diff --git a/endpoints/OAI/types/common.py b/endpoints/OAI/types/common.py
index b9aac68..4f02b47 100644
--- a/endpoints/OAI/types/common.py
+++ b/endpoints/OAI/types/common.py
@@ -3,7 +3,7 @@
 from pydantic import BaseModel, Field
 from typing import Optional
 
-from common.sampling import BaseSamplerRequest
+from common.sampling import BaseSamplerRequest, get_default_sampler_value
 
 
 class UsageStats(BaseModel):
@@ -27,10 +27,13 @@ class CommonCompletionRequest(BaseSamplerRequest):
 
     # Generation info (remainder is in BaseSamplerRequest superclass)
     stream: Optional[bool] = False
-    logprobs: Optional[int] = 0
+    logprobs: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("logprobs", 0)
+    )
     response_format: Optional[CompletionResponseFormat] = Field(
         default_factory=CompletionResponseFormat
     )
+    n: Optional[int] = Field(default_factory=lambda: get_default_sampler_value("n", 1))
 
     # Extra OAI request stuff
     best_of: Optional[int] = Field(
@@ -39,9 +42,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
     echo: Optional[bool] = Field(
         description="Not parsed. Only used for OAI compliance.", default=False
     )
-    n: Optional[int] = Field(
-        description="Not parsed. Only used for OAI compliance.", default=1
-    )
     suffix: Optional[str] = Field(
         description="Not parsed. Only used for OAI compliance.", default=None
     )
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 9359f74..e88fc6b 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -3,7 +3,7 @@
 import asyncio
 import pathlib
 from asyncio import CancelledError
-from typing import Optional
+from typing import List, Optional
 from uuid import uuid4
 
 from fastapi import HTTPException, Request
@@ -31,47 +31,52 @@ from endpoints.OAI.types.chat_completion import (
 from endpoints.OAI.types.common import UsageStats
 
 
-def _create_response(generation: dict, model_name: Optional[str]):
+def _create_response(generations: List[dict], model_name: Optional[str]):
     """Create a chat completion response from the provided text."""
 
-    message = ChatCompletionMessage(
-        role="assistant", content=unwrap(generation.get("text"), "")
-    )
+    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
+    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
 
-    logprob_response = None
+    choices = []
+    for index, generation in enumerate(generations):
+        message = ChatCompletionMessage(
+            role="assistant", content=unwrap(generation.get("text"), "")
+        )
 
-    token_probs = unwrap(generation.get("token_probs"), {})
-    if token_probs:
-        logprobs = unwrap(generation.get("logprobs"), [])
+        logprob_response = None
 
-        collected_token_probs = []
-        for index, token in enumerate(token_probs.keys()):
-            top_logprobs = [
-                ChatCompletionLogprob(token=token, logprob=logprob)
-                for token, logprob in logprobs[index].items()
-            ]
+        token_probs = unwrap(generation.get("token_probs"), {})
+        if token_probs:
+            logprobs = unwrap(generation.get("logprobs"), [])
 
-            collected_token_probs.append(
-                ChatCompletionLogprob(
-                    token=token,
-                    logprob=token_probs[token],
-                    top_logprobs=top_logprobs,
+            collected_token_probs = []
+            for index, token in enumerate(token_probs.keys()):
+                top_logprobs = [
+                    ChatCompletionLogprob(token=token, logprob=logprob)
+                    for token, logprob in logprobs[index].items()
+                ]
+
+                collected_token_probs.append(
+                    ChatCompletionLogprob(
+                        token=token,
+                        logprob=token_probs[token],
+                        top_logprobs=top_logprobs,
+                    )
                 )
-            )
 
-        logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
+            logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
 
-    choice = ChatCompletionRespChoice(
-        finish_reason=generation.get("finish_reason"),
-        message=message,
-        logprobs=logprob_response,
-    )
+        choice = ChatCompletionRespChoice(
+            index=index,
+            finish_reason=generation.get("finish_reason"),
+            message=message,
+            logprobs=logprob_response,
+        )
 
-    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+        choices.append(choice)
 
     response = ChatCompletionResponse(
-        choices=[choice],
+        choices=choices,
         model=unwrap(model_name, ""),
         usage=UsageStats(
            prompt_tokens=prompt_tokens,
@@ -236,12 +241,18 @@ async def stream_generate_chat_completion(
 async def generate_chat_completion(
     prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
 ):
+    gen_tasks: List[asyncio.Task] = []
+
     try:
-        generation = await model.container.generate(
-            prompt,
-            **data.to_gen_params(),
-        )
-        response = _create_response(generation, model_path.name)
+        for _ in range(0, data.n):
+            gen_tasks.append(
+                asyncio.create_task(
+                    model.container.generate(prompt, **data.to_gen_params())
+                )
+            )
+
+        generations = await asyncio.gather(*gen_tasks)
+        response = _create_response(generations, model_path.name)
 
         return response
     except Exception as exc:
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index e242be6..b86ccd5 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -4,7 +4,7 @@ import asyncio
 import pathlib
 from asyncio import CancelledError
 from fastapi import HTTPException, Request
-from typing import Optional
+from typing import List, Optional
 
 from common import model
 from common.networking import (
@@ -23,34 +23,39 @@ from endpoints.OAI.types.completion import (
 from endpoints.OAI.types.common import UsageStats
 
 
-def _create_response(generation: dict, model_name: Optional[str]):
+def _create_response(generations: List[dict], model_name: Optional[str]):
     """Create a completion response from the provided text."""
 
-    logprob_response = None
+    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
+    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
 
-    token_probs = unwrap(generation.get("token_probs"), {})
-    if token_probs:
-        logprobs = unwrap(generation.get("logprobs"), [])
-        offset = unwrap(generation.get("offset"), [])
+    choices = []
+    for index, generation in enumerate(generations):
+        logprob_response = None
 
-        logprob_response = CompletionLogProbs(
-            text_offset=offset if isinstance(offset, list) else [offset],
-            token_logprobs=token_probs.values(),
-            tokens=token_probs.keys(),
-            top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
+        token_probs = unwrap(generation.get("token_probs"), {})
+        if token_probs:
+            logprobs = unwrap(generation.get("logprobs"), [])
+            offset = unwrap(generation.get("offset"), [])
+
+            logprob_response = CompletionLogProbs(
+                text_offset=offset if isinstance(offset, list) else [offset],
+                token_logprobs=token_probs.values(),
+                tokens=token_probs.keys(),
+                top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
+            )
+
+        choice = CompletionRespChoice(
+            index=index,
+            finish_reason=generation.get("finish_reason"),
+            text=unwrap(generation.get("text"), ""),
+            logprobs=logprob_response,
         )
 
-    choice = CompletionRespChoice(
-        finish_reason=generation.get("finish_reason"),
-        text=unwrap(generation.get("text"), ""),
-        logprobs=logprob_response,
-    )
-
-    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+        choices.append(choice)
 
     response = CompletionResponse(
-        choices=[choice],
+        choices=choices,
         model=unwrap(model_name, ""),
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
@@ -84,7 +89,7 @@ async def stream_generate_completion(
                 abort_event.set()
                 handle_request_disconnect("Completion generation cancelled by user.")
 
-            response = _create_response(generation, model_path.name)
+            response = _create_response([generation], model_path.name)
             yield response.model_dump_json()
 
             # Break if the generation is finished
@@ -105,9 +110,18 @@
 async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
     """Non-streaming generate for completions"""
 
+    gen_tasks: List[asyncio.Task] = []
+
     try:
-        generation = await model.container.generate(data.prompt, **data.to_gen_params())
-        response = _create_response(generation, model_path.name)
+        for _ in range(0, data.n):
+            gen_tasks.append(
+                asyncio.create_task(
+                    model.container.generate(data.prompt, **data.to_gen_params())
+                )
+            )
+
+        generations = await asyncio.gather(*gen_tasks)
+        response = _create_response(generations, model_path.name)
 
         return response
     except Exception as exc:
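
A quick way to exercise the new behavior by hand once a server with this patch is running: the sketch below is not part of the diff and rests on a few assumptions (a local tabbyAPI instance listening on http://127.0.0.1:5000, the usual OpenAI-style /v1/completions route, auth headers omitted, and a placeholder model name). It only checks that sending n > 1 yields one indexed entry per generation in "choices", which is what the _create_response changes above produce.

# Hypothetical usage sketch (not part of the patch); adjust host, port, auth, and model to your setup.
import requests

payload = {
    "model": "your-model-name",      # placeholder model name
    "prompt": "The quick brown fox",
    "max_tokens": 16,
    "n": 2,                          # request two parallel generations
}
resp = requests.post("http://127.0.0.1:5000/v1/completions", json=payload, timeout=120)
resp.raise_for_status()

# Each generation should come back as its own choice carrying its own "index"
# and "finish_reason"; usage stats are reported once for the whole response.
for choice in resp.json()["choices"]:
    print(choice["index"], repr(choice["text"]))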