Tree: Update to use ModelContainer and args
Use command-line arguments to load an initial model if necessary. The API routes are currently broken, but from now on the container should be the primary interface to the exllamav2 library. These arguments should also be converted into a YAML configuration file in the future.

Signed-off-by: kingbri <bdashore3@proton.me>
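As a rough sketch of that future YAML direction (hypothetical file name and keys, not part of this commit), the config could be loaded with PyYAML and splatted into ModelContainer the same way main.py splats the argparse namespace below:

    # Hypothetical sketch -- this commit still uses argparse, not YAML.
    # Assumes a config.yml whose keys mirror the add_args flags
    # (model_dir, gpu_split, max_seq_len, rope_scale, rope_alpha, no_flash_attn, low_mem).
    import yaml  # PyYAML, assumed as an extra dependency

    from model import ModelContainer

    with open("config.yml") as f:
        config = yaml.safe_load(f)  # plain dict of option -> value

    # Mirrors main.py's ModelContainer(args.model_dir, False, **vars(args))
    model_container = ModelContainer(config["model_dir"], False, **config)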
parent 9d34479e3e
commit 5d32aa02cd

4 changed files with 39 additions and 61 deletions
llm.py (53 lines deleted, file removed)

@@ -1,53 +0,0 @@
-# exllama.py
-import random
-from exllamav2 import (
-    ExLlamaV2,
-    ExLlamaV2Config,
-    ExLlamaV2Cache,
-    ExLlamaV2Tokenizer,
-)
-from exllamav2.generator import (
-    ExLlamaV2BaseGenerator,
-    ExLlamaV2Sampler
-)
-import time
-
-
-class ModelManager:
-    def __init__(self, model_directory: str = None):
-        if model_directory is None:
-            model_directory = "/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.0bpw/"
-        self.config = ExLlamaV2Config()
-        self.config.model_dir = model_directory
-        self.config.prepare()
-        self.model = ExLlamaV2(self.config)
-        print("Loading model: " + model_directory)
-        self.cache = ExLlamaV2Cache(self.model, lazy=True)
-        self.model.load_autosplit(self.cache)
-        self.tokenizer = ExLlamaV2Tokenizer(self.config)
-        self.generator = ExLlamaV2BaseGenerator(self.model, self.cache, self.tokenizer)
-
-    def generate_text(self,
-                      prompt: str,
-                      max_tokens: int = 150,
-                      temperature=0.5,
-                      seed: int = random.randint(0, 999999),
-                      token_repetition_penalty: float = 1.0,
-                      stop: list = None):
-        try:
-            self.generator.warmup()
-            time_begin = time.time()
-            settings = ExLlamaV2Sampler.Settings()
-            settings.token_repetition_penalty = token_repetition_penalty
-
-            if stop:
-                settings.stop_sequence = stop
-
-            output = self.generator.generate_simple(
-                prompt, settings, max_tokens, seed=seed
-            )
-            time_end = time.time()
-            time_total = time_end - time_begin
-            return output, f"{time_total:.2f} seconds"
-        except Exception as e:
-            raise RuntimeError(f"Error generating text: {str(e)}")
main.py (34 changes)

@@ -1,15 +1,16 @@
 import os
+import argparse
+import uvicorn
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from llm import ModelManager
-from uvicorn import run
+from model import ModelContainer
+from utils import add_args
 
 app = FastAPI()
 
-# Initialize the modelManager with a default model path
-default_model_path = "/home/david/Models/SynthIA-7B-v2.0-5.0bpw-h6-exl2"
-modelManager = ModelManager(default_model_path)
-print(output)
+# Initialize a model container. This can be undefined at any period of time
+model_container: ModelContainer = None
+
 class TextRequest(BaseModel):
     model: str = None  # Make the "model" field optional with a default value of None
     prompt: str
@@ -25,6 +26,7 @@ class TextResponse(BaseModel):
     response: str
     generation_time: str
 
+# TODO: Currently broken
 @app.post("/generate-text", response_model=TextResponse)
 def generate_text(request: TextRequest):
     global modelManager
@@ -36,5 +38,23 @@ def generate_text(request: TextRequest):
     except RuntimeError as e:
         raise HTTPException(status_code=500, detail=str(e))
 
+# Debug progress check
+def progress(module, modules):
+    print(f"Loaded {module}/{modules} modules")
+    yield
+
 if __name__ == "__main__":
-    run(app, host="0.0.0.0", port=8012, reload=True)
+    # Convert this parser to use a YAML config
+    parser = argparse.ArgumentParser(description = "TabbyAPI - An API server for exllamav2")
+    add_args(parser)
+    args = parser.parse_args()
+
+    # If an initial model dir is specified, create a container and load the model
+    if args.model_dir:
+        model_container = ModelContainer(args.model_dir, False, **vars(args))
+        print("Loading an initial model...")
+        model_container.load(progress)
+        print("Model successfully loaded.")
+
+    # Reload is for dev purposes ONLY!
+    uvicorn.run("main:app", host="0.0.0.0", port=8012, log_level="debug", reload=True)
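One detail worth calling out in the new __main__ block: vars(args) converts the argparse Namespace into a plain dict, so every flag defined in utils.py reaches ModelContainer as a keyword argument. A standalone sketch of that mechanism (hypothetical flags, trimmed to two):

    # Demonstrates the Namespace -> kwargs expansion used by main.py above.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str)
    parser.add_argument("--max_seq_len", type=int)
    args = parser.parse_args(["--model_dir", "/models/demo", "--max_seq_len", "4096"])

    print(vars(args))  # {'model_dir': '/models/demo', 'max_seq_len': 4096}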
model.py (5 changes)

@@ -1,4 +1,4 @@
-import json, uuid, os, gc, time
+import gc, time
 import torch
 
 from exllamav2 import(
@@ -34,6 +34,7 @@ class ModelContainer:
     gpu_split: list or None = None
 
     def __init__(self, model_directory: str, quiet = False, **kwargs):
+        print(kwargs)
        """
        Create model container
 
@@ -57,6 +58,7 @@ class ModelContainer:
            full model.
        'gpu_split_auto' (bool): Automatically split model across available devices (default: True)
        'gpu_split' (list): Allocation for weights and (some) tensors, per device
+       'no_flash_attn' (bool): Turns off flash attention (increases vram usage)
        """
 
        self.quiet = quiet
@@ -72,6 +74,7 @@ class ModelContainer:
        if "max_seq_len" in kwargs: self.config.max_seq_len = kwargs["max_seq_len"]
        if "rope_scale" in kwargs: self.config.scale_pos_emb = kwargs["rope_scale"]
        if "rope_alpha" in kwargs: self.config.scale_alpha_value = kwargs["rope_alpha"]
+       if "no_flash_attn" in kwargs: self.config.no_flash_attn = kwargs["no_flash_attn"]
 
        chunk_size = min(kwargs.get("chunk_size", 2048), self.config.max_seq_len)
        self.config.max_input_len = chunk_size
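For context on the unchanged chunk_size lines at the end of that hunk: clamping against max_seq_len keeps max_input_len from ever exceeding the model's context window. A minimal illustration with hypothetical numbers:

    # If the context window is smaller than the default chunk size,
    # the clamp wins: min(2048, 1024) -> 1024.
    max_seq_len = 1024
    chunk_size = min(2048, max_seq_len)
    assert chunk_size == 1024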
utils.py (new file, 8 additions)

@@ -0,0 +1,8 @@
+def add_args(parser):
+    parser.add_argument("-m", "--model_dir", type = str, help = "Path to model directory")
+    parser.add_argument("-gs", "--gpu_split", type = str, help = "\"auto\", or VRAM allocation per GPU in GB")
+    parser.add_argument("-l", "--max_seq_len", type = int, help = "Maximum sequence length")
+    parser.add_argument("-rs", "--rope_scale", type = float, default = 1.0, help = "RoPE scaling factor")
+    parser.add_argument("-ra", "--rope_alpha", type = float, default = 1.0, help = "RoPE alpha value (NTK)")
+    parser.add_argument("-nfa", "--no_flash_attn", action = "store_true", help = "Disable Flash Attention")
+    parser.add_argument("-lm", "--low_mem", action = "store_true", help = "Enable VRAM optimizations, potentially trading off speed")
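Assuming main.py stays the entry point, a typical invocation with the new flags might look like (hypothetical paths and values):

    python main.py --model_dir /path/to/exl2-model --gpu_split auto --max_seq_len 4096 --no_flash_attn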