diff --git a/common/downloader.py b/common/downloader.py index 0401159..8307bbc 100644 --- a/common/downloader.py +++ b/common/downloader.py @@ -4,7 +4,9 @@ import asyncio import math import pathlib import shutil +from dataclasses import dataclass from huggingface_hub import HfApi, hf_hub_url +from huggingface_hub.hf_api import RepoFile from fnmatch import fnmatch from loguru import logger from rich.progress import Progress @@ -15,9 +17,16 @@ from common.tabby_config import config from common.utils import unwrap +@dataclass +class RepoItem: + path: str + size: int + url: str + + async def _download_file( session: aiohttp.ClientSession, - repo_item: dict, + repo_item: RepoItem, token: Optional[str], download_path: pathlib.Path, chunk_limit: int, @@ -25,8 +34,8 @@ async def _download_file( ): """Downloads a repo from HuggingFace.""" - filename = repo_item.get("filename") - url = repo_item.get("url") + filename = repo_item.path + url = repo_item.url # Default is 2MB chunk_limit_bytes = math.ceil(unwrap(chunk_limit, 2000000) * 100000) @@ -46,33 +55,20 @@ async def _download_file( message=f"HTTP {response.status}: {error_text}", ) - # Sometimes, Content-Length can be undefined - content_length = response.headers.get("Content-Length") - file_size = int(content_length) if content_length else None - # Create progress task with appropriate total (None for indeterminate) download_task = progress.add_task( - f"[cyan]Downloading {filename}", total=file_size + f"[cyan]Downloading {filename}", total=repo_item.size ) # Chunk limit is 2 MB - downloaded_size = 0 async with aiofiles.open(str(filepath), "wb") as f: async for chunk in response.content.iter_chunked(chunk_limit_bytes): await f.write(chunk) # Store and update progress bar - downloaded_size += len(chunk) - progress.update(download_task, completed=downloaded_size) - - # For indeterminate files, set final total and mark as complete - if file_size is None: - progress.update( - download_task, total=downloaded_size, completed=downloaded_size - ) + progress.update(download_task, advance=len(chunk)) -# Huggingface does not know how async works def _get_repo_info(repo_id, revision, token): """Fetches information about a HuggingFace repository.""" @@ -81,13 +77,18 @@ def _get_repo_info(repo_id, revision, token): token = token or None api_client = HfApi() - repo_tree = api_client.list_repo_files(repo_id, revision=revision, token=token) + repo_tree = api_client.list_repo_tree( + repo_id, revision=revision, token=token, recursive=True + ) + return [ - { - "filename": filename, - "url": hf_hub_url(repo_id, filename, revision=revision), - } - for filename in repo_tree + RepoItem( + path=item.path, + size=item.size, + url=hf_hub_url(repo_id, item.path, revision=revision), + ) + for item in repo_tree + if isinstance(item, RepoFile) ] @@ -130,12 +131,13 @@ async def hf_repo_download( # Auto-detect repo type if it isn't provided if not repo_type: lora_filter = filter( - lambda repo_item: repo_item.get("filename", "").endswith( + lambda repo_item: repo_item.path.endswith( ("adapter_config.json", "adapter_model.bin") - ) + ), + file_list, ) - if lora_filter: + if any(lora_filter): repo_type = "lora" if include or exclude: @@ -145,9 +147,7 @@ async def hf_repo_download( file_list = [ file for file in file_list - if _check_exclusions( - file.get("filename"), include_patterns, exclude_patterns - ) + if _check_exclusions(file.path, include_patterns, exclude_patterns) ] if not file_list: