formatting changes

This commit is contained in:
John 2024-10-13 07:27:27 -04:00 committed by Jakob Lechner
parent 1fd4a9e119
commit 4ace973244
4 changed files with 36 additions and 35 deletions

View file

@ -13,6 +13,10 @@ from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.tabby_config import config
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
from endpoints.OAI.types.common import (
ModelItem, ModelListResponse
)
import hashlib
from endpoints.OAI.types.chat_completion import (
ChatCompletionRequest,
ChatCompletionResponse,
@ -172,28 +176,14 @@ async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsRes
return response
from pydantic import BaseModel, Field
from typing import List, Optional
import hashlib
class ModelItem(BaseModel):
model: str
name: str
digest: str
urls: List[int]
class ModelListResponse(BaseModel):
object: str = Field("list", description="Type of the response object.")
models: List[ModelItem]
async def fetch_models():
models_dir = "models"
models = []
# Iterate over the files in the models directory
if os.path.exists(models_dir):
for model in os.listdir(models_dir):
model_path = os.path.join(models_dir, model)
if os.path.isdir(model_path): # Assuming each model is in its own directory
if os.path.isdir(model_path):
digest = hashlib.md5(model.encode()).hexdigest()
models.append({
"model":f"{model}:latest",
@ -224,18 +214,13 @@ async def dummy(request: Request):
dependencies=[Depends(check_api_key)]
)
async def get_all_models(request: Request) -> ModelListResponse:
print(f"Processing request for models {request.state.id}")
response = await run_with_request_disconnect(
request,
asyncio.create_task(fetch_models()),
disconnect_message=f"All models fetched",
disconnect_message="All models fetched",
)
return response
@router.post(
"/api/chat",
dependencies=[Depends(check_api_key)],
@ -284,7 +269,6 @@ async def chat_completion_request_ollama(
if data.stream and not disable_request_streaming:
return StreamingResponse(stream_response(request), media_type="application/x-ndjson")
else:
generate_task = asyncio.create_task(
generate_chat_completion(prompt, data, request, model_path)
@ -295,5 +279,4 @@ async def chat_completion_request_ollama(
generate_task,
disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
)
return response
return response

View file

@ -1,6 +1,6 @@
"""Common types for OAI."""
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, List
from typing import Optional, Union
from common.sampling import BaseSamplerRequest, get_default_sampler_value
@ -58,3 +58,29 @@ class CommonCompletionRequest(BaseSamplerRequest):
user: Optional[str] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)
def validate_params(self):
# Temperature
if self.n < 1:
raise ValueError(f"n must be greater than or equal to 1. Got {self.n}")
return super().validate_params()
def to_gen_params(self):
extra_gen_params = {
"stream": self.stream,
"logprobs": self.logprobs,
}
return super().to_gen_params(**extra_gen_params)
class ModelItem(BaseModel):
model: str
name: str
digest: str
urls: List[int]
class ModelListResponse(BaseModel):
object: str = Field("list", description="Type of the response object.")
models: List[ModelItem]

View file

@ -177,7 +177,7 @@ def _create_stream_chunk_ollama(
delta=message,
logprobs=logprob_response,
)
ollama_bit = {
ollama_chunk = {
"model":model_name,
"created_at": datetime.utcnow().isoformat(timespec='microseconds') + "Z",
"message": {"role":choice.delta.role if hasattr(choice.delta, 'role') else 'none',
@ -185,7 +185,7 @@ def _create_stream_chunk_ollama(
"done_reason": choice.finish_reason,
"done": choice.finish_reason=="stop",
}
return ollama_bit
return ollama_chunk
def _create_stream_chunk(
request_id: str,

View file

@ -176,14 +176,6 @@ async def load_inline_model(model_name: str, request: Request):
return
if model.container is not None:
if model.container.model_dir.name != model_name:
logger.info(f"New model requested: {model_name}. Unloading current model.")
await model.unload_model()
elif model.container.model_loaded:
logger.info(f"Model {model_name} is already loaded.")
return
model_path = pathlib.Path(config.model.model_dir)
model_path = model_path / model_name