Downloader: Add include and exclude parameters
Both parameters take an array of glob strings that specify which files or directories to include or exclude when parsing the download list. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
c47869c606
commit
21a01741c9
3 changed files with 48 additions and 13 deletions
|
|
@ -5,9 +5,10 @@ import math
|
|||
import pathlib
|
||||
import shutil
|
||||
from huggingface_hub import HfApi, hf_hub_url
|
||||
from fnmatch import fnmatch
|
||||
from loguru import logger
|
||||
from rich.progress import Progress
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from common.config import lora_config, model_config
|
||||
from common.logger import get_progress_bar
|
||||
|
|
@ -85,12 +86,23 @@ def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str
|
|||
return download_path
|
||||
|
||||
|
||||
def check_exclusions(
    filename: str, include_patterns: List[str], exclude_patterns: List[str]
):
    """Decide whether a file passes the include/exclude glob filters.

    Args:
        filename: Name (or repo-relative path) of the file to test.
        include_patterns: Glob patterns of which at least one must match.
            An empty list means "no include filter", so every file is a
            candidate.
        exclude_patterns: Glob patterns that reject the file when any of
            them matches. An empty list excludes nothing.

    Returns:
        True when the file should be kept, False when it is filtered out.
        Exclusion wins over inclusion when both match.

    Note:
        Matching uses ``fnmatch.fnmatch``, so ``*`` also matches path
        separators and case sensitivity follows the host OS.
    """

    # An empty include list must not reject everything: a caller that only
    # supplies exclude patterns expects all non-excluded files to be kept.
    included = not include_patterns or any(
        fnmatch(filename, pattern) for pattern in include_patterns
    )
    excluded = any(fnmatch(filename, pattern) for pattern in exclude_patterns or ())

    return included and not excluded
|
||||
|
||||
|
||||
async def hf_repo_download(
|
||||
repo_id: str,
|
||||
folder_name: Optional[str],
|
||||
revision: Optional[str],
|
||||
token: Optional[str],
|
||||
chunk_limit: Optional[float],
|
||||
include: Optional[List[str]],
|
||||
exclude: Optional[List[str]],
|
||||
repo_type: Optional[str] = "model",
|
||||
):
|
||||
"""Gets a repo's information from HuggingFace and downloads it locally."""
|
||||
|
|
@ -108,14 +120,30 @@ async def hf_repo_download(
|
|||
if lora_filter:
|
||||
repo_type = "lora"
|
||||
|
||||
if include or exclude:
|
||||
include_patterns = unwrap(include, [])
|
||||
exclude_patterns = unwrap(exclude, [])
|
||||
|
||||
file_list = [
|
||||
file
|
||||
for file in file_list
|
||||
if check_exclusions(
|
||||
file.get("filename"), include_patterns, exclude_patterns
|
||||
)
|
||||
]
|
||||
|
||||
if not file_list:
|
||||
raise ValueError(f"File list for repo {repo_id} is empty. Check your filters?")
|
||||
|
||||
download_path = _get_download_folder(repo_id, repo_type, folder_name)
|
||||
download_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if download_path.exists():
|
||||
raise FileExistsError(
|
||||
f"The path {download_path} already exists. Remove the folder and try again."
|
||||
)
|
||||
|
||||
download_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info(f"Saving {repo_id} to {str(download_path)}")
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -294,16 +294,21 @@ async def unload_sampler_override():
|
|||
async def download_model(request: Request, data: DownloadRequest):
|
||||
"""Downloads a model from HuggingFace."""
|
||||
|
||||
download_task = asyncio.create_task(hf_repo_download(**data.model_dump()))
|
||||
try:
|
||||
download_task = asyncio.create_task(hf_repo_download(**data.model_dump()))
|
||||
|
||||
# For now, the downloader and request data are 1:1
|
||||
download_path = await run_with_request_disconnect(
|
||||
request,
|
||||
download_task,
|
||||
"Download request cancelled by user. Files have been cleaned up.",
|
||||
)
|
||||
# For now, the downloader and request data are 1:1
|
||||
download_path = await run_with_request_disconnect(
|
||||
request,
|
||||
download_task,
|
||||
"Download request cancelled by user. Files have been cleaned up.",
|
||||
)
|
||||
|
||||
return DownloadResponse(download_path=str(download_path))
|
||||
return DownloadResponse(download_path=str(download_path))
|
||||
except Exception as exc:
|
||||
error_message = handle_request_error(str(exc)).error.message
|
||||
|
||||
raise HTTPException(400, error_message) from exc
|
||||
|
||||
|
||||
# Lora list endpoint
|
||||
|
|
|
|||
|
|
@ -1,15 +1,17 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class DownloadRequest(BaseModel):
|
||||
"""Parameters for a HuggingFace repo download."""
|
||||
|
||||
repo_id: str
|
||||
repo_type: Optional[str] = "model"
|
||||
repo_type: str = "model"
|
||||
folder_name: Optional[str] = None
|
||||
revision: Optional[str] = None
|
||||
token: Optional[str] = None
|
||||
include: List[str] = Field(default_factory=list)
|
||||
exclude: List[str] = Field(default_factory=list)
|
||||
chunk_limit: Optional[int] = None
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue