Downloader: Switch to use API sizes

Rather than relying on the Content-Length header, which can be unreliable,
query the API for file sizes and work from those.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri 2025-06-30 12:49:53 -04:00
parent 03ff4c3128
commit 0152a1665b

View file

@ -4,7 +4,9 @@ import asyncio
import math
import pathlib
import shutil
from dataclasses import dataclass
from huggingface_hub import HfApi, hf_hub_url
from huggingface_hub.hf_api import RepoFile
from fnmatch import fnmatch
from loguru import logger
from rich.progress import Progress
@ -15,9 +17,16 @@ from common.tabby_config import config
from common.utils import unwrap
@dataclass
class RepoItem:
    """A single downloadable file in a HuggingFace repository.

    Attributes:
        path: Path of the file within the repository; also used as the
            display name in the download progress bar.
        size: File size in bytes as reported by the repository tree
            listing; used as the progress bar total.
        url: Direct download URL for the file.
    """

    path: str
    size: int
    url: str
async def _download_file(
session: aiohttp.ClientSession,
repo_item: dict,
repo_item: RepoItem,
token: Optional[str],
download_path: pathlib.Path,
chunk_limit: int,
@ -25,8 +34,8 @@ async def _download_file(
):
"""Downloads a repo from HuggingFace."""
filename = repo_item.get("filename")
url = repo_item.get("url")
filename = repo_item.path
url = repo_item.url
# Default is 2MB
chunk_limit_bytes = math.ceil(unwrap(chunk_limit, 2000000) * 100000)
@ -46,33 +55,20 @@ async def _download_file(
message=f"HTTP {response.status}: {error_text}",
)
# Sometimes, Content-Length can be undefined
content_length = response.headers.get("Content-Length")
file_size = int(content_length) if content_length else None
# Create progress task with appropriate total (None for indeterminate)
download_task = progress.add_task(
f"[cyan]Downloading {filename}", total=file_size
f"[cyan]Downloading {filename}", total=repo_item.size
)
# Chunk limit is 2 MB
downloaded_size = 0
async with aiofiles.open(str(filepath), "wb") as f:
async for chunk in response.content.iter_chunked(chunk_limit_bytes):
await f.write(chunk)
# Store and update progress bar
downloaded_size += len(chunk)
progress.update(download_task, completed=downloaded_size)
# For indeterminate files, set final total and mark as complete
if file_size is None:
progress.update(
download_task, total=downloaded_size, completed=downloaded_size
)
progress.update(download_task, advance=len(chunk))
# The huggingface_hub client is synchronous (no async API), so this is a plain blocking function
def _get_repo_info(repo_id, revision, token):
"""Fetches information about a HuggingFace repository."""
@ -81,13 +77,18 @@ def _get_repo_info(repo_id, revision, token):
token = token or None
api_client = HfApi()
repo_tree = api_client.list_repo_files(repo_id, revision=revision, token=token)
repo_tree = api_client.list_repo_tree(
repo_id, revision=revision, token=token, recursive=True
)
return [
{
"filename": filename,
"url": hf_hub_url(repo_id, filename, revision=revision),
}
for filename in repo_tree
RepoItem(
path=item.path,
size=item.size,
url=hf_hub_url(repo_id, item.path, revision=revision),
)
for item in repo_tree
if isinstance(item, RepoFile)
]
@ -130,12 +131,13 @@ async def hf_repo_download(
# Auto-detect repo type if it isn't provided
if not repo_type:
lora_filter = filter(
lambda repo_item: repo_item.get("filename", "").endswith(
lambda repo_item: repo_item.path.endswith(
("adapter_config.json", "adapter_model.bin")
)
),
file_list,
)
if lora_filter:
if any(lora_filter):
repo_type = "lora"
if include or exclude:
@ -145,9 +147,7 @@ async def hf_repo_download(
file_list = [
file
for file in file_list
if _check_exclusions(
file.get("filename"), include_patterns, exclude_patterns
)
if _check_exclusions(file.path, include_patterns, exclude_patterns)
]
if not file_list: