OAI: Implement completion API endpoint

Add support for /v1/completions with optional streaming via server-sent
events. Also rewrite the API endpoints to be async where possible, since
that improves concurrent request handling.

Model container parameter names were also renamed to match the new request
fields, and their fallback values were set to the samplers' disabled defaults.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2023-11-13 18:24:12 -05:00
parent 4fa4386275
commit eee8b642bd
6 changed files with 190 additions and 57 deletions

OAI/models/common.py (new file, 13 additions)

@@ -0,0 +1,13 @@
from pydantic import BaseModel, Field
from typing import List, Dict

class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[float] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Dict[str, float]] = Field(default_factory=list)

class UsageStats(BaseModel):
    completion_tokens: int
    prompt_tokens: int
    total_tokens: int

OAI/models/completions.py (new file, 99 additions)

@@ -0,0 +1,99 @@
from uuid import uuid4
from time import time
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Union
from OAI.models.common import LogProbs, UsageStats

class CompletionRespChoice(BaseModel):
    finish_reason: str
    index: int
    logprobs: Optional[LogProbs] = None
    text: str

class CompletionRequest(BaseModel):
    # Model information
    model: str

    # Prompt can also contain token ids, but that's out of scope for this project.
    prompt: Union[str, List[str]]

    # Extra OAI request stuff
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    n: Optional[int] = 1
    suffix: Optional[str] = None
    user: Optional[str] = None

    # Generation info
    seed: Optional[int] = -1
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None

    # Default to 150 as 16 makes no sense as a default
    max_tokens: Optional[int] = 150

    # Unsupported sampling params
    presence_penalty: Optional[int] = 0

    # Aliased to repetition_penalty
    frequency_penalty: int = 0

    # Sampling params
    token_healing: Optional[bool] = False
    temperature: Optional[float] = 1.0
    top_k: Optional[int] = 0
    top_p: Optional[float] = 1.0
    typical: Optional[float] = 0.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    repetition_penalty_range: Optional[int] = 0
    repetition_decay: Optional[int] = 0
    mirostat_mode: Optional[int] = 0
    mirostat_tau: Optional[float] = 1.5
    mirostat_eta: Optional[float] = 0.1

    # Converts to internal generation parameters
    def to_gen_params(self):
        # Convert prompt to a string
        if isinstance(self.prompt, list):
            self.prompt = "\n".join(self.prompt)

        # Convert stop to a list of strings
        if isinstance(self.stop, str):
            self.stop = [self.stop]

        # Alias frequency_penalty to repetition_penalty if the latter isn't explicitly set
        if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty:
            self.repetition_penalty = self.frequency_penalty

        return {
            "prompt": self.prompt,
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "token_healing": self.token_healing,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "typical": self.typical,
            "min_p": self.min_p,
            "tfs": self.tfs,
            "repetition_penalty": self.repetition_penalty,
            "repetition_penalty_range": self.repetition_penalty_range,
            "repetition_decay": self.repetition_decay,
            "mirostat": self.mirostat_mode == 2,
            "mirostat_tau": self.mirostat_tau,
            "mirostat_eta": self.mirostat_eta
        }

class CompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}")
    choices: List[CompletionRespChoice]
    created: int = Field(default_factory=lambda: int(time()))
    model: str
    object: str = "text-completion"

    # TODO: Add usage stats
    usage: Optional[UsageStats] = None
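For reference, a quick usage sketch (not part of the commit) of how the request model above maps OAI-style fields onto the internal generation parameters. It assumes the OAI package is importable from the project root; the model name is hypothetical.

from OAI.models.completions import CompletionRequest

# frequency_penalty is aliased onto repetition_penalty when the latter is left
# at its 1.0 default, and a string stop value is wrapped into a list.
req = CompletionRequest(
    model="local-model",    # hypothetical model name
    prompt=["Once upon a time,", "there was"],
    stop="###",
    frequency_penalty=2,
    max_tokens=64,
)

params = req.to_gen_params()
print(params["prompt"])              # "Once upon a time,\nthere was"
print(params["stop"])                # ["###"]
print(params["repetition_penalty"])  # 2 (taken from frequency_penalty)
print(params["mirostat"])            # False (mirostat_mode != 2)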

OAI/utils.py (new file, 19 additions)

@@ -0,0 +1,19 @@
from OAI.models.completions import CompletionResponse, CompletionRespChoice
from OAI.models.common import UsageStats
from typing import Optional

def create_completion_response(text: str, index: int, model_name: Optional[str]):
    # TODO: Add method to get token amounts in model for UsageStats
    choice = CompletionRespChoice(
        finish_reason="Generated",
        index=index,
        text=text
    )

    response = CompletionResponse(
        choices=[choice],
        model=model_name or ""
    )

    return response
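As a small usage sketch (not part of the commit), the helper above produces an OAI-style response body; the shape shown in the comments follows the CompletionResponse defaults from OAI/models/completions.py, and the id/created values are generated per call.

from OAI.utils import create_completion_response

response = create_completion_response("Hello there!", 0, "my-model")

# model_dump_json() yields something like:
# {"id": "cmpl-<hex>", "choices": [{"finish_reason": "Generated", "index": 0,
#  "logprobs": null, "text": "Hello there!"}], "created": 1699999999,
#  "model": "my-model", "object": "text-completion", "usage": null}
print(response.model_dump_json())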

main.py (51 lines changed)

@@ -1,41 +1,37 @@
import uvicorn
import yaml
-from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
+from fastapi import FastAPI, Request
from model import ModelContainer
from progress.bar import IncrementalBar
+from sse_starlette import EventSourceResponse
+from OAI.models.completions import CompletionRequest, CompletionResponse, CompletionRespChoice
+from OAI.utils import create_completion_response

app = FastAPI()

# Initialize a model container. This can be undefined at any period of time
model_container: ModelContainer = None

-class TextRequest(BaseModel):
-    model: str = None # Make the "model" field optional with a default value of None
-    prompt: str
-    max_tokens: int = 200
-    temperature: float = 1
-    top_p: float = 0.9
-    seed: int = 10
-    stream: bool = False
-    token_repetition_penalty: float = 1.0
-    stop: list = None
+@app.post("/v1/completions")
+async def generate_completion(request: Request, data: CompletionRequest):
+    if data.stream:
+        async def generator():
+            new_generation = model_container.generate_gen(**data.to_gen_params())
+            for index, part in enumerate(new_generation):
+                if await request.is_disconnected():
+                    break

-class TextResponse(BaseModel):
-    response: str
-    generation_time: str
+                response = create_completion_response(part, index, model_container.get_model_name())
+                yield response.model_dump_json()

+        return EventSourceResponse(generator())
+    else:
+        response_text = model_container.generate(**data.to_gen_params())
+        response = create_completion_response(response_text, 0, model_container.get_model_name())
+        return response.model_dump_json()

-# TODO: Currently broken
-@app.post("/generate-text", response_model=TextResponse)
-def generate_text(request: TextRequest):
-    global modelManager
-    try:
-        prompt = request.prompt # Get the prompt from the request
-        user_message = prompt # Assuming that prompt is equivalent to the user's message
-        output, generation_time = modelManager.generate_text(prompt=user_message)
-        return {"response": output, "generation_time": generation_time}
-    except RuntimeError as e:
-        raise HTTPException(status_code=500, detail=str(e))

# Wrapper callback for load progress
def load_progress(module, modules):
@@ -63,5 +59,4 @@ if __name__ == "__main__":
    print("Model successfully loaded.")

-    # Reload is for dev purposes ONLY!
-    uvicorn.run("main:app", host="0.0.0.0", port=8012, log_level="debug", reload=True)
+    uvicorn.run(app, host="0.0.0.0", port=8012, log_level="debug")
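For context, a minimal client sketch (not part of the commit) for exercising the new endpoint. It assumes the server is running on localhost:8012 as configured above and that the requests package is installed; the model name is hypothetical.

import json
import requests

BASE_URL = "http://localhost:8012/v1/completions"
payload = {"model": "local-model", "prompt": "Once upon a time,", "max_tokens": 64}

# Non-streaming: a single JSON body is returned.
print(requests.post(BASE_URL, json=payload).text)

# Streaming: the endpoint responds with server-sent events, one JSON chunk per "data:" line.
with requests.post(BASE_URL, json={**payload, "stream": True}, stream=True) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data:"):
            chunk = json.loads(line.removeprefix(b"data:").strip())
            print(chunk["choices"][0]["text"], end="", flush=True)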

model.py

@@ -1,6 +1,5 @@
import gc, time
import torch
from exllamav2 import(
    ExLlamaV2,
    ExLlamaV2Config,
@@ -8,25 +7,26 @@ from exllamav2 import(
    ExLlamaV2Cache_8bit,
    ExLlamaV2Tokenizer,
)
from exllamav2.generator import(
    ExLlamaV2StreamingGenerator,
    ExLlamaV2Sampler
)
from os import path
+from typing import Optional

# Bytes to reserve on first device when loading with auto split
auto_split_reserve_bytes = 96 * 1024**2

class ModelContainer:
-    config: ExLlamaV2Config or None = None
-    draft_config: ExLlamaV2Config or None = None
-    model: ExLlamaV2 or None = None
-    draft_model: ExLlamaV2 or None = None
-    cache: ExLlamaV2Cache or None = None
-    draft_cache: ExLlamaV2Cache or None = None
-    tokenizer: ExLlamaV2Tokenizer or None = None
-    generator: ExLlamaV2StreamingGenerator or None = None
+    config: Optional[ExLlamaV2Config] = None
+    draft_config: Optional[ExLlamaV2Config] = None
+    model: Optional[ExLlamaV2] = None
+    draft_model: Optional[ExLlamaV2] = None
+    cache: Optional[ExLlamaV2Cache] = None
+    draft_cache: Optional[ExLlamaV2Cache] = None
+    tokenizer: Optional[ExLlamaV2Tokenizer] = None
+    generator: Optional[ExLlamaV2StreamingGenerator] = None
    cache_fp8: bool = False
    draft_enabled: bool = False
@@ -102,6 +102,11 @@ class ModelContainer:
            self.draft_config.max_input_len = kwargs["chunk_size"]
            self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2

+    def get_model_name(self):
+        if self.draft_enabled:
+            return path.basename(path.normpath(self.draft_config.model_dir))
+        else:
+            return path.basename(path.normpath(self.config.model_dir))

    def load(self, progress_callback = None):
        """
@@ -201,20 +206,20 @@
        prompt (str): Input prompt
        **kwargs:
            'token_healing' (bool): Use token healing (default: False)
-            'temperature' (float): Sampling temperature (default: 0.8)
-            'top_k' (int): Sampling top-K (default: 100)
-            'top_p' (float): Sampling top-P (default: 0.8)
+            'temperature' (float): Sampling temperature (default: 1.0)
+            'top_k' (int): Sampling top-K (default: 0)
+            'top_p' (float): Sampling top-P (default: 1.0)
            'min_p' (float): Sampling min-P (default: 0.0)
            'tfs' (float): Tail-free sampling (default: 0.0)
            'typical' (float): Sampling typical (default: 0.0)
            'mirostat' (bool): Use Mirostat (default: False)
            'mirostat_tau' (float) Mirostat tau parameter (default: 1.5)
            'mirostat_eta' (float) Mirostat eta parameter (default: 0.1)
-            'token_repetition_penalty' (float): Token repetition/presence penalty (default: 1.15)
-            'token_repetition_range' (int): Repetition penalty range (default: whole context)
-            'token_repetition_decay' (int): Repetition penalty range (default: same as range)
-            'stop_conditions' (list): List of stop strings/tokens to end response (default: [EOS])
-            'max_new_tokens' (int): Max no. tokens in response (default: 150)
+            'repetition_penalty' (float): Token repetition/presence penalty (default: 1.15)
+            'repetition_range' (int): Repetition penalty range (default: whole context)
+            'repetition_decay' (int): Repetition penalty range (default: same as range)
+            'stop' (list): List of stop strings/tokens to end response (default: [EOS])
+            'max_tokens' (int): Max no. tokens in response (default: 150)
            'stream_interval' (float): Interval in seconds between each output chunk (default: immediate)
            'generate_window' (int): Space to reserve at the end of the model's context when generating.
                Rolls context window by the same amount if context length is exceeded to allow generating past
@@ -223,25 +228,27 @@ class ModelContainer:
        """
        token_healing = kwargs.get("token_healing", False)
-        max_new_tokens = kwargs.get("max_new_tokens", 150)
+        max_tokens = kwargs.get("max_tokens", 150)
        stream_interval = kwargs.get("stream_interval", 0)
-        generate_window = min(kwargs.get("generate_window", 512), max_new_tokens)
+        generate_window = min(kwargs.get("generate_window", 512), max_tokens)

        # Sampler settings
        gen_settings = ExLlamaV2Sampler.Settings()
-        gen_settings.temperature = kwargs.get("temperature", 0.8)
-        gen_settings.top_k = kwargs.get("top_k", 100)
-        gen_settings.top_p = kwargs.get("top_p", 0.8)
+        gen_settings.temperature = kwargs.get("temperature", 1.0)
+        gen_settings.top_k = kwargs.get("top_k", 1)
+        gen_settings.top_p = kwargs.get("top_p", 1.0)
        gen_settings.min_p = kwargs.get("min_p", 0.0)
        gen_settings.tfs = kwargs.get("tfs", 0.0)
        gen_settings.typical = kwargs.get("typical", 0.0)
        gen_settings.mirostat = kwargs.get("mirostat", False)

+        # Default tau and eta fallbacks don't matter if mirostat is off
        gen_settings.mirostat_tau = kwargs.get("mirostat_tau", 1.5)
        gen_settings.mirostat_eta = kwargs.get("mirostat_eta", 0.1)
-        gen_settings.token_repetition_penalty = kwargs.get("token_repetition_penalty", 1.15)
-        gen_settings.token_repetition_range = kwargs.get("token_repetition_range", self.config.max_seq_len)
-        gen_settings.token_repetition_decay = kwargs.get("token_repetition_decay", gen_settings.token_repetition_range)
+        gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty", 1.0)
+        gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len)
+        gen_settings.token_repetition_decay = kwargs.get("repetition_decay", gen_settings.token_repetition_range)

        # Override sampler settings for temp = 0
@@ -253,7 +260,7 @@ class ModelContainer:
        # Stop conditions
-        self.generator.set_stop_conditions(kwargs.get("stop_conditions", [self.tokenizer.eos_token_id]))
+        self.generator.set_stop_conditions(kwargs.get("stop", [self.tokenizer.eos_token_id]))

        # Tokenized context
@@ -302,10 +309,10 @@ class ModelContainer:
            now = time.time()
            elapsed = now - last_chunk_time

-            if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_new_tokens):
+            if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_tokens):
                yield chunk_buffer
                full_response += chunk_buffer
                chunk_buffer = ""
                last_chunk_time = now

-            if eos or generated_tokens == max_new_tokens: break
+            if eos or generated_tokens == max_tokens: break
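To show the renamed generation parameters in use, here is a hedged sketch (not part of the commit). The ModelContainer constructor and load arguments are assumptions, since they are not visible in this diff; only the generate_gen/generate kwargs follow the names documented above.

from model import ModelContainer

# Hypothetical setup; the real constructor/load arguments live elsewhere in model.py.
container = ModelContainer(model_dir="/path/to/model")  # assumed keyword, not shown in this diff
container.load()

# The generation kwargs now match the names sent by CompletionRequest.to_gen_params():
for chunk in container.generate_gen(
    prompt="Once upon a time,",
    max_tokens=100,            # was max_new_tokens
    stop=["###"],              # was stop_conditions
    repetition_penalty=1.1,    # was token_repetition_penalty
    temperature=1.0,           # fallbacks now default to the samplers' disabled values
):
    print(chunk, end="", flush=True)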

Binary file not shown.