Merge remote-tracking branch 'upstream/main' into HEAD

commit e8fcecd56a
28 changed files with 386 additions and 171 deletions
@@ -137,7 +137,7 @@ async def get_version():
 async def get_extra_version():
     """Impersonate Koboldcpp."""

-    return {"result": "KoboldCpp", "version": "1.61"}
+    return {"result": "KoboldCpp", "version": "1.71"}


 @kai_router.get("/config/soft_prompts_list")

@@ -22,6 +22,7 @@ from endpoints.OAI.utils.chat_completion import (
 )
 from endpoints.OAI.utils.completion import (
     generate_completion,
+    load_inline_model,
     stream_generate_completion,
 )
 from endpoints.OAI.utils.embeddings import get_embeddings

@@ -42,7 +43,7 @@ def setup():
 # Completions endpoint
 @router.post(
     "/v1/completions",
-    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    dependencies=[Depends(check_api_key)],
 )
 async def completion_request(
     request: Request, data: CompletionRequest

@@ -53,6 +54,18 @@ async def completion_request(
     If stream = true, this returns an SSE stream.
     """

+    if data.model:
+        inline_load_task = asyncio.create_task(load_inline_model(data.model, request))
+
+        await run_with_request_disconnect(
+            request,
+            inline_load_task,
+            disconnect_message=f"Model switch for generation {request.state.id} "
+            + "cancelled by user.",
+        )
+    else:
+        await check_model_container()
+
     model_path = model.container.model_dir

     if isinstance(data.prompt, list):
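
For context, a minimal client-side sketch of the inline-load path added above. The host, port, header name, and model folder name are placeholders/assumptions, not taken from this diff:

import requests  # assumes the requests package is installed

# Passing "model" on /v1/completions now requests an inline model switch before
# generation; without it, the previous check_model_container() behavior applies.
resp = requests.post(
    "http://127.0.0.1:5000/v1/completions",  # placeholder host/port
    headers={"x-admin-key": "<admin key>"},  # assumed admin-key header name
    json={
        "model": "MyModel-exl2",  # placeholder folder under the configured model_dir
        "prompt": "Hello",
        "max_tokens": 32,
    },
    timeout=600,
)
print(resp.json())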

@@ -85,7 +98,7 @@ async def completion_request(
 # Chat completions endpoint
 @router.post(
     "/v1/chat/completions",
-    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    dependencies=[Depends(check_api_key)],
 )
 async def chat_completion_request(
     request: Request, data: ChatCompletionRequest

@@ -96,6 +109,11 @@ async def chat_completion_request(
     If stream = true, this returns an SSE stream.
     """

+    if data.model:
+        await load_inline_model(data.model, request)
+    else:
+        await check_model_container()
+
     if model.container.prompt_template is None:
         error_message = handle_request_error(
             "Chat completions are disabled because a prompt template is not set.",

@@ -56,6 +56,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
     add_generation_prompt: Optional[bool] = True
     template_vars: Optional[dict] = {}
     response_prefix: Optional[str] = None
+    model: Optional[str] = None

     # tools is follows the format OAI schema, functions is more flexible
     # both are available in the chat template.
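
A sketch of a chat completion payload that exercises the new optional field (the model folder name is a placeholder):

import json

# "model" is optional; when present (and inline loading is permitted), the server
# switches to that model before generating the chat completion.
payload = {
    "model": "MyModel-exl2",  # placeholder folder name under the configured model_dir
    "messages": [{"role": "user", "content": "Hi there"}],
    "max_tokens": 64,
    "stream": False,
}
print(json.dumps(payload, indent=2))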

@@ -1,4 +1,8 @@
-"""Completion utilities for OAI server."""
+"""
+Completion utilities for OAI server.
+
+Also serves as a common module for completions and chat completions.
+"""

 import asyncio
 import pathlib

@@ -10,12 +14,14 @@ from typing import List, Union
 from loguru import logger

 from common import model
+from common.auth import get_key_permission
 from common.networking import (
     get_generator_error,
     handle_request_disconnect,
     handle_request_error,
     request_disconnect_loop,
 )
+from common.tabby_config import config
 from common.utils import unwrap
 from endpoints.OAI.types.completion import (
     CompletionRequest,

@@ -103,6 +109,50 @@ async def _stream_collector(
         await gen_queue.put(e)


+async def load_inline_model(model_name: str, request: Request):
+    """Load a model from the data.model parameter"""
+
+    # Return if the model container already exists and the model is fully loaded
+    if (
+        model.container
+        and model.container.model_dir.name == model_name
+        and model.container.model_loaded
+    ):
+        return
+
+    # Inline model loading isn't enabled or the user isn't an admin
+    if not get_key_permission(request) == "admin":
+        error_message = handle_request_error(
+            f"Unable to switch model to {model_name} because "
+            + "an admin key isn't provided",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(401, error_message)
+
+    if not unwrap(config.model.get("inline_model_loading"), False):
+        logger.warning(
+            f"Unable to switch model to {model_name} because "
+            '"inline_model_loading" is not True in config.yml.'
+        )
+
+        return
+
+    model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
+    model_path = model_path / model_name
+
+    # Model path doesn't exist
+    if not model_path.exists():
+        logger.warning(
+            f"Could not find model path {str(model_path)}. Skipping inline model load."
+        )
+
+        return
+
+    # Load the model
+    await model.load_model(model_path)
+
+
 async def stream_generate_completion(
     data: CompletionRequest, request: Request, model_path: pathlib.Path
 ):
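
As a standalone illustration of the path-resolution step above (the helper, model name, and directory are examples, not part of the project):

import pathlib

# Mirrors the lookup in load_inline_model: join the requested name onto the
# configured model directory ("models" when unset) and skip if it doesn't exist.
def resolve_model_path(model_name: str, model_dir: str | None) -> pathlib.Path | None:
    base = pathlib.Path(model_dir or "models")
    candidate = base / model_name
    return candidate if candidate.exists() else None

print(resolve_model_path("MyModel-exl2", None))  # None unless ./models/MyModel-exl2 exists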

@@ -103,7 +103,7 @@ async def list_draft_models(request: Request) -> ModelList:

         models = get_model_list(draft_model_path.resolve())
     else:
-        models = await get_current_model_list(is_draft=True)
+        models = await get_current_model_list(model_type="draft")

     return models

@@ -441,7 +441,8 @@ async def switch_template(data: TemplateSwitchRequest):
         raise HTTPException(400, error_message)

     try:
-        model.container.prompt_template = PromptTemplate.from_file(data.name)
+        template_path = pathlib.Path("templates") / data.name
+        model.container.prompt_template = await PromptTemplate.from_file(template_path)
     except FileNotFoundError as e:
         error_message = handle_request_error(
             f"The template name {data.name} doesn't exist. Check the spelling?",

@@ -490,7 +491,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):

     if data.preset:
         try:
-            sampling.overrides_from_file(data.preset)
+            await sampling.overrides_from_file(data.preset)
         except FileNotFoundError as e:
             error_message = handle_request_error(
                 f"Sampler override preset with name {data.preset} does not exist. "

@@ -5,7 +5,8 @@ from time import time
 from typing import List, Literal, Optional, Union

-from common.model import get_config_default
+from common.config_models import logging_config_model
+from common.tabby_config import config
 from common.utils import unwrap


 class ModelCardParameters(BaseModel):

@@ -51,23 +52,13 @@ class DraftModelLoadRequest(BaseModel):
     draft_model_name: str

     # Config arguments
-    draft_rope_scale: Optional[float] = Field(
-        default_factory=lambda: get_config_default(
-            "draft_rope_scale", model_type="draft"
-        )
-    )
+    draft_rope_scale: Optional[float] = None
     draft_rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
         description='Automatically calculated if set to "auto"',
-        default_factory=lambda: get_config_default(
-            "draft_rope_alpha", model_type="draft"
-        ),
+        default=None,
         examples=[1.0],
     )
-    draft_cache_mode: Optional[str] = Field(
-        default_factory=lambda: get_config_default(
-            "draft_cache_mode", model_type="draft"
-        )
-    )
+    draft_cache_mode: Optional[str] = None


 class ModelLoadRequest(BaseModel):

@@ -78,62 +69,45 @@ class ModelLoadRequest(BaseModel):
     # Config arguments

     # Max seq len is fetched from config.json of the model by default
     max_seq_len: Optional[int] = Field(
         description="Leave this blank to use the model's base sequence length",
-        default_factory=lambda: get_config_default("max_seq_len"),
+        default=None,
         examples=[4096],
     )
     override_base_seq_len: Optional[int] = Field(
         description=(
             "Overrides the model's base sequence length. " "Leave blank if unsure"
         ),
-        default_factory=lambda: get_config_default("override_base_seq_len"),
+        default=None,
         examples=[4096],
     )
     cache_size: Optional[int] = Field(
         description=("Number in tokens, must be greater than or equal to max_seq_len"),
-        default_factory=lambda: get_config_default("cache_size"),
+        default=None,
         examples=[4096],
     )
-    tensor_parallel: Optional[bool] = Field(
-        default_factory=lambda: get_config_default("tensor_parallel")
-    )
-    gpu_split_auto: Optional[bool] = Field(
-        default_factory=lambda: get_config_default("gpu_split_auto")
-    )
-    autosplit_reserve: Optional[List[float]] = Field(
-        default_factory=lambda: get_config_default("autosplit_reserve")
-    )
+    tensor_parallel: Optional[bool] = None
+    gpu_split_auto: Optional[bool] = None
+    autosplit_reserve: Optional[List[float]] = None
     gpu_split: Optional[List[float]] = Field(
-        default_factory=lambda: get_config_default("gpu_split"),
+        default=None,
         examples=[[24.0, 20.0]],
     )
     rope_scale: Optional[float] = Field(
         description="Automatically pulled from the model's config if not present",
-        default_factory=lambda: get_config_default("rope_scale"),
+        default=None,
         examples=[1.0],
     )
     rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
         description='Automatically calculated if set to "auto"',
-        default_factory=lambda: get_config_default("rope_alpha"),
+        default=None,
         examples=[1.0],
     )
-    cache_mode: Optional[str] = Field(
-        default_factory=lambda: get_config_default("cache_mode")
-    )
-    chunk_size: Optional[int] = Field(
-        default_factory=lambda: get_config_default("chunk_size")
-    )
-    prompt_template: Optional[str] = Field(
-        default_factory=lambda: get_config_default("prompt_template")
-    )
-    num_experts_per_token: Optional[int] = Field(
-        default_factory=lambda: get_config_default("num_experts_per_token")
-    )
-    fasttensors: Optional[bool] = Field(
-        default_factory=lambda: get_config_default("fasttensors")
-    )
+    cache_mode: Optional[str] = None
+    chunk_size: Optional[int] = None
+    prompt_template: Optional[str] = None
+    num_experts_per_token: Optional[int] = None
+    fasttensors: Optional[bool] = None

     # Non-config arguments
     draft: Optional[DraftModelLoadRequest] = None
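
A standalone sketch (not the project's class) of the practical effect of dropping the get_config_default default factories: config-backed fields on an empty load request are now left unset, and any fallback to config.yml happens server-side at load time.

from typing import Optional
from pydantic import BaseModel, Field

# Minimal stand-in mirroring two of the fields above, assuming pydantic v2.
class ModelLoadRequestSketch(BaseModel):
    max_seq_len: Optional[int] = Field(default=None, examples=[4096])
    cache_mode: Optional[str] = None

req = ModelLoadRequestSketch()
assert req.max_seq_len is None and req.cache_mode is None
print(req.model_dump())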

@@ -142,9 +116,11 @@ class ModelLoadRequest(BaseModel):
 class EmbeddingModelLoadRequest(BaseModel):
     name: str

     # Set default from the config
     embeddings_device: Optional[str] = Field(
-        default_factory=lambda: get_config_default(
-            "embeddings_device", model_type="embedding"
+        default_factory=lambda: unwrap(
+            config.embeddings.get("embeddings_device"), "cpu"
         )
     )