diff --git a/common/downloader.py b/common/downloader.py
index e8b4330..97f4d55 100644
--- a/common/downloader.py
+++ b/common/downloader.py
@@ -5,9 +5,10 @@ import math
 import pathlib
 import shutil
 from huggingface_hub import HfApi, hf_hub_url
+from fnmatch import fnmatch
 from loguru import logger
 from rich.progress import Progress
-from typing import Optional
+from typing import List, Optional
 
 from common.config import lora_config, model_config
 from common.logger import get_progress_bar
@@ -85,12 +86,33 @@ def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str
     return download_path
 
 
+def check_exclusions(
+    filename: str, include_patterns: List[str], exclude_patterns: List[str]
+):
+    """Checks whether a file passes the include/exclude glob filters.
+
+    A file is kept when it matches at least one include pattern and none of
+    the exclude patterns. An empty include list keeps every file.
+    """
+
+    # any() over an empty list is False, so an empty include list has to mean
+    # "match everything"; otherwise exclude-only filters would drop every file.
+    include_result = not include_patterns or any(
+        fnmatch(filename, pattern) for pattern in include_patterns
+    )
+    exclude_result = any(fnmatch(filename, pattern) for pattern in exclude_patterns)
+
+    return include_result and not exclude_result
+
+
 async def hf_repo_download(
     repo_id: str,
     folder_name: Optional[str],
     revision: Optional[str],
     token: Optional[str],
     chunk_limit: Optional[float],
+    include: Optional[List[str]],
+    exclude: Optional[List[str]],
     repo_type: Optional[str] = "model",
 ):
     """Gets a repo's information from HuggingFace and downloads it locally."""
@@ -108,14 +130,30 @@ async def hf_repo_download(
     if lora_filter:
         repo_type = "lora"
 
+    if include or exclude:
+        include_patterns = unwrap(include, [])
+        exclude_patterns = unwrap(exclude, [])
+
+        file_list = [
+            file
+            for file in file_list
+            if check_exclusions(
+                file.get("filename"), include_patterns, exclude_patterns
+            )
+        ]
+
+    if not file_list:
+        raise ValueError(f"File list for repo {repo_id} is empty. Check your filters?")
+
     download_path = _get_download_folder(repo_id, repo_type, folder_name)
-    download_path.parent.mkdir(parents=True, exist_ok=True)
 
     if download_path.exists():
         raise FileExistsError(
             f"The path {download_path} already exists. Remove the folder and try again."
         )
 
+    download_path.parent.mkdir(parents=True, exist_ok=True)
+
     logger.info(f"Saving {repo_id} to {str(download_path)}")
 
     try:
diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index 5958c8b..6964c9d 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -294,16 +294,21 @@ async def unload_sampler_override():
 async def download_model(request: Request, data: DownloadRequest):
     """Downloads a model from HuggingFace."""
 
-    download_task = asyncio.create_task(hf_repo_download(**data.model_dump()))
+    try:
+        download_task = asyncio.create_task(hf_repo_download(**data.model_dump()))
 
-    # For now, the downloader and request data are 1:1
-    download_path = await run_with_request_disconnect(
-        request,
-        download_task,
-        "Download request cancelled by user. Files have been cleaned up.",
-    )
+        # For now, the downloader and request data are 1:1
+        download_path = await run_with_request_disconnect(
+            request,
+            download_task,
+            "Download request cancelled by user. Files have been cleaned up.",
+        )
 
-    return DownloadResponse(download_path=str(download_path))
+        return DownloadResponse(download_path=str(download_path))
+    except Exception as exc:
+        error_message = handle_request_error(str(exc)).error.message
+
+        raise HTTPException(400, error_message) from exc
 
 
 # Lora list endpoint
diff --git a/endpoints/OAI/types/download.py b/endpoints/OAI/types/download.py
index 6ba33d9..63ca26f 100644
--- a/endpoints/OAI/types/download.py
+++ b/endpoints/OAI/types/download.py
@@ -1,15 +1,17 @@
-from pydantic import BaseModel
-from typing import Optional
+from pydantic import BaseModel, Field
+from typing import List, Optional
 
 
 class DownloadRequest(BaseModel):
     """Parameters for a HuggingFace repo download."""
 
     repo_id: str
-    repo_type: Optional[str] = "model"
+    repo_type: str = "model"
     folder_name: Optional[str] = None
     revision: Optional[str] = None
     token: Optional[str] = None
+    include: List[str] = Field(default_factory=list)
+    exclude: List[str] = Field(default_factory=list)
     chunk_limit: Optional[int] = None