diff --git a/auth.py b/auth.py
index df0fbc5..593e376 100644
--- a/auth.py
+++ b/auth.py
@@ -37,7 +37,7 @@ def load_auth_keys():
             api_key = auth_keys_dict["api_key"],
             admin_key = auth_keys_dict["admin_key"]
         )
-    except:
+    except Exception as _:
         new_auth_keys = AuthKeys(
             api_key = secrets.token_hex(16),
             admin_key = secrets.token_hex(16)
diff --git a/main.py b/main.py
index 9f4871d..6a092c0 100644
--- a/main.py
+++ b/main.py
@@ -1,15 +1,18 @@
 import uvicorn
 import yaml
 import pathlib
-import gen_logging
 from asyncio import CancelledError
-from auth import check_admin_key, check_api_key, load_auth_keys
 from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
-from model import ModelContainer
 from progress.bar import IncrementalBar
+from typing import Optional
+from uuid import uuid4
+
+import gen_logging
+from auth import check_admin_key, check_api_key, load_auth_keys
 from generators import generate_with_semaphore
+from model import ModelContainer
 from OAI.types.completion import CompletionRequest
 from OAI.types.chat_completion import ChatCompletionRequest
 from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
@@ -28,9 +31,7 @@ from OAI.utils import (
     create_chat_completion_response,
     create_chat_completion_stream_chunk
 )
-from typing import Optional
 from utils import get_generator_error, get_sse_packet, load_progress, unwrap
-from uuid import uuid4
 
 app = FastAPI()
 
diff --git a/model.py b/model.py
index f34501d..3316db5 100644
--- a/model.py
+++ b/model.py
@@ -1,4 +1,6 @@
-import gc, time, pathlib
+import gc
+import pathlib
+import time
 import torch
 from exllamav2 import(
     ExLlamaV2,
@@ -12,9 +14,10 @@ from exllamav2.generator import(
     ExLlamaV2StreamingGenerator,
     ExLlamaV2Sampler
 )
+
+from gen_logging import log_generation_params, log_prompt, log_response
 from typing import List, Optional, Union
 from utils import coalesce, unwrap
-from gen_logging import log_generation_params, log_prompt, log_response
 
 # Bytes to reserve on first device when loading with auto split
 auto_split_reserve_bytes = 96 * 1024**2
@@ -147,7 +150,8 @@ class ModelContainer:
             progress_callback (function, optional): A function to call for each module loaded.
                 Prototype: def progress(loaded_modules: int, total_modules: int)
         """
-        for _ in self.load_gen(progress_callback): pass
+        for _ in self.load_gen(progress_callback):
+            pass
 
     def load_loras(self, lora_directory: pathlib.Path, **kwargs):
         """
@@ -243,10 +247,14 @@ class ModelContainer:
 
         # Unload the entire model if not just unloading loras
         if not loras_only:
-            if self.model: self.model.unload()
+            if self.model:
+                self.model.unload()
             self.model = None
-            if self.draft_model: self.draft_model.unload()
+
+            if self.draft_model:
+                self.draft_model.unload()
             self.draft_model = None
+
             self.config = None
             self.cache = None
             self.tokenizer = None
@@ -440,7 +448,8 @@ class ModelContainer:
                     chunk_buffer = ""
                     last_chunk_time = now
 
-            if eos or generated_tokens == max_tokens: break
+            if eos or generated_tokens == max_tokens:
+                break
 
         # Print response
         log_response(full_response)