diff --git a/OAI/types/model.py b/OAI/types/model.py
index 5301ca8..03f8fd3 100644
--- a/OAI/types/model.py
+++ b/OAI/types/model.py
@@ -18,6 +18,8 @@ class ModelCardParameters(BaseModel):
     prompt_template: Optional[str] = None
     num_experts_per_token: Optional[int] = None
     use_cfg: Optional[bool] = None
+
+    # Draft is another model, so include it in the card params
     draft: Optional["ModelCard"] = None
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 6e36ee2..7b21c11 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -293,6 +293,32 @@ class ExllamaV2Container:
         )
         return model_path
 
+    def get_model_parameters(self):
+        model_params = {
+            "name": self.get_model_path().name,
+            "rope_scale": self.config.scale_pos_emb,
+            "rope_alpha": self.config.scale_alpha_value,
+            "max_seq_len": self.config.max_seq_len,
+            "cache_mode": self.cache_mode,
+            "num_experts_per_token": self.config.num_experts_per_token,
+            "use_cfg": self.use_cfg,
+            "prompt_template": self.prompt_template.name
+            if self.prompt_template
+            else None,
+        }
+
+        if self.draft_config:
+            draft_model_params = {
+                "name": self.get_model_path(is_draft=True).name,
+                "rope_scale": self.draft_config.scale_pos_emb,
+                "rope_alpha": self.draft_config.scale_alpha_value,
+                "max_seq_len": self.draft_config.max_seq_len,
+            }
+
+            model_params["draft"] = draft_model_params
+
+        return model_params
+
     def load(self, progress_callback=None):
         """
         Load model
diff --git a/common/model.py b/common/model.py
new file mode 100644
index 0000000..e626237
--- /dev/null
+++ b/common/model.py
@@ -0,0 +1,75 @@
+"""
+Manages the storage and utility of model containers.
+
+Containers exist as a common interface for backends.
+"""
+
+import pathlib
+from loguru import logger
+from typing import Optional
+
+from backends.exllamav2.model import ExllamaV2Container
+from common.logger import get_loading_progress_bar
+from common.utils import load_progress
+
+
+container: Optional[ExllamaV2Container] = None
+
+
+async def unload_model():
+    """Unloads a model"""
+    global container
+
+    container.unload()
+    container = None
+
+
+async def load_model_gen(model_path: pathlib.Path, **kwargs):
+    """Generator to load a model"""
+    global container
+
+    # Check if the model is already loaded
+    if container and container.model:
+        loaded_model_name = container.get_model_path().name
+
+        if loaded_model_name == model_path.name:
+            raise ValueError(
+                f'Model "{loaded_model_name}" is already loaded! Aborting.'
+            )
+
+    # Unload the existing model
+    if container and container.model:
+        logger.info("Unloading existing model.")
+        await unload_model()
+
+    container = ExllamaV2Container(model_path.resolve(), False, **kwargs)
+
+    model_type = "draft" if container.draft_config else "model"
+    load_status = container.load_gen(load_progress)
+
+    progress = get_loading_progress_bar()
+    progress.start()
+
+    try:
+        for module, modules in load_status:
+            if module == 0:
+                loading_task = progress.add_task(
+                    f"[cyan]Loading {model_type} modules", total=modules
+                )
+            else:
+                progress.advance(loading_task)
+
+            if module == modules:
+                # Switch to model progress if the draft model is loaded
+                if model_type == "draft":
+                    model_type = "model"
+                else:
+                    progress.stop()
+
+            yield module, modules, model_type
+    finally:
+        progress.stop()
+
+
+async def load_model(model_path: pathlib.Path, **kwargs):
+    async for _, _, _ in load_model_gen(model_path, **kwargs):
+        pass
diff --git a/main.py b/main.py
index 8b0aebd..878fe70 100644
--- a/main.py
+++ b/main.py
@@ -1,4 +1,5 @@
 """The main tabbyAPI module. Contains the FastAPI server and endpoints."""
+import asyncio
 import os
 import pathlib
 import signal
@@ -17,10 +18,10 @@ from fastapi.middleware.cors import CORSMiddleware
 from functools import partial
 from loguru import logger
 
-from common.logger import UVICORN_LOG_CONFIG, setup_logger, get_loading_progress_bar
+from common.logger import UVICORN_LOG_CONFIG, setup_logger
 import common.gen_logging as gen_logging
-from backends.exllamav2.model import ExllamaV2Container
 from backends.exllamav2.utils import check_exllama_version
+from common import model
 from common.args import convert_args_to_dict, init_argparser
 from common.auth import check_admin_key, check_api_key, load_auth_keys
 from common.config import (
@@ -52,7 +53,6 @@ from common.templating import (
 from common.utils import (
     get_generator_error,
     handle_request_error,
-    load_progress,
     is_port_in_use,
     unwrap,
 )
@@ -90,24 +90,6 @@ app = FastAPI(
     ),
 )
 
-# Globally scoped variables. Undefined until initalized in main
-MODEL_CONTAINER: Optional[ExllamaV2Container] = None
-
-
-async def _check_model_container():
-    """Checks if a model isn't loading or loaded."""
-
-    if MODEL_CONTAINER is None or not (
-        MODEL_CONTAINER.model_is_loading or MODEL_CONTAINER.model_loaded
-    ):
-        error_message = handle_request_error(
-            "No models are currently loaded.",
-            exc_info=False,
-        ).error.message
-
-        raise HTTPException(400, error_message)
-
-
 # ALlow CORS requests
 app.add_middleware(
     CORSMiddleware,
@@ -118,6 +100,20 @@ app.add_middleware(
 )
 
 
+async def check_model_container():
+    """FastAPI depends that checks if a model isn't loaded or currently loading."""
+
+    if model.container is None or not (
+        model.container.model_is_loading or model.container.model_loaded
+    ):
+        error_message = handle_request_error(
+            "No models are currently loaded.",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(400, error_message)
+
+
 # Model list endpoint
 @app.get("/v1/models", dependencies=[Depends(check_api_key)])
 @app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
@@ -139,35 +135,33 @@ async def list_models():
 # Currently loaded model endpoint
 @app.get(
     "/v1/model",
-    dependencies=[Depends(check_api_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
 async def get_current_model():
     """Returns the currently loaded model."""
-    model_name = MODEL_CONTAINER.get_model_path().name
-    prompt_template = MODEL_CONTAINER.prompt_template
+    model_params = model.container.get_model_parameters()
+    draft_model_params = model_params.pop("draft", {})
+
+    if draft_model_params:
+        model_params["draft"] = ModelCard(
+            id=unwrap(draft_model_params.get("name"), "unknown"),
+            parameters=ModelCardParameters.model_validate(draft_model_params),
+        )
+    else:
+        draft_model_params = None
+
     model_card = ModelCard(
-        id=model_name,
-        parameters=ModelCardParameters(
-            rope_scale=MODEL_CONTAINER.config.scale_pos_emb,
-            rope_alpha=MODEL_CONTAINER.config.scale_alpha_value,
-            max_seq_len=MODEL_CONTAINER.config.max_seq_len,
-            cache_mode=MODEL_CONTAINER.cache_mode,
-            prompt_template=prompt_template.name if prompt_template else None,
-            num_experts_per_token=MODEL_CONTAINER.config.num_experts_per_token,
-            use_cfg=MODEL_CONTAINER.use_cfg,
-        ),
+        id=unwrap(model_params.pop("name", None), "unknown"),
+        parameters=ModelCardParameters.model_validate(model_params),
         logging=gen_logging.PREFERENCES,
     )
 
-    if MODEL_CONTAINER.draft_config:
+    if draft_model_params:
         draft_card = ModelCard(
-            id=MODEL_CONTAINER.get_model_path(True).name,
-            parameters=ModelCardParameters(
-                rope_scale=MODEL_CONTAINER.draft_config.scale_pos_emb,
-                rope_alpha=MODEL_CONTAINER.draft_config.scale_alpha_value,
-                max_seq_len=MODEL_CONTAINER.draft_config.max_seq_len,
-            ),
+            id=unwrap(draft_model_params.pop("name", None), "unknown"),
+            parameters=ModelCardParameters.model_validate(draft_model_params),
         )
+
         model_card.parameters.draft = draft_card
 
     return model_card
@@ -211,35 +205,12 @@ async def load_model(request: Request, data: ModelLoadRequest):
     if not model_path.exists():
         raise HTTPException(400, "model_path does not exist. Check model_name?")
 
-    # Check if the model is already loaded
-    if MODEL_CONTAINER and MODEL_CONTAINER.model:
-        loaded_model_name = MODEL_CONTAINER.get_model_path().name
-
-        if loaded_model_name == data.name:
-            raise HTTPException(
-                400, f'Model "{loaded_model_name}"is already loaded! Aborting.'
-            )
-
     async def generator():
-        """Generator for the loading process."""
-        global MODEL_CONTAINER
-
-        # Unload the existing model
-        if MODEL_CONTAINER and MODEL_CONTAINER.model:
-            logger.info("Unloading existing model.")
-            await unload_model()
-
-        MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
-
-        model_type = "draft" if MODEL_CONTAINER.draft_config else "model"
-        load_status = MODEL_CONTAINER.load_gen(load_progress)
+        """Request generation wrapper for the loading process."""
+        load_status = model.load_model_gen(model_path, **load_data)
 
         try:
-            progress = get_loading_progress_bar()
-            progress.start()
-
-            for module, modules in load_status:
-                # Get out if the request gets disconnected
+            async for module, modules, model_type in load_status:
                 if await request.is_disconnected():
                     release_semaphore()
                     logger.error(
@@ -248,13 +219,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
                     )
                     return
 
-                if module == 0:
-                    loading_task = progress.add_task(
-                        f"[cyan]Loading {model_type} modules", total=modules
-                    )
-                else:
-                    progress.advance(loading_task)
-
+                if module != 0:
                     response = ModelLoadResponse(
                         model_type=model_type,
                         module=module,
@@ -273,13 +238,6 @@ async def load_model(request: Request, data: ModelLoadRequest):
                     )
 
                     yield response.model_dump_json()
-
-                    # Switch to model progress if the draft model is loaded
-                    if model_type == "draft":
-                        model_type = "model"
-                    else:
-                        progress.stop()
-
         except CancelledError:
             logger.error(
                 "Model load cancelled by user. "
            )
        except Exception as exc:
            yield get_generator_error(str(exc))
-        finally:
-            progress.stop()
 
     # Determine whether to use or skip the queue
     if data.skip_queue:
@@ -306,14 +262,11 @@ async def load_model(request: Request, data: ModelLoadRequest):
 # Unload model endpoint
 @app.post(
     "/v1/model/unload",
-    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
 async def unload_model():
     """Unloads the currently loaded model."""
-    global MODEL_CONTAINER
-
-    MODEL_CONTAINER.unload()
-    MODEL_CONTAINER = None
+    await model.unload_model()
 
 
 @app.get("/v1/templates", dependencies=[Depends(check_api_key)])
@@ -326,7 +279,7 @@ async def get_templates():
 
 @app.post(
     "/v1/template/switch",
-    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
 async def switch_template(data: TemplateSwitchRequest):
     """Switch the currently loaded template"""
@@ -335,19 +288,19 @@ async def switch_template(data: TemplateSwitchRequest):
 
     try:
         template = get_template_from_file(data.name)
-        MODEL_CONTAINER.prompt_template = template
+        model.container.prompt_template = template
     except FileNotFoundError as e:
         raise HTTPException(400, "Template does not exist. Check the name?") from e
 
 
 @app.post(
     "/v1/template/unload",
-    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
 async def unload_template():
     """Unloads the currently selected template"""
 
-    MODEL_CONTAINER.prompt_template = None
+    model.container.prompt_template = None
 
 
 # Sampler override endpoints
@@ -405,7 +358,7 @@ async def get_all_loras():
 # Currently loaded loras endpoint
 @app.get(
     "/v1/lora",
-    dependencies=[Depends(check_api_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
 async def get_active_loras():
     """Returns the currently loaded loras."""
@@ -416,7 +369,7 @@ async def get_active_loras():
                 id=pathlib.Path(lora.lora_path).parent.name,
                 scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
             ),
-            MODEL_CONTAINER.active_loras,
+            model.container.active_loras,
         )
     )
 )
@@ -427,7 +380,7 @@ async def get_active_loras():
 # Load lora endpoint
 @app.post(
     "/v1/lora/load",
-    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
 async def load_lora(data: LoraLoadRequest):
     """Loads a LoRA into the model container."""
@@ -443,10 +396,10 @@ async def load_lora(data: LoraLoadRequest):
 
     # Clean-up existing loras if present
     def load_loras_internal():
-        if len(MODEL_CONTAINER.active_loras) > 0:
+        if len(model.container.active_loras) > 0:
             unload_loras()
 
-        result = MODEL_CONTAINER.load_loras(lora_dir, **data.model_dump())
+        result = model.container.load_loras(lora_dir, **data.model_dump())
         return LoraLoadResponse(
             success=unwrap(result.get("success"), []),
             failure=unwrap(result.get("failure"), []),
@@ -468,21 +421,21 @@ async def load_lora(data: LoraLoadRequest):
 # Unload lora endpoint
 @app.post(
     "/v1/lora/unload",
-    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
 async def unload_loras():
     """Unloads the currently loaded loras."""
 
-    MODEL_CONTAINER.unload(loras_only=True)
+    model.container.unload(loras_only=True)
 
 
 # Encode tokens endpoint
 @app.post(
     "/v1/token/encode",
-    dependencies=[Depends(check_api_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
 async def encode_tokens(data: TokenEncodeRequest):
     """Encodes a string into tokens."""
 
-    raw_tokens = MODEL_CONTAINER.encode_tokens(data.text, **data.get_params())
+    raw_tokens = model.container.encode_tokens(data.text, **data.get_params())
     tokens = unwrap(raw_tokens, [])
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
@@ -492,11 +445,11 @@ async def encode_tokens(data: TokenEncodeRequest):
 # Decode tokens endpoint
 @app.post(
     "/v1/token/decode",
-    dependencies=[Depends(check_api_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
 async def decode_tokens(data: TokenDecodeRequest):
     """Decodes tokens into a string."""
 
-    message = MODEL_CONTAINER.decode_tokens(data.tokens, **data.get_params())
+    message = model.container.decode_tokens(data.tokens, **data.get_params())
     response = TokenDecodeResponse(text=unwrap(message, ""))
 
     return response
@@ -505,11 +458,11 @@ async def decode_tokens(data: TokenDecodeRequest):
 # Completions endpoint
 @app.post(
     "/v1/completions",
-    dependencies=[Depends(check_api_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
 async def generate_completion(request: Request, data: CompletionRequest):
     """Generates a completion from a prompt."""
 
-    model_path = MODEL_CONTAINER.get_model_path()
+    model_path = model.container.get_model_path()
 
     if isinstance(data.prompt, list):
         data.prompt = "\n".join(data.prompt)
@@ -522,7 +475,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
 
     async def generator():
         try:
-            new_generation = MODEL_CONTAINER.generate_gen(
+            new_generation = model.container.generate_gen(
                 data.prompt, **data.to_gen_params()
             )
             for generation in new_generation:
@@ -549,7 +502,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
         generation = await call_with_semaphore(
             partial(
                 run_in_threadpool,
-                MODEL_CONTAINER.generate,
+                model.container.generate,
                 data.prompt,
                 **data.to_gen_params(),
             )
@@ -570,30 +523,31 @@ async def generate_completion(request: Request, data: CompletionRequest):
 # Chat completions endpoint
 @app.post(
     "/v1/chat/completions",
-    dependencies=[Depends(check_api_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
 async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
     """Generates a chat completion from a prompt."""
 
-    if MODEL_CONTAINER.prompt_template is None:
+    if model.container.prompt_template is None:
         raise HTTPException(
             422,
             "This endpoint is disabled because a prompt template is not set.",
         )
 
-    model_path = MODEL_CONTAINER.get_model_path()
+    model_path = model.container.get_model_path()
 
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
         try:
-            special_tokens_dict = MODEL_CONTAINER.get_special_tokens(
+            special_tokens_dict = model.container.get_special_tokens(
                 unwrap(data.add_bos_token, True),
                 unwrap(data.ban_eos_token, False),
             )
+
             prompt = get_prompt_from_template(
                 data.messages,
-                MODEL_CONTAINER.prompt_template,
+                model.container.prompt_template,
                 data.add_generation_prompt,
                 special_tokens_dict,
             )
@@ -601,7 +555,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
             raise HTTPException(
                 400,
                 "Could not find a Conversation from prompt template "
-                f"'{MODEL_CONTAINER.prompt_template.name}'. "
+                f"'{model.container.prompt_template.name}'. "
                 "Check your spelling?",
             ) from exc
         except TemplateError as exc:
@@ -620,7 +574,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
     async def generator():
         """Generator for the generation process."""
         try:
-            new_generation = MODEL_CONTAINER.generate_gen(
+            new_generation = model.container.generate_gen(
                 prompt, **data.to_gen_params()
             )
             for generation in new_generation:
@@ -653,7 +607,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
         generation = await call_with_semaphore(
             partial(
                 run_in_threadpool,
-                MODEL_CONTAINER.generate,
+                model.container.generate,
                 prompt,
                 **data.to_gen_params(),
             )
@@ -692,11 +646,9 @@ def signal_handler(*_):
     sys.exit(0)
 
 
-def entrypoint(args: Optional[dict] = None):
+async def entrypoint(args: Optional[dict] = None):
     """Entry function for program startup"""
 
-    global MODEL_CONTAINER
-
     setup_logger()
 
     # Set up signal aborting
@@ -782,34 +734,13 @@ def entrypoint(args: Optional[dict] = None):
         model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
         model_path = model_path / model_name
 
-        MODEL_CONTAINER = ExllamaV2Container(
-            model_path.resolve(), False, **model_config
-        )
-        load_status = MODEL_CONTAINER.load_gen(load_progress)
-
-        progress = get_loading_progress_bar()
-        progress.start()
-        model_type = "draft" if MODEL_CONTAINER.draft_config else "model"
-
-        for module, modules in load_status:
-            if module == 0:
-                loading_task = progress.add_task(
-                    f"[cyan]Loading {model_type} modules", total=modules
-                )
-            else:
-                progress.advance(loading_task, 1)
-
-            if module == modules:
-                if model_type == "draft":
-                    model_type = "model"
-                else:
-                    progress.stop()
+        await model.load_model(model_path.resolve(), **model_config)
 
     # Load loras after loading the model
     lora_config = get_lora_config()
     if lora_config.get("loras"):
         lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
-        MODEL_CONTAINER.load_loras(lora_dir.resolve(), **lora_config)
+        model.container.load_loras(lora_dir.resolve(), **lora_config)
 
     # TODO: Replace this with abortables, async via producer consumer, or something else
     api_thread = threading.Thread(target=partial(start_api, host, port), daemon=True)
@@ -821,4 +752,4 @@
 
 
 if __name__ == "__main__":
-    entrypoint()
+    asyncio.run(entrypoint())
diff --git a/start.py b/start.py
index 9bfae3e..ec99299 100644
--- a/start.py
+++ b/start.py
@@ -1,4 +1,5 @@
 """Utility to automatically upgrade and start the API"""
+import asyncio
 import argparse
 import os
 import pathlib
@@ -66,4 +67,4 @@ if __name__ == "__main__":
     # Import entrypoint after installing all requirements
     from main import entrypoint
 
-    entrypoint(convert_args_to_dict(args, parser))
+    asyncio.run(entrypoint(convert_args_to_dict(args, parser)))