API: Fix finish_reason returns

OAI expects finish_reason to be "stop" or "length" (other values exist,
but they're outside the current scope of this project).

Make all completion and chat completion responses return this value
from the model generation itself rather than using a placeholder.
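
For reference, a minimal sketch of the mapping this commit applies when a generation ends; the function name is illustrative, not the backend's actual code:

```python
# Minimal sketch of the finish_reason mapping applied at the end of a
# generation; resolve_finish_reason is illustrative, not the backend's code.
def resolve_finish_reason(generated_tokens: int, max_tokens: int) -> str:
    """Return "length" when the token budget was exhausted, otherwise "stop"."""
    return "length" if generated_tokens == max_tokens else "stop"


assert resolve_finish_reason(generated_tokens=200, max_tokens=200) == "length"
assert resolve_finish_reason(generated_tokens=57, max_tokens=200) == "stop"
```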

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2024-03-18 15:59:28 -04:00
Parent: 25f5d4a690
Commit: 5c7fc69ded

5 changed files with 35 additions and 17 deletions


@@ -605,6 +605,9 @@ class ExllamaV2Container:
         joined_generation["generation_tokens"] = unwrap(
             generations[-1].get("generated_tokens"), 0
         )
+        joined_generation["finish_reason"] = unwrap(
+            generations[-1].get("finish_reason"), "stop"
+        )

         return joined_generation
@@ -1004,6 +1007,10 @@ class ExllamaV2Container:
             last_chunk_time = now

             if eos or generated_tokens == max_tokens:
+                finish_reason = "length" if generated_tokens == max_tokens else "stop"
+                generation = {"finish_reason": finish_reason}
+                yield generation
+
                 break

             # Print response
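
As a rough illustration of the new contract, assuming a stand-in generator in place of ExllamaV2Container's real stream, a consumer can now read the finish reason from the final yielded dict:

```python
# Stand-in for the backend's streaming generator; the real container now
# yields a final {"finish_reason": ...} chunk as shown in the diff above.
from typing import Iterator


def fake_generation_stream() -> Iterator[dict]:
    yield {"text": "Hello"}
    yield {"text": ", world"}
    yield {"finish_reason": "stop"}  # final chunk: reason only, no text


def join_stream(stream: Iterator[dict]) -> dict:
    """Collect text chunks and stop once a finish_reason arrives."""
    text_parts = []
    finish_reason = None
    for generation in stream:
        if "finish_reason" in generation:
            finish_reason = generation["finish_reason"]
            break
        text_parts.append(generation.get("text", ""))
    return {"text": "".join(text_parts), "finish_reason": finish_reason}


print(join_stream(fake_generation_stream()))
# {'text': 'Hello, world', 'finish_reason': 'stop'}
```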


@@ -24,7 +24,7 @@ class ChatCompletionMessage(BaseModel):
 class ChatCompletionRespChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
     index: int = 0
-    finish_reason: str
+    finish_reason: Optional[str] = None
     message: ChatCompletionMessage
     logprobs: Optional[ChatCompletionLogprobs] = None
@@ -32,7 +32,7 @@ class ChatCompletionRespChoice(BaseModel):
 class ChatCompletionStreamChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
     index: int = 0
-    finish_reason: Optional[str]
+    finish_reason: Optional[str] = None
     delta: Union[ChatCompletionMessage, dict] = {}
     logprobs: Optional[ChatCompletionLogprobs] = None


@@ -22,7 +22,7 @@ class CompletionRespChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
     index: int = 0
-    finish_reason: str
+    finish_reason: Optional[str] = None
     logprobs: Optional[CompletionLogProbs] = None
     text: str
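
To see why these fields gained a default of None, here is a standalone sketch (assuming Pydantic v2, consistent with the model_dump_json calls elsewhere in this diff) using an illustrative model rather than the project's actual classes:

```python
# Standalone sketch (Pydantic v2 assumed) of the Optional[str] = None default:
# a choice can be built mid-stream without a finish_reason and filled in only
# on the final chunk. StreamChoiceSketch is illustrative, not a project class.
from typing import Optional, Union

from pydantic import BaseModel


class StreamChoiceSketch(BaseModel):
    index: int = 0
    finish_reason: Optional[str] = None
    delta: Union[dict, str] = {}


mid_stream = StreamChoiceSketch(delta={"content": "partial text"})
final = StreamChoiceSketch(finish_reason="stop")

print(mid_stream.model_dump_json())  # {"index":0,"finish_reason":null,"delta":{"content":"partial text"}}
print(final.model_dump_json())       # {"index":0,"finish_reason":"stop","delta":{}}
```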


@@ -60,7 +60,9 @@ def _create_response(generation: dict, model_name: Optional[str]):
         logprob_response = ChatCompletionLogprobs(content=collected_token_probs)

     choice = ChatCompletionRespChoice(
-        finish_reason="Generated", message=message, logprobs=logprob_response
+        finish_reason=generation.get("finish_reason"),
+        message=message,
+        logprobs=logprob_response,
     )

     prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
@@ -83,14 +85,15 @@ def _create_stream_chunk(
     const_id: str,
     generation: Optional[dict] = None,
     model_name: Optional[str] = None,
-    finish_reason: Optional[str] = None,
 ):
     """Create a chat completion stream chunk from the provided text."""

     logprob_response = None

-    if finish_reason:
-        message = {}
+    if "finish_reason" in generation:
+        choice = ChatCompletionStreamChoice(
+            finish_reason=generation.get("finish_reason")
+        )
     else:
         message = ChatCompletionMessage(
             role="assistant", content=unwrap(generation.get("text"), "")
@@ -113,10 +116,10 @@ def _create_stream_chunk(
             logprob_response = ChatCompletionLogprobs(content=[token_prob_response])

-    # The finish reason can be None
-    choice = ChatCompletionStreamChoice(
-        finish_reason=finish_reason, delta=message, logprobs=logprob_response
-    )
+        choice = ChatCompletionStreamChoice(
+            delta=message,
+            logprobs=logprob_response,
+        )

     chunk = ChatCompletionStreamChunk(
         id=const_id, choices=[choice], model=unwrap(model_name, "")
@@ -165,10 +168,14 @@ async def stream_generate_chat_completion(
             yield response.model_dump_json()

-        # Yield a finish response on successful generation
-        finish_response = _create_stream_chunk(const_id, finish_reason="stop")
+            # Break if the generation is finished
+            if "finish_reason" in generation:
+                break

-        yield finish_response.model_dump_json()
+        # Yield a finish response on successful generation
+        # finish_response = _create_stream_chunk(const_id, finish_reason="stop")
+        # yield finish_response.model_dump_json()
     except CancelledError:
         # Get out if the request gets disconnected
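
A hedged sketch of the resulting chat stream shape, using illustrative stand-ins for the real backend generator and chunk builder: the loop now stops as soon as the backend reports a finish_reason, instead of appending a synthetic "stop" chunk afterwards.

```python
# fake_backend_stream and make_chunk are illustrative stand-ins, not the
# project's actual functions; they mirror the break-on-finish_reason flow.
import asyncio
import json
from typing import AsyncIterator


async def fake_backend_stream() -> AsyncIterator[dict]:
    for generation in [{"text": "Hel"}, {"text": "lo"}, {"finish_reason": "stop"}]:
        yield generation


def make_chunk(generation: dict) -> str:
    finish_reason = generation.get("finish_reason")
    delta = {} if finish_reason else {"role": "assistant", "content": generation.get("text", "")}
    return json.dumps({"choices": [{"index": 0, "finish_reason": finish_reason, "delta": delta}]})


async def stream_chat() -> AsyncIterator[str]:
    async for generation in fake_backend_stream():
        yield make_chunk(generation)
        # Break if the generation is finished
        if "finish_reason" in generation:
            break


async def main() -> None:
    async for chunk in stream_chat():
        print(chunk)


asyncio.run(main())
```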


@@ -39,7 +39,7 @@ def _create_response(generation: dict, model_name: Optional[str]):
     )

     choice = CompletionRespChoice(
-        finish_reason="Generated",
+        finish_reason=generation.get("finish_reason"),
         text=unwrap(generation.get("text"), ""),
         logprobs=logprob_response,
     )
@@ -69,11 +69,15 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli
         )

         async for generation in new_generation:
             response = _create_response(generation, model_path.name)

             yield response.model_dump_json()

+            # Break if the generation is finished
+            if "finish_reason" in generation:
+                yield "[DONE]"
+                break
+
         # Yield a finish response on successful generation
-        yield "[DONE]"
+        # yield "[DONE]"
     except CancelledError:
         # Get out if the request gets disconnected
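
On the client side, a consumer of the completions stream can now key off finish_reason and the "[DONE]" sentinel; the events below are mocked for illustration, not captured server output.

```python
# Mocked client-side sketch: stop collecting text once finish_reason is set
# or the "[DONE]" sentinel arrives.
import json

events = [
    '{"choices": [{"index": 0, "finish_reason": null, "text": "Hel"}]}',
    '{"choices": [{"index": 0, "finish_reason": null, "text": "lo"}]}',
    '{"choices": [{"index": 0, "finish_reason": "stop", "text": ""}]}',
    "[DONE]",
]

pieces = []
for event in events:
    if event == "[DONE]":
        break
    choice = json.loads(event)["choices"][0]
    pieces.append(choice["text"])
    if choice["finish_reason"] is not None:
        print("finished:", choice["finish_reason"])

print("".join(pieces))  # Hello
```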