Downloader: Add include and exclude parameters

These both take an array of glob strings to state what files or
directories to include or exclude when parsing the download list.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-04-30 00:58:54 -04:00
parent c47869c606
commit 21a01741c9
3 changed files with 48 additions and 13 deletions

View file

@ -5,9 +5,10 @@ import math
import pathlib
import shutil
from huggingface_hub import HfApi, hf_hub_url
from fnmatch import fnmatch
from loguru import logger
from rich.progress import Progress
from typing import Optional
from typing import List, Optional
from common.config import lora_config, model_config
from common.logger import get_progress_bar
@ -85,12 +86,23 @@ def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str
return download_path
def check_exclusions(
filename: str, include_patterns: List[str], exclude_patterns: List[str]
):
include_result = any(fnmatch(filename, pattern) for pattern in include_patterns)
exclude_result = any(fnmatch(filename, pattern) for pattern in exclude_patterns)
return include_result and not exclude_result
async def hf_repo_download(
repo_id: str,
folder_name: Optional[str],
revision: Optional[str],
token: Optional[str],
chunk_limit: Optional[float],
include: Optional[List[str]],
exclude: Optional[List[str]],
repo_type: Optional[str] = "model",
):
"""Gets a repo's information from HuggingFace and downloads it locally."""
@ -108,14 +120,30 @@ async def hf_repo_download(
if lora_filter:
repo_type = "lora"
if include or exclude:
include_patterns = unwrap(include, [])
exclude_patterns = unwrap(exclude, [])
file_list = [
file
for file in file_list
if check_exclusions(
file.get("filename"), include_patterns, exclude_patterns
)
]
if not file_list:
raise ValueError(f"File list for repo {repo_id} is empty. Check your filters?")
download_path = _get_download_folder(repo_id, repo_type, folder_name)
download_path.parent.mkdir(parents=True, exist_ok=True)
if download_path.exists():
raise FileExistsError(
f"The path {download_path} already exists. Remove the folder and try again."
)
download_path.parent.mkdir(parents=True, exist_ok=True)
logger.info(f"Saving {repo_id} to {str(download_path)}")
try:

View file

@ -294,16 +294,21 @@ async def unload_sampler_override():
async def download_model(request: Request, data: DownloadRequest):
"""Downloads a model from HuggingFace."""
download_task = asyncio.create_task(hf_repo_download(**data.model_dump()))
try:
download_task = asyncio.create_task(hf_repo_download(**data.model_dump()))
# For now, the downloader and request data are 1:1
download_path = await run_with_request_disconnect(
request,
download_task,
"Download request cancelled by user. Files have been cleaned up.",
)
# For now, the downloader and request data are 1:1
download_path = await run_with_request_disconnect(
request,
download_task,
"Download request cancelled by user. Files have been cleaned up.",
)
return DownloadResponse(download_path=str(download_path))
return DownloadResponse(download_path=str(download_path))
except Exception as exc:
error_message = handle_request_error(str(exc)).error.message
raise HTTPException(400, error_message) from exc
# Lora list endpoint

View file

@ -1,15 +1,17 @@
from pydantic import BaseModel
from typing import Optional
from pydantic import BaseModel, Field
from typing import List, Optional
class DownloadRequest(BaseModel):
"""Parameters for a HuggingFace repo download."""
repo_id: str
repo_type: Optional[str] = "model"
repo_type: str = "model"
folder_name: Optional[str] = None
revision: Optional[str] = None
token: Optional[str] = None
include: List[str] = Field(default_factory=list)
exclude: List[str] = Field(default_factory=list)
chunk_limit: Optional[int] = None