API: Fix response creation
Change chat completion and text completion responses to be more flexible.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 0af6a38af3
commit c02fe4d1db
4 changed files with 50 additions and 37 deletions
@@ -6,6 +6,16 @@ from uuid import uuid4
 from OAI.types.common import UsageStats, CommonCompletionRequest


+class ChatCompletionLogprobs(BaseModel):
+    token: str
+    logprob: float
+    top_logprobs: List["ChatCompletionLogprobs"]
+
+
+class WrappedChatCompletionLogprobs(BaseModel):
+    content: List[ChatCompletionLogprobs]
+
+
 class ChatCompletionMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
@@ -16,6 +26,7 @@ class ChatCompletionRespChoice(BaseModel):
     index: int = 0
     finish_reason: str
     message: ChatCompletionMessage
+    logprobs: Optional[WrappedChatCompletionLogprobs] = None


 class ChatCompletionStreamChoice(BaseModel):
@@ -23,6 +34,7 @@ class ChatCompletionStreamChoice(BaseModel):
     index: int = 0
     finish_reason: Optional[str]
     delta: Union[ChatCompletionMessage, dict] = {}
+    logprobs: Optional[WrappedChatCompletionLogprobs] = None


 # Inherited from common request
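To make the new schema concrete, here is a small standalone sketch (not part of the commit) of how the logprob models nest; the default on top_logprobs is added only for this example so nested alternatives can omit it.

from typing import List
from pydantic import BaseModel


class ChatCompletionLogprobs(BaseModel):
    token: str
    logprob: float
    # Default added for this sketch only; the commit leaves the field required.
    top_logprobs: List["ChatCompletionLogprobs"] = []


class WrappedChatCompletionLogprobs(BaseModel):
    content: List[ChatCompletionLogprobs]


# One entry per generated token; alternative candidates nest under top_logprobs.
wrapped = WrappedChatCompletionLogprobs(
    content=[
        ChatCompletionLogprobs(
            token="Hello",
            logprob=-0.12,
            top_logprobs=[ChatCompletionLogprobs(token="Hi", logprob=-2.3)],
        )
    ]
)
print(wrapped.model_dump_json())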
@@ -17,32 +17,35 @@ from OAI.types.completion import (
 from OAI.types.common import UsageStats


-def create_completion_response(**kwargs):
+def create_completion_response(generation: dict, model_name: Optional[str]):
     """Create a completion response from the provided text."""

-    token_probs = unwrap(kwargs.get("token_probs"), {})
-    logprobs = unwrap(kwargs.get("logprobs"), [])
-    offset = unwrap(kwargs.get("offset"), [])
     logprob_response = None

-    logprob_response = CompletionLogProbs(
-        text_offset=offset if isinstance(offset, list) else [offset],
-        token_logprobs=token_probs.values(),
-        tokens=token_probs.keys(),
-        top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
-    )
+    token_probs = unwrap(generation.get("token_probs"), {})
+    if token_probs:
+        logprobs = unwrap(generation.get("logprobs"), [])
+        offset = unwrap(generation.get("offset"), [])
+
+        logprob_response = CompletionLogProbs(
+            text_offset=offset if isinstance(offset, list) else [offset],
+            token_logprobs=token_probs.values(),
+            tokens=token_probs.keys(),
+            top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
+        )

     choice = CompletionRespChoice(
         finish_reason="Generated",
-        text=unwrap(kwargs.get("text"), ""),
+        text=unwrap(generation.get("text"), ""),
         logprobs=logprob_response,
     )

-    prompt_tokens = unwrap(kwargs.get("prompt_tokens"), 0)
-    completion_tokens = unwrap(kwargs.get("completion_tokens"), 0)
+    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
+    completion_tokens = unwrap(generation.get("completion_tokens"), 0)

     response = CompletionResponse(
         choices=[choice],
-        model=unwrap(kwargs.get("model_name"), ""),
+        model=unwrap(model_name, ""),
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
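For orientation, a hypothetical example of the generation dict this reworked helper consumes, inferred from the keys it reads above; the values are made up. Passing the dict plus an explicit model_name keeps backend payload keys out of the call signature, which appears to be the point of dropping **kwargs.

# Illustrative payload only; real values come from the model backend.
generation = {
    "text": "Hello there!",
    "prompt_tokens": 5,
    "completion_tokens": 3,
    # Only present when the client asked for logprobs; without "token_probs"
    # the helper now skips building CompletionLogProbs entirely.
    "token_probs": {"Hello": -0.1, " there": -0.4, "!": -0.9},
    "logprobs": [{"Hello": -0.1, "Hi": -2.0}],
    "offset": [0, 5, 11],
}

# response = create_completion_response(generation, "my-model")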
@@ -53,17 +56,18 @@ def create_completion_response(**kwargs):
     return response


-def create_chat_completion_response(
-    text: str,
-    prompt_tokens: Optional[int],
-    completion_tokens: Optional[int],
-    model_name: Optional[str],
-):
+def create_chat_completion_response(generation: dict, model_name: Optional[str]):
     """Create a chat completion response from the provided text."""
-    message = ChatCompletionMessage(role="assistant", content=unwrap(text, ""))
+
+    message = ChatCompletionMessage(
+        role="assistant", content=unwrap(generation.get("text"), "")
+    )

     choice = ChatCompletionRespChoice(finish_reason="Generated", message=message)

+    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
+    completion_tokens = unwrap(generation.get("completion_tokens"), 0)
+
     response = ChatCompletionResponse(
         choices=[choice],
         model=unwrap(model_name, ""),
@@ -79,15 +83,18 @@ def create_chat_completion_response(
 def create_chat_completion_stream_chunk(
     const_id: str,
-    text: Optional[str] = None,
+    generation: Optional[dict] = None,
     model_name: Optional[str] = None,
     finish_reason: Optional[str] = None,
 ):
     """Create a chat completion stream chunk from the provided text."""

     if finish_reason:
         message = {}
     else:
-        message = ChatCompletionMessage(role="assistant", content=text)
+        message = ChatCompletionMessage(
+            role="assistant", content=unwrap(generation.get("text"), "")
+        )

     # The finish reason can be None
     choice = ChatCompletionStreamChoice(finish_reason=finish_reason, delta=message)
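A self-contained sketch (not from the repo) of the branch above: the final chunk carries an empty delta and never reads from generation, which is why generation may stay None when only a finish_reason is passed.

from typing import Optional


def build_delta(generation: Optional[dict], finish_reason: Optional[str]) -> dict:
    # Final chunk: no content, only the finish reason on the choice.
    if finish_reason:
        return {}
    # Content chunk: the backend text becomes the assistant delta.
    return {"role": "assistant", "content": (generation or {}).get("text", "")}


print(build_delta({"text": "Hi"}, None))  # {'role': 'assistant', 'content': 'Hi'}
print(build_delta(None, "stop"))          # {}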
@@ -505,7 +505,7 @@ class ExllamaV2Container:
         generations = list(self.generate_gen(prompt, **kwargs))

         joined_generation = {
             "chunk": "",
             "text": "",
             "prompt_tokens": 0,
             "generation_tokens": 0,
             "offset": [],
@@ -515,7 +515,7 @@ class ExllamaV2Container:

         if generations:
             for generation in generations:
                 joined_generation["chunk"] += unwrap(generation.get("chunk"), "")
                 joined_generation["text"] += unwrap(generation.get("text"), "")
                 joined_generation["offset"].append(unwrap(generation.get("offset"), []))
                 joined_generation["token_probs"].update(
                     unwrap(generation.get("token_probs"), {})
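The loop above folds streamed chunks into a single payload for the non-streaming path. A minimal standalone sketch of that joining step, assuming the same dict keys (not the container's actual method):

def join_generations(generations: list[dict]) -> dict:
    # Accumulate text, offsets, and per-token probabilities into one dict
    # so a single "generation" payload can be handed to the response helpers.
    joined = {"chunk": "", "text": "", "offset": [], "token_probs": {}}
    for generation in generations:
        joined["chunk"] += generation.get("chunk", "")
        joined["text"] += generation.get("text", "")
        joined["offset"].append(generation.get("offset", []))
        joined["token_probs"].update(generation.get("token_probs", {}))
    return joined


chunks = [
    {"chunk": "Hel", "text": "Hel", "offset": 0, "token_probs": {"Hel": -0.2}},
    {"chunk": "lo", "text": "lo", "offset": 3, "token_probs": {"lo": -0.5}},
]
print(join_generations(chunks))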
@@ -835,7 +835,7 @@ class ExllamaV2Container:
                 elapsed > stream_interval or eos or generated_tokens == max_tokens
             ):
                 generation = {
                     "chunk": chunk_buffer,
                     "text": chunk_buffer,
                     "prompt_tokens": prompt_tokens,
                     "generated_tokens": generated_tokens,
                     "offset": len(full_response),
main.py
@@ -462,10 +462,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
             if await request.is_disconnected():
                 break

-            response = create_completion_response(
-                **generation,
-                model_name=model_path.name,
-            )
+            response = create_completion_response(generation, model_path.name)

             yield get_sse_packet(response.model_dump_json())

@@ -483,7 +480,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
     generation = await call_with_semaphore(
         partial(MODEL_CONTAINER.generate, data.prompt, **data.to_gen_params())
     )
-    response = create_completion_response(**generation)
+    response = create_completion_response(generation, model_path.name)

     return response

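call_with_semaphore and partial here are the app's own plumbing. A hypothetical stand-in, assuming the helper serializes generation requests and pushes the blocking generate call off the event loop (the real implementation may differ):

import asyncio
from functools import partial

_generate_semaphore = asyncio.Semaphore(1)  # assumption: one generation at a time


async def call_with_semaphore(callback):
    # Hypothetical: queue behind the semaphore, run the blocking call in a thread.
    async with _generate_semaphore:
        return await asyncio.to_thread(callback)


async def demo():
    fake_generate = lambda prompt: {"text": prompt.upper(), "prompt_tokens": 1}
    generation = await call_with_semaphore(partial(fake_generate, "hello"))
    print(generation)


asyncio.run(demo())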
@@ -548,7 +545,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                 break

             response = create_chat_completion_stream_chunk(
-                const_id, generation.get("chunk"), model_path.name
+                const_id, generation, model_path.name
             )

             yield get_sse_packet(response.model_dump_json())
@@ -568,13 +565,10 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
         generate_with_semaphore(generator), media_type="text/event-stream"
     )

-    response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
+    generation = await call_with_semaphore(
         partial(MODEL_CONTAINER.generate, prompt, **data.to_gen_params())
     )
-
-    response = create_chat_completion_response(
-        response_text, prompt_tokens, completion_tokens, model_path.name
-    )
+    response = create_chat_completion_response(generation, model_path.name)

     return response

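get_sse_packet is also project plumbing; for orientation, a typical server-sent-events framing that such a helper would approximate (hypothetical, not the repo's implementation):

def sse_packet(json_payload: str) -> str:
    # SSE frames are "data: <payload>" terminated by a blank line.
    return f"data: {json_payload}\n\n"


print(sse_packet('{"choices": []}'), end="")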