YAML is a more flexible format for configuration. Command-line arguments are difficult to remember and configure, especially for an API with complicated option names. Rather than using half-baked text files, implement a proper config solution. Also add a progress bar when loading models on the command line.

Signed-off-by: kingbri <bdashore3@proton.me>
import uvicorn
import yaml
from typing import Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from model import ModelContainer
from progress.bar import IncrementalBar

app = FastAPI()

# Initialize a model container. This can be undefined at any point in time
model_container: Optional[ModelContainer] = None

class TextRequest(BaseModel):
    model: Optional[str] = None  # Make the "model" field optional with a default value of None
    prompt: str
    max_tokens: int = 200
    temperature: float = 1
    top_p: float = 0.9
    seed: int = 10
    stream: bool = False
    token_repetition_penalty: float = 1.0
    stop: Optional[list] = None

class TextResponse(BaseModel):
    response: str
    generation_time: str

# TODO: Currently broken
@app.post("/generate-text", response_model=TextResponse)
def generate_text(request: TextRequest):
    global model_container
    try:
        prompt = request.prompt  # Get the prompt from the request
        user_message = prompt  # Assuming that prompt is equivalent to the user's message
        output, generation_time = model_container.generate_text(prompt=user_message)
        return {"response": output, "generation_time": generation_time}
    except RuntimeError as e:
        raise HTTPException(status_code=500, detail=str(e))

# Wrapper callback for load progress
def load_progress(module, modules):
    yield module, modules

if __name__ == "__main__":
    # Load from YAML config. Possibly add a config -> kwargs conversion function
    with open('config.yml', 'r') as config_file:
        config = yaml.safe_load(config_file)
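    # For reference, a minimal config.yml for this script might look like the
    # sketch below. Only model_name and model_dir are read directly here; every
    # other key is passed straight through to ModelContainer as a keyword
    # argument, so the available options depend on that class. The values shown
    # are placeholders, not project defaults.
    #
    #   model_name: my-model   # folder name inside model_dir; leave empty to skip loading
    #   model_dir: models      # falls back to "models" when empty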
    # If an initial model name is specified, create a container and load the model
    if config.get("model_name"):
        model_path = f"{config['model_dir']}/{config['model_name']}" if config.get("model_dir") else f"models/{config['model_name']}"

        model_container = ModelContainer(model_path, False, **config)
        load_status = model_container.load_gen(load_progress)

        # The first callback creates the progress bar; later callbacks advance it
        # and finish once the last module has been loaded
        for (module, modules) in load_status:
            if module == 0:
                loading_bar: IncrementalBar = IncrementalBar("Modules", max=modules)
            else:
                loading_bar.next()

                if module == modules:
                    loading_bar.finish()

        print("Model successfully loaded.")

    # Reload is for dev purposes ONLY!
    uvicorn.run("main:app", host="0.0.0.0", port=8012, log_level="debug", reload=True)
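# Once the server is running, a request might look like the sketch below (kept
# as a comment because the endpoint above is still marked as broken):
#
#   curl -X POST http://localhost:8012/generate-text \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello there", "max_tokens": 64}'
#
# Every TextRequest field other than "prompt" is optional and falls back to the
# defaults declared above.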