OAI: Add chat completions endpoint

Chat completions is the endpoint that OAI will use going forward, so it
makes sense to support it even though the completions endpoint will be
used more often here.

Also unify common parameters between the chat completion and completion
requests since they're very similar.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2023-11-16 01:06:07 -05:00
parent 593471a04d
commit 5e8419ec0c
5 changed files with 247 additions and 96 deletions

OAI/types/chat_completion.py

@@ -0,0 +1,44 @@
from uuid import uuid4
from time import time
from pydantic import BaseModel, Field
from typing import Union, List, Optional
from OAI.types.common import UsageStats, CommonCompletionRequest

class ChatCompletionMessage(BaseModel):
    role: str
    content: str

class ChatCompletionRespChoice(BaseModel):
    # Index is 0 since we aren't using multiple choices
    index: int = 0
    finish_reason: str
    message: ChatCompletionMessage

class ChatCompletionStreamChoice(BaseModel):
    # Index is 0 since we aren't using multiple choices
    index: int = 0
    finish_reason: str
    delta: ChatCompletionMessage

# Inherited from common request
class ChatCompletionRequest(CommonCompletionRequest):
    # Messages
    # Take in a string as well even though it's not part of the OAI spec
    messages: Union[str, List[ChatCompletionMessage]]

class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
    choices: List[ChatCompletionRespChoice]
    created: int = Field(default_factory=lambda: int(time()))
    model: str
    object: str = "chat.completion"

    # TODO: Add usage stats
    usage: Optional[UsageStats] = None

class ChatCompletionStreamChunk(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
    choices: List[ChatCompletionStreamChoice]
    created: int = Field(default_factory=lambda: int(time()))
    model: str
    object: str = "chat.completion.chunk"
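
For context, a rough sketch of how these new models compose; the module path matches the imports used elsewhere in this commit, while the message text and model name are made-up placeholders.

from OAI.types.chat_completion import (
    ChatCompletionMessage,
    ChatCompletionRequest,
    ChatCompletionRespChoice,
    ChatCompletionResponse
)

# Build a request the way a client payload would be parsed (placeholder content)
request = ChatCompletionRequest(
    messages=[
        ChatCompletionMessage(role="system", content="You are a helpful assistant."),
        ChatCompletionMessage(role="user", content="Hello!")
    ],
    max_tokens=50
)

# Wrap generated text in a response, mirroring what the server returns
response = ChatCompletionResponse(
    choices=[
        ChatCompletionRespChoice(
            finish_reason="Generated",
            message=ChatCompletionMessage(role="assistant", content="Hi there!")
        )
    ],
    model="placeholder-model"
)
print(response.json())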

OAI/types/common.py

@@ -1,5 +1,5 @@
from pydantic import BaseModel, Field
from typing import List, Dict
from typing import List, Dict, Optional, Union

class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
@@ -11,3 +11,75 @@ class UsageStats(BaseModel):
    completion_tokens: int
    prompt_tokens: int
    total_tokens: int

class CommonCompletionRequest(BaseModel):
    # Model information
    # This parameter is not used, the loaded model is used instead
    model: Optional[str] = None

    # Extra OAI request stuff
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    n: Optional[int] = 1
    suffix: Optional[str] = None
    user: Optional[str] = None

    # Generation info
    seed: Optional[int] = -1
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None

    # Default to 150 as 16 makes no sense as a default
    max_tokens: Optional[int] = 150

    # Aliased to repetition_penalty
    frequency_penalty: Optional[float] = 0.0

    # Sampling params
    token_healing: Optional[bool] = False
    temperature: Optional[float] = 1.0
    top_k: Optional[int] = 0
    top_p: Optional[float] = 1.0
    typical: Optional[float] = 0.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    repetition_penalty_range: Optional[int] = 0
    repetition_decay: Optional[int] = 0
    mirostat_mode: Optional[int] = 0
    mirostat_tau: Optional[float] = 1.5
    mirostat_eta: Optional[float] = 0.1
    add_bos_token: Optional[bool] = True
    ban_eos_token: Optional[bool] = False

    # Converts to internal generation parameters
    def to_gen_params(self):
        # Convert stop to an array of strings
        if isinstance(self.stop, str):
            self.stop = [self.stop]

        # Set repetition_penalty to frequency_penalty if repetition_penalty isn't already defined
        if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty:
            self.repetition_penalty = self.frequency_penalty

        return {
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "add_bos_token": self.add_bos_token,
            "ban_eos_token": self.ban_eos_token,
            "token_healing": self.token_healing,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "typical": self.typical,
            "min_p": self.min_p,
            "tfs": self.tfs,
            "repetition_penalty": self.repetition_penalty,
            "repetition_penalty_range": self.repetition_penalty_range,
            "repetition_decay": self.repetition_decay,
            "mirostat": True if self.mirostat_mode == 2 else False,
            "mirostat_tau": self.mirostat_tau,
            "mirostat_eta": self.mirostat_eta
        }
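
A brief sketch of what the shared to_gen_params() conversion does with the OAI-style fields; the stop string and penalty values below are arbitrary examples.

from OAI.types.common import CommonCompletionRequest

req = CommonCompletionRequest(stop="###", frequency_penalty=1.15)
params = req.to_gen_params()

# stop is coerced into a list of strings
assert params["stop"] == ["###"]
# frequency_penalty is aliased onto repetition_penalty because
# repetition_penalty was left at its 1.0 default
assert params["repetition_penalty"] == 1.15
# mirostat is only enabled when mirostat_mode == 2
assert params["mirostat"] is False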

OAI/types/completion.py

@@ -2,102 +2,26 @@ from uuid import uuid4
from time import time
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Union
from OAI.types.common import LogProbs, UsageStats
from OAI.types.common import LogProbs, UsageStats, CommonCompletionRequest

class CompletionRespChoice(BaseModel):
    # Index is 0 since we aren't using multiple choices
    index: int = 0
    finish_reason: str
    index: int
    logprobs: Optional[LogProbs] = None
    text: str

class CompletionRequest(BaseModel):
    # Model information
    model: str

# Inherited from common request
class CompletionRequest(CommonCompletionRequest):
    # Prompt can also contain token ids, but that's out of scope for this project.
    prompt: Union[str, List[str]]

    # Extra OAI request stuff
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    n: Optional[int] = 1
    suffix: Optional[str] = None
    user: Optional[str] = None

    # Generation info
    seed: Optional[int] = -1
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None

    # Default to 150 as 16 makes no sense as a default
    max_tokens: Optional[int] = 150

    # Not supported sampling params
    presence_penalty: Optional[float] = 0.0

    # Aliased to repetition_penalty
    frequency_penalty: Optional[float] = 0.0

    # Sampling params
    token_healing: Optional[bool] = False
    temperature: Optional[float] = 1.0
    top_k: Optional[int] = 0
    top_p: Optional[float] = 1.0
    typical: Optional[float] = 0.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    repetition_penalty_range: Optional[int] = 0
    repetition_decay: Optional[int] = 0
    mirostat_mode: Optional[int] = 0
    mirostat_tau: Optional[float] = 1.5
    mirostat_eta: Optional[float] = 0.1
    add_bos_token: Optional[bool] = True
    ban_eos_token: Optional[bool] = False

    # Converts to internal generation parameters
    def to_gen_params(self):
        # Convert prompt to a string
        if isinstance(self.prompt, list):
            self.prompt = "\n".join(self.prompt)

        # Convert stop to an array of strings
        if isinstance(self.stop, str):
            self.stop = [self.stop]

        # Set repetition_penalty to frequency_penalty if repetition_penalty isn't already defined
        if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty:
            self.repetition_penalty = self.frequency_penalty

        return {
            "prompt": self.prompt,
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "add_bos_token": self.add_bos_token,
            "ban_eos_token": self.ban_eos_token,
            "token_healing": self.token_healing,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "typical": self.typical,
            "min_p": self.min_p,
            "tfs": self.tfs,
            "repetition_penalty": self.repetition_penalty,
            "repetition_penalty_range": self.repetition_penalty_range,
            "repetition_decay": self.repetition_decay,
            "mirostat": True if self.mirostat_mode == 2 else False,
            "mirostat_tau": self.mirostat_tau,
            "mirostat_eta": self.mirostat_eta
        }

class CompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}")
    choices: List[CompletionRespChoice]
    created: int = Field(default_factory=lambda: int(time()))
    model: str
    object: str = "text-completion"
    object: str = "text_completion"

    # TODO: Add usage stats
    usage: Optional[UsageStats] = None
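
After this refactor, CompletionRequest only adds the prompt field on top of the shared CommonCompletionRequest (the long field list above is the removed pre-refactor copy). A rough sketch of the resulting behavior, with placeholder values:

from OAI.types.completion import CompletionRequest

req = CompletionRequest(prompt=["First line", "Second line"], max_tokens=32)
gen_params = req.to_gen_params()

# The prompt is no longer a generation parameter...
assert "prompt" not in gen_params
# ...list prompts are joined in the route handler instead
prompt = "\n".join(req.prompt)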

OAI/utils.py

@@ -1,15 +1,22 @@
import pathlib
from fastchat.model.model_adapter import get_conversation_template, Conversation
from OAI.types.completion import CompletionResponse, CompletionRespChoice
from OAI.types.chat_completion import (
    ChatCompletionMessage,
    ChatCompletionRespChoice,
    ChatCompletionStreamChunk,
    ChatCompletionResponse,
    ChatCompletionStreamChoice
)
from OAI.types.common import UsageStats
from OAI.types.model import ModelList, ModelCard
from typing import Optional
from typing import Optional, List

def create_completion_response(text: str, index: int, model_name: Optional[str]):
def create_completion_response(text: str, model_name: Optional[str]):
    # TODO: Add method to get token amounts in model for UsageStats
    choice = CompletionRespChoice(
        finish_reason="Generated",
        index = index,
        finish_reason = "Generated",
        text = text
    )
@@ -20,11 +27,70 @@ def create_completion_response(text: str, index: int, model_name: Optional[str])
    return response

def create_chat_completion_response(text: str, model_name: Optional[str]):
    # TODO: Add method to get token amounts in model for UsageStats
    message = ChatCompletionMessage(
        role = "assistant",
        content = text
    )

    choice = ChatCompletionRespChoice(
        finish_reason = "Generated",
        message = message
    )

    response = ChatCompletionResponse(
        choices = [choice],
        model = model_name or ""
    )

    return response

def create_chat_completion_stream_chunk(const_id: str, text: str, model_name: Optional[str]):
    # TODO: Add method to get token amounts in model for UsageStats
    message = ChatCompletionMessage(
        role = "assistant",
        content = text
    )

    choice = ChatCompletionStreamChoice(
        finish_reason = "Generated",
        delta = message
    )

    chunk = ChatCompletionStreamChunk(
        id = const_id,
        choices = [choice],
        model = model_name or ""
    )

    return chunk

def get_model_list(model_path: pathlib.Path):
    model_card_list = ModelList()
    for path in model_path.parent.iterdir():
    for path in model_path.iterdir():
        if path.is_dir():
            model_card = ModelCard(id = path.name)
            model_card_list.data.append(model_card)

    return model_card_list

def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]):
    conv = get_conversation_template(model_path)
    for message in messages:
        msg_role = message.role
        if msg_role == "system":
            conv.system_message = message.content
        elif msg_role == "user":
            conv.append_message(conv.roles[0], message.content)
        elif msg_role == "assistant":
            conv.append_message(conv.roles[1], message.content)
        else:
            raise ValueError(f"Unknown role: {msg_role}")

    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    return prompt
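
A small usage sketch for the new prompt builder, assuming fastchat is installed and the name passed in resolves to a known conversation template; the template name and messages below are placeholders.

from OAI.types.chat_completion import ChatCompletionMessage
from OAI.utils import get_chat_completion_prompt

messages = [
    ChatCompletionMessage(role="system", content="You are a helpful assistant."),
    ChatCompletionMessage(role="user", content="What does token healing do?")
]

# fastchat picks a template from the name; "vicuna" is just an example
prompt = get_chat_completion_prompt("vicuna", messages)
print(prompt)  # formatted prompt, ready to pass to the model container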

main.py

@@ -7,11 +7,19 @@ from model import ModelContainer
from progress.bar import IncrementalBar
from sse_starlette import EventSourceResponse
from OAI.types.completion import CompletionRequest
from OAI.types.chat_completion import ChatCompletionRequest
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
from OAI.types.token import TokenEncodeRequest, TokenEncodeResponse, TokenDecodeRequest, TokenDecodeResponse
from OAI.utils import create_completion_response, get_model_list
from OAI.utils import (
    create_completion_response,
    get_model_list,
    get_chat_completion_prompt,
    create_chat_completion_response,
    create_chat_completion_stream_chunk
)
from typing import Optional
from utils import load_progress
from uuid import uuid4

app = FastAPI()
@@ -45,8 +53,8 @@ async def load_model(data: ModelLoadRequest):
    model_config = config["model"]
    model_path = pathlib.Path(model_config["model_dir"] or "models")
    model_path = model_path / data.name

    model_container = ModelContainer(model_path, False, **data.model_dump())
    model_container = ModelContainer(model_path, False, **data.dict())
    load_status = model_container.load_gen(load_progress)
    for (module, modules) in load_status:
        if module == 0:
@@ -97,21 +105,58 @@ async def decode_tokens(data: TokenDecodeRequest):
@app.post("/v1/completions", dependencies=[Depends(check_api_key)])
async def generate_completion(request: Request, data: CompletionRequest):
    model_path = model_container.get_model_path()

    if isinstance(data.prompt, list):
        data.prompt = "\n".join(data.prompt)

    if data.stream:
        async def generator():
            new_generation = model_container.generate_gen(**data.to_gen_params())
            for index, part in enumerate(new_generation):
            new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params())
            for part in new_generation:
                if await request.is_disconnected():
                    break

                response = create_completion_response(part, index, model_container.get_model_path().name)
                response = create_completion_response(part, model_path.name)

                yield response.json()

        return EventSourceResponse(generator())
    else:
        response_text = model_container.generate(**data.to_gen_params())
        response = create_completion_response(response_text, 0, model_container.get_model_path().name)
        response_text = model_container.generate(data.prompt, **data.to_gen_params())
        response = create_completion_response(response_text, model_path.name)

        return response.json()

@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
    model_path = model_container.get_model_path()

    if isinstance(data.messages, str):
        prompt = data.messages
    else:
        prompt = get_chat_completion_prompt(model_path.name, data.messages)

    if data.stream:
        const_id = f"chatcmpl-{uuid4().hex}"

        async def generator():
            new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
            for part in new_generation:
                if await request.is_disconnected():
                    break

                response = create_chat_completion_stream_chunk(
                    const_id,
                    part,
                    model_path.name
                )

                yield response.json()

        return EventSourceResponse(generator())
    else:
        response_text = model_container.generate(prompt, **data.to_gen_params())
        response = create_chat_completion_response(response_text, model_path.name)

        return response.json()
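
Finally, a client-side sketch for exercising the new endpoint. The host, port, and API key header below are assumptions made for illustration and are not defined in this commit; check_api_key determines the actual auth scheme.

import requests  # third-party HTTP client, used only for this example

payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 50,
    "stream": False
}

# Assumed local server address and header name; adjust to match check_api_key
resp = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",
    headers={"x-api-key": "your-api-key"},
    json=payload
)
print(resp.status_code, resp.text)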