tabbyAPI-ollama/common/downloader.py
kingbri 0152a1665b Downloader: Switch to use API sizes
Rather than relying on Content-Length which can be unreliable, ping
the API to get file sizes and work from there.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
2025-06-30 12:49:53 -04:00

205 lines
6 KiB
Python

import aiofiles
import aiohttp
import asyncio
import math
import pathlib
import shutil
from dataclasses import dataclass
from huggingface_hub import HfApi, hf_hub_url
from huggingface_hub.hf_api import RepoFile
from fnmatch import fnmatch
from loguru import logger
from rich.progress import Progress
from typing import List, Optional
from common.logger import get_progress_bar
from common.tabby_config import config
from common.utils import unwrap
@dataclass
class RepoItem:
    """A single downloadable file within a HuggingFace repository."""

    # Relative path of the file inside the repo
    path: str
    # File size in bytes, as reported by the HF API (used as the
    # progress-bar total instead of Content-Length)
    size: int
    # Fully-qualified download URL for this file
    url: str
async def _download_file(
    session: aiohttp.ClientSession,
    repo_item: RepoItem,
    token: Optional[str],
    download_path: pathlib.Path,
    chunk_limit: int,
    progress: Progress,
):
    """Downloads a single repo file from HuggingFace.

    Streams the file at repo_item.url into download_path, advancing the
    shared progress bar as chunks are written. The progress total comes
    from the HF API (repo_item.size) rather than Content-Length, which
    can be unreliable.

    Raises:
        aiohttp.ClientResponseError: on a non-2xx HTTP response.
    """

    filename = repo_item.path
    url = repo_item.url

    # Default is 2 MB. chunk_limit is scaled by 100000 (units of 100 KB),
    # so the default must be 20 — the previous default of 2000000 produced
    # a ~200 GB chunk size after scaling.
    # NOTE(review): confirm the 100 KB unit with callers passing chunk_limit.
    chunk_limit_bytes = math.ceil(unwrap(chunk_limit, 20) * 100000)

    filepath = download_path / filename
    filepath.parent.mkdir(parents=True, exist_ok=True)

    req_headers = {"Authorization": f"Bearer {token}"} if token else {}
    async with session.get(url, headers=req_headers) as response:
        if not response.ok:
            error_text = await response.text()
            raise aiohttp.ClientResponseError(
                response.request_info,
                response.history,
                status=response.status,
                message=f"HTTP {response.status}: {error_text}",
            )

        # Create progress task, labeled with the actual filename
        download_task = progress.add_task(
            f"[cyan]Downloading {filename}", total=repo_item.size
        )

        async with aiofiles.open(str(filepath), "wb") as f:
            async for chunk in response.content.iter_chunked(chunk_limit_bytes):
                await f.write(chunk)

                # Advance the progress bar by the bytes just written
                progress.update(download_task, advance=len(chunk))
def _get_repo_info(repo_id, revision, token):
    """Fetches the file listing of a HuggingFace repository."""

    # Collapse empty strings/containers to None before passing downstream
    revision = revision if revision else None
    token = token if token else None

    api_client = HfApi()
    repo_tree = api_client.list_repo_tree(
        repo_id, revision=revision, token=token, recursive=True
    )

    repo_items = []
    for entry in repo_tree:
        # Skip directories and anything else that isn't a plain file
        if not isinstance(entry, RepoFile):
            continue

        repo_items.append(
            RepoItem(
                path=entry.path,
                size=entry.size,
                url=hf_hub_url(repo_id, entry.path, revision=revision),
            )
        )

    return repo_items
def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str]):
    """Gets the download folder for the repo."""

    # Loras live in the configured lora dir, everything else in the model dir
    base_dir = config.lora.lora_dir if repo_type == "lora" else config.model.model_dir

    # Fall back to the repo's own name when no folder override is given
    target_name = folder_name or repo_id.split("/")[-1]

    return pathlib.Path(base_dir) / target_name
def _check_exclusions(
filename: str, include_patterns: List[str], exclude_patterns: List[str]
):
include_result = any(fnmatch(filename, pattern) for pattern in include_patterns)
exclude_result = any(fnmatch(filename, pattern) for pattern in exclude_patterns)
return include_result and not exclude_result
async def hf_repo_download(
    repo_id: str,
    folder_name: Optional[str],
    revision: Optional[str],
    token: Optional[str],
    include: Optional[List[str]],
    exclude: Optional[List[str]],
    chunk_limit: Optional[float] = None,
    timeout: Optional[int] = None,
    repo_type: Optional[str] = "model",
):
    """Gets a repo's information from HuggingFace and downloads it locally.

    Args:
        repo_id: "owner/name" identifier of the HF repo.
        folder_name: Override for the local folder name (defaults to the
            repo name).
        revision: Git revision to download (repo default when falsy).
        token: HF access token for gated/private repos.
        include: Glob patterns of files to include (all files when unset).
        exclude: Glob patterns of files to exclude.
        chunk_limit: Download chunk size, passed through to _download_file.
        timeout: Total request timeout in seconds (None disables it).
        repo_type: "model" or "lora"; auto-detected when falsy.

    Returns:
        The path the repo was downloaded to.

    Raises:
        ValueError: if filtering leaves no files to download.
        FileExistsError: if the destination folder already exists.
    """

    file_list = await asyncio.to_thread(_get_repo_info, repo_id, revision, token)

    # Auto-detect repo type if it isn't provided
    if not repo_type:
        lora_filter = filter(
            lambda repo_item: repo_item.path.endswith(
                ("adapter_config.json", "adapter_model.bin")
            ),
            file_list,
        )

        if any(lora_filter):
            repo_type = "lora"

    if include or exclude:
        include_patterns = unwrap(include, ["*"])
        exclude_patterns = unwrap(exclude, [])

        file_list = [
            file
            for file in file_list
            if _check_exclusions(file.path, include_patterns, exclude_patterns)
        ]

    if not file_list:
        raise ValueError(f"File list for repo {repo_id} is empty. Check your filters?")

    download_path = _get_download_folder(repo_id, repo_type, folder_name)

    if download_path.exists():
        raise FileExistsError(
            f"The path {download_path} already exists. Remove the folder and try again."
        )

    download_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Saving {repo_id} to {str(download_path)}")

    # Bound before the try block so the except handler can reference it
    # even when a failure occurs before the bar is created
    progress = None

    try:
        # total=None disables the overall request timeout
        client_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            tasks = []

            logger.info(f"Starting download for {repo_id}")

            progress = get_progress_bar()
            progress.start()

            for repo_item in file_list:
                tasks.append(
                    _download_file(
                        session,
                        repo_item,
                        token=token,
                        download_path=download_path.resolve(),
                        chunk_limit=chunk_limit,
                        progress=progress,
                    )
                )

            await asyncio.gather(*tasks)

            progress.stop()

            logger.info(f"Finished download for {repo_id}")
            return download_path
    except (asyncio.CancelledError, Exception) as exc:
        # Cleanup on cancel/error. missing_ok guards the case where the
        # failure happened before anything was written to disk.
        if download_path.is_dir():
            shutil.rmtree(download_path)
        else:
            download_path.unlink(missing_ok=True)

        # Stop the progress bar if it was started
        if progress is not None:
            progress.stop()

        # Re-raise exception if the task isn't cancelled
        if not isinstance(exc, asyncio.CancelledError):
            raise exc