API: Add HuggingFace downloader

Adds an asynchronous huggingface downloader that uses HF hub to fetch
all repo files. The current HF hub package has a snapshot_download
function that does not cancel on KeyboardInterrupt.

Instead, make a downloader that uses the Rich progress bar styling
along with a cancellable interface. Finally, link this to TabbyAPI.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-04-29 01:12:30 -04:00
parent 6114bfd221
commit 55ccd1baad
5 changed files with 196 additions and 0 deletions

150
common/downloader.py Normal file
View file

@ -0,0 +1,150 @@
import aiofiles
import aiohttp
import asyncio
import math
import pathlib
import shutil
from huggingface_hub import HfApi, hf_hub_url
from loguru import logger
from rich.progress import Progress
from typing import Optional
from common.config import lora_config, model_config
from common.logger import get_progress_bar
from common.utils import unwrap
async def _download_file(
session: aiohttp.ClientSession,
repo_item: dict,
token: Optional[str],
download_path: pathlib.Path,
chunk_limit: int,
progress: Progress,
):
"""Downloads a repo from HuggingFace."""
filename = repo_item.get("filename")
url = repo_item.get("url")
# Default is 2MB
chunk_limit_bytes = math.ceil(unwrap(chunk_limit, 2000000) * 100000)
filepath = download_path / filename
filepath.parent.mkdir(parents=True, exist_ok=True)
req_headers = {"Authorization": f"Bearer {token}"} if token else {}
async with session.get(url, headers=req_headers) as response:
# TODO: Change to raise errors
assert response.status == 200
file_size = int(response.headers["Content-Length"])
download_task = progress.add_task(
f"[cyan]Downloading {filename}", total=file_size
)
# Chunk limit is 2 MB
async with aiofiles.open(str(filepath), "wb") as f:
async for chunk in response.content.iter_chunked(chunk_limit_bytes):
await f.write(chunk)
progress.update(download_task, advance=len(chunk))
# Huggingface does not know how async works
def _get_repo_info(repo_id, revision, token):
"""Fetches information about a HuggingFace repository."""
api_client = HfApi()
repo_tree = api_client.list_repo_files(repo_id, revision=revision, token=token)
return list(
map(
lambda filename: {
"filename": filename,
"url": hf_hub_url(repo_id, filename, revision=revision),
},
repo_tree,
)
)
def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str]):
"""Gets the download folder for the repo."""
if repo_type == "lora":
download_path = pathlib.Path(unwrap(lora_config().get("lora_dir"), "loras"))
else:
download_path = pathlib.Path(unwrap(model_config().get("model_dir"), "models"))
download_path = download_path / unwrap(folder_name, repo_id.split("/")[-1])
return download_path
async def hf_repo_download(
repo_id: str,
folder_name: Optional[str],
revision: Optional[str],
token: Optional[str],
chunk_limit: Optional[float],
repo_type: Optional[str] = "model",
):
"""Gets a repo's information from HuggingFace and downloads it locally."""
file_list = await asyncio.to_thread(_get_repo_info, repo_id, revision, token)
# Auto-detect repo type if it isn't provided
if not repo_type:
lora_filter = filter(
lambda repo_item: repo_item.get("filename", "").endswith(
("adapter_config.json", "adapter_model.bin")
)
)
if lora_filter:
repo_type = "lora"
download_path = _get_download_folder(repo_id, repo_type, folder_name)
download_path.parent.mkdir(parents=True, exist_ok=True)
if download_path.exists():
raise FileExistsError(
f"The path {download_path} already exists. Remove the folder and try again."
)
logger.info(f"Saving {repo_id} to {str(download_path)}")
try:
async with aiohttp.ClientSession() as session:
tasks = []
logger.info(f"Starting download for {repo_id}")
progress = get_progress_bar()
progress.start()
for repo_item in file_list:
tasks.append(
_download_file(
session,
repo_item,
token=token,
download_path=download_path.resolve(),
chunk_limit=chunk_limit,
progress=progress,
)
)
await asyncio.gather(*tasks)
progress.stop()
logger.info(f"Finished download for {repo_id}")
return download_path
except asyncio.CancelledError:
# Cleanup on cancel
if download_path.is_dir():
shutil.rmtree(download_path)
else:
download_path.unlink()
# Stop the progress bar
progress.stop()

View file

@ -23,6 +23,10 @@ RICH_CONSOLE = Console()
LOG_LEVEL = os.getenv("TABBY_LOG_LEVEL", "INFO")
def get_progress_bar():
return Progress(console=RICH_CONSOLE)
def get_loading_progress_bar():
"""Gets a pre-made progress bar for loading tasks."""