API: Move to ModelManager
This is a shared module that manages the model container and provides extra utility functions around it to help slim down the API.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 8b46282aef
commit b373b25235

5 changed files with 178 additions and 143 deletions
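The pattern this commit introduces: instead of a MODEL_CONTAINER global in main.py, every consumer reads the container through the shared common.model module attribute, so all callers see the current container without `global` bookkeeping. A minimal sketch of the consumption pattern, assuming the module layout from this commit (the /v1/example endpoint is hypothetical):

    from fastapi import Depends, FastAPI, HTTPException

    from common import model  # shared container module added by this commit

    app = FastAPI()


    async def check_model_container():
        # Mirrors the guard in main.py: reject requests until a model is ready
        if model.container is None or not (
            model.container.model_is_loading or model.container.model_loaded
        ):
            raise HTTPException(400, "No models are currently loaded.")


    # Hypothetical endpoint showing the pattern: handlers read the container
    # through the module attribute instead of a module-level global.
    @app.get("/v1/example", dependencies=[Depends(check_model_container)])
    async def example():
        return {"model": model.container.get_model_path().name}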
@@ -18,6 +18,8 @@ class ModelCardParameters(BaseModel):
     prompt_template: Optional[str] = None
     num_experts_per_token: Optional[int] = None
     use_cfg: Optional[bool] = None

+    # Draft is another model, so include it in the card params
+    draft: Optional["ModelCard"] = None
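Since `draft` nests a full ModelCard inside the parent card's parameters, a /v1/model response gains one level of recursion. A rough sketch of the resulting payload shape (all values hypothetical):

    # Hypothetical /v1/model payload; structure follows
    # ModelCard -> ModelCardParameters -> draft (nested ModelCard)
    card = {
        "id": "MyModel-7B",
        "parameters": {
            "max_seq_len": 4096,
            "cache_mode": "FP16",
            "prompt_template": "chatml",
            "draft": {
                "id": "MyDraft-1B",
                "parameters": {"max_seq_len": 4096},
            },
        },
    }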
backends/exllamav2/model.py
@@ -293,6 +293,32 @@ class ExllamaV2Container:
         )
         return model_path

+    def get_model_parameters(self):
+        model_params = {
+            "name": self.get_model_path().name,
+            "rope_scale": self.config.scale_pos_emb,
+            "rope_alpha": self.config.scale_alpha_value,
+            "max_seq_len": self.config.max_seq_len,
+            "cache_mode": self.cache_mode,
+            "num_experts_per_token": self.config.num_experts_per_token,
+            "use_cfg": self.use_cfg,
+            "prompt_template": self.prompt_template.name
+            if self.prompt_template
+            else None,
+        }
+
+        if self.draft_config:
+            draft_model_params = {
+                "name": self.get_model_path(is_draft=True).name,
+                "rope_scale": self.draft_config.scale_pos_emb,
+                "rope_alpha": self.draft_config.scale_alpha_value,
+                "max_seq_len": self.draft_config.max_seq_len,
+            }
+
+            model_params["draft"] = draft_model_params
+
+        return model_params
+
     def load(self, progress_callback=None):
         """
         Load model
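`get_model_parameters` flattens container state into a plain dict, with the draft model folded in under a "draft" key; that is what lets the API build a ModelCard via `model_validate` instead of naming every field. A sketch of a possible return value (values hypothetical; "draft" only appears when `draft_config` is set):

    params = {
        "name": "MyModel-7B",            # get_model_path().name
        "rope_scale": 1.0,
        "rope_alpha": 1.0,
        "max_seq_len": 4096,
        "cache_mode": "FP16",
        "num_experts_per_token": None,
        "use_cfg": False,
        "prompt_template": "chatml",     # None when no template is loaded
        "draft": {
            "name": "MyDraft-1B",
            "rope_scale": 1.0,
            "rope_alpha": 1.0,
            "max_seq_len": 4096,
        },
    }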
common/model.py  (new file, 75 lines)
@@ -0,0 +1,75 @@
+"""
+Manages the storage and utility of model containers.
+
+Containers exist as a common interface for backends.
+"""
+
+import pathlib
+from loguru import logger
+from typing import Optional
+
+from backends.exllamav2.model import ExllamaV2Container
+from common.logger import get_loading_progress_bar
+from common.utils import load_progress
+
+
+container: Optional[ExllamaV2Container] = None
+
+
+async def unload_model():
+    """Unloads a model"""
+    global container
+
+    container.unload()
+    container = None
+
+
+async def load_model_gen(model_path: pathlib.Path, **kwargs):
+    """Generator to load a model"""
+    global container
+
+    # Check if the model is already loaded
+    if container and container.model:
+        loaded_model_name = container.get_model_path().name
+
+        if loaded_model_name == model_path.name:
+            raise ValueError(
+                f'Model "{loaded_model_name}" is already loaded! Aborting.'
+            )
+
+    # Unload the existing model
+    if container and container.model:
+        logger.info("Unloading existing model.")
+        await unload_model()
+
+    container = ExllamaV2Container(model_path.resolve(), False, **kwargs)
+
+    model_type = "draft" if container.draft_config else "model"
+    load_status = container.load_gen(load_progress)
+
+    progress = get_loading_progress_bar()
+    progress.start()
+
+    try:
+        for module, modules in load_status:
+            if module == 0:
+                loading_task = progress.add_task(
+                    f"[cyan]Loading {model_type} modules", total=modules
+                )
+            else:
+                progress.advance(loading_task)
+
+            if module == modules:
+                # Switch to model progress if the draft model is loaded
+                if model_type == "draft":
+                    model_type = "model"
+                else:
+                    progress.stop()
+
+            yield module, modules, model_type
+    finally:
+        progress.stop()
+
+
+async def load_model(model_path: pathlib.Path, **kwargs):
+    async for _, _, _ in load_model_gen(model_path, **kwargs):
+        pass
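The module exposes three entry points: `load_model_gen`, an async generator that yields `(module, modules, model_type)` progress tuples while driving the console progress bar; `load_model`, which simply drains the generator; and `unload_model`. A minimal standalone usage sketch, assuming a models/MyModel directory (hypothetical path):

    import asyncio
    import pathlib

    from common import model


    async def main():
        model_path = pathlib.Path("models/MyModel")  # hypothetical model directory

        # Consume the generator directly to observe per-module progress
        async for module, modules, model_type in model.load_model_gen(model_path):
            print(f"{model_type}: {module}/{modules} modules loaded")

        # The shared container is now populated
        print(model.container.get_model_parameters())

        await model.unload_model()


    asyncio.run(main())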
main.py  (215 changed lines)
@@ -1,4 +1,5 @@
 """The main tabbyAPI module. Contains the FastAPI server and endpoints."""
+import asyncio
 import os
 import pathlib
 import signal
@@ -17,10 +18,10 @@ from fastapi.middleware.cors import CORSMiddleware
 from functools import partial
 from loguru import logger

-from common.logger import UVICORN_LOG_CONFIG, setup_logger, get_loading_progress_bar
+from common.logger import UVICORN_LOG_CONFIG, setup_logger
 import common.gen_logging as gen_logging
-from backends.exllamav2.model import ExllamaV2Container
 from backends.exllamav2.utils import check_exllama_version
+from common import model
 from common.args import convert_args_to_dict, init_argparser
 from common.auth import check_admin_key, check_api_key, load_auth_keys
 from common.config import (
@@ -52,7 +53,6 @@ from common.templating import (
 from common.utils import (
     get_generator_error,
     handle_request_error,
-    load_progress,
     is_port_in_use,
     unwrap,
 )
@@ -90,24 +90,6 @@ app = FastAPI(
     ),
 )

-# Globally scoped variables. Undefined until initalized in main
-MODEL_CONTAINER: Optional[ExllamaV2Container] = None
-
-
-async def _check_model_container():
-    """Checks if a model isn't loading or loaded."""
-
-    if MODEL_CONTAINER is None or not (
-        MODEL_CONTAINER.model_is_loading or MODEL_CONTAINER.model_loaded
-    ):
-        error_message = handle_request_error(
-            "No models are currently loaded.",
-            exc_info=False,
-        ).error.message
-
-        raise HTTPException(400, error_message)
-
-
 # ALlow CORS requests
 app.add_middleware(
     CORSMiddleware,
@@ -118,6 +100,20 @@ app.add_middleware(
 )


+async def check_model_container():
+    """FastAPI depends that checks if a model isn't loaded or currently loading."""
+
+    if model.container is None or not (
+        model.container.model_is_loading or model.container.model_loaded
+    ):
+        error_message = handle_request_error(
+            "No models are currently loaded.",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(400, error_message)
+
+
 # Model list endpoint
 @app.get("/v1/models", dependencies=[Depends(check_api_key)])
 @app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
@@ -139,35 +135,33 @@ async def list_models():
 # Currently loaded model endpoint
 @app.get(
     "/v1/model",
-    dependencies=[Depends(check_api_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
 async def get_current_model():
     """Returns the currently loaded model."""
-    model_name = MODEL_CONTAINER.get_model_path().name
-    prompt_template = MODEL_CONTAINER.prompt_template
+    model_params = model.container.get_model_parameters()
+    draft_model_params = model_params.pop("draft", {})
+
+    if draft_model_params:
+        model_params["draft"] = ModelCard(
+            id=unwrap(draft_model_params.get("name"), "unknown"),
+            parameters=ModelCardParameters.model_validate(draft_model_params),
+        )
+    else:
+        draft_model_params = None

     model_card = ModelCard(
-        id=model_name,
-        parameters=ModelCardParameters(
-            rope_scale=MODEL_CONTAINER.config.scale_pos_emb,
-            rope_alpha=MODEL_CONTAINER.config.scale_alpha_value,
-            max_seq_len=MODEL_CONTAINER.config.max_seq_len,
-            cache_mode=MODEL_CONTAINER.cache_mode,
-            prompt_template=prompt_template.name if prompt_template else None,
-            num_experts_per_token=MODEL_CONTAINER.config.num_experts_per_token,
-            use_cfg=MODEL_CONTAINER.use_cfg,
-        ),
+        id=unwrap(model_params.pop("name", None), "unknown"),
+        parameters=ModelCardParameters.model_validate(model_params),
         logging=gen_logging.PREFERENCES,
     )

-    if MODEL_CONTAINER.draft_config:
+    if draft_model_params:
         draft_card = ModelCard(
-            id=MODEL_CONTAINER.get_model_path(True).name,
-            parameters=ModelCardParameters(
-                rope_scale=MODEL_CONTAINER.draft_config.scale_pos_emb,
-                rope_alpha=MODEL_CONTAINER.draft_config.scale_alpha_value,
-                max_seq_len=MODEL_CONTAINER.draft_config.max_seq_len,
-            ),
+            id=unwrap(draft_model_params.pop("name", None), "unknown"),
+            parameters=ModelCardParameters.model_validate(draft_model_params),
         )

         model_card.parameters.draft = draft_card

     return model_card
@@ -211,35 +205,12 @@ async def load_model(request: Request, data: ModelLoadRequest):
     if not model_path.exists():
         raise HTTPException(400, "model_path does not exist. Check model_name?")

-    # Check if the model is already loaded
-    if MODEL_CONTAINER and MODEL_CONTAINER.model:
-        loaded_model_name = MODEL_CONTAINER.get_model_path().name
-
-        if loaded_model_name == data.name:
-            raise HTTPException(
-                400, f'Model "{loaded_model_name}"is already loaded! Aborting.'
-            )
-
     async def generator():
-        """Generator for the loading process."""
-        global MODEL_CONTAINER
-
-        # Unload the existing model
-        if MODEL_CONTAINER and MODEL_CONTAINER.model:
-            logger.info("Unloading existing model.")
-            await unload_model()
-
-        MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
-
-        model_type = "draft" if MODEL_CONTAINER.draft_config else "model"
-        load_status = MODEL_CONTAINER.load_gen(load_progress)
+        """Request generation wrapper for the loading process."""
+
+        load_status = model.load_model_gen(model_path, **load_data)
         try:
-            progress = get_loading_progress_bar()
-            progress.start()
-
-            for module, modules in load_status:
+            async for module, modules, model_type in load_status:
                 # Get out if the request gets disconnected
                 if await request.is_disconnected():
                     release_semaphore()
                     logger.error(
@@ -248,13 +219,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
                     )
                     return

-                if module == 0:
-                    loading_task = progress.add_task(
-                        f"[cyan]Loading {model_type} modules", total=modules
-                    )
-                else:
-                    progress.advance(loading_task)
-
+                if module != 0:
                     response = ModelLoadResponse(
                         model_type=model_type,
                         module=module,
@@ -273,13 +238,6 @@ async def load_model(request: Request, data: ModelLoadRequest):
                     )

                     yield response.model_dump_json()

-                # Switch to model progress if the draft model is loaded
-                if model_type == "draft":
-                    model_type = "model"
-                else:
-                    progress.stop()
-
         except CancelledError:
             logger.error(
                 "Model load cancelled by user. "
@@ -287,8 +245,6 @@ async def load_model(request: Request, data: ModelLoadRequest):
             )
         except Exception as exc:
             yield get_generator_error(str(exc))
-        finally:
-            progress.stop()

     # Determine whether to use or skip the queue
     if data.skip_queue:
@@ -306,14 +262,11 @@ async def load_model(request: Request, data: ModelLoadRequest):
 # Unload model endpoint
 @app.post(
     "/v1/model/unload",
-    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
 async def unload_model():
     """Unloads the currently loaded model."""
-    global MODEL_CONTAINER
-
-    MODEL_CONTAINER.unload()
-    MODEL_CONTAINER = None
+    await model.unload_model()


 @app.get("/v1/templates", dependencies=[Depends(check_api_key)])
@@ -326,7 +279,7 @@ async def get_templates():

 @app.post(
     "/v1/template/switch",
-    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
 async def switch_template(data: TemplateSwitchRequest):
     """Switch the currently loaded template"""
@@ -335,19 +288,19 @@ async def switch_template(data: TemplateSwitchRequest):

     try:
         template = get_template_from_file(data.name)
-        MODEL_CONTAINER.prompt_template = template
+        model.container.prompt_template = template
     except FileNotFoundError as e:
         raise HTTPException(400, "Template does not exist. Check the name?") from e


 @app.post(
     "/v1/template/unload",
-    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
 async def unload_template():
     """Unloads the currently selected template"""

-    MODEL_CONTAINER.prompt_template = None
+    model.container.prompt_template = None


 # Sampler override endpoints
@@ -405,7 +358,7 @@ async def get_all_loras():
 # Currently loaded loras endpoint
 @app.get(
     "/v1/lora",
-    dependencies=[Depends(check_api_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
 async def get_active_loras():
     """Returns the currently loaded loras."""
@@ -416,7 +369,7 @@ async def get_active_loras():
                 id=pathlib.Path(lora.lora_path).parent.name,
                 scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
             ),
-            MODEL_CONTAINER.active_loras,
+            model.container.active_loras,
         )
     )
 )
@@ -427,7 +380,7 @@ async def get_active_loras():
 # Load lora endpoint
 @app.post(
     "/v1/lora/load",
-    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
 async def load_lora(data: LoraLoadRequest):
     """Loads a LoRA into the model container."""
@@ -443,10 +396,10 @@ async def load_lora(data: LoraLoadRequest):

     # Clean-up existing loras if present
     def load_loras_internal():
-        if len(MODEL_CONTAINER.active_loras) > 0:
+        if len(model.container.active_loras) > 0:
             unload_loras()

-        result = MODEL_CONTAINER.load_loras(lora_dir, **data.model_dump())
+        result = model.container.load_loras(lora_dir, **data.model_dump())
         return LoraLoadResponse(
             success=unwrap(result.get("success"), []),
             failure=unwrap(result.get("failure"), []),
@@ -468,21 +421,21 @@ async def load_lora(data: LoraLoadRequest):
 # Unload lora endpoint
 @app.post(
     "/v1/lora/unload",
-    dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
 async def unload_loras():
     """Unloads the currently loaded loras."""
-    MODEL_CONTAINER.unload(loras_only=True)
+    model.container.unload(loras_only=True)


 # Encode tokens endpoint
 @app.post(
     "/v1/token/encode",
-    dependencies=[Depends(check_api_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
 async def encode_tokens(data: TokenEncodeRequest):
     """Encodes a string into tokens."""
-    raw_tokens = MODEL_CONTAINER.encode_tokens(data.text, **data.get_params())
+    raw_tokens = model.container.encode_tokens(data.text, **data.get_params())
     tokens = unwrap(raw_tokens, [])
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))

@@ -492,11 +445,11 @@ async def encode_tokens(data: TokenEncodeRequest):
 # Decode tokens endpoint
 @app.post(
     "/v1/token/decode",
-    dependencies=[Depends(check_api_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
 async def decode_tokens(data: TokenDecodeRequest):
     """Decodes tokens into a string."""
-    message = MODEL_CONTAINER.decode_tokens(data.tokens, **data.get_params())
+    message = model.container.decode_tokens(data.tokens, **data.get_params())
     response = TokenDecodeResponse(text=unwrap(message, ""))

     return response
@@ -505,11 +458,11 @@ async def decode_tokens(data: TokenDecodeRequest):
 # Completions endpoint
 @app.post(
     "/v1/completions",
-    dependencies=[Depends(check_api_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
 async def generate_completion(request: Request, data: CompletionRequest):
     """Generates a completion from a prompt."""
-    model_path = MODEL_CONTAINER.get_model_path()
+    model_path = model.container.get_model_path()

     if isinstance(data.prompt, list):
         data.prompt = "\n".join(data.prompt)
@@ -522,7 +475,7 @@ async def generate_completion(request: Request, data: CompletionRequest):

     async def generator():
         try:
-            new_generation = MODEL_CONTAINER.generate_gen(
+            new_generation = model.container.generate_gen(
                 data.prompt, **data.to_gen_params()
             )
             for generation in new_generation:
@@ -549,7 +502,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
     generation = await call_with_semaphore(
         partial(
             run_in_threadpool,
-            MODEL_CONTAINER.generate,
+            model.container.generate,
             data.prompt,
             **data.to_gen_params(),
         )
@@ -570,30 +523,31 @@ async def generate_completion(request: Request, data: CompletionRequest):
 # Chat completions endpoint
 @app.post(
     "/v1/chat/completions",
-    dependencies=[Depends(check_api_key), Depends(_check_model_container)],
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
 async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
     """Generates a chat completion from a prompt."""

-    if MODEL_CONTAINER.prompt_template is None:
+    if model.container.prompt_template is None:
         raise HTTPException(
             422,
             "This endpoint is disabled because a prompt template is not set.",
         )

-    model_path = MODEL_CONTAINER.get_model_path()
+    model_path = model.container.get_model_path()

     if isinstance(data.messages, str):
         prompt = data.messages
     else:
         try:
-            special_tokens_dict = MODEL_CONTAINER.get_special_tokens(
+            special_tokens_dict = model.container.get_special_tokens(
                 unwrap(data.add_bos_token, True),
                 unwrap(data.ban_eos_token, False),
             )

             prompt = get_prompt_from_template(
                 data.messages,
-                MODEL_CONTAINER.prompt_template,
+                model.container.prompt_template,
                 data.add_generation_prompt,
                 special_tokens_dict,
             )
@@ -601,7 +555,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
             raise HTTPException(
                 400,
                 "Could not find a Conversation from prompt template "
-                f"'{MODEL_CONTAINER.prompt_template.name}'. "
+                f"'{model.container.prompt_template.name}'. "
                 "Check your spelling?",
             ) from exc
         except TemplateError as exc:
@@ -620,7 +574,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
     async def generator():
         """Generator for the generation process."""
         try:
-            new_generation = MODEL_CONTAINER.generate_gen(
+            new_generation = model.container.generate_gen(
                 prompt, **data.to_gen_params()
             )
             for generation in new_generation:
@@ -653,7 +607,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
     generation = await call_with_semaphore(
         partial(
             run_in_threadpool,
-            MODEL_CONTAINER.generate,
+            model.container.generate,
             prompt,
             **data.to_gen_params(),
         )
@@ -692,11 +646,9 @@ def signal_handler(*_):
     sys.exit(0)


-def entrypoint(args: Optional[dict] = None):
+async def entrypoint(args: Optional[dict] = None):
     """Entry function for program startup"""

-    global MODEL_CONTAINER
-
     setup_logger()

     # Set up signal aborting
@@ -782,34 +734,13 @@ def entrypoint(args: Optional[dict] = None):
         model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
         model_path = model_path / model_name

-        MODEL_CONTAINER = ExllamaV2Container(
-            model_path.resolve(), False, **model_config
-        )
-        load_status = MODEL_CONTAINER.load_gen(load_progress)
-
-        progress = get_loading_progress_bar()
-        progress.start()
-        model_type = "draft" if MODEL_CONTAINER.draft_config else "model"
-
-        for module, modules in load_status:
-            if module == 0:
-                loading_task = progress.add_task(
-                    f"[cyan]Loading {model_type} modules", total=modules
-                )
-            else:
-                progress.advance(loading_task, 1)
-
-            if module == modules:
-                if model_type == "draft":
-                    model_type = "model"
-                else:
-                    progress.stop()
+        await model.load_model(model_path.resolve(), **model_config)

     # Load loras after loading the model
     lora_config = get_lora_config()
     if lora_config.get("loras"):
         lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
-        MODEL_CONTAINER.load_loras(lora_dir.resolve(), **lora_config)
+        model.container.load_loras(lora_dir.resolve(), **lora_config)

     # TODO: Replace this with abortables, async via producer consumer, or something else
     api_thread = threading.Thread(target=partial(start_api, host, port), daemon=True)
@@ -821,4 +752,4 @@ def entrypoint(args: Optional[dict] = None):


 if __name__ == "__main__":
-    entrypoint()
+    asyncio.run(entrypoint())
start.py  (3 changed lines)
@@ -1,4 +1,5 @@
 """Utility to automatically upgrade and start the API"""
+import asyncio
 import argparse
 import os
 import pathlib
@@ -66,4 +67,4 @@ if __name__ == "__main__":
     # Import entrypoint after installing all requirements
     from main import entrypoint

-    entrypoint(convert_args_to_dict(args, parser))
+    asyncio.run(entrypoint(convert_args_to_dict(args, parser)))