Downloader: Switch to use API sizes
Rather than relying on Content-Length, which can be unreliable, ping the API to get file sizes and work from there.
Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
03ff4c3128
commit
0152a1665b
1 changed file with 30 additions and 30 deletions
|
|
@ -4,7 +4,9 @@ import asyncio
|
|||
import math
|
||||
import pathlib
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from huggingface_hub import HfApi, hf_hub_url
|
||||
from huggingface_hub.hf_api import RepoFile
|
||||
from fnmatch import fnmatch
|
||||
from loguru import logger
|
||||
from rich.progress import Progress
|
||||
|
|
@ -15,9 +17,16 @@ from common.tabby_config import config
|
|||
from common.utils import unwrap
|
||||
|
||||
|
||||
@dataclass
|
||||
class RepoItem:
|
||||
path: str
|
||||
size: int
|
||||
url: str
|
||||
|
||||
|
||||
async def _download_file(
|
||||
session: aiohttp.ClientSession,
|
||||
repo_item: dict,
|
||||
repo_item: RepoItem,
|
||||
token: Optional[str],
|
||||
download_path: pathlib.Path,
|
||||
chunk_limit: int,
|
||||
|
|
@ -25,8 +34,8 @@ async def _download_file(
|
|||
):
|
||||
"""Downloads a repo from HuggingFace."""
|
||||
|
||||
filename = repo_item.get("filename")
|
||||
url = repo_item.get("url")
|
||||
filename = repo_item.path
|
||||
url = repo_item.url
|
||||
|
||||
# Default is 2MB
|
||||
chunk_limit_bytes = math.ceil(unwrap(chunk_limit, 2000000) * 100000)
|
||||
|
|
@ -46,33 +55,20 @@ async def _download_file(
|
|||
message=f"HTTP {response.status}: {error_text}",
|
||||
)
|
||||
|
||||
# Sometimes, Content-Length can be undefined
|
||||
content_length = response.headers.get("Content-Length")
|
||||
file_size = int(content_length) if content_length else None
|
||||
|
||||
# Create progress task with appropriate total (None for indeterminate)
|
||||
download_task = progress.add_task(
|
||||
f"[cyan]Downloading {filename}", total=file_size
|
||||
f"[cyan]Downloading {filename}", total=repo_item.size
|
||||
)
|
||||
|
||||
# Chunk limit is 2 MB
|
||||
downloaded_size = 0
|
||||
async with aiofiles.open(str(filepath), "wb") as f:
|
||||
async for chunk in response.content.iter_chunked(chunk_limit_bytes):
|
||||
await f.write(chunk)
|
||||
|
||||
# Store and update progress bar
|
||||
downloaded_size += len(chunk)
|
||||
progress.update(download_task, completed=downloaded_size)
|
||||
|
||||
# For indeterminate files, set final total and mark as complete
|
||||
if file_size is None:
|
||||
progress.update(
|
||||
download_task, total=downloaded_size, completed=downloaded_size
|
||||
)
|
||||
progress.update(download_task, advance=len(chunk))
|
||||
|
||||
|
||||
# Huggingface does not know how async works
|
||||
def _get_repo_info(repo_id, revision, token):
|
||||
"""Fetches information about a HuggingFace repository."""
|
||||
|
||||
|
|
@ -81,13 +77,18 @@ def _get_repo_info(repo_id, revision, token):
|
|||
token = token or None
|
||||
|
||||
api_client = HfApi()
|
||||
repo_tree = api_client.list_repo_files(repo_id, revision=revision, token=token)
|
||||
repo_tree = api_client.list_repo_tree(
|
||||
repo_id, revision=revision, token=token, recursive=True
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"filename": filename,
|
||||
"url": hf_hub_url(repo_id, filename, revision=revision),
|
||||
}
|
||||
for filename in repo_tree
|
||||
RepoItem(
|
||||
path=item.path,
|
||||
size=item.size,
|
||||
url=hf_hub_url(repo_id, item.path, revision=revision),
|
||||
)
|
||||
for item in repo_tree
|
||||
if isinstance(item, RepoFile)
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -130,12 +131,13 @@ async def hf_repo_download(
|
|||
# Auto-detect repo type if it isn't provided
|
||||
if not repo_type:
|
||||
lora_filter = filter(
|
||||
lambda repo_item: repo_item.get("filename", "").endswith(
|
||||
lambda repo_item: repo_item.path.endswith(
|
||||
("adapter_config.json", "adapter_model.bin")
|
||||
)
|
||||
),
|
||||
file_list,
|
||||
)
|
||||
|
||||
if lora_filter:
|
||||
if any(lora_filter):
|
||||
repo_type = "lora"
|
||||
|
||||
if include or exclude:
|
||||
|
|
@ -145,9 +147,7 @@ async def hf_repo_download(
|
|||
file_list = [
|
||||
file
|
||||
for file in file_list
|
||||
if _check_exclusions(
|
||||
file.get("filename"), include_patterns, exclude_patterns
|
||||
)
|
||||
if _check_exclusions(file.path, include_patterns, exclude_patterns)
|
||||
]
|
||||
|
||||
if not file_list:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue