diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index 8947723..9874244 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -13,6 +13,10 @@ from common.model import check_embeddings_container, check_model_container
 from common.networking import handle_request_error, run_with_request_disconnect
 from common.tabby_config import config
 from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
+from endpoints.OAI.types.common import (
+    ModelItem, ModelListResponse
+)
+import hashlib
 from endpoints.OAI.types.chat_completion import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -172,28 +176,14 @@ async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsRes
     return response
 
 
-from pydantic import BaseModel, Field
-from typing import List, Optional
-import hashlib
-
-class ModelItem(BaseModel):
-    model: str
-    name: str
-    digest: str
-    urls: List[int]
-
-class ModelListResponse(BaseModel):
-    object: str = Field("list", description="Type of the response object.")
-    models: List[ModelItem]
 
 async def fetch_models():
     models_dir = "models"
     models = []
-    # Iterate over the files in the models directory
     if os.path.exists(models_dir):
         for model in os.listdir(models_dir):
             model_path = os.path.join(models_dir, model)
-            if os.path.isdir(model_path): # Assuming each model is in its own directory
+            if os.path.isdir(model_path):
                 digest = hashlib.md5(model.encode()).hexdigest()
                 models.append({
                     "model":f"{model}:latest",
@@ -224,18 +214,13 @@ async def dummy(request: Request):
     dependencies=[Depends(check_api_key)]
 )
 async def get_all_models(request: Request) -> ModelListResponse:
-    print(f"Processing request for models {request.state.id}")
-
     response = await run_with_request_disconnect(
         request,
         asyncio.create_task(fetch_models()),
-        disconnect_message=f"All models fetched",
+        disconnect_message="All models fetched",
     )
-
     return response
-
-
 
 @router.post(
     "/api/chat",
     dependencies=[Depends(check_api_key)],
@@ -284,7 +269,6 @@ async def chat_completion_request_ollama(
     if data.stream and not disable_request_streaming:
         return StreamingResponse(stream_response(request), media_type="application/x-ndjson")
-
     else:
         generate_task = asyncio.create_task(
             generate_chat_completion(prompt, data, request, model_path)
@@ -295,5 +279,4 @@ async def chat_completion_request_ollama(
             generate_task,
             disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
         )
-        return response
-
+        return response
\ No newline at end of file
diff --git a/endpoints/OAI/types/common.py b/endpoints/OAI/types/common.py
index 0f765de..e4c853c 100644
--- a/endpoints/OAI/types/common.py
+++ b/endpoints/OAI/types/common.py
@@ -1,6 +1,6 @@
 """Common types for OAI."""
 
 from pydantic import BaseModel, Field
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 from common.sampling import BaseSamplerRequest, get_default_sampler_value
@@ -58,3 +58,29 @@ class CommonCompletionRequest(BaseSamplerRequest):
     user: Optional[str] = Field(
         description="Not parsed. Only used for OAI compliance.", default=None
     )
+
+    def validate_params(self):
+        # n (number of completion choices) must be a positive integer
+        if self.n < 1:
+            raise ValueError(f"n must be greater than or equal to 1. Got {self.n}")
+
+        return super().validate_params()
+
+    def to_gen_params(self):
+        extra_gen_params = {
+            "stream": self.stream,
+            "logprobs": self.logprobs,
+        }
+
+        return super().to_gen_params(**extra_gen_params)
+
+
+class ModelItem(BaseModel):
+    model: str
+    name: str
+    digest: str
+    urls: List[int]
+
+class ModelListResponse(BaseModel):
+    object: str = Field("list", description="Type of the response object.")
+    models: List[ModelItem]
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index efc7b34..1ac2466 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -177,7 +177,7 @@ def _create_stream_chunk_ollama(
         delta=message,
         logprobs=logprob_response,
     )
-    ollama_bit = {
+    ollama_chunk = {
         "model":model_name,
         "created_at": datetime.utcnow().isoformat(timespec='microseconds') + "Z",
         "message": {"role":choice.delta.role if hasattr(choice.delta, 'role') else 'none',
@@ -185,7 +185,7 @@ def _create_stream_chunk_ollama(
         "done_reason": choice.finish_reason,
         "done": choice.finish_reason=="stop",
     }
-    return ollama_bit
+    return ollama_chunk
 
 def _create_stream_chunk(
     request_id: str,
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index 98712d9..9cf1418 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -176,14 +176,6 @@ async def load_inline_model(model_name: str, request: Request):
 
         return
 
-    if model.container is not None:
-        if model.container.model_dir.name != model_name:
-            logger.info(f"New model requested: {model_name}. Unloading current model.")
-            await model.unload_model()
-        elif model.container.model_loaded:
-            logger.info(f"Model {model_name} is already loaded.")
-            return
-
     model_path = pathlib.Path(config.model.model_dir)
     model_path = model_path / model_name
 