OAI: Add chat completions endpoint
Chat completions is the endpoint that OAI will rely on going forward, so it makes sense to support it even though the completions endpoint will see more use here. Also unify the common parameters between the chat completion and completion requests, since the two are very similar.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent 593471a04d
commit 5e8419ec0c
5 changed files with 247 additions and 96 deletions
OAI/types/chat_completion.py (new file, 44)
@@ -0,0 +1,44 @@
+from uuid import uuid4
+from time import time
+from pydantic import BaseModel, Field
+from typing import Union, List, Optional
+from OAI.types.common import UsageStats, CommonCompletionRequest
+
+class ChatCompletionMessage(BaseModel):
+    role: str
+    content: str
+
+class ChatCompletionRespChoice(BaseModel):
+    # Index is 0 since we aren't using multiple choices
+    index: int = 0
+    finish_reason: str
+    message: ChatCompletionMessage
+
+class ChatCompletionStreamChoice(BaseModel):
+    # Index is 0 since we aren't using multiple choices
+    index: int = 0
+    finish_reason: str
+    delta: ChatCompletionMessage
+
+# Inherited from common request
+class ChatCompletionRequest(CommonCompletionRequest):
+    # Messages
+    # Take in a string as well even though it's not part of the OAI spec
+    messages: Union[str, List[ChatCompletionMessage]]
+
+class ChatCompletionResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
+    choices: List[ChatCompletionRespChoice]
+    created: int = Field(default_factory=lambda: int(time()))
+    model: str
+    object: str = "chat.completion"
+
+    # TODO: Add usage stats
+    usage: Optional[UsageStats] = None
+
+class ChatCompletionStreamChunk(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
+    choices: List[ChatCompletionStreamChoice]
+    created: int = Field(default_factory=lambda: int(time()))
+    model: str
+    object: str = "chat.completion.chunk"
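Not part of the commit, but a minimal sketch of how these models are meant to be used (the message contents and max_tokens value are made up for illustration):

from OAI.types.chat_completion import ChatCompletionMessage, ChatCompletionRequest

# Standard OAI-style form: a list of role/content messages
request = ChatCompletionRequest(
    messages=[
        ChatCompletionMessage(role="system", content="You are a helpful assistant."),
        ChatCompletionMessage(role="user", content="Hello!"),
    ],
    max_tokens=100,
)

# The extra string form accepted here even though it is not part of the OAI spec
raw_request = ChatCompletionRequest(messages="### User: Hello!\n### Assistant:")

# Sampler fields and to_gen_params() come from CommonCompletionRequest
print(request.to_gen_params()["max_tokens"])  # 100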
OAI/types/common.py
@@ -1,5 +1,5 @@
 from pydantic import BaseModel, Field
-from typing import List, Dict
+from typing import List, Dict, Optional, Union

 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
@@ -11,3 +11,75 @@ class UsageStats(BaseModel):
     completion_tokens: int
     prompt_tokens: int
     total_tokens: int
+
+class CommonCompletionRequest(BaseModel):
+    # Model information
+    # This parameter is not used, the loaded model is used instead
+    model: Optional[str] = None
+
+    # Extra OAI request stuff
+    best_of: Optional[int] = None
+    echo: Optional[bool] = False
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[int] = None
+    n: Optional[int] = 1
+    suffix: Optional[str] = None
+    user: Optional[str] = None
+
+    # Generation info
+    seed: Optional[int] = -1
+    stream: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = None
+
+    # Default to 150 as 16 makes no sense as a default
+    max_tokens: Optional[int] = 150
+
+    # Aliased to repetition_penalty
+    frequency_penalty: Optional[float] = 0.0
+
+    # Sampling params
+    token_healing: Optional[bool] = False
+    temperature: Optional[float] = 1.0
+    top_k: Optional[int] = 0
+    top_p: Optional[float] = 1.0
+    typical: Optional[float] = 0.0
+    min_p: Optional[float] = 0.0
+    tfs: Optional[float] = 1.0
+    repetition_penalty: Optional[float] = 1.0
+    repetition_penalty_range: Optional[int] = 0
+    repetition_decay: Optional[int] = 0
+    mirostat_mode: Optional[int] = 0
+    mirostat_tau: Optional[float] = 1.5
+    mirostat_eta: Optional[float] = 0.1
+    add_bos_token: Optional[bool] = True
+    ban_eos_token: Optional[bool] = False
+
+    # Converts to internal generation parameters
+    def to_gen_params(self):
+        # Convert stop to an array of strings
+        if isinstance(self.stop, str):
+            self.stop = [self.stop]
+
+        # Set repetition_penalty to frequency_penalty if repetition_penalty isn't already defined
+        if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty:
+            self.repetition_penalty = self.frequency_penalty
+
+        return {
+            "stop": self.stop,
+            "max_tokens": self.max_tokens,
+            "add_bos_token": self.add_bos_token,
+            "ban_eos_token": self.ban_eos_token,
+            "token_healing": self.token_healing,
+            "temperature": self.temperature,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+            "typical": self.typical,
+            "min_p": self.min_p,
+            "tfs": self.tfs,
+            "repetition_penalty": self.repetition_penalty,
+            "repetition_penalty_range": self.repetition_penalty_range,
+            "repetition_decay": self.repetition_decay,
+            "mirostat": True if self.mirostat_mode == 2 else False,
+            "mirostat_tau": self.mirostat_tau,
+            "mirostat_eta": self.mirostat_eta
+        }
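A quick illustration of to_gen_params() above (values invented; note that frequency_penalty only overrides repetition_penalty when the latter is unset or left at 1.0):

from OAI.types.common import CommonCompletionRequest

req = CommonCompletionRequest(stop="###", frequency_penalty=1.15, max_tokens=200)
params = req.to_gen_params()

print(params["stop"])                # ['###'] - a single stop string is wrapped in a list
print(params["repetition_penalty"])  # 1.15 - aliased from frequency_penalty
print(params["max_tokens"])          # 200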
OAI/types/completion.py
@@ -2,102 +2,26 @@ from uuid import uuid4
 from time import time
 from pydantic import BaseModel, Field
 from typing import List, Optional, Dict, Union
-from OAI.types.common import LogProbs, UsageStats
+from OAI.types.common import LogProbs, UsageStats, CommonCompletionRequest

 class CompletionRespChoice(BaseModel):
+    # Index is 0 since we aren't using multiple choices
+    index: int = 0
     finish_reason: str
-    index: int
     logprobs: Optional[LogProbs] = None
     text: str

-class CompletionRequest(BaseModel):
-    # Model information
-    model: str
-
+# Inherited from common request
+class CompletionRequest(CommonCompletionRequest):
     # Prompt can also contain token ids, but that's out of scope for this project.
     prompt: Union[str, List[str]]

-    # Extra OAI request stuff
-    best_of: Optional[int] = None
-    echo: Optional[bool] = False
-    logit_bias: Optional[Dict[str, float]] = None
-    logprobs: Optional[int] = None
-    n: Optional[int] = 1
-    suffix: Optional[str] = None
-    user: Optional[str] = None
-
-    # Generation info
-    seed: Optional[int] = -1
-    stream: Optional[bool] = False
-    stop: Optional[Union[str, List[str]]] = None
-
-    # Default to 150 as 16 makes no sense as a default
-    max_tokens: Optional[int] = 150
-
-    # Not supported sampling params
-    presence_penalty: Optional[float] = 0.0
-
-    # Aliased to repetition_penalty
-    frequency_penalty: Optional[float] = 0.0
-
-    # Sampling params
-    token_healing: Optional[bool] = False
-    temperature: Optional[float] = 1.0
-    top_k: Optional[int] = 0
-    top_p: Optional[float] = 1.0
-    typical: Optional[float] = 0.0
-    min_p: Optional[float] = 0.0
-    tfs: Optional[float] = 1.0
-    repetition_penalty: Optional[float] = 1.0
-    repetition_penalty_range: Optional[int] = 0
-    repetition_decay: Optional[int] = 0
-    mirostat_mode: Optional[int] = 0
-    mirostat_tau: Optional[float] = 1.5
-    mirostat_eta: Optional[float] = 0.1
-    add_bos_token: Optional[bool] = True
-    ban_eos_token: Optional[bool] = False
-
-    # Converts to internal generation parameters
-    def to_gen_params(self):
-        # Convert prompt to a string
-        if isinstance(self.prompt, list):
-            self.prompt = "\n".join(self.prompt)
-
-        # Convert stop to an array of strings
-        if isinstance(self.stop, str):
-            self.stop = [self.stop]
-
-        # Set repetition_penalty to frequency_penalty if repetition_penalty isn't already defined
-        if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty:
-            self.repetition_penalty = self.frequency_penalty
-
-        return {
-            "prompt": self.prompt,
-            "stop": self.stop,
-            "max_tokens": self.max_tokens,
-            "add_bos_token": self.add_bos_token,
-            "ban_eos_token": self.ban_eos_token,
-            "token_healing": self.token_healing,
-            "temperature": self.temperature,
-            "top_k": self.top_k,
-            "top_p": self.top_p,
-            "typical": self.typical,
-            "min_p": self.min_p,
-            "tfs": self.tfs,
-            "repetition_penalty": self.repetition_penalty,
-            "repetition_penalty_range": self.repetition_penalty_range,
-            "repetition_decay": self.repetition_decay,
-            "mirostat": True if self.mirostat_mode == 2 else False,
-            "mirostat_tau": self.mirostat_tau,
-            "mirostat_eta": self.mirostat_eta
-        }
-
 class CompletionResponse(BaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}")
     choices: List[CompletionRespChoice]
     created: int = Field(default_factory=lambda: int(time()))
     model: str
-    object: str = "text-completion"
+    object: str = "text_completion"

     # TODO: Add usage stats
     usage: Optional[UsageStats] = None
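With the shared base class, a completion request sketch (values invented) only adds prompt on top of the inherited fields; joining a list prompt now happens in the endpoint handler rather than in to_gen_params():

from OAI.types.completion import CompletionRequest

req = CompletionRequest(prompt=["First line", "Second line"], temperature=0.7, stop="\n\n")

print(req.to_gen_params()["temperature"])  # 0.7 - inherited from CommonCompletionRequest
print(req.prompt)  # ['First line', 'Second line'] - joined later by the /v1/completions route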
OAI/utils.py (76)
@@ -1,15 +1,22 @@
 import pathlib
+from fastchat.model.model_adapter import get_conversation_template, Conversation
 from OAI.types.completion import CompletionResponse, CompletionRespChoice
+from OAI.types.chat_completion import (
+    ChatCompletionMessage,
+    ChatCompletionRespChoice,
+    ChatCompletionStreamChunk,
+    ChatCompletionResponse,
+    ChatCompletionStreamChoice
+)
 from OAI.types.common import UsageStats
 from OAI.types.model import ModelList, ModelCard
-from typing import Optional
+from typing import Optional, List

-def create_completion_response(text: str, index: int, model_name: Optional[str]):
+def create_completion_response(text: str, model_name: Optional[str]):
     # TODO: Add method to get token amounts in model for UsageStats

     choice = CompletionRespChoice(
-        finish_reason="Generated",
-        index = index,
+        finish_reason = "Generated",
         text = text
     )

@@ -20,11 +27,70 @@ def create_completion_response(text: str, index: int, model_name: Optional[str])

     return response

+def create_chat_completion_response(text: str, model_name: Optional[str]):
+    # TODO: Add method to get token amounts in model for UsageStats
+
+    message = ChatCompletionMessage(
+        role = "assistant",
+        content = text
+    )
+
+    choice = ChatCompletionRespChoice(
+        finish_reason = "Generated",
+        message = message
+    )
+
+    response = ChatCompletionResponse(
+        choices = [choice],
+        model = model_name or ""
+    )
+
+    return response
+
+def create_chat_completion_stream_chunk(const_id: str, text: str, model_name: Optional[str]):
+    # TODO: Add method to get token amounts in model for UsageStats
+
+    message = ChatCompletionMessage(
+        role = "assistant",
+        content = text
+    )
+
+    choice = ChatCompletionStreamChoice(
+        finish_reason = "Generated",
+        delta = message
+    )
+
+    chunk = ChatCompletionStreamChunk(
+        id = const_id,
+        choices = [choice],
+        model = model_name or ""
+    )
+
+    return chunk
+
 def get_model_list(model_path: pathlib.Path):
     model_card_list = ModelList()
-    for path in model_path.parent.iterdir():
+    for path in model_path.iterdir():
         if path.is_dir():
             model_card = ModelCard(id = path.name)
             model_card_list.data.append(model_card)

     return model_card_list
+
+def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]):
+    conv = get_conversation_template(model_path)
+    for message in messages:
+        msg_role = message.role
+        if msg_role == "system":
+            conv.system_message = message.content
+        elif msg_role == "user":
+            conv.append_message(conv.roles[0], message.content)
+        elif msg_role == "assistant":
+            conv.append_message(conv.roles[1], message.content)
+        else:
+            raise ValueError(f"Unknown role: {msg_role}")
+
+    conv.append_message(conv.roles[1], None)
+    prompt = conv.get_prompt()
+
+    return prompt
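A rough usage sketch for the new prompt builder (illustration only; the template that fastchat's get_conversation_template resolves, and therefore the exact prompt text, depends on the model name passed in; "vicuna" below is just an assumed example):

from OAI.types.chat_completion import ChatCompletionMessage
from OAI.utils import get_chat_completion_prompt

messages = [
    ChatCompletionMessage(role="system", content="You are a helpful assistant."),
    ChatCompletionMessage(role="user", content="Summarize this commit."),
]

# fastchat picks a conversation template based on the name; the assistant turn is left open
prompt = get_chat_completion_prompt("vicuna", messages)
print(prompt)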
main.py (61)
@@ -7,11 +7,19 @@ from model import ModelContainer
 from progress.bar import IncrementalBar
 from sse_starlette import EventSourceResponse
 from OAI.types.completion import CompletionRequest
+from OAI.types.chat_completion import ChatCompletionRequest
 from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
 from OAI.types.token import TokenEncodeRequest, TokenEncodeResponse, TokenDecodeRequest, TokenDecodeResponse
-from OAI.utils import create_completion_response, get_model_list
+from OAI.utils import (
+    create_completion_response,
+    get_model_list,
+    get_chat_completion_prompt,
+    create_chat_completion_response,
+    create_chat_completion_stream_chunk
+)
 from typing import Optional
 from utils import load_progress
+from uuid import uuid4

 app = FastAPI()

@@ -45,8 +53,8 @@ async def load_model(data: ModelLoadRequest):
     model_config = config["model"]
     model_path = pathlib.Path(model_config["model_dir"] or "models")
     model_path = model_path / data.name

-    model_container = ModelContainer(model_path, False, **data.model_dump())
+    model_container = ModelContainer(model_path, False, **data.dict())
     load_status = model_container.load_gen(load_progress)
     for (module, modules) in load_status:
         if module == 0:
@@ -97,21 +105,58 @@ async def decode_tokens(data: TokenDecodeRequest):

 @app.post("/v1/completions", dependencies=[Depends(check_api_key)])
 async def generate_completion(request: Request, data: CompletionRequest):
+    model_path = model_container.get_model_path()
+
+    if isinstance(data.prompt, list):
+        data.prompt = "\n".join(data.prompt)
+
     if data.stream:
         async def generator():
-            new_generation = model_container.generate_gen(**data.to_gen_params())
-            for index, part in enumerate(new_generation):
+            new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params())
+            for part in new_generation:
                 if await request.is_disconnected():
                     break

-                response = create_completion_response(part, index, model_container.get_model_path().name)
+                response = create_completion_response(part, model_path.name)

                 yield response.json()

         return EventSourceResponse(generator())
     else:
-        response_text = model_container.generate(**data.to_gen_params())
-        response = create_completion_response(response_text, 0, model_container.get_model_path().name)
+        response_text = model_container.generate(data.prompt, **data.to_gen_params())
+        response = create_completion_response(response_text, model_path.name)

         return response.json()
+
+@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
+async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
+    model_path = model_container.get_model_path()
+
+    if isinstance(data.messages, str):
+        prompt = data.messages
+    else:
+        prompt = get_chat_completion_prompt(model_path.name, data.messages)
+
+    if data.stream:
+        const_id = f"chatcmpl-{uuid4().hex}"
+        async def generator():
+            new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
+            for part in new_generation:
+                if await request.is_disconnected():
+                    break
+
+                response = create_chat_completion_stream_chunk(
+                    const_id,
+                    part,
+                    model_path.name
+                )
+
+                yield response.json()
+
+        return EventSourceResponse(generator())
+    else:
+        response_text = model_container.generate(prompt, **data.to_gen_params())
+        response = create_chat_completion_response(response_text, model_path.name)
+
+        return response.json()
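Finally, a hypothetical client-side sketch for the new endpoint (not part of the commit; the host, port, and x-api-key header name are assumptions rather than anything this diff defines). Streamed chunks arrive as SSE data: lines whose delta.content pieces concatenate into the full reply:

import json
import requests  # any HTTP client works; requests is only an example

payload = {
    "messages": [{"role": "user", "content": "Hello there!"}],
    "max_tokens": 100,
    "stream": True,
}

# URL and auth header are placeholders for whatever the server is configured with
with requests.post(
    "http://localhost:8000/v1/chat/completions",
    json=payload,
    headers={"x-api-key": "<your key>"},
    stream=True,
) as resp:
    reply = ""
    for line in resp.iter_lines():
        if not line.startswith(b"data: "):
            continue  # skip comments and keep-alive pings
        try:
            chunk = json.loads(line[len(b"data: "):])
        except json.JSONDecodeError:
            continue  # not a JSON chunk, ignore it
        reply += chunk["choices"][0]["delta"]["content"]

print(reply)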