formatting changes

Authored by John on 2024-10-13 07:27:27 -04:00; committed by Jakob Lechner
parent 1fd4a9e119
commit 4ace973244
4 changed files with 36 additions and 35 deletions

View file

@@ -13,6 +13,10 @@ from common.model import check_embeddings_container, check_model_container
 from common.networking import handle_request_error, run_with_request_disconnect
 from common.tabby_config import config
 from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
+from endpoints.OAI.types.common import (
+    ModelItem, ModelListResponse
+)
+import hashlib
 from endpoints.OAI.types.chat_completion import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -172,28 +176,14 @@ async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse
     return response

-from pydantic import BaseModel, Field
-from typing import List, Optional
-import hashlib
-
-class ModelItem(BaseModel):
-    model: str
-    name: str
-    digest: str
-    urls: List[int]
-
-class ModelListResponse(BaseModel):
-    object: str = Field("list", description="Type of the response object.")
-    models: List[ModelItem]

 async def fetch_models():
     models_dir = "models"
     models = []
+    # Iterate over the files in the models directory
     if os.path.exists(models_dir):
         for model in os.listdir(models_dir):
             model_path = os.path.join(models_dir, model)
-            if os.path.isdir(model_path):  # Assuming each model is in its own directory
+            if os.path.isdir(model_path):
                 digest = hashlib.md5(model.encode()).hexdigest()
                 models.append({
                     "model":f"{model}:latest",
@@ -224,18 +214,13 @@ async def dummy(request: Request):
     dependencies=[Depends(check_api_key)]
 )
 async def get_all_models(request: Request) -> ModelListResponse:
-    print(f"Processing request for models {request.state.id}")
     response = await run_with_request_disconnect(
         request,
         asyncio.create_task(fetch_models()),
-        disconnect_message=f"All models fetched",
+        disconnect_message="All models fetched",
     )
     return response

 @router.post(
     "/api/chat",
     dependencies=[Depends(check_api_key)],
@@ -284,7 +269,6 @@ async def chat_completion_request_ollama(
     if data.stream and not disable_request_streaming:
         return StreamingResponse(stream_response(request), media_type="application/x-ndjson")
     else:
-
         generate_task = asyncio.create_task(
             generate_chat_completion(prompt, data, request, model_path)
@@ -296,4 +280,3 @@ async def chat_completion_request_ollama(
         disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
     )
     return response
-
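Note: run_with_request_disconnect lives in common.networking (see the import in the first hunk) and its body is not shown in this diff. Judging only from the call sites above, it awaits the task and cancels it when the client goes away. A polling sketch of that contract; the status code and the polling interval are assumptions:

    import asyncio
    from fastapi import HTTPException, Request

    async def run_with_request_disconnect(
        request: Request, task: asyncio.Task, disconnect_message: str
    ):
        """Await task; cancel it and abort if the client disconnects first."""
        while not task.done():
            if await request.is_disconnected():  # Starlette's disconnect check
                task.cancel()
                raise HTTPException(status_code=422, detail=disconnect_message)
            await asyncio.sleep(0.1)  # assumed polling interval
        return task.result()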

View file

@@ -1,6 +1,6 @@
 """Common types for OAI."""

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, List
 from typing import Optional, Union

 from common.sampling import BaseSamplerRequest, get_default_sampler_value
@@ -58,3 +58,29 @@ class CommonCompletionRequest(BaseSamplerRequest):
     user: Optional[str] = Field(
         description="Not parsed. Only used for OAI compliance.", default=None
     )
+
+    def validate_params(self):
+        # Temperature
+        if self.n < 1:
+            raise ValueError(f"n must be greater than or equal to 1. Got {self.n}")
+
+        return super().validate_params()
+
+    def to_gen_params(self):
+        extra_gen_params = {
+            "stream": self.stream,
+            "logprobs": self.logprobs,
+        }
+
+        return super().to_gen_params(**extra_gen_params)
+
+
+class ModelItem(BaseModel):
+    model: str
+    name: str
+    digest: str
+    urls: List[int]
+
+class ModelListResponse(BaseModel):
+    object: str = Field("list", description="Type of the response object.")
+    models: List[ModelItem]
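Note: pydantic does not export List, so the new import line in the first hunk of this file ("from pydantic import BaseModel, Field, List") raises ImportError as soon as the module loads; List belongs in the typing import one line below it. A corrected, self-contained sketch of the two new models as added here:

    from typing import List
    from pydantic import BaseModel, Field

    class ModelItem(BaseModel):
        model: str
        name: str
        digest: str
        urls: List[int]

    class ModelListResponse(BaseModel):
        object: str = Field("list", description="Type of the response object.")
        models: List[ModelItem]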

View file

@@ -177,7 +177,7 @@ def _create_stream_chunk_ollama(
         delta=message,
         logprobs=logprob_response,
     )
-    ollama_bit = {
+    ollama_chunk = {
         "model":model_name,
         "created_at": datetime.utcnow().isoformat(timespec='microseconds') + "Z",
         "message": {"role":choice.delta.role if hasattr(choice.delta, 'role') else 'none',
@@ -185,7 +185,7 @@ def _create_stream_chunk_ollama(
         "done_reason": choice.finish_reason,
         "done": choice.finish_reason=="stop",
     }
-    return ollama_bit
+    return ollama_chunk

 def _create_stream_chunk(
     request_id: str,
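Note: with media_type "application/x-ndjson" (see the streaming branch in the first file), each chunk built by _create_stream_chunk_ollama is serialized as one JSON object per line. Going by the fields visible in this hunk, a streamed line looks roughly like the following; the model name, the content value, and the "content" key itself are illustrative, not taken from the diff:

    {"model": "llama3:latest", "created_at": "2024-10-13T11:27:27.000000Z", "message": {"role": "assistant", "content": "Hello"}, "done_reason": null, "done": false}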

View file

@@ -176,14 +176,6 @@ async def load_inline_model(model_name: str, request: Request):
         return

-    if model.container is not None:
-        if model.container.model_dir.name != model_name:
-            logger.info(f"New model requested: {model_name}. Unloading current model.")
-            await model.unload_model()
-        elif model.container.model_loaded:
-            logger.info(f"Model {model_name} is already loaded.")
-            return
-
     model_path = pathlib.Path(config.model.model_dir)
     model_path = model_path / model_name