Model: Add params to current model endpoint
Grabs the current model rope params, max seq len, and the draft model if applicable. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
0f4290f05c
commit
fd9f3eac87
3 changed files with 31 additions and 6 deletions
|
|
@@ -2,11 +2,18 @@ from pydantic import BaseModel, Field
|
|||
from time import time
|
||||
from typing import List, Optional
|
||||
|
||||
class ModelCardParameters(BaseModel):
    """Load-time parameters reported on a model card."""

    # Maximum sequence (context) length the model was loaded with.
    max_seq_len: Optional[int] = 4096
    # RoPE scaling factor (presumably linear scale — confirm against loader).
    rope_scale: Optional[float] = 1.0
    # RoPE alpha value (NTK-style scaling, per the loader's naming).
    rope_alpha: Optional[float] = 1.0
    # Card for the draft model when one is loaded; forward reference since
    # ModelCard is declared after this class.
    draft: Optional['ModelCard'] = None
|
||||
|
||||
class ModelCard(BaseModel):
    """OpenAI-style model card describing a single model."""

    id: str = "test"
    object: str = "model"
    # Creation timestamp in whole seconds since the epoch, captured when the
    # card instance is built.
    created: int = Field(default_factory=lambda: int(time()))
    owned_by: str = "tabbyAPI"
    # Optional load parameters (max_seq_len, rope scale/alpha, draft card).
    parameters: Optional[ModelCardParameters] = None
|
||||
|
||||
class ModelList(BaseModel):
|
||||
object: str = "list"
|
||||
|
|
@@ -17,6 +24,7 @@ class DraftModelLoadRequest(BaseModel):
|
|||
draft_rope_alpha: float = 1.0
|
||||
draft_rope_scale: float = 1.0
|
||||
|
||||
# TODO: Unify this with ModelCardParams
|
||||
class ModelLoadRequest(BaseModel):
|
||||
name: str
|
||||
max_seq_len: Optional[int] = 4096
|
||||
|
|
|
|||
25
main.py
25
main.py
|
|
@@ -12,7 +12,7 @@ from generators import generate_with_semaphore
|
|||
from OAI.types.completion import CompletionRequest
|
||||
from OAI.types.chat_completion import ChatCompletionRequest
|
||||
from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
|
||||
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
|
||||
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse, ModelCardParameters
|
||||
from OAI.types.token import (
|
||||
TokenEncodeRequest,
|
||||
TokenEncodeResponse,
|
||||
|
|
@@ -74,7 +74,25 @@ async def list_models():
|
|||
@app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def get_current_model():
    """Return a ModelCard for the currently loaded model.

    The card includes the model's load parameters (rope scale/alpha and
    max_seq_len) and, when speculative decoding is active, a nested card
    for the draft model. Previously the card carried only the model name.
    """
    model_name = model_container.get_model_path().name
    model_card = ModelCard(
        id=model_name,
        parameters=ModelCardParameters(
            rope_scale=model_container.config.scale_pos_emb,
            rope_alpha=model_container.config.scale_alpha_value,
            max_seq_len=model_container.config.max_seq_len,
        ),
    )

    # A draft_config is only present when a draft model is loaded; attach
    # its card so clients can see the speculative-decoding setup.
    if model_container.draft_config:
        draft_card = ModelCard(
            id=model_container.get_model_path(True).name,
            parameters=ModelCardParameters(
                rope_scale=model_container.draft_config.scale_pos_emb,
                rope_alpha=model_container.draft_config.scale_alpha_value,
                max_seq_len=model_container.draft_config.max_seq_len,
            ),
        )
        model_card.parameters.draft = draft_card

    return model_card
|
||||
|
||||
|
|
@@ -132,7 +150,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
status="finished"
|
||||
)
|
||||
|
||||
yield get_sse_packet(response.json(ensure_ascii=False))
|
||||
yield get_sse_packet(response.json(ensure_ascii = False))
|
||||
|
||||
# Switch to model progress if the draft model is loaded
|
||||
if model_container.draft_config:
|
||||
|
|
@@ -345,7 +363,6 @@ if __name__ == "__main__":
|
|||
config = {}
|
||||
|
||||
# If an initial model name is specified, create a container and load the model
|
||||
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
if "model_name" in model_config:
|
||||
# TODO: Move this to model_container
|
||||
|
|
|
|||
4
model.py
4
model.py
|
|
@@ -129,8 +129,8 @@ class ModelContainer:
|
|||
alpha = 1 if ratio <= 1.0 else -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
|
||||
return alpha
|
||||
|
||||
def get_model_path(self, is_draft: bool = False):
    """Return the directory path of a loaded model.

    Args:
        is_draft: When True, resolve the draft model's directory
            (``self.draft_config.model_dir``) instead of the main
            model's. Defaults to False, preserving the old behavior
            of always returning the main model path.

    Returns:
        pathlib.Path: path to the requested model directory.
    """
    model_dir = self.draft_config.model_dir if is_draft else self.config.model_dir
    return pathlib.Path(model_dir)
|
||||
|
||||
def load(self, progress_callback = None):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue