OAI: Fix chat completion streaming

The OAI spec requires a finish reason to be provided in the final
chunk once streaming completes. This differs from a non-streaming
chat completion response.

Also fix some errors raised by the endpoint.
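
For reference, the terminating chunk produced by
create_chat_completion_stream_chunk(const_id, finish_reason = "stop")
serializes roughly like this (a sketch; the exact id value and any
extra fields come from the pydantic models, which are only partially
shown in this diff):

    {
        "id": "chatcmpl-...",
        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]
    }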

References #15

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2023-12-01 00:14:24 -05:00
parent c4d8c901e1
commit aef411bed5
3 changed files with 30 additions and 16 deletions


@@ -5,8 +5,8 @@ from typing import Union, List, Optional
 from OAI.types.common import UsageStats, CommonCompletionRequest
 
 class ChatCompletionMessage(BaseModel):
-    role: str
-    content: str
+    role: Optional[str] = None
+    content: Optional[str] = None
 
 class ChatCompletionRespChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
@@ -17,8 +17,8 @@ class ChatCompletionRespChoice(BaseModel):
 class ChatCompletionStreamChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
     index: int = 0
-    finish_reason: str
-    delta: ChatCompletionMessage
+    finish_reason: Optional[str]
+    delta: Union[ChatCompletionRespChoice, dict] = {}
 
 # Inherited from common request
 class ChatCompletionRequest(CommonCompletionRequest):


@@ -57,16 +57,21 @@ def create_chat_completion_response(text: str, prompt_tokens: int, completion_to
     return response
 
-def create_chat_completion_stream_chunk(const_id: str, text: str, model_name: Optional[str]):
+def create_chat_completion_stream_chunk(const_id: str,
+                                        text: Optional[str] = None,
+                                        model_name: Optional[str] = None,
+                                        finish_reason: Optional[str] = None):
     # TODO: Add method to get token amounts in model for UsageStats
-    message = ChatCompletionMessage(
-        role = "assistant",
-        content = text
-    )
+    if finish_reason:
+        message = {}
+    else:
+        message = ChatCompletionMessage(
+            role = "assistant",
+            content = text
+        )
 
+    # The finish reason can be None
     choice = ChatCompletionStreamChoice(
-        finish_reason = "Generated",
+        finish_reason = finish_reason,
         delta = message
     )
@@ -95,8 +100,8 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str]):
     return model_card_list
 
 def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]):
 
-    # Check if fastchat is available
+    # Check if fastchat is available
     if not _fastchat_available:
         raise ModuleNotFoundError(
             "Fastchat must be installed to parse these chat completion messages.\n"
@@ -114,7 +119,7 @@ def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMes
     for message in messages:
         msg_role = message.role
         if msg_role == "system":
-            conv.system_message = message.content
+            conv.set_system_message(message.content)
         elif msg_role == "user":
             conv.append_message(conv.roles[0], message.content)
         elif msg_role == "assistant":

main.py (13 changes)

@@ -225,8 +225,8 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
         const_id = f"chatcmpl-{uuid4().hex}"
 
         async def generator():
             try:
-                new_generation, prompt_tokens, completion_tokens = model_container.generate_gen(prompt, **data.to_gen_params())
-                for part in new_generation:
+                new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
+                for (part, _, _) in new_generation:
                     if await request.is_disconnected():
                         break
@@ -239,6 +239,15 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                     yield response.json(ensure_ascii=False)
             except Exception as e:
                 yield get_generator_error(e)
+            finally:
+                # Always finish no matter what
+                finish_response = create_chat_completion_stream_chunk(
+                    const_id,
+                    finish_reason = "stop"
+                )
+
+                yield finish_response.json(ensure_ascii=False)
 
         return EventSourceResponse(generator())
     else:
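
As a rough client-side sketch of the fixed behavior (assuming the route is
mounted at an OAI-style /v1/chat/completions path, that no auth headers are
needed, and that host/port are the defaults below; none of these details are
shown in this diff), every streamed request should now end with a chunk that
carries a finish reason:

import json
import requests  # any HTTP client that supports streamed responses works

resp = requests.post(
    "http://localhost:5000/v1/chat/completions",  # host/port are assumptions
    # Request body trimmed to the essentials; other fields may be required.
    json={"messages": [{"role": "user", "content": "Hello"}], "stream": True},
    stream=True,
)

for line in resp.iter_lines():
    if not line.startswith(b"data: "):
        continue
    data = line[len(b"data: "):]
    if not data.startswith(b"{"):
        continue  # skip SSE pings and other non-JSON frames
    payload = json.loads(data)
    choices = payload.get("choices") or []
    if choices and choices[0].get("finish_reason"):
        # With this commit, this branch is always reached, even when
        # generation raises an error server-side.
        break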