OAI: Add "n" for non-streaming generations
This adds the ability to request multiple choices for a single generation. It is only available for non-streaming generations for now; porting it to streaming requires some more work.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 8d31a5aed1
commit b944f8d756
3 changed files with 89 additions and 64 deletions
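As a rough illustration of what the new parameter enables (not part of the commit itself), a non-streaming request body could carry "n" like the sketch below; only fields visible in the diffs that follow are used, and anything omitted keeps its sampler default.

# Hypothetical non-streaming completion request body exercising the new field.
# Leaving "n" out falls back to get_default_sampler_value("n", 1), i.e. one choice.
payload = {
    "prompt": "Write a haiku about autumn.",
    "stream": False,  # "n" is only honored on the non-streaming path for now
    "n": 3,           # ask for three independent completions in one call
}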
@@ -3,7 +3,7 @@
 from pydantic import BaseModel, Field
 from typing import Optional
 
-from common.sampling import BaseSamplerRequest
+from common.sampling import BaseSamplerRequest, get_default_sampler_value
 
 
 class UsageStats(BaseModel):
@@ -27,10 +27,13 @@ class CommonCompletionRequest(BaseSamplerRequest):
 
     # Generation info (remainder is in BaseSamplerRequest superclass)
     stream: Optional[bool] = False
-    logprobs: Optional[int] = 0
+    logprobs: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("logprobs", 0)
+    )
     response_format: Optional[CompletionResponseFormat] = Field(
         default_factory=CompletionResponseFormat
     )
+    n: Optional[int] = Field(default_factory=lambda: get_default_sampler_value("n", 1))
 
     # Extra OAI request stuff
     best_of: Optional[int] = Field(
@@ -39,9 +42,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
     echo: Optional[bool] = Field(
         description="Not parsed. Only used for OAI compliance.", default=False
    )
-    n: Optional[int] = Field(
-        description="Not parsed. Only used for OAI compliance.", default=1
-    )
     suffix: Optional[str] = Field(
         description="Not parsed. Only used for OAI compliance.", default=None
     )
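The request model now resolves its defaults through the sampler defaults helper instead of hardcoding them. A minimal, self-contained sketch of that default_factory pattern, using a toy stand-in for get_default_sampler_value (the real helper lives in common.sampling and is assumed here to look up a configured override before returning the fallback):

from typing import Optional

from pydantic import BaseModel, Field

# Toy stand-in for common.sampling.get_default_sampler_value, purely illustrative
_sampler_overrides = {}


def get_default_sampler_value(key, fallback=None):
    # Return a configured override if present, otherwise the fallback
    return _sampler_overrides.get(key, fallback)


class DemoRequest(BaseModel):
    n: Optional[int] = Field(default_factory=lambda: get_default_sampler_value("n", 1))


print(DemoRequest().n)  # 1 unless an override for "n" is configured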
@@ -3,7 +3,7 @@
 import asyncio
 import pathlib
 from asyncio import CancelledError
-from typing import Optional
+from typing import List, Optional
 from uuid import uuid4
 
 from fastapi import HTTPException, Request
@@ -31,47 +31,52 @@ from endpoints.OAI.types.chat_completion import (
 from endpoints.OAI.types.common import UsageStats
 
 
-def _create_response(generation: dict, model_name: Optional[str]):
+def _create_response(generations: List[dict], model_name: Optional[str]):
     """Create a chat completion response from the provided text."""
 
-    message = ChatCompletionMessage(
-        role="assistant", content=unwrap(generation.get("text"), "")
-    )
-
-    logprob_response = None
-
-    token_probs = unwrap(generation.get("token_probs"), {})
-    if token_probs:
-        logprobs = unwrap(generation.get("logprobs"), [])
-
-        collected_token_probs = []
-        for index, token in enumerate(token_probs.keys()):
-            top_logprobs = [
-                ChatCompletionLogprob(token=token, logprob=logprob)
-                for token, logprob in logprobs[index].items()
-            ]
-
-            collected_token_probs.append(
-                ChatCompletionLogprob(
-                    token=token,
-                    logprob=token_probs[token],
-                    top_logprobs=top_logprobs,
-                )
-            )
-
-        logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
-
-    choice = ChatCompletionRespChoice(
-        finish_reason=generation.get("finish_reason"),
-        message=message,
-        logprobs=logprob_response,
-    )
-
-    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generation.get("generated_tokens"), 0)
-
+    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
+    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
+
+    choices = []
+    for index, generation in enumerate(generations):
+        message = ChatCompletionMessage(
+            role="assistant", content=unwrap(generation.get("text"), "")
+        )
+
+        logprob_response = None
+
+        token_probs = unwrap(generation.get("token_probs"), {})
+        if token_probs:
+            logprobs = unwrap(generation.get("logprobs"), [])
+
+            collected_token_probs = []
+            for index, token in enumerate(token_probs.keys()):
+                top_logprobs = [
+                    ChatCompletionLogprob(token=token, logprob=logprob)
+                    for token, logprob in logprobs[index].items()
+                ]
+
+                collected_token_probs.append(
+                    ChatCompletionLogprob(
+                        token=token,
+                        logprob=token_probs[token],
+                        top_logprobs=top_logprobs,
+                    )
+                )
+
+            logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
+
+        choice = ChatCompletionRespChoice(
+            index=index,
+            finish_reason=generation.get("finish_reason"),
+            message=message,
+            logprobs=logprob_response,
+        )
+
+        choices.append(choice)
+
     response = ChatCompletionResponse(
-        choices=[choice],
+        choices=choices,
         model=unwrap(model_name, ""),
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
@@ -236,12 +241,18 @@ async def stream_generate_chat_completion(
 async def generate_chat_completion(
     prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
 ):
+    gen_tasks: List[asyncio.Task] = []
+
     try:
-        generation = await model.container.generate(
-            prompt,
-            **data.to_gen_params(),
-        )
-        response = _create_response(generation, model_path.name)
+        for _ in range(0, data.n):
+            gen_tasks.append(
+                asyncio.create_task(
+                    model.container.generate(prompt, **data.to_gen_params())
+                )
+            )
+
+        generations = await asyncio.gather(*gen_tasks)
+        response = _create_response(generations, model_path.name)
 
         return response
     except Exception as exc:
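The handler change above boils down to a fan-out/gather pattern: queue n generation tasks, await them together, and hand the ordered list to _create_response. A standalone sketch of just that pattern, with a dummy coroutine standing in for model.container.generate:

import asyncio
from typing import List


async def fake_generate(prompt: str, idx: int) -> dict:
    # Dummy stand-in for model.container.generate
    await asyncio.sleep(0.01)
    return {"text": f"choice {idx} for: {prompt}", "finish_reason": "stop"}


async def generate_n(prompt: str, n: int) -> List[dict]:
    tasks = [asyncio.create_task(fake_generate(prompt, i)) for i in range(n)]
    # gather preserves task order, so choice indices stay stable
    return await asyncio.gather(*tasks)


if __name__ == "__main__":
    for generation in asyncio.run(generate_n("hello", 3)):
        print(generation["text"])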
@@ -4,7 +4,7 @@ import asyncio
 import pathlib
 from asyncio import CancelledError
 from fastapi import HTTPException, Request
-from typing import Optional
+from typing import List, Optional
 
 from common import model
 from common.networking import (
@@ -23,34 +23,39 @@ from endpoints.OAI.types.completion import (
 from endpoints.OAI.types.common import UsageStats
 
 
-def _create_response(generation: dict, model_name: Optional[str]):
+def _create_response(generations: List[dict], model_name: Optional[str]):
     """Create a completion response from the provided text."""
 
-    logprob_response = None
-
-    token_probs = unwrap(generation.get("token_probs"), {})
-    if token_probs:
-        logprobs = unwrap(generation.get("logprobs"), [])
-        offset = unwrap(generation.get("offset"), [])
-
-        logprob_response = CompletionLogProbs(
-            text_offset=offset if isinstance(offset, list) else [offset],
-            token_logprobs=token_probs.values(),
-            tokens=token_probs.keys(),
-            top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
-        )
-
-    choice = CompletionRespChoice(
-        finish_reason=generation.get("finish_reason"),
-        text=unwrap(generation.get("text"), ""),
-        logprobs=logprob_response,
-    )
-
-    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generation.get("generated_tokens"), 0)
-
+    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
+    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
+
+    choices = []
+    for index, generation in enumerate(generations):
+        logprob_response = None
+
+        token_probs = unwrap(generation.get("token_probs"), {})
+        if token_probs:
+            logprobs = unwrap(generation.get("logprobs"), [])
+            offset = unwrap(generation.get("offset"), [])
+
+            logprob_response = CompletionLogProbs(
+                text_offset=offset if isinstance(offset, list) else [offset],
+                token_logprobs=token_probs.values(),
+                tokens=token_probs.keys(),
+                top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
+            )
+
+        choice = CompletionRespChoice(
+            index=index,
+            finish_reason=generation.get("finish_reason"),
+            text=unwrap(generation.get("text"), ""),
+            logprobs=logprob_response,
+        )
+
+        choices.append(choice)
+
     response = CompletionResponse(
-        choices=[choice],
+        choices=choices,
         model=unwrap(model_name, ""),
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
@@ -84,7 +89,7 @@ async def stream_generate_completion(
             abort_event.set()
             handle_request_disconnect("Completion generation cancelled by user.")
 
-        response = _create_response(generation, model_path.name)
+        response = _create_response([generation], model_path.name)
         yield response.model_dump_json()
 
         # Break if the generation is finished
@@ -105,9 +110,18 @@ async def stream_generate_completion(
 async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
     """Non-streaming generate for completions"""
 
+    gen_tasks: List[asyncio.Task] = []
+
     try:
-        generation = await model.container.generate(data.prompt, **data.to_gen_params())
-        response = _create_response(generation, model_path.name)
+        for _ in range(0, data.n):
+            gen_tasks.append(
+                asyncio.create_task(
+                    model.container.generate(data.prompt, **data.to_gen_params())
+                )
+            )
+
+        generations = await asyncio.gather(*gen_tasks)
+        response = _create_response(generations, model_path.name)
 
         return response
     except Exception as exc:
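End to end, a client hitting the OAI-compatible completions route can now read multiple indexed choices from a single non-streaming request. A hedged sketch of that round trip: the host, port, and route prefix are assumptions and not part of this commit, while the per-choice index, finish_reason, and text fields come from CompletionRespChoice above.

import requests

# Assumed local server address and OAI-compatible route; adjust for your deployment
resp = requests.post(
    "http://127.0.0.1:5000/v1/completions",
    json={"prompt": "Say hello in three different ways.", "n": 3, "stream": False},
)
resp.raise_for_status()

for choice in resp.json()["choices"]:
    # Each choice carries its own index, finish reason, and generated text
    print(choice["index"], choice["finish_reason"], repr(choice["text"]))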