diff --git a/llm.py b/llm.py
deleted file mode 100644
index 19167a3..0000000
--- a/llm.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# exllama.py
-import random
-from exllamav2 import (
-    ExLlamaV2,
-    ExLlamaV2Config,
-    ExLlamaV2Cache,
-    ExLlamaV2Tokenizer,
-)
-from exllamav2.generator import (
-    ExLlamaV2BaseGenerator,
-    ExLlamaV2Sampler
-)
-import time
-
-
-class ModelManager:
-    def __init__(self, model_directory: str = None):
-        if model_directory is None:
-            model_directory = "/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.0bpw/"
-        self.config = ExLlamaV2Config()
-        self.config.model_dir = model_directory
-        self.config.prepare()
-        self.model = ExLlamaV2(self.config)
-        print("Loading model: " + model_directory)
-        self.cache = ExLlamaV2Cache(self.model, lazy=True)
-        self.model.load_autosplit(self.cache)
-        self.tokenizer = ExLlamaV2Tokenizer(self.config)
-        self.generator = ExLlamaV2BaseGenerator(self.model, self.cache, self.tokenizer)
-
-    def generate_text(self,
-                      prompt: str,
-                      max_tokens: int = 150,
-                      temperature=0.5,
-                      seed: int = random.randint(0, 999999),
-                      token_repetition_penalty: float = 1.0,
-                      stop: list = None):
-        try:
-            self.generator.warmup()
-            time_begin = time.time()
-            settings = ExLlamaV2Sampler.Settings()
-            settings.token_repetition_penalty = token_repetition_penalty
-
-            if stop:
-                settings.stop_sequence = stop
-
-            output = self.generator.generate_simple(
-                prompt, settings, max_tokens, seed=seed
-            )
-            time_end = time.time()
-            time_total = time_end - time_begin
-            return output, f"{time_total:.2f} seconds"
-        except Exception as e:
-            raise RuntimeError(f"Error generating text: {str(e)}")
diff --git a/main.py b/main.py
index 7efab37..7e11f3d 100644
--- a/main.py
+++ b/main.py
@@ -1,15 +1,16 @@
 import os
+import argparse
+import uvicorn
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from llm import ModelManager
-from uvicorn import run
+from model import ModelContainer
+from utils import add_args
 
 app = FastAPI()
 
-# Initialize the modelManager with a default model path
-default_model_path = "/home/david/Models/SynthIA-7B-v2.0-5.0bpw-h6-exl2"
-modelManager = ModelManager(default_model_path)
-print(output)
+# Initialize a model container. This can be undefined at any period of time
+model_container: ModelContainer = None
+
 class TextRequest(BaseModel):
     model: str = None # Make the "model" field optional with a default value of None
     prompt: str
@@ -25,6 +26,7 @@ class TextResponse(BaseModel):
     response: str
     generation_time: str
 
+# TODO: Currently broken
 @app.post("/generate-text", response_model=TextResponse)
 def generate_text(request: TextRequest):
     global modelManager
@@ -36,5 +38,23 @@ def generate_text(request: TextRequest):
     except RuntimeError as e:
         raise HTTPException(status_code=500, detail=str(e))
 
+# Debug progress check
+def progress(module, modules):
+    print(f"Loaded {module}/{modules} modules")
+    yield
+
 if __name__ == "__main__":
-    run(app, host="0.0.0.0", port=8012, reload=True)
+    # Convert this parser to use a YAML config
+    parser = argparse.ArgumentParser(description = "TabbyAPI - An API server for exllamav2")
+    add_args(parser)
+    args = parser.parse_args()
+
+    # If an initial model dir is specified, create a container and load the model
+    if args.model_dir:
+        model_container = ModelContainer(args.model_dir, False, **vars(args))
+        print("Loading an initial model...")
+        model_container.load(progress)
+        print("Model successfully loaded.")
+
+    # Reload is for dev purposes ONLY!
+    uvicorn.run("main:app", host="0.0.0.0", port=8012, log_level="debug", reload=True)
diff --git a/model.py b/model.py
index 920dd97..0285faa 100644
--- a/model.py
+++ b/model.py
@@ -1,4 +1,4 @@
-import json, uuid, os, gc, time
+import gc, time
 import torch
 
 from exllamav2 import(
@@ -34,6 +34,7 @@ class ModelContainer:
     gpu_split: list or None = None
 
     def __init__(self, model_directory: str, quiet = False, **kwargs):
+        print(kwargs)
         """
         Create model container
 
@@ -57,6 +58,7 @@ class ModelContainer:
                 full model.
             'gpu_split_auto' (bool): Automatically split model across available devices (default: True)
             'gpu_split' (list): Allocation for weights and (some) tensors, per device
+            'no_flash_attn' (bool): Turns off flash attention (increases vram usage)
         """
 
         self.quiet = quiet
@@ -72,6 +74,7 @@ class ModelContainer:
         if "max_seq_len" in kwargs: self.config.max_seq_len = kwargs["max_seq_len"]
         if "rope_scale" in kwargs: self.config.scale_pos_emb = kwargs["rope_scale"]
         if "rope_alpha" in kwargs: self.config.scale_alpha_value = kwargs["rope_alpha"]
+        if "no_flash_attn" in kwargs: self.config.no_flash_attn = kwargs["no_flash_attn"]
 
         chunk_size = min(kwargs.get("chunk_size", 2048), self.config.max_seq_len)
         self.config.max_input_len = chunk_size
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..9d82216
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,8 @@
+def add_args(parser):
+    parser.add_argument("-m", "--model_dir", type = str, help = "Path to model directory")
+    parser.add_argument("-gs", "--gpu_split", type = str, help = "\"auto\", or VRAM allocation per GPU in GB")
+    parser.add_argument("-l", "--max_seq_len", type = int, help = "Maximum sequence length")
+    parser.add_argument("-rs", "--rope_scale", type = float, default = 1.0, help = "RoPE scaling factor")
+    parser.add_argument("-ra", "--rope_alpha", type = float, default = 1.0, help = "RoPE alpha value (NTK)")
+    parser.add_argument("-nfa", "--no_flash_attn", action = "store_true", help = "Disable Flash Attention")
+    parser.add_argument("-lm", "--low_mem", action = "store_true", help = "Enable VRAM optimizations, potentially trading off speed")
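Not part of the patch itself, but for reviewers who want to exercise the changes: a minimal client sketch against the new `/generate-text` route (still marked `# TODO: Currently broken` above). It assumes the server was started with something like `python main.py --model_dir <path-to-exl2-model> --gpu_split auto`, that it is listening on the host/port passed to `uvicorn.run()` in main.py, and that the `requests` package is installed; the prompt value is illustrative only.

```python
# Sketch only: exercises the /generate-text endpoint defined in main.py.
# Host/port mirror the uvicorn.run() call; the payload follows the
# TextRequest model ("model" is optional, "prompt" is required).
import requests

payload = {
    "prompt": "Write a haiku about GPUs.",  # illustrative prompt
}

resp = requests.post("http://localhost:8012/generate-text", json=payload)
resp.raise_for_status()

# The body should match TextResponse: generated text plus timing info.
body = resp.json()
print(body["response"])
print(body["generation_time"])
```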