Tree: Update to use ModelContainer and args

Use command-line arguments to load an initial model if necessary.
API routes are broken, but we should be using the container from
now on as the primary interface to the exllamav2 library.

Also, these args should be turned into a YAML configuration file in
the future (a hypothetical sketch of that file follows the utils.py
diff below).

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2023-11-10 23:19:54 -05:00
parent 9d34479e3e
commit 5d32aa02cd
4 changed files with 39 additions and 61 deletions

llm.py (deleted, 53 lines)

@@ -1,53 +0,0 @@
# exllama.py
import random
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)
from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2Sampler
)
import time

class ModelManager:
    def __init__(self, model_directory: str = None):
        if model_directory is None:
            model_directory = "/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.0bpw/"

        self.config = ExLlamaV2Config()
        self.config.model_dir = model_directory
        self.config.prepare()

        self.model = ExLlamaV2(self.config)
        print("Loading model: " + model_directory)

        self.cache = ExLlamaV2Cache(self.model, lazy=True)
        self.model.load_autosplit(self.cache)

        self.tokenizer = ExLlamaV2Tokenizer(self.config)
        self.generator = ExLlamaV2BaseGenerator(self.model, self.cache, self.tokenizer)

    def generate_text(self,
                      prompt: str,
                      max_tokens: int = 150,
                      temperature=0.5,
                      seed: int = random.randint(0, 999999),
                      token_repetition_penalty: float = 1.0,
                      stop: list = None):
        try:
            self.generator.warmup()
            time_begin = time.time()

            settings = ExLlamaV2Sampler.Settings()
            settings.token_repetition_penalty = token_repetition_penalty
            if stop:
                settings.stop_sequence = stop

            output = self.generator.generate_simple(
                prompt, settings, max_tokens, seed=seed
            )

            time_end = time.time()
            time_total = time_end - time_begin

            return output, f"{time_total:.2f} seconds"
        except Exception as e:
            raise RuntimeError(f"Error generating text: {str(e)}")
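
Incidentally, the removed generate_text signature carries a classic Python pitfall worth noting: a default argument expression such as random.randint(0, 999999) is evaluated once, when the function is defined, so every call that omits seed reuses the same "random" value. A minimal standalone demonstration:

import random

def f(seed: int = random.randint(0, 999999)):
    return seed

# The default was fixed at definition time, so repeated calls agree:
print(f() == f())  # True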

main.py (34 changed lines)

@@ -1,15 +1,16 @@
import os
import argparse
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from llm import ModelManager
from uvicorn import run
from model import ModelContainer
from utils import add_args

app = FastAPI()

# Initialize the modelManager with a default model path
default_model_path = "/home/david/Models/SynthIA-7B-v2.0-5.0bpw-h6-exl2"
modelManager = ModelManager(default_model_path)
print(output)

# Initialize a model container. This can be undefined at any given time
model_container: ModelContainer = None

class TextRequest(BaseModel):
    model: str = None  # Make the "model" field optional with a default value of None
    prompt: str
@@ -25,6 +26,7 @@ class TextResponse(BaseModel):
    response: str
    generation_time: str

# TODO: Currently broken
@app.post("/generate-text", response_model=TextResponse)
def generate_text(request: TextRequest):
    global modelManager
@@ -36,5 +38,23 @@ def generate_text(request: TextRequest):
    except RuntimeError as e:
        raise HTTPException(status_code=500, detail=str(e))

# Debug progress check
def progress(module, modules):
    print(f"Loaded {module}/{modules} modules")
    yield

if __name__ == "__main__":
    run(app, host="0.0.0.0", port=8012, reload=True)

    # Convert this parser to use a YAML config
    parser = argparse.ArgumentParser(description = "TabbyAPI - An API server for exllamav2")
    add_args(parser)
    args = parser.parse_args()

    # If an initial model dir is specified, create a container and load the model
    if args.model_dir:
        model_container = ModelContainer(args.model_dir, False, **vars(args))

        print("Loading an initial model...")
        model_container.load(progress)
        print("Model successfully loaded.")

    # Reload is for dev purposes ONLY!
    uvicorn.run("main:app", host="0.0.0.0", port=8012, log_level="debug", reload=True)
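
ModelContainer.load is not shown in this diff, so the exact contract behind model_container.load(progress) has to be inferred. Below is a minimal sketch of the simplest interpretation, where load invokes the callback once per loaded module; fake_load is a hypothetical stand-in, not the real method, and note that main.py's progress also yields, which hints load may instead drive it as a generator:

# Sketch of the assumed progress-callback contract; fake_load is hypothetical.
def fake_load(progress_callback=None):
    modules = 32  # hypothetical module count
    for module in range(1, modules + 1):
        # ... one module's weights would be loaded here ...
        if progress_callback:
            progress_callback(module, modules)

def progress(module, modules):
    print(f"Loaded {module}/{modules} modules")

fake_load(progress)  # "Loaded 1/32 modules" ... "Loaded 32/32 modules"

With that in place, a hypothetical launch with an initial model would look like python main.py --model_dir /models/my-exl2-model --gpu_split auto (path illustrative).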

model.py

@@ -1,4 +1,4 @@
import json, uuid, os, gc, time
import gc, time

import torch
from exllamav2 import (

@@ -34,6 +34,7 @@ class ModelContainer:
    gpu_split: list or None = None

    def __init__(self, model_directory: str, quiet = False, **kwargs):
        print(kwargs)
        """
        Create model container

@@ -57,6 +58,7 @@ class ModelContainer:
        full model.
        'gpu_split_auto' (bool): Automatically split model across available devices (default: True)
        'gpu_split' (list): Allocation for weights and (some) tensors, per device
        'no_flash_attn' (bool): Turns off flash attention (increases vram usage)
        """

        self.quiet = quiet

@@ -72,6 +74,7 @@ class ModelContainer:
        if "max_seq_len" in kwargs: self.config.max_seq_len = kwargs["max_seq_len"]
        if "rope_scale" in kwargs: self.config.scale_pos_emb = kwargs["rope_scale"]
        if "rope_alpha" in kwargs: self.config.scale_alpha_value = kwargs["rope_alpha"]
        if "no_flash_attn" in kwargs: self.config.no_flash_attn = kwargs["no_flash_attn"]

        chunk_size = min(kwargs.get("chunk_size", 2048), self.config.max_seq_len)
        self.config.max_input_len = chunk_size
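
These membership checks are the whole bridge between the CLI and exllamav2: main.py splats the parsed namespace via **vars(args), and the container copies recognized keys onto its ExLlamaV2Config. A self-contained illustration follows; FakeConfig is a stand-in for ExLlamaV2Config, and note that vars(args) contains every defined flag (None when unset), so a bare "in kwargs" test passes even for flags the user never supplied:

import argparse

class FakeConfig:
    # Stand-in for ExLlamaV2Config; attribute names match the diff above.
    max_seq_len = 2048
    scale_alpha_value = 1.0
    no_flash_attn = False

# Roughly what parse_args() yields: every flag is a key, unset ones are None.
args = argparse.Namespace(model_dir="/models/m", gpu_split=None,
                          max_seq_len=4096, rope_scale=1.0,
                          rope_alpha=2.0, no_flash_attn=True, low_mem=False)
kwargs = vars(args)  # the dict that main.py splats as **vars(args)

config = FakeConfig()
if "max_seq_len" in kwargs: config.max_seq_len = kwargs["max_seq_len"]
if "rope_alpha" in kwargs: config.scale_alpha_value = kwargs["rope_alpha"]
if "no_flash_attn" in kwargs: config.no_flash_attn = kwargs["no_flash_attn"]

print(config.max_seq_len, config.scale_alpha_value, config.no_flash_attn)  # 4096 2.0 True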

utils.py (new file, 8 lines)

@@ -0,0 +1,8 @@
def add_args(parser):
    parser.add_argument("-m", "--model_dir", type = str, help = "Path to model directory")
    parser.add_argument("-gs", "--gpu_split", type = str, help = "\"auto\", or VRAM allocation per GPU in GB")
    parser.add_argument("-l", "--max_seq_len", type = int, help = "Maximum sequence length")
    parser.add_argument("-rs", "--rope_scale", type = float, default = 1.0, help = "RoPE scaling factor")
    parser.add_argument("-ra", "--rope_alpha", type = float, default = 1.0, help = "RoPE alpha value (NTK)")
    parser.add_argument("-nfa", "--no_flash_attn", action = "store_true", help = "Disable Flash Attention")
    parser.add_argument("-lm", "--low_mem", action = "store_true", help = "Enable VRAM optimizations, potentially trading off speed")