API: Add HuggingFace downloader
Adds an asynchronous huggingface downloader that uses HF hub to fetch all repo files. The current HF hub package has a snapshot_download function that does not cancel on KeyboardInterrupt. Instead, make a downloader that uses the Rich progress bar styling along with a cancellable interface. Finally, link this to TabbyAPI. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
6114bfd221
commit
55ccd1baad
5 changed files with 196 additions and 0 deletions
150
common/downloader.py
Normal file
150
common/downloader.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
import aiofiles
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import math
|
||||
import pathlib
|
||||
import shutil
|
||||
from huggingface_hub import HfApi, hf_hub_url
|
||||
from loguru import logger
|
||||
from rich.progress import Progress
|
||||
from typing import Optional
|
||||
|
||||
from common.config import lora_config, model_config
|
||||
from common.logger import get_progress_bar
|
||||
from common.utils import unwrap
|
||||
|
||||
|
||||
async def _download_file(
    session: aiohttp.ClientSession,
    repo_item: dict,
    token: Optional[str],
    download_path: pathlib.Path,
    chunk_limit: Optional[float],
    progress: Progress,
):
    """Downloads a single file of a HuggingFace repo to disk.

    Args:
        session: Shared aiohttp session used to issue the GET request.
        repo_item: Dict with a "filename" (path relative to the repo root)
            and the "url" to fetch it from.
        token: Optional access token, sent as a Bearer authorization header.
        download_path: Folder the file is written under. Missing parent
            folders of the file are created.
        chunk_limit: Optional read size, scaled by 100,000 to get bytes.
            When None, chunks of 2,000,000 bytes (2 MB) are used.
        progress: Rich progress display that receives one task per file.

    Raises:
        aiohttp.ClientResponseError: If the server does not answer with 200.
    """

    filename = repo_item.get("filename")
    url = repo_item.get("url")

    # Default is 2MB. The old default was scaled like a user-provided value
    # and ended up at roughly 200 GB per read.
    if chunk_limit is None:
        chunk_limit_bytes = 2000000
    else:
        chunk_limit_bytes = math.ceil(chunk_limit * 100000)

    filepath = download_path / filename
    filepath.parent.mkdir(parents=True, exist_ok=True)

    req_headers = {"Authorization": f"Bearer {token}"} if token else {}

    async with session.get(url, headers=req_headers) as response:
        # Raise instead of assert: asserts are stripped when running with -O
        if response.status != 200:
            raise aiohttp.ClientResponseError(
                response.request_info,
                response.history,
                status=response.status,
                message=f"Failed to download {filename}",
                headers=response.headers,
            )

        file_size = int(response.headers["Content-Length"])

        download_task = progress.add_task(
            f"[cyan]Downloading {filename}", total=file_size
        )

        # Stream the body to disk so the whole file is never held in memory
        async with aiofiles.open(str(filepath), "wb") as f:
            async for chunk in response.content.iter_chunked(chunk_limit_bytes):
                await f.write(chunk)
                progress.update(download_task, advance=len(chunk))
|
||||
|
||||
|
||||
# Huggingface does not know how async works
|
||||
def _get_repo_info(repo_id, revision, token):
    """Lists a HuggingFace repo's files along with their download URLs.

    Returns a list of dicts, each holding a "filename" and its "url".
    """

    hub = HfApi()
    filenames = hub.list_repo_files(repo_id, revision=revision, token=token)

    return [
        {
            "filename": name,
            "url": hf_hub_url(repo_id, name, revision=revision),
        }
        for name in filenames
    ]
|
||||
|
||||
|
||||
def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str]):
    """Resolves the local folder a repo should be saved into.

    Loras go under the configured lora directory (fallback "loras"), anything
    else under the configured model directory (fallback "models"). The final
    folder is folder_name, or the last segment of repo_id when it's not given.
    """

    if repo_type == "lora":
        base_dir = unwrap(lora_config().get("lora_dir"), "loras")
    else:
        base_dir = unwrap(model_config().get("model_dir"), "models")

    leaf_name = unwrap(folder_name, repo_id.split("/")[-1])
    return pathlib.Path(base_dir) / leaf_name
|
||||
|
||||
|
||||
async def hf_repo_download(
    repo_id: str,
    folder_name: Optional[str],
    revision: Optional[str],
    token: Optional[str],
    chunk_limit: Optional[float],
    repo_type: Optional[str] = "model",
):
    """Gets a repo's information from HuggingFace and downloads it locally.

    Args:
        repo_id: HuggingFace repo identifier.
        folder_name: Optional name of the destination folder. Falls back to
            the last segment of repo_id.
        revision: Optional revision to fetch.
        token: Optional access token used for listing and downloading.
        chunk_limit: Optional chunk size passed to each file download.
        repo_type: "lora" or "model". When falsy, the type is auto-detected
            from the repo's file list.

    Returns:
        The download folder path, or None if the download was cancelled
        (in which case any partially downloaded files are removed).

    Raises:
        FileExistsError: If the destination folder already exists.
    """

    file_list = await asyncio.to_thread(_get_repo_info, repo_id, revision, token)

    # Auto-detect repo type if it isn't provided
    if not repo_type:
        # any() actually walks the file list; a bare filter object is
        # always truthy and never inspects a single filename
        is_lora = any(
            repo_item.get("filename", "").endswith(
                ("adapter_config.json", "adapter_model.bin")
            )
            for repo_item in file_list
        )

        if is_lora:
            repo_type = "lora"

    download_path = _get_download_folder(repo_id, repo_type, folder_name)
    download_path.parent.mkdir(parents=True, exist_ok=True)

    if download_path.exists():
        raise FileExistsError(
            f"The path {download_path} already exists. Remove the folder and try again."
        )

    logger.info(f"Saving {repo_id} to {str(download_path)}")

    try:
        async with aiohttp.ClientSession() as session:
            logger.info(f"Starting download for {repo_id}")

            resolved_path = download_path.resolve()

            # The context manager stops the progress bar on success,
            # cancellation and any other error
            with get_progress_bar() as progress:
                tasks = [
                    _download_file(
                        session,
                        repo_item,
                        token=token,
                        download_path=resolved_path,
                        chunk_limit=chunk_limit,
                        progress=progress,
                    )
                    for repo_item in file_list
                ]

                await asyncio.gather(*tasks)

            logger.info(f"Finished download for {repo_id}")

            return download_path
    except asyncio.CancelledError:
        # Cleanup on cancel
        if download_path.is_dir():
            shutil.rmtree(download_path)
        else:
            # Nothing may have been written yet, so a missing path is fine
            download_path.unlink(missing_ok=True)
|
||||
|
|
@ -23,6 +23,10 @@ RICH_CONSOLE = Console()
|
|||
LOG_LEVEL = os.getenv("TABBY_LOG_LEVEL", "INFO")
|
||||
|
||||
|
||||
def get_progress_bar():
    """Gets a plain progress bar that renders to the shared Rich console."""

    progress_bar = Progress(console=RICH_CONSOLE)
    return progress_bar
|
||||
|
||||
|
||||
def get_loading_progress_bar():
|
||||
"""Gets a pre-made progress bar for loading tasks."""
|
||||
|
||||
|
|
|
|||
|
|
@ -13,12 +13,14 @@ from common.concurrency import (
|
|||
call_with_semaphore,
|
||||
generate_with_semaphore,
|
||||
)
|
||||
from common.downloader import hf_repo_download
|
||||
from common.networking import handle_request_error, run_with_request_disconnect
|
||||
from common.templating import PromptTemplate, get_all_templates
|
||||
from common.utils import coalesce, unwrap
|
||||
from endpoints.OAI.types.auth import AuthPermissionResponse
|
||||
from endpoints.OAI.types.completion import CompletionRequest
|
||||
from endpoints.OAI.types.chat_completion import ChatCompletionRequest
|
||||
from endpoints.OAI.types.download import DownloadRequest, DownloadResponse
|
||||
from endpoints.OAI.types.lora import (
|
||||
LoraCard,
|
||||
LoraList,
|
||||
|
|
@ -288,6 +290,22 @@ async def unload_sampler_override():
|
|||
sampling.overrides_from_dict({})
|
||||
|
||||
|
||||
@router.post("/v1/download", dependencies=[Depends(check_admin_key)])
async def download_model(request: Request, data: DownloadRequest):
    """Downloads a HuggingFace repo and responds with the folder it was saved to."""

    # For now, the downloader and request data are 1:1
    downloader_kwargs = data.model_dump()
    repo_download = asyncio.create_task(hf_repo_download(**downloader_kwargs))

    # Tie the download to the request so a client disconnect can stop it
    saved_path = await run_with_request_disconnect(
        request,
        repo_download,
        "Download request cancelled by user. Files have been cleaned up.",
    )

    return DownloadResponse(download_path=str(saved_path))
|
||||
|
||||
|
||||
# Lora list endpoint
|
||||
@router.get("/v1/loras", dependencies=[Depends(check_api_key)])
|
||||
@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
||||
|
|
|
|||
19
endpoints/OAI/types/download.py
Normal file
19
endpoints/OAI/types/download.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class DownloadRequest(BaseModel):
    """Parameters for a HuggingFace repo download.

    The fields map 1:1 onto the downloader's keyword arguments.
    """

    # HuggingFace repo identifier
    repo_id: str

    # "lora" saves into the lora directory, anything else into the model
    # directory. A falsy value makes the downloader auto-detect the type.
    repo_type: Optional[str] = "model"

    # Destination folder name. Defaults to the last segment of repo_id.
    folder_name: Optional[str] = None

    # Repo revision to fetch
    revision: Optional[str] = None

    # Access token, sent as a Bearer authorization header
    token: Optional[str] = None

    # Chunk size, scaled by the downloader into bytes. Typed as float to match
    # the downloader's signature; the downloader picks a default when omitted.
    chunk_limit: Optional[float] = None
|
||||
|
||||
|
||||
class DownloadResponse(BaseModel):
    """Payload returned once a download request has been handled."""

    # Stringified path of the folder the repo was saved to
    download_path: str
|
||||
|
|
@ -27,6 +27,11 @@ dependencies = [
|
|||
"packaging",
|
||||
"tokenizers",
|
||||
"lm-format-enforcer >= 0.9.6",
|
||||
"aiofiles",
|
||||
|
||||
# TODO: Maybe move these to a downloader feature?
|
||||
"aiohttp",
|
||||
"huggingface_hub",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue