OAI: Fix chat completion streaming
The OAI spec requires a finish reason to be provided once streaming is completed; this differs from a non-streaming chat completion response. Also fix some errors that were raised from the endpoint.

References #15

Signed-off-by: kingbri <bdashore3@proton.me>
parent c4d8c901e1
commit aef411bed5

3 changed files with 30 additions and 16 deletions
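For context, an OpenAI-style stream emits content chunks whose choices carry a delta and a null finish_reason, and must terminate with a chunk whose delta is empty and whose finish_reason is set. A minimal sketch of the two shapes this commit targets (field subset only; the id is a placeholder):

# Mid-stream chunk: token text in the delta, finish_reason still null.
content_chunk = {
    "id": "chatcmpl-<hex>",
    "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hel"}, "finish_reason": None}],
}

# Terminal chunk: empty delta, finish_reason set, as the spec requires.
finish_chunk = {
    "id": "chatcmpl-<hex>",
    "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}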
@@ -5,8 +5,8 @@ from typing import Union, List, Optional
 from OAI.types.common import UsageStats, CommonCompletionRequest

 class ChatCompletionMessage(BaseModel):
-    role: str
-    content: str
+    role: Optional[str] = None
+    content: Optional[str] = None

 class ChatCompletionRespChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
@@ -17,8 +17,8 @@ class ChatCompletionRespChoice(BaseModel):
 class ChatCompletionStreamChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
     index: int = 0
-    finish_reason: str
-    delta: ChatCompletionMessage
+    finish_reason: Optional[str]
+    delta: Union[ChatCompletionRespChoice, dict] = {}

 # Inherited from common request
 class ChatCompletionRequest(CommonCompletionRequest):
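Loosening role and content to Optional and defaulting delta to an empty dict lets one model describe both chunk flavors. A quick sketch of the resulting behavior, assuming pydantic v1 (the .json() calls later in this commit suggest it) and the models above in scope; note the delta annotation names ChatCompletionRespChoice, but a ChatCompletionMessage instance still validates by coercing through the dict branch of the Union:

# Content chunk: the assistant message coerces into the delta dict.
chunk = ChatCompletionStreamChoice(
    finish_reason = None,
    delta = ChatCompletionMessage(role = "assistant", content = "Hello")
)

# Finish chunk: delta falls back to its {} default.
done = ChatCompletionStreamChoice(finish_reason = "stop")
print(done.json())  # {"index": 0, "finish_reason": "stop", "delta": {}}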
OAI/utils.py (25 changes)
@@ -57,16 +57,21 @@ def create_chat_completion_response(text: str, prompt_tokens: int, completion_to
     return response

-def create_chat_completion_stream_chunk(const_id: str, text: str, model_name: Optional[str]):
+def create_chat_completion_stream_chunk(const_id: str,
+                                        text: Optional[str] = None,
+                                        model_name: Optional[str] = None,
+                                        finish_reason: Optional[str] = None):
     # TODO: Add method to get token amounts in model for UsageStats
-
-    message = ChatCompletionMessage(
-        role = "assistant",
-        content = text
-    )
+    if finish_reason:
+        message = {}
+    else:
+        message = ChatCompletionMessage(
+            role = "assistant",
+            content = text
+        )

+    # The finish reason can be None
     choice = ChatCompletionStreamChoice(
-        finish_reason = "Generated",
+        finish_reason = finish_reason,
         delta = message
     )
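The reworked helper now covers both chunk flavors from one call site. A sketch of the intended usage (const_id, the text, and the model name are placeholders):

# Mid-stream: wrap newly generated text in an assistant delta.
chunk = create_chat_completion_stream_chunk(const_id, text = "Hel", model_name = "my-model")

# End of stream: omit text; delta collapses to {} and only finish_reason is set.
done = create_chat_completion_stream_chunk(const_id, finish_reason = "stop")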
@@ -95,8 +100,8 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str]):
     return model_card_list

 def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]):
-    # Check if fastchat is available

+    # Check if fastchat is available
     if not _fastchat_available:
         raise ModuleNotFoundError(
             "Fastchat must be installed to parse these chat completion messages.\n"
@@ -114,7 +119,7 @@ def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMes
     for message in messages:
         msg_role = message.role
         if msg_role == "system":
-            conv.system_message = message.content
+            conv.set_system_message(message.content)
         elif msg_role == "user":
             conv.append_message(conv.roles[0], message.content)
         elif msg_role == "assistant":
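For context, get_chat_completion_prompt walks the OAI messages through a fastchat conversation template, and set_system_message is the template-aware setter (plain attribute assignment skips the template's own handling). A rough sketch of the surrounding flow, with the template name as an illustrative stand-in for whatever the real code resolves from model_path:

from fastchat.model import get_conversation_template

conv = get_conversation_template("llama-2")  # real code derives this from model_path
conv.set_system_message("You are a helpful assistant.")
conv.append_message(conv.roles[0], "Hello!")  # user turn
conv.append_message(conv.roles[1], None)      # open assistant slot so the template appends its generation prefix
prompt = conv.get_prompt()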
main.py (13 changes)
@@ -225,8 +225,8 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
     const_id = f"chatcmpl-{uuid4().hex}"

     async def generator():
         try:
-            new_generation, prompt_tokens, completion_tokens = model_container.generate_gen(prompt, **data.to_gen_params())
-            for part in new_generation:
+            new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
+            for (part, _, _) in new_generation:
                 if await request.is_disconnected():
                     break
@@ -239,6 +239,15 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                 yield response.json(ensure_ascii=False)
         except Exception as e:
             yield get_generator_error(e)
+        finally:
+
+            # Always finish no matter what
+            finish_response = create_chat_completion_stream_chunk(
+                const_id,
+                finish_reason = "stop"
+            )
+
+            yield finish_response.json(ensure_ascii=False)

     return EventSourceResponse(generator())
 else:
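The try/except/finally shape is what guarantees the terminal chunk: whether generation completes, raises, or the client disconnects, exactly one finish chunk is yielded last. Reduced to a skeleton (make_content_chunk is a hypothetical stand-in for the response-building code elided in the diff):

async def generator():
    try:
        for (part, _, _) in new_generation:      # text plus two ignored token counts
            if await request.is_disconnected():
                break
            yield make_content_chunk(part)       # hypothetical stand-in
    except Exception as e:
        yield get_generator_error(e)             # error event still precedes the finish chunk
    finally:
        # Always emitted, even after a break or an exception
        finish = create_chat_completion_stream_chunk(const_id, finish_reason = "stop")
        yield finish.json(ensure_ascii=False)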