diff --git a/common/downloader.py b/common/downloader.py index b252a0f..b9e1b72 100644 --- a/common/downloader.py +++ b/common/downloader.py @@ -101,6 +101,7 @@ async def hf_repo_download( chunk_limit: Optional[float], include: Optional[List[str]], exclude: Optional[List[str]], + timeout: Optional[int], repo_type: Optional[str] = "model", ): """Gets a repo's information from HuggingFace and downloads it locally.""" @@ -145,8 +146,8 @@ async def hf_repo_download( logger.info(f"Saving {repo_id} to {str(download_path)}") try: - timeout = aiohttp.ClientTimeout(total=None) # Turn off timeout - async with aiohttp.ClientSession(timeout=timeout) as session: + client_timeout = aiohttp.ClientTimeout(total=timeout) # Turn off timeout + async with aiohttp.ClientSession(timeout=client_timeout) as session: tasks = [] logger.info(f"Starting download for {repo_id}") diff --git a/endpoints/core/types/download.py b/endpoints/core/types/download.py index ac681bf..cf49501 100644 --- a/endpoints/core/types/download.py +++ b/endpoints/core/types/download.py @@ -17,6 +17,7 @@ class DownloadRequest(BaseModel): include: List[str] = Field(default_factory=_generate_include_list) exclude: List[str] = Field(default_factory=list) chunk_limit: Optional[int] = None + timeout: Optional[int] = None class DownloadResponse(BaseModel):