formatting changes
commit 4ace973244 (parent 1fd4a9e119)
4 changed files with 36 additions and 35 deletions
@@ -13,6 +13,10 @@ from common.model import check_embeddings_container, check_model_container
 from common.networking import handle_request_error, run_with_request_disconnect
 from common.tabby_config import config
 from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
+from endpoints.OAI.types.common import (
+    ModelItem, ModelListResponse
+)
+import hashlib
 from endpoints.OAI.types.chat_completion import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -172,28 +176,14 @@ async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
     return response
 
-from pydantic import BaseModel, Field
-from typing import List, Optional
-import hashlib
-
-class ModelItem(BaseModel):
-    model: str
-    name: str
-    digest: str
-    urls: List[int]
-
-class ModelListResponse(BaseModel):
-    object: str = Field("list", description="Type of the response object.")
-    models: List[ModelItem]
 
 
 async def fetch_models():
     models_dir = "models"
     models = []
-    # Iterate over the files in the models directory
     if os.path.exists(models_dir):
        for model in os.listdir(models_dir):
             model_path = os.path.join(models_dir, model)
-            if os.path.isdir(model_path): # Assuming each model is in its own directory
+            if os.path.isdir(model_path):
                 digest = hashlib.md5(model.encode()).hexdigest()
                 models.append({
                     "model":f"{model}:latest",
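Note: the hunk above keeps the directory scan that backs the model listing. For reference, a minimal standalone sketch of that pattern follows; the helper name list_model_dirs and the trimmed-down entry (only "model" and "digest", since the appended dict is cut off in the hunk) are illustrative, not the project's actual endpoint code.

import hashlib
import os

def list_model_dirs(models_dir="models"):
    # Collect one entry per subdirectory, tagging it with an MD5 digest of its name.
    models = []
    if os.path.exists(models_dir):
        for model in os.listdir(models_dir):
            model_path = os.path.join(models_dir, model)
            if os.path.isdir(model_path):  # each model lives in its own directory
                digest = hashlib.md5(model.encode()).hexdigest()
                models.append({"model": f"{model}:latest", "digest": digest})
    return models

if __name__ == "__main__":
    print(list_model_dirs())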
@@ -224,18 +214,13 @@ async def dummy(request: Request):
     dependencies=[Depends(check_api_key)]
 )
 async def get_all_models(request: Request) -> ModelListResponse:
-    print(f"Processing request for models {request.state.id}")
-
     response = await run_with_request_disconnect(
         request,
         asyncio.create_task(fetch_models()),
-        disconnect_message=f"All models fetched",
+        disconnect_message="All models fetched",
     )
 
     return response
-
-
-
 @router.post(
     "/api/chat",
     dependencies=[Depends(check_api_key)],
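Note: get_all_models hands the fetch task to run_with_request_disconnect so a dropped client cancels the work. That helper is internal to tabbyAPI; the sketch below only illustrates the general pattern, assuming a Starlette-style request.is_disconnected() poll, and is not the project's implementation.

import asyncio

async def run_with_disconnect_watch(request, task: asyncio.Task, disconnect_message: str):
    # Poll for a client disconnect while the task runs; cancel the work if the
    # client goes away. request.is_disconnected() is the Starlette coroutine.
    while not task.done():
        if await request.is_disconnected():
            task.cancel()
            print(disconnect_message)
            return None
        await asyncio.sleep(0.1)
    return task.result()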
@@ -284,7 +269,6 @@ async def chat_completion_request_ollama(
     if data.stream and not disable_request_streaming:
         return StreamingResponse(stream_response(request), media_type="application/x-ndjson")
-
 
     else:
         generate_task = asyncio.create_task(
             generate_chat_completion(prompt, data, request, model_path)
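Note: the streaming branch answers with NDJSON. A self-contained sketch of that response shape is below; the route path, the stream_response generator body, and the chunk fields are placeholders rather than the project's actual generator.

import json
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.post("/api/chat")
async def chat_stream_demo():
    async def stream_response():
        # Emit one JSON object per line (NDJSON).
        for text in ("Hel", "lo"):
            yield json.dumps({"message": {"role": "assistant", "content": text}, "done": False}) + "\n"
        yield json.dumps({"done": True}) + "\n"

    return StreamingResponse(stream_response(), media_type="application/x-ndjson")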
@@ -296,4 +280,3 @@ async def chat_completion_request_ollama(
             disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
         )
         return response
-
@@ -1,6 +1,6 @@
 """Common types for OAI."""
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, List
 from typing import Optional, Union
 
 from common.sampling import BaseSamplerRequest, get_default_sampler_value
@@ -58,3 +58,29 @@ class CommonCompletionRequest(BaseSamplerRequest):
     user: Optional[str] = Field(
         description="Not parsed. Only used for OAI compliance.", default=None
     )
+
+    def validate_params(self):
+        # Temperature
+        if self.n < 1:
+            raise ValueError(f"n must be greater than or equal to 1. Got {self.n}")
+
+        return super().validate_params()
+
+    def to_gen_params(self):
+        extra_gen_params = {
+            "stream": self.stream,
+            "logprobs": self.logprobs,
+        }
+
+        return super().to_gen_params(**extra_gen_params)
+
+
+class ModelItem(BaseModel):
+    model: str
+    name: str
+    digest: str
+    urls: List[int]
+
+class ModelListResponse(BaseModel):
+    object: str = Field("list", description="Type of the response object.")
+    models: List[ModelItem]
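Note: to exercise the new response models in isolation, the sketch below rebuilds ModelItem and ModelListResponse as standalone pydantic v2 models, importing List from typing for self-containment; the field values are made up for the example.

from typing import List
from pydantic import BaseModel, Field

class ModelItem(BaseModel):
    model: str
    name: str
    digest: str
    urls: List[int]

class ModelListResponse(BaseModel):
    object: str = Field("list", description="Type of the response object.")
    models: List[ModelItem]

# Example payload, mirroring the per-directory entries fetch_models assembles.
item = ModelItem(model="llama3:latest", name="llama3", digest="0" * 32, urls=[0])
print(ModelListResponse(models=[item]).model_dump())  # pydantic v2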
@@ -177,7 +177,7 @@ def _create_stream_chunk_ollama(
         delta=message,
         logprobs=logprob_response,
     )
-    ollama_bit = {
+    ollama_chunk = {
         "model":model_name,
         "created_at": datetime.utcnow().isoformat(timespec='microseconds') + "Z",
         "message": {"role":choice.delta.role if hasattr(choice.delta, 'role') else 'none',
@@ -185,7 +185,7 @@ def _create_stream_chunk_ollama(
         "done_reason": choice.finish_reason,
         "done": choice.finish_reason=="stop",
     }
-    return ollama_bit
+    return ollama_chunk
 
 def _create_stream_chunk(
     request_id: str,
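Note: each chunk built by _create_stream_chunk_ollama is emitted as one NDJSON line. The sketch below shows that serialization with illustrative values; the "content" key and all sample strings are assumptions, since the hunk only shows part of the dict.

import json
from datetime import datetime

# Illustrative chunk mirroring the keys the builder above fills in.
ollama_chunk = {
    "model": "example-model",
    "created_at": datetime.utcnow().isoformat(timespec="microseconds") + "Z",
    "message": {"role": "assistant", "content": "Hi"},
    "done_reason": "stop",
    "done": True,
}
print(json.dumps(ollama_chunk))  # one line per chunk in an application/x-ndjson stream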
@@ -176,14 +176,6 @@ async def load_inline_model(model_name: str, request: Request):
 
         return
 
-    if model.container is not None:
-        if model.container.model_dir.name != model_name:
-            logger.info(f"New model requested: {model_name}. Unloading current model.")
-            await model.unload_model()
-        elif model.container.model_loaded:
-            logger.info(f"Model {model_name} is already loaded.")
-            return
-
     model_path = pathlib.Path(config.model.model_dir)
     model_path = model_path / model_name
 
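Note: the surviving tail of the hunk composes the model directory with pathlib. A tiny sketch of that composition, with config.model.model_dir and the requested model name stubbed by literals:

import pathlib

model_dir = "models"          # stand-in for config.model.model_dir
model_name = "example-model"  # stand-in for the requested model name

model_path = pathlib.Path(model_dir)
model_path = model_path / model_name
print(model_path, model_path.exists())  # e.g. models/example-model False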