Initial commit

commit b967e2e604

4 changed files with 80 additions and 0 deletions
README.md (Normal file, 0 additions)

llm.py (Normal file, 38 additions)
@@ -0,0 +1,38 @@
# llm.py

import random
import time

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)
from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2Sampler,
)


class ModelManager:
    def __init__(self, model_directory: str = None):
        if model_directory is None:
            model_directory = "/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.0bpw/"
        self.config = ExLlamaV2Config()
        self.config.model_dir = model_directory
        self.config.prepare()
        self.model = ExLlamaV2(self.config)
        print("Loading model: " + model_directory)
        self.cache = ExLlamaV2Cache(self.model, lazy=True)
        self.model.load_autosplit(self.cache)
        self.tokenizer = ExLlamaV2Tokenizer(self.config)
        self.generator = ExLlamaV2BaseGenerator(self.model, self.cache, self.tokenizer)

    def generate_text(self, prompt: str, max_new_tokens: int = 150, seed: int = None):
        # A random.randint(...) default in the signature would be evaluated only
        # once, at definition time; draw a fresh seed per call instead.
        if seed is None:
            seed = random.randint(0, 999999)
        try:
            self.generator.warmup()
            time_begin = time.time()
            output = self.generator.generate_simple(
                prompt, ExLlamaV2Sampler.Settings(), max_new_tokens, seed=seed
            )
            time_end = time.time()
            time_total = time_end - time_begin
            return output, f"{time_total:.2f} seconds"
        except Exception as e:
            raise RuntimeError(f"Error generating text: {str(e)}")
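For reference, a minimal sketch of driving ModelManager directly, outside the API. The model path is the one used in main.py below and is purely illustrative; any local exl2-quantized model directory works.

# Minimal usage sketch for llm.ModelManager (path is illustrative).
from llm import ModelManager

manager = ModelManager("/home/david/Models/SynthIA-7B-v2.0-5.0bpw-h6-exl2")
output, elapsed = manager.generate_text(
    "Explain quantization in one sentence.", max_new_tokens=64
)
print(output)            # generated completion, prompt included
print("took", elapsed)   # e.g. "1.84 seconds"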
main.py (Normal file, 39 additions)

@@ -0,0 +1,39 @@
# main.py

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from llm import ModelManager
from uvicorn import run

app = FastAPI()

# Example: Using a different model directory
model_manager = ModelManager("/home/david/Models/SynthIA-7B-v2.0-5.0bpw-h6-exl2")


class TextRequest(BaseModel):
    model: str
    messages: list[dict]
    temperature: float


class TextResponse(BaseModel):
    response: str
    generation_time: str


@app.post("/generate-text", response_model=TextResponse)
def generate_text(request: TextRequest):
    try:
        # model_path = request.model  # You can use this path to load a specific model if needed
        messages = request.messages
        # temperature = request.temperature

        # Extract the first user message from the messages list to use as the prompt
        user_message = next(
            (msg["content"] for msg in messages if msg["role"] == "user"), None
        )
        if user_message is None:
            raise HTTPException(status_code=400, detail="No user message in request")

        output, generation_time = model_manager.generate_text(user_message)
        return {"response": output, "generation_time": generation_time}
    except RuntimeError as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    # uvicorn requires an import string (not an app object) when reload=True
    run("main:app", host="0.0.0.0", port=8012, reload=True)
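A quick way to exercise the endpoint once the server is up, sketched with the requests library (not in requirements.txt; install it separately). The payload mirrors TextRequest; note that model and temperature are accepted but currently unused by the handler, so their values here are placeholders.

# Client-side sketch for POST /generate-text (assumes `pip install requests`).
import requests

payload = {
    "model": "synthia-7b",   # placeholder; handler ignores this field for now
    "temperature": 0.7,      # placeholder; handler ignores this field for now
    "messages": [{"role": "user", "content": "Write a haiku about GPUs."}],
}
resp = requests.post("http://localhost:8012/generate-text", json=payload)
resp.raise_for_status()
body = resp.json()
print(body["response"], "|", body["generation_time"])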
requirements.txt (Normal file, 3 additions)

@@ -0,0 +1,3 @@
fastapi==0.104.1
uvicorn==0.24.0
pydantic==2.4.2
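Note that llm.py also imports exllamav2, which is not listed above. A fuller requirements.txt would include it as well; the entry below is a suggestion with the version deliberately left unpinned, since the right build depends on the local CUDA/torch setup.

exllamav2  # imported by llm.py; pin a version matching your torch/CUDA install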