formatting changes
This commit is contained in:
parent
1fd4a9e119
commit
4ace973244
4 changed files with 36 additions and 35 deletions
|
|
@ -13,6 +13,10 @@ from common.model import check_embeddings_container, check_model_container
|
|||
from common.networking import handle_request_error, run_with_request_disconnect
|
||||
from common.tabby_config import config
|
||||
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
|
||||
from endpoints.OAI.types.common import (
|
||||
ModelItem, ModelListResponse
|
||||
)
|
||||
import hashlib
|
||||
from endpoints.OAI.types.chat_completion import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
|
|
@ -172,28 +176,14 @@ async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsRes
|
|||
|
||||
return response
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
import hashlib
|
||||
|
||||
class ModelItem(BaseModel):
|
||||
model: str
|
||||
name: str
|
||||
digest: str
|
||||
urls: List[int]
|
||||
|
||||
class ModelListResponse(BaseModel):
|
||||
object: str = Field("list", description="Type of the response object.")
|
||||
models: List[ModelItem]
|
||||
|
||||
async def fetch_models():
|
||||
models_dir = "models"
|
||||
models = []
|
||||
# Iterate over the files in the models directory
|
||||
if os.path.exists(models_dir):
|
||||
for model in os.listdir(models_dir):
|
||||
model_path = os.path.join(models_dir, model)
|
||||
if os.path.isdir(model_path): # Assuming each model is in its own directory
|
||||
if os.path.isdir(model_path):
|
||||
digest = hashlib.md5(model.encode()).hexdigest()
|
||||
models.append({
|
||||
"model":f"{model}:latest",
|
||||
|
|
@ -224,18 +214,13 @@ async def dummy(request: Request):
|
|||
dependencies=[Depends(check_api_key)]
|
||||
)
|
||||
async def get_all_models(request: Request) -> ModelListResponse:
|
||||
print(f"Processing request for models {request.state.id}")
|
||||
|
||||
response = await run_with_request_disconnect(
|
||||
request,
|
||||
asyncio.create_task(fetch_models()),
|
||||
disconnect_message=f"All models fetched",
|
||||
disconnect_message="All models fetched",
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
|
||||
@router.post(
|
||||
"/api/chat",
|
||||
dependencies=[Depends(check_api_key)],
|
||||
|
|
@ -284,7 +269,6 @@ async def chat_completion_request_ollama(
|
|||
if data.stream and not disable_request_streaming:
|
||||
return StreamingResponse(stream_response(request), media_type="application/x-ndjson")
|
||||
|
||||
|
||||
else:
|
||||
generate_task = asyncio.create_task(
|
||||
generate_chat_completion(prompt, data, request, model_path)
|
||||
|
|
@ -295,5 +279,4 @@ async def chat_completion_request_ollama(
|
|||
generate_task,
|
||||
disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
|
||||
)
|
||||
return response
|
||||
|
||||
return response
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
"""Common types for OAI."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, List
|
||||
from typing import Optional, Union
|
||||
|
||||
from common.sampling import BaseSamplerRequest, get_default_sampler_value
|
||||
|
|
@ -58,3 +58,29 @@ class CommonCompletionRequest(BaseSamplerRequest):
|
|||
user: Optional[str] = Field(
|
||||
description="Not parsed. Only used for OAI compliance.", default=None
|
||||
)
|
||||
|
||||
def validate_params(self):
|
||||
# Temperature
|
||||
if self.n < 1:
|
||||
raise ValueError(f"n must be greater than or equal to 1. Got {self.n}")
|
||||
|
||||
return super().validate_params()
|
||||
|
||||
def to_gen_params(self):
|
||||
extra_gen_params = {
|
||||
"stream": self.stream,
|
||||
"logprobs": self.logprobs,
|
||||
}
|
||||
|
||||
return super().to_gen_params(**extra_gen_params)
|
||||
|
||||
|
||||
class ModelItem(BaseModel):
|
||||
model: str
|
||||
name: str
|
||||
digest: str
|
||||
urls: List[int]
|
||||
|
||||
class ModelListResponse(BaseModel):
|
||||
object: str = Field("list", description="Type of the response object.")
|
||||
models: List[ModelItem]
|
||||
|
|
|
|||
|
|
@ -177,7 +177,7 @@ def _create_stream_chunk_ollama(
|
|||
delta=message,
|
||||
logprobs=logprob_response,
|
||||
)
|
||||
ollama_bit = {
|
||||
ollama_chunk = {
|
||||
"model":model_name,
|
||||
"created_at": datetime.utcnow().isoformat(timespec='microseconds') + "Z",
|
||||
"message": {"role":choice.delta.role if hasattr(choice.delta, 'role') else 'none',
|
||||
|
|
@ -185,7 +185,7 @@ def _create_stream_chunk_ollama(
|
|||
"done_reason": choice.finish_reason,
|
||||
"done": choice.finish_reason=="stop",
|
||||
}
|
||||
return ollama_bit
|
||||
return ollama_chunk
|
||||
|
||||
def _create_stream_chunk(
|
||||
request_id: str,
|
||||
|
|
|
|||
|
|
@ -176,14 +176,6 @@ async def load_inline_model(model_name: str, request: Request):
|
|||
|
||||
return
|
||||
|
||||
if model.container is not None:
|
||||
if model.container.model_dir.name != model_name:
|
||||
logger.info(f"New model requested: {model_name}. Unloading current model.")
|
||||
await model.unload_model()
|
||||
elif model.container.model_loaded:
|
||||
logger.info(f"Model {model_name} is already loaded.")
|
||||
return
|
||||
|
||||
model_path = pathlib.Path(config.model.model_dir)
|
||||
model_path = model_path / model_name
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue