API: Add HuggingFace downloader
Adds an asynchronous HuggingFace downloader that uses the HF Hub API to fetch all repo files. The current huggingface_hub package has a snapshot_download function that does not cancel on KeyboardInterrupt. Instead, this adds a downloader that uses the Rich progress bar styling along with a cancellable interface. Finally, this is linked into TabbyAPI.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 6114bfd221
commit 55ccd1baad

5 changed files with 196 additions and 0 deletions
common/downloader.py: 150 additions (new file)

@@ -0,0 +1,150 @@
import aiofiles
import aiohttp
import asyncio
import math
import pathlib
import shutil
from huggingface_hub import HfApi, hf_hub_url
from loguru import logger
from rich.progress import Progress
from typing import Optional

from common.config import lora_config, model_config
from common.logger import get_progress_bar
from common.utils import unwrap


async def _download_file(
    session: aiohttp.ClientSession,
    repo_item: dict,
    token: Optional[str],
    download_path: pathlib.Path,
    chunk_limit: Optional[float],
    progress: Progress,
):
    """Downloads a single file from a HuggingFace repo."""

    filename = repo_item.get("filename")
    url = repo_item.get("url")

    # Chunk size in bytes. Default is 2MB
    chunk_limit_bytes = math.ceil(unwrap(chunk_limit, 2000000))

    filepath = download_path / filename
    filepath.parent.mkdir(parents=True, exist_ok=True)

    req_headers = {"Authorization": f"Bearer {token}"} if token else {}

    async with session.get(url, headers=req_headers) as response:
        # TODO: Change to raise errors
        assert response.status == 200

        file_size = int(response.headers["Content-Length"])

        download_task = progress.add_task(
            f"[cyan]Downloading {filename}", total=file_size
        )

        # Stream the response to disk one chunk at a time
        async with aiofiles.open(str(filepath), "wb") as f:
            async for chunk in response.content.iter_chunked(chunk_limit_bytes):
                await f.write(chunk)
                progress.update(download_task, advance=len(chunk))


# Huggingface does not know how async works
def _get_repo_info(repo_id, revision, token):
    """Fetches information about a HuggingFace repository."""

    api_client = HfApi()
    repo_tree = api_client.list_repo_files(repo_id, revision=revision, token=token)
    return list(
        map(
            lambda filename: {
                "filename": filename,
                "url": hf_hub_url(repo_id, filename, revision=revision),
            },
            repo_tree,
        )
    )


def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str]):
    """Gets the download folder for the repo."""

    if repo_type == "lora":
        download_path = pathlib.Path(unwrap(lora_config().get("lora_dir"), "loras"))
    else:
        download_path = pathlib.Path(unwrap(model_config().get("model_dir"), "models"))

    download_path = download_path / unwrap(folder_name, repo_id.split("/")[-1])
    return download_path


async def hf_repo_download(
    repo_id: str,
    folder_name: Optional[str],
    revision: Optional[str],
    token: Optional[str],
    chunk_limit: Optional[float],
    repo_type: Optional[str] = "model",
):
    """Gets a repo's information from HuggingFace and downloads it locally."""

    file_list = await asyncio.to_thread(_get_repo_info, repo_id, revision, token)

    # Auto-detect repo type if it isn't provided
    if not repo_type:
        lora_filter = list(
            filter(
                lambda repo_item: repo_item.get("filename", "").endswith(
                    ("adapter_config.json", "adapter_model.bin")
                ),
                file_list,
            )
        )

        if lora_filter:
            repo_type = "lora"

    download_path = _get_download_folder(repo_id, repo_type, folder_name)
    download_path.parent.mkdir(parents=True, exist_ok=True)

    if download_path.exists():
        raise FileExistsError(
            f"The path {download_path} already exists. Remove the folder and try again."
        )

    logger.info(f"Saving {repo_id} to {str(download_path)}")

    try:
        async with aiohttp.ClientSession() as session:
            tasks = []
            logger.info(f"Starting download for {repo_id}")

            progress = get_progress_bar()
            progress.start()

            for repo_item in file_list:
                tasks.append(
                    _download_file(
                        session,
                        repo_item,
                        token=token,
                        download_path=download_path.resolve(),
                        chunk_limit=chunk_limit,
                        progress=progress,
                    )
                )

            await asyncio.gather(*tasks)
            progress.stop()
            logger.info(f"Finished download for {repo_id}")

            return download_path
    except asyncio.CancelledError:
        # Cleanup on cancel
        if download_path.is_dir():
            shutil.rmtree(download_path)
        elif download_path.exists():
            download_path.unlink()

        # Stop the progress bar
        progress.stop()
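For context on the cancellable interface described in the commit message, here is a minimal usage sketch of how hf_repo_download might be called. It is not part of this commit; the repo id and the driver script are hypothetical. Because the download is a plain coroutine, a Ctrl+C during asyncio.run() should cancel it and land in the CancelledError cleanup branch above, which is the behavior snapshot_download lacks:

import asyncio

from common.downloader import hf_repo_download


async def main():
    # Awaiting the coroutine keeps it cancellable: asyncio.run() cancels it
    # on Ctrl+C, which triggers the cleanup branch in hf_repo_download.
    path = await hf_repo_download(
        repo_id="some-org/some-model",  # hypothetical repo id
        folder_name=None,               # default to the repo name
        revision=None,                  # default branch
        token=None,                     # public repo
        chunk_limit=None,               # use the 2MB default
    )

    # hf_repo_download returns None when the download was cancelled
    if path:
        print(f"Downloaded to {path}")


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("Download cancelled, partial files were cleaned up")
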
@@ -23,6 +23,10 @@ RICH_CONSOLE = Console()
LOG_LEVEL = os.getenv("TABBY_LOG_LEVEL", "INFO")


def get_progress_bar():
    return Progress(console=RICH_CONSOLE)


def get_loading_progress_bar():
    """Gets a pre-made progress bar for loading tasks."""
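The get_progress_bar helper in this hunk returns a Progress with Rich's default columns on the shared RICH_CONSOLE. As a sketch only, not part of this commit and with a hypothetical helper name, a download-oriented variant could use Rich's byte-aware columns, which pair well with the Content-Length totals that downloader.py passes to add_task:

from rich.console import Console
from rich.progress import (
    BarColumn,
    DownloadColumn,
    Progress,
    TextColumn,
    TimeRemainingColumn,
    TransferSpeedColumn,
)


def get_download_progress_bar(console: Console) -> Progress:
    """Hypothetical variant: byte-aware columns suited to file downloads."""

    return Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        DownloadColumn(),       # rendered as downloaded/total bytes
        TransferSpeedColumn(),  # live transfer speed
        TimeRemainingColumn(),
        console=console,        # pass the shared RICH_CONSOLE here
    )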