diff --git a/OAI/types/chat_completion.py b/OAI/types/chat_completion.py
new file mode 100644
index 0000000..233846d
--- /dev/null
+++ b/OAI/types/chat_completion.py
@@ -0,0 +1,44 @@
+from uuid import uuid4
+from time import time
+from pydantic import BaseModel, Field
+from typing import Union, List, Optional
+from OAI.types.common import UsageStats, CommonCompletionRequest
+
+class ChatCompletionMessage(BaseModel):
+    role: str
+    content: str
+
+class ChatCompletionRespChoice(BaseModel):
+    # Index is 0 since we aren't using multiple choices
+    index: int = 0
+    finish_reason: str
+    message: ChatCompletionMessage
+
+class ChatCompletionStreamChoice(BaseModel):
+    # Index is 0 since we aren't using multiple choices
+    index: int = 0
+    finish_reason: str
+    delta: ChatCompletionMessage
+
+# Inherited from common request
+class ChatCompletionRequest(CommonCompletionRequest):
+    # Messages
+    # Take in a string as well even though it's not part of the OAI spec
+    messages: Union[str, List[ChatCompletionMessage]]
+
+class ChatCompletionResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
+    choices: List[ChatCompletionRespChoice]
+    created: int = Field(default_factory=lambda: int(time()))
+    model: str
+    object: str = "chat.completion"
+
+    # TODO: Add usage stats
+    usage: Optional[UsageStats] = None
+
+class ChatCompletionStreamChunk(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
+    choices: List[ChatCompletionStreamChoice]
+    created: int = Field(default_factory=lambda: int(time()))
+    model: str
+    object: str = "chat.completion.chunk"
diff --git a/OAI/types/common.py b/OAI/types/common.py
index a86341d..c0afb06 100644
--- a/OAI/types/common.py
+++ b/OAI/types/common.py
@@ -1,5 +1,5 @@
 from pydantic import BaseModel, Field
-from typing import List, Dict
+from typing import List, Dict, Optional, Union
 
 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
@@ -11,3 +11,75 @@ class UsageStats(BaseModel):
     completion_tokens: int
     prompt_tokens: int
     total_tokens: int
+
+class CommonCompletionRequest(BaseModel):
+    # Model information
+    # This parameter is not used, the loaded model is used instead
+    model: Optional[str] = None
+
+    # Extra OAI request stuff
+    best_of: Optional[int] = None
+    echo: Optional[bool] = False
+    logit_bias: Optional[Dict[str, float]] = None
+    logprobs: Optional[int] = None
+    n: Optional[int] = 1
+    suffix: Optional[str] = None
+    user: Optional[str] = None
+
+    # Generation info
+    seed: Optional[int] = -1
+    stream: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = None
+
+    # Default to 150 as 16 makes no sense as a default
+    max_tokens: Optional[int] = 150
+
+    # Aliased to repetition_penalty
+    frequency_penalty: Optional[float] = 0.0
+
+    # Sampling params
+    token_healing: Optional[bool] = False
+    temperature: Optional[float] = 1.0
+    top_k: Optional[int] = 0
+    top_p: Optional[float] = 1.0
+    typical: Optional[float] = 0.0
+    min_p: Optional[float] = 0.0
+    tfs: Optional[float] = 1.0
+    repetition_penalty: Optional[float] = 1.0
+    repetition_penalty_range: Optional[int] = 0
+    repetition_decay: Optional[int] = 0
+    mirostat_mode: Optional[int] = 0
+    mirostat_tau: Optional[float] = 1.5
+    mirostat_eta: Optional[float] = 0.1
+    add_bos_token: Optional[bool] = True
+    ban_eos_token: Optional[bool] = False
+
+    # Converts to internal generation parameters
+    def to_gen_params(self):
+        # Convert stop to an array of strings
+        if isinstance(self.stop, str):
+            self.stop = [self.stop]
+
+        # Set repetition_penalty to frequency_penalty if repetition_penalty isn't already defined
+        if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty:
+            self.repetition_penalty = self.frequency_penalty
+
+        return {
+            "stop": self.stop,
+            "max_tokens": self.max_tokens,
+            "add_bos_token": self.add_bos_token,
+            "ban_eos_token": self.ban_eos_token,
+            "token_healing": self.token_healing,
+            "temperature": self.temperature,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+            "typical": self.typical,
+            "min_p": self.min_p,
+            "tfs": self.tfs,
+            "repetition_penalty": self.repetition_penalty,
+            "repetition_penalty_range": self.repetition_penalty_range,
+            "repetition_decay": self.repetition_decay,
+            "mirostat": True if self.mirostat_mode == 2 else False,
+            "mirostat_tau": self.mirostat_tau,
+            "mirostat_eta": self.mirostat_eta
+        }
diff --git a/OAI/types/completion.py b/OAI/types/completion.py
index 813877c..59cb333 100644
--- a/OAI/types/completion.py
+++ b/OAI/types/completion.py
@@ -2,102 +2,26 @@ from uuid import uuid4
 from time import time
 from pydantic import BaseModel, Field
 from typing import List, Optional, Dict, Union
-from OAI.types.common import LogProbs, UsageStats
+from OAI.types.common import LogProbs, UsageStats, CommonCompletionRequest
 
 class CompletionRespChoice(BaseModel):
+    # Index is 0 since we aren't using multiple choices
+    index: int = 0
     finish_reason: str
-    index: int
     logprobs: Optional[LogProbs] = None
     text: str
 
-class CompletionRequest(BaseModel):
-    # Model information
-    model: str
-
+# Inherited from common request
+class CompletionRequest(CommonCompletionRequest):
     # Prompt can also contain token ids, but that's out of scope for this project.
     prompt: Union[str, List[str]]
 
-    # Extra OAI request stuff
-    best_of: Optional[int] = None
-    echo: Optional[bool] = False
-    logit_bias: Optional[Dict[str, float]] = None
-    logprobs: Optional[int] = None
-    n: Optional[int] = 1
-    suffix: Optional[str] = None
-    user: Optional[str] = None
-
-    # Generation info
-    seed: Optional[int] = -1
-    stream: Optional[bool] = False
-    stop: Optional[Union[str, List[str]]] = None
-
-    # Default to 150 as 16 makes no sense as a default
-    max_tokens: Optional[int] = 150
-
-    # Not supported sampling params
-    presence_penalty: Optional[float] = 0.0
-
-    # Aliased to repetition_penalty
-    frequency_penalty: Optional[float] = 0.0
-
-    # Sampling params
-    token_healing: Optional[bool] = False
-    temperature: Optional[float] = 1.0
-    top_k: Optional[int] = 0
-    top_p: Optional[float] = 1.0
-    typical: Optional[float] = 0.0
-    min_p: Optional[float] = 0.0
-    tfs: Optional[float] = 1.0
-    repetition_penalty: Optional[float] = 1.0
-    repetition_penalty_range: Optional[int] = 0
-    repetition_decay: Optional[int] = 0
-    mirostat_mode: Optional[int] = 0
-    mirostat_tau: Optional[float] = 1.5
-    mirostat_eta: Optional[float] = 0.1
-    add_bos_token: Optional[bool] = True
-    ban_eos_token: Optional[bool] = False
-
-    # Converts to internal generation parameters
-    def to_gen_params(self):
-        # Convert prompt to a string
-        if isinstance(self.prompt, list):
-            self.prompt = "\n".join(self.prompt)
-
-        # Convert stop to an array of strings
-        if isinstance(self.stop, str):
-            self.stop = [self.stop]
-
-        # Set repetition_penalty to frequency_penalty if repetition_penalty isn't already defined
-        if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty:
-            self.repetition_penalty = self.frequency_penalty
-
-        return {
-            "prompt": self.prompt,
-            "stop": self.stop,
-            "max_tokens": self.max_tokens,
-            "add_bos_token": self.add_bos_token,
-            "ban_eos_token": self.ban_eos_token,
-            "token_healing": self.token_healing,
-            "temperature": self.temperature,
-            "top_k": self.top_k,
-            "top_p": self.top_p,
-            "typical": self.typical,
-            "min_p": self.min_p,
-            "tfs": self.tfs,
-            "repetition_penalty": self.repetition_penalty,
-            "repetition_penalty_range": self.repetition_penalty_range,
-            "repetition_decay": self.repetition_decay,
-            "mirostat": True if self.mirostat_mode == 2 else False,
-            "mirostat_tau": self.mirostat_tau,
-            "mirostat_eta": self.mirostat_eta
-        }
-
 class CompletionResponse(BaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}")
     choices: List[CompletionRespChoice]
     created: int = Field(default_factory=lambda: int(time()))
     model: str
-    object: str = "text-completion"
+    object: str = "text_completion"
 
     # TODO: Add usage stats
     usage: Optional[UsageStats] = None
diff --git a/OAI/utils.py b/OAI/utils.py
index f6f70a0..87ce8be 100644
--- a/OAI/utils.py
+++ b/OAI/utils.py
@@ -1,15 +1,22 @@
 import pathlib
+from fastchat.model.model_adapter import get_conversation_template, Conversation
 from OAI.types.completion import CompletionResponse, CompletionRespChoice
+from OAI.types.chat_completion import (
+    ChatCompletionMessage,
+    ChatCompletionRespChoice,
+    ChatCompletionStreamChunk,
+    ChatCompletionResponse,
+    ChatCompletionStreamChoice
+)
 from OAI.types.common import UsageStats
 from OAI.types.model import ModelList, ModelCard
-from typing import Optional
+from typing import Optional, List
 
-def create_completion_response(text: str, index: int, model_name: Optional[str]):
+def create_completion_response(text: str, model_name: Optional[str]):
     # TODO: Add method to get token amounts in model for UsageStats
     choice = CompletionRespChoice(
-        finish_reason="Generated",
-        index = index,
+        finish_reason = "Generated",
         text = text
     )
@@ -20,11 +27,70 @@ def create_completion_response(text: str, index: int, model_name: Optional[str]):
     return response
 
+def create_chat_completion_response(text: str, model_name: Optional[str]):
+    # TODO: Add method to get token amounts in model for UsageStats
+
+    message = ChatCompletionMessage(
+        role = "assistant",
+        content = text
+    )
+
+    choice = ChatCompletionRespChoice(
+        finish_reason = "Generated",
+        message = message
+    )
+
+    response = ChatCompletionResponse(
+        choices = [choice],
+        model = model_name or ""
+    )
+
+    return response
+
+def create_chat_completion_stream_chunk(const_id: str, text: str, model_name: Optional[str]):
+    # TODO: Add method to get token amounts in model for UsageStats
+
+    message = ChatCompletionMessage(
+        role = "assistant",
+        content = text
+    )
+
+    choice = ChatCompletionStreamChoice(
+        finish_reason = "Generated",
+        delta = message
+    )
+
+    chunk = ChatCompletionStreamChunk(
+        id = const_id,
+        choices = [choice],
+        model = model_name or ""
+    )
+
+    return chunk
+
 def get_model_list(model_path: pathlib.Path):
     model_card_list = ModelList()
-    for path in model_path.parent.iterdir():
+    for path in model_path.iterdir():
         if path.is_dir():
             model_card = ModelCard(id = path.name)
             model_card_list.data.append(model_card)
 
     return model_card_list
+
+def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]):
+    conv = get_conversation_template(model_path)
+    for message in messages:
+        msg_role = message.role
+        if msg_role == "system":
+            conv.system_message = message.content
+        elif msg_role == "user":
+            conv.append_message(conv.roles[0], message.content)
+        elif msg_role == "assistant":
+            conv.append_message(conv.roles[1], message.content)
+        else:
+            raise ValueError(f"Unknown role: {msg_role}")
+
+    conv.append_message(conv.roles[1], None)
+    prompt = conv.get_prompt()
+
+    return prompt
diff --git a/main.py b/main.py
index cb8cefd..81022d8 100644
--- a/main.py
+++ b/main.py
@@ -7,11 +7,19 @@
 from model import ModelContainer
 from progress.bar import IncrementalBar
 from sse_starlette import EventSourceResponse
 from OAI.types.completion import CompletionRequest
+from OAI.types.chat_completion import ChatCompletionRequest
 from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
 from OAI.types.token import TokenEncodeRequest, TokenEncodeResponse, TokenDecodeRequest, TokenDecodeResponse
-from OAI.utils import create_completion_response, get_model_list
+from OAI.utils import (
+    create_completion_response,
+    get_model_list,
+    get_chat_completion_prompt,
+    create_chat_completion_response,
+    create_chat_completion_stream_chunk
+)
 from typing import Optional
 from utils import load_progress
+from uuid import uuid4
 
 app = FastAPI()
@@ -45,8 +53,8 @@ async def load_model(data: ModelLoadRequest):
     model_config = config["model"]
     model_path = pathlib.Path(model_config["model_dir"] or "models")
     model_path = model_path / data.name
-
-    model_container = ModelContainer(model_path, False, **data.model_dump())
+
+    model_container = ModelContainer(model_path, False, **data.dict())
     load_status = model_container.load_gen(load_progress)
     for (module, modules) in load_status:
         if module == 0:
@@ -97,21 +105,58 @@ async def decode_tokens(data: TokenDecodeRequest):
 
 @app.post("/v1/completions", dependencies=[Depends(check_api_key)])
 async def generate_completion(request: Request, data: CompletionRequest):
+    model_path = model_container.get_model_path()
+
+    if isinstance(data.prompt, list):
+        data.prompt = "\n".join(data.prompt)
+
     if data.stream:
         async def generator():
-            new_generation = model_container.generate_gen(**data.to_gen_params())
-            for index, part in enumerate(new_generation):
+            new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params())
+            for part in new_generation:
                 if await request.is_disconnected():
                     break
 
-                response = create_completion_response(part, index, model_container.get_model_path().name)
+                response = create_completion_response(part, model_path.name)
 
                 yield response.json()
 
         return EventSourceResponse(generator())
     else:
-        response_text = model_container.generate(**data.to_gen_params())
-        response = create_completion_response(response_text, 0, model_container.get_model_path().name)
+        response_text = model_container.generate(data.prompt, **data.to_gen_params())
+        response = create_completion_response(response_text, model_path.name)
+
+        return response.json()
+
+@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
+async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
+    model_path = model_container.get_model_path()
+
+    if isinstance(data.messages, str):
+        prompt = data.messages
+    else:
+        prompt = get_chat_completion_prompt(model_path.name, data.messages)
+
+    if data.stream:
+        const_id = f"chatcmpl-{uuid4().hex}"
+        async def generator():
+            new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
+            for part in new_generation:
+                if await request.is_disconnected():
+                    break
+
+                response = create_chat_completion_stream_chunk(
+                    const_id,
+                    part,
+                    model_path.name
+                )
+
+                yield response.json()
+
+        return EventSourceResponse(generator())
+    else:
+        response_text = model_container.generate(prompt, **data.to_gen_params())
+        response = create_chat_completion_response(response_text, model_path.name)
 
     return response.json()
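
Not part of the patch: a minimal client sketch for exercising the new /v1/chat/completions route added above. The host, port, API key value, and the x-api-key header name are assumptions for illustration only (check_api_key is not shown in this diff), so adjust them to the running server before use; the example just prints the raw response body rather than assuming its exact shape.

# Illustrative client sketch -- assumed host/port and auth header, not part of the patch.
import requests

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Give me three facts about owls."},
    ],
    "max_tokens": 200,
    # frequency_penalty is aliased onto repetition_penalty by to_gen_params()
    "frequency_penalty": 1.05,
    # set to True to receive SSE chunks built by create_chat_completion_stream_chunk
    "stream": False,
}

resp = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",  # assumed base URL
    headers={"x-api-key": "<your key>"},          # assumed header name
    json=payload,
    timeout=120,
)
resp.raise_for_status()
print(resp.text)

With "stream": True, the same route returns an SSE stream via EventSourceResponse, one chat.completion.chunk object per generated part, all sharing the const_id generated in generate_chat_completion.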