From 71de3060bbeff492bb3793b5701b3fc114b2722f Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 23 Jul 2024 21:42:38 -0400 Subject: [PATCH] Downloader: Make timeout configurable Add an API parameter to set the timeout in seconds. Keep it to None by default for uninterrupted downloads. Signed-off-by: kingbri --- common/downloader.py | 5 +++-- endpoints/core/types/download.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/common/downloader.py b/common/downloader.py index b252a0f..b9e1b72 100644 --- a/common/downloader.py +++ b/common/downloader.py @@ -101,6 +101,7 @@ async def hf_repo_download( chunk_limit: Optional[float], include: Optional[List[str]], exclude: Optional[List[str]], + timeout: Optional[int], repo_type: Optional[str] = "model", ): """Gets a repo's information from HuggingFace and downloads it locally.""" @@ -145,8 +146,8 @@ async def hf_repo_download( logger.info(f"Saving {repo_id} to {str(download_path)}") try: - timeout = aiohttp.ClientTimeout(total=None) # Turn off timeout - async with aiohttp.ClientSession(timeout=timeout) as session: + client_timeout = aiohttp.ClientTimeout(total=timeout) # Turn off timeout + async with aiohttp.ClientSession(timeout=client_timeout) as session: tasks = [] logger.info(f"Starting download for {repo_id}") diff --git a/endpoints/core/types/download.py b/endpoints/core/types/download.py index ac681bf..cf49501 100644 --- a/endpoints/core/types/download.py +++ b/endpoints/core/types/download.py @@ -17,6 +17,7 @@ class DownloadRequest(BaseModel): include: List[str] = Field(default_factory=_generate_include_list) exclude: List[str] = Field(default_factory=list) chunk_limit: Optional[int] = None + timeout: Optional[int] = None class DownloadResponse(BaseModel):