OAI: Add "n" for non-streaming generations

This adds the ability to return multiple choices for a generation. It is
currently available only for non-streaming generations; porting it over to
streaming requires some more work.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-05-26 22:22:51 -04:00 committed by Brian Dashore
parent 8d31a5aed1
commit b944f8d756
3 changed files with 89 additions and 64 deletions

View file

@@ -3,7 +3,7 @@
from pydantic import BaseModel, Field
from typing import Optional
from common.sampling import BaseSamplerRequest
from common.sampling import BaseSamplerRequest, get_default_sampler_value
class UsageStats(BaseModel):
@@ -27,10 +27,13 @@ class CommonCompletionRequest(BaseSamplerRequest):
# Generation info (remainder is in BaseSamplerRequest superclass)
stream: Optional[bool] = False
logprobs: Optional[int] = 0
logprobs: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("logprobs", 0)
)
response_format: Optional[CompletionResponseFormat] = Field(
default_factory=CompletionResponseFormat
)
n: Optional[int] = Field(default_factory=lambda: get_default_sampler_value("n", 1))
# Extra OAI request stuff
best_of: Optional[int] = Field(
@@ -39,9 +42,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
echo: Optional[bool] = Field(
description="Not parsed. Only used for OAI compliance.", default=False
)
n: Optional[int] = Field(
description="Not parsed. Only used for OAI compliance.", default=1
)
suffix: Optional[str] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)

View file

@@ -3,7 +3,7 @@
import asyncio
import pathlib
from asyncio import CancelledError
from typing import Optional
from typing import List, Optional
from uuid import uuid4
from fastapi import HTTPException, Request
@@ -31,47 +31,52 @@ from endpoints.OAI.types.chat_completion import (
from endpoints.OAI.types.common import UsageStats
def _create_response(generation: dict, model_name: Optional[str]):
def _create_response(generations: List[dict], model_name: Optional[str]):
"""Create a chat completion response from the provided text."""
message = ChatCompletionMessage(
role="assistant", content=unwrap(generation.get("text"), "")
)
prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
logprob_response = None
choices = []
for index, generation in enumerate(generations):
message = ChatCompletionMessage(
role="assistant", content=unwrap(generation.get("text"), "")
)
token_probs = unwrap(generation.get("token_probs"), {})
if token_probs:
logprobs = unwrap(generation.get("logprobs"), [])
logprob_response = None
collected_token_probs = []
for index, token in enumerate(token_probs.keys()):
top_logprobs = [
ChatCompletionLogprob(token=token, logprob=logprob)
for token, logprob in logprobs[index].items()
]
token_probs = unwrap(generation.get("token_probs"), {})
if token_probs:
logprobs = unwrap(generation.get("logprobs"), [])
collected_token_probs.append(
ChatCompletionLogprob(
token=token,
logprob=token_probs[token],
top_logprobs=top_logprobs,
collected_token_probs = []
for index, token in enumerate(token_probs.keys()):
top_logprobs = [
ChatCompletionLogprob(token=token, logprob=logprob)
for token, logprob in logprobs[index].items()
]
collected_token_probs.append(
ChatCompletionLogprob(
token=token,
logprob=token_probs[token],
top_logprobs=top_logprobs,
)
)
)
logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
choice = ChatCompletionRespChoice(
finish_reason=generation.get("finish_reason"),
message=message,
logprobs=logprob_response,
)
choice = ChatCompletionRespChoice(
index=index,
finish_reason=generation.get("finish_reason"),
message=message,
logprobs=logprob_response,
)
prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
completion_tokens = unwrap(generation.get("generated_tokens"), 0)
choices.append(choice)
response = ChatCompletionResponse(
choices=[choice],
choices=choices,
model=unwrap(model_name, ""),
usage=UsageStats(
prompt_tokens=prompt_tokens,
@@ -236,12 +241,18 @@ async def stream_generate_chat_completion(
async def generate_chat_completion(
prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
):
gen_tasks: List[asyncio.Task] = []
try:
generation = await model.container.generate(
prompt,
**data.to_gen_params(),
)
response = _create_response(generation, model_path.name)
for _ in range(0, data.n):
gen_tasks.append(
asyncio.create_task(
model.container.generate(prompt, **data.to_gen_params())
)
)
generations = await asyncio.gather(*gen_tasks)
response = _create_response(generations, model_path.name)
return response
except Exception as exc:

View file

@@ -4,7 +4,7 @@ import asyncio
import pathlib
from asyncio import CancelledError
from fastapi import HTTPException, Request
from typing import Optional
from typing import List, Optional
from common import model
from common.networking import (
@@ -23,34 +23,39 @@ from endpoints.OAI.types.completion import (
from endpoints.OAI.types.common import UsageStats
def _create_response(generation: dict, model_name: Optional[str]):
def _create_response(generations: List[dict], model_name: Optional[str]):
"""Create a completion response from the provided text."""
logprob_response = None
prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
token_probs = unwrap(generation.get("token_probs"), {})
if token_probs:
logprobs = unwrap(generation.get("logprobs"), [])
offset = unwrap(generation.get("offset"), [])
choices = []
for index, generation in enumerate(generations):
logprob_response = None
logprob_response = CompletionLogProbs(
text_offset=offset if isinstance(offset, list) else [offset],
token_logprobs=token_probs.values(),
tokens=token_probs.keys(),
top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
token_probs = unwrap(generation.get("token_probs"), {})
if token_probs:
logprobs = unwrap(generation.get("logprobs"), [])
offset = unwrap(generation.get("offset"), [])
logprob_response = CompletionLogProbs(
text_offset=offset if isinstance(offset, list) else [offset],
token_logprobs=token_probs.values(),
tokens=token_probs.keys(),
top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
)
choice = CompletionRespChoice(
index=index,
finish_reason=generation.get("finish_reason"),
text=unwrap(generation.get("text"), ""),
logprobs=logprob_response,
)
choice = CompletionRespChoice(
finish_reason=generation.get("finish_reason"),
text=unwrap(generation.get("text"), ""),
logprobs=logprob_response,
)
prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
completion_tokens = unwrap(generation.get("generated_tokens"), 0)
choices.append(choice)
response = CompletionResponse(
choices=[choice],
choices=choices,
model=unwrap(model_name, ""),
usage=UsageStats(
prompt_tokens=prompt_tokens,
@@ -84,7 +89,7 @@ async def stream_generate_completion(
abort_event.set()
handle_request_disconnect("Completion generation cancelled by user.")
response = _create_response(generation, model_path.name)
response = _create_response([generation], model_path.name)
yield response.model_dump_json()
# Break if the generation is finished
@@ -105,9 +110,18 @@ async def stream_generate_completion(
async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
"""Non-streaming generate for completions"""
gen_tasks: List[asyncio.Task] = []
try:
generation = await model.container.generate(data.prompt, **data.to_gen_params())
response = _create_response(generation, model_path.name)
for _ in range(0, data.n):
gen_tasks.append(
asyncio.create_task(
model.container.generate(data.prompt, **data.to_gen_params())
)
)
generations = await asyncio.gather(*gen_tasks)
response = _create_response(generations, model_path.name)
return response
except Exception as exc: