OAI: Implement completion API endpoint
Add support for /v1/completions with the option to use streaming if needed. Also rewrite API endpoints to use async when possible since that improves request performance. Model container parameter names also needed rewrites as well and set fallback cases to their disabled values. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
4fa4386275
commit
eee8b642bd
6 changed files with 190 additions and 57 deletions
13
OAI/models/common.py
Normal file
13
OAI/models/common.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict
|
||||
|
||||
class LogProbs(BaseModel):
|
||||
text_offset: List[int] = Field(default_factory=list)
|
||||
token_logprobs: List[float] = Field(default_factory=list)
|
||||
tokens: List[str] = Field(default_factory=list)
|
||||
top_logprobs: List[Dict[str, float]] = Field(default_factory=list)
|
||||
|
||||
class UsageStats(BaseModel):
|
||||
completion_tokens: int
|
||||
prompt_tokens: int
|
||||
total_tokens: int
|
||||
99
OAI/models/completions.py
Normal file
99
OAI/models/completions.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
from uuid import uuid4
|
||||
from time import time
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Dict, Union
|
||||
from OAI.models.common import LogProbs, UsageStats
|
||||
|
||||
class CompletionRespChoice(BaseModel):
|
||||
finish_reason: str
|
||||
index: int
|
||||
logprobs: Optional[LogProbs] = None
|
||||
text: str
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
# Model information
|
||||
model: str
|
||||
|
||||
# Prompt can also contain token ids, but that's out of scope for this project.
|
||||
prompt: Union[str, List[str]]
|
||||
|
||||
# Extra OAI request stuff
|
||||
best_of: Optional[int] = None
|
||||
echo: Optional[bool] = False
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
logprobs: Optional[int] = None
|
||||
n: Optional[int] = 1
|
||||
suffix: Optional[str] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
# Generation info
|
||||
seed: Optional[int] = -1
|
||||
stream: Optional[bool] = False
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
|
||||
# Default to 150 as 16 makes no sense as a default
|
||||
max_tokens: Optional[int] = 150
|
||||
|
||||
# Not supported sampling params
|
||||
presence_penalty: Optional[int] = 0
|
||||
|
||||
# Aliased to repetition_penalty
|
||||
frequency_penalty: int = 0
|
||||
|
||||
# Sampling params
|
||||
token_healing: Optional[bool] = False
|
||||
temperature: Optional[float] = 1.0
|
||||
top_k: Optional[int] = 0
|
||||
top_p: Optional[float] = 1.0
|
||||
typical: Optional[float] = 0.0
|
||||
min_p: Optional[float] = 0.0
|
||||
tfs: Optional[float] = 1.0
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
repetition_penalty_range: Optional[int] = 0
|
||||
repetition_decay: Optional[int] = 0
|
||||
mirostat_mode: Optional[int] = 0
|
||||
mirostat_tau: Optional[float] = 1.5
|
||||
mirostat_eta: Optional[float] = 0.1
|
||||
|
||||
# Converts to internal generation parameters
|
||||
def to_gen_params(self):
|
||||
# Convert prompt to a string
|
||||
if isinstance(self.prompt, list):
|
||||
self.prompt = "\n".join(self.prompt)
|
||||
|
||||
# Convert stop to an array of strings
|
||||
if isinstance(self.stop, str):
|
||||
self.stop = [self.stop]
|
||||
|
||||
# Set repetition_penalty to frequency_penalty if repetition_penalty isn't already defined
|
||||
if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty:
|
||||
self.repetition_penalty = self.frequency_penalty
|
||||
|
||||
return {
|
||||
"prompt": self.prompt,
|
||||
"stop": self.stop,
|
||||
"max_tokens": self.max_tokens,
|
||||
"token_healing": self.token_healing,
|
||||
"temperature": self.temperature,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
"typical": self.typical,
|
||||
"min_p": self.min_p,
|
||||
"tfs": self.tfs,
|
||||
"repetition_penalty": self.repetition_penalty,
|
||||
"repetition_penalty_range": self.repetition_penalty_range,
|
||||
"repetition_decay": self.repetition_decay,
|
||||
"mirostat": True if self.mirostat_mode == 2 else False,
|
||||
"mirostat_tau": self.mirostat_tau,
|
||||
"mirostat_eta": self.mirostat_eta
|
||||
}
|
||||
|
||||
class CompletionResponse(BaseModel):
|
||||
id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}")
|
||||
choices: List[CompletionRespChoice]
|
||||
created: int = Field(default_factory=lambda: int(time()))
|
||||
model: str
|
||||
object: str = "text-completion"
|
||||
|
||||
# TODO: Add usage stats
|
||||
usage: Optional[UsageStats] = None
|
||||
19
OAI/utils.py
Normal file
19
OAI/utils.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from OAI.models.completions import CompletionResponse, CompletionRespChoice
|
||||
from OAI.models.common import UsageStats
|
||||
from typing import Optional
|
||||
|
||||
def create_completion_response(text: str, index: int, model_name: Optional[str]):
|
||||
# TODO: Add method to get token amounts in model for UsageStats
|
||||
|
||||
choice = CompletionRespChoice(
|
||||
finish_reason="Generated",
|
||||
index = index,
|
||||
text = text
|
||||
)
|
||||
|
||||
response = CompletionResponse(
|
||||
choices = [choice],
|
||||
model = model_name or ""
|
||||
)
|
||||
|
||||
return response
|
||||
51
main.py
51
main.py
|
|
@ -1,41 +1,37 @@
|
|||
import uvicorn
|
||||
import yaml
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from fastapi import FastAPI, Request
|
||||
from model import ModelContainer
|
||||
from progress.bar import IncrementalBar
|
||||
from sse_starlette import EventSourceResponse
|
||||
from OAI.models.completions import CompletionRequest, CompletionResponse, CompletionRespChoice
|
||||
from OAI.utils import create_completion_response
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Initialize a model container. This can be undefined at any period of time
|
||||
model_container: ModelContainer = None
|
||||
|
||||
class TextRequest(BaseModel):
|
||||
model: str = None # Make the "model" field optional with a default value of None
|
||||
prompt: str
|
||||
max_tokens: int = 200
|
||||
temperature: float = 1
|
||||
top_p: float = 0.9
|
||||
seed: int = 10
|
||||
stream: bool = False
|
||||
token_repetition_penalty: float = 1.0
|
||||
stop: list = None
|
||||
@app.post("/v1/completions")
|
||||
async def generate_completion(request: Request, data: CompletionRequest):
|
||||
if data.stream:
|
||||
async def generator():
|
||||
new_generation = model_container.generate_gen(**data.to_gen_params())
|
||||
for index, part in enumerate(new_generation):
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
class TextResponse(BaseModel):
|
||||
response: str
|
||||
generation_time: str
|
||||
response = create_completion_response(part, index, model_container.get_model_name())
|
||||
|
||||
yield response.model_dump_json()
|
||||
|
||||
return EventSourceResponse(generator())
|
||||
else:
|
||||
response_text = model_container.generate(**data.to_gen_params())
|
||||
response = create_completion_response(response_text, 0, model_container.get_model_name())
|
||||
|
||||
return response.model_dump_json()
|
||||
|
||||
# TODO: Currently broken
|
||||
@app.post("/generate-text", response_model=TextResponse)
|
||||
def generate_text(request: TextRequest):
|
||||
global modelManager
|
||||
try:
|
||||
prompt = request.prompt # Get the prompt from the request
|
||||
user_message = prompt # Assuming that prompt is equivalent to the user's message
|
||||
output, generation_time = modelManager.generate_text(prompt=user_message)
|
||||
return {"response": output, "generation_time": generation_time}
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# Wrapper callback for load progress
|
||||
def load_progress(module, modules):
|
||||
|
|
@ -63,5 +59,4 @@ if __name__ == "__main__":
|
|||
|
||||
print("Model successfully loaded.")
|
||||
|
||||
# Reload is for dev purposes ONLY!
|
||||
uvicorn.run("main:app", host="0.0.0.0", port=8012, log_level="debug", reload=True)
|
||||
uvicorn.run(app, host="0.0.0.0", port=8012, log_level="debug")
|
||||
|
|
|
|||
65
model.py
65
model.py
|
|
@ -1,6 +1,5 @@
|
|||
import gc, time
|
||||
import torch
|
||||
|
||||
from exllamav2 import(
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Config,
|
||||
|
|
@ -8,25 +7,26 @@ from exllamav2 import(
|
|||
ExLlamaV2Cache_8bit,
|
||||
ExLlamaV2Tokenizer,
|
||||
)
|
||||
|
||||
from exllamav2.generator import(
|
||||
ExLlamaV2StreamingGenerator,
|
||||
ExLlamaV2Sampler
|
||||
)
|
||||
from os import path
|
||||
from typing import Optional
|
||||
|
||||
# Bytes to reserve on first device when loading with auto split
|
||||
auto_split_reserve_bytes = 96 * 1024**2
|
||||
|
||||
class ModelContainer:
|
||||
|
||||
config: ExLlamaV2Config or None = None
|
||||
draft_config: ExLlamaV2Config or None = None
|
||||
model: ExLlamaV2 or None = None
|
||||
draft_model: ExLlamaV2 or None = None
|
||||
cache: ExLlamaV2Cache or None = None
|
||||
draft_cache: ExLlamaV2Cache or None = None
|
||||
tokenizer: ExLlamaV2Tokenizer or None = None
|
||||
generator: ExLlamaV2StreamingGenerator or None = None
|
||||
config: Optional[ExLlamaV2Config] = None
|
||||
draft_config: Optional[ExLlamaV2Config] = None
|
||||
model: Optional[ExLlamaV2] = None
|
||||
draft_model: Optional[ExLlamaV2] = None
|
||||
cache: Optional[ExLlamaV2Cache] = None
|
||||
draft_cache: Optional[ExLlamaV2Cache] = None
|
||||
tokenizer: Optional[ExLlamaV2Tokenizer] = None
|
||||
generator: Optional[ExLlamaV2StreamingGenerator] = None
|
||||
|
||||
cache_fp8: bool = False
|
||||
draft_enabled: bool = False
|
||||
|
|
@ -102,6 +102,11 @@ class ModelContainer:
|
|||
self.draft_config.max_input_len = kwargs["chunk_size"]
|
||||
self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2
|
||||
|
||||
def get_model_name(self):
|
||||
if self.draft_enabled:
|
||||
return path.basename(path.normpath(self.draft_config.model_dir))
|
||||
else:
|
||||
return path.basename(path.normpath(self.config.model_dir))
|
||||
|
||||
def load(self, progress_callback = None):
|
||||
"""
|
||||
|
|
@ -201,20 +206,20 @@ class ModelContainer:
|
|||
prompt (str): Input prompt
|
||||
**kwargs:
|
||||
'token_healing' (bool): Use token healing (default: False)
|
||||
'temperature' (float): Sampling temperature (default: 0.8)
|
||||
'top_k' (int): Sampling top-K (default: 100)
|
||||
'top_p' (float): Sampling top-P (default: 0.8)
|
||||
'temperature' (float): Sampling temperature (default: 1.0)
|
||||
'top_k' (int): Sampling top-K (default: 0)
|
||||
'top_p' (float): Sampling top-P (default: 1.0)
|
||||
'min_p' (float): Sampling min-P (default: 0.0)
|
||||
'tfs' (float): Tail-free sampling (default: 0.0)
|
||||
'typical' (float): Sampling typical (default: 0.0)
|
||||
'mirostat' (bool): Use Mirostat (default: False)
|
||||
'mirostat_tau' (float) Mirostat tau parameter (default: 1.5)
|
||||
'mirostat_eta' (float) Mirostat eta parameter (default: 0.1)
|
||||
'token_repetition_penalty' (float): Token repetition/presence penalty (default: 1.15)
|
||||
'token_repetition_range' (int): Repetition penalty range (default: whole context)
|
||||
'token_repetition_decay' (int): Repetition penalty range (default: same as range)
|
||||
'stop_conditions' (list): List of stop strings/tokens to end response (default: [EOS])
|
||||
'max_new_tokens' (int): Max no. tokens in response (default: 150)
|
||||
'repetition_penalty' (float): Token repetition/presence penalty (default: 1.15)
|
||||
'repetition_range' (int): Repetition penalty range (default: whole context)
|
||||
'repetition_decay' (int): Repetition penalty range (default: same as range)
|
||||
'stop' (list): List of stop strings/tokens to end response (default: [EOS])
|
||||
'max_tokens' (int): Max no. tokens in response (default: 150)
|
||||
'stream_interval' (float): Interval in seconds between each output chunk (default: immediate)
|
||||
'generate_window' (int): Space to reserve at the end of the model's context when generating.
|
||||
Rolls context window by the same amount if context length is exceeded to allow generating past
|
||||
|
|
@ -223,25 +228,27 @@ class ModelContainer:
|
|||
"""
|
||||
|
||||
token_healing = kwargs.get("token_healing", False)
|
||||
max_new_tokens = kwargs.get("max_new_tokens", 150)
|
||||
max_tokens = kwargs.get("max_tokens", 150)
|
||||
stream_interval = kwargs.get("stream_interval", 0)
|
||||
generate_window = min(kwargs.get("generate_window", 512), max_new_tokens)
|
||||
generate_window = min(kwargs.get("generate_window", 512), max_tokens)
|
||||
|
||||
# Sampler settings
|
||||
|
||||
gen_settings = ExLlamaV2Sampler.Settings()
|
||||
gen_settings.temperature = kwargs.get("temperature", 0.8)
|
||||
gen_settings.top_k = kwargs.get("top_k", 100)
|
||||
gen_settings.top_p = kwargs.get("top_p", 0.8)
|
||||
gen_settings.temperature = kwargs.get("temperature", 1.0)
|
||||
gen_settings.top_k = kwargs.get("top_k", 1)
|
||||
gen_settings.top_p = kwargs.get("top_p", 1.0)
|
||||
gen_settings.min_p = kwargs.get("min_p", 0.0)
|
||||
gen_settings.tfs = kwargs.get("tfs", 0.0)
|
||||
gen_settings.typical = kwargs.get("typical", 0.0)
|
||||
gen_settings.mirostat = kwargs.get("mirostat", False)
|
||||
|
||||
# Default tau and eta fallbacks don't matter if mirostat is off
|
||||
gen_settings.mirostat_tau = kwargs.get("mirostat_tau", 1.5)
|
||||
gen_settings.mirostat_eta = kwargs.get("mirostat_eta", 0.1)
|
||||
gen_settings.token_repetition_penalty = kwargs.get("token_repetition_penalty", 1.15)
|
||||
gen_settings.token_repetition_range = kwargs.get("token_repetition_range", self.config.max_seq_len)
|
||||
gen_settings.token_repetition_decay = kwargs.get("token_repetition_decay", gen_settings.token_repetition_range)
|
||||
gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty", 1.0)
|
||||
gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len)
|
||||
gen_settings.token_repetition_decay = kwargs.get("repetition_decay", gen_settings.token_repetition_range)
|
||||
|
||||
# Override sampler settings for temp = 0
|
||||
|
||||
|
|
@ -253,7 +260,7 @@ class ModelContainer:
|
|||
|
||||
# Stop conditions
|
||||
|
||||
self.generator.set_stop_conditions(kwargs.get("stop_conditions", [self.tokenizer.eos_token_id]))
|
||||
self.generator.set_stop_conditions(kwargs.get("stop", [self.tokenizer.eos_token_id]))
|
||||
|
||||
# Tokenized context
|
||||
|
||||
|
|
@ -302,10 +309,10 @@ class ModelContainer:
|
|||
now = time.time()
|
||||
elapsed = now - last_chunk_time
|
||||
|
||||
if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_new_tokens):
|
||||
if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_tokens):
|
||||
yield chunk_buffer
|
||||
full_response += chunk_buffer
|
||||
chunk_buffer = ""
|
||||
last_chunk_time = now
|
||||
|
||||
if eos or generated_tokens == max_new_tokens: break
|
||||
if eos or generated_tokens == max_tokens: break
|
||||
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
Loading…
Add table
Add a link
Reference in a new issue