API: Add draft model support

Models can be loaded with a child object called "draft" in the POST
request. As with regular models, draft models must be located within the
draft model directory to be loaded.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2023-11-19 00:32:25 -05:00
parent 6b9af58cc1
commit f47919b1d3
2 changed files with 37 additions and 16 deletions

View file

@ -12,6 +12,10 @@ class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class DraftModelLoadRequest(BaseModel):
    """Nested "draft" object accepted by the model-load POST request.

    Carries the parameters used to load a speculative-decoding draft
    model alongside the main model.
    """

    # Name of the draft model folder inside the configured draft model dir
    draft_model_name: str
    # RoPE alpha scaling factor applied to the draft model
    draft_rope_alpha: float = 1.0
class ModelLoadRequest(BaseModel):
name: str
max_seq_len: Optional[int] = 4096
@ -21,8 +25,10 @@ class ModelLoadRequest(BaseModel):
rope_alpha: Optional[float] = 1.0
no_flash_attention: Optional[bool] = False
low_mem: Optional[bool] = False
draft: Optional[DraftModelLoadRequest] = None
class ModelLoadResponse(BaseModel):
    """Progress payload streamed back (as SSE events) while a model loads."""

    # "draft" while the draft model is loading, "model" for the main model
    model_type: str = "model"
    # Index of the module just loaded — presumably 1-based; confirm against load_gen
    module: int
    # Total number of modules to load
    modules: int
    # "processing" while loading, "finished" on completion
    status: str

47
main.py
View file

@ -73,18 +73,28 @@ async def load_model(data: ModelLoadRequest):
if model_container and model_container.model:
raise HTTPException(400, "A model is already loaded! Please unload it first.")
if not data.name:
raise HTTPException(400, "model_name not found.")
model_config = config.get("model") or {}
model_path = pathlib.Path(model_config.get("model_dir") or "models")
model_path = model_path / data.name
load_data = data.dict()
if data.draft and "draft" in model_config:
draft_config = model_config.get("draft") or {}
if not data.draft.draft_model_name:
raise HTTPException(400, "draft_model_name was not found inside the draft object.")
load_data["draft_model_dir"] = draft_config.get("draft_model_dir") or "models"
def generator():
global model_container
model_config = config.get("model") or {}
if "model_dir" in model_config:
model_path = pathlib.Path(model_config["model_dir"])
else:
model_path = pathlib.Path("models")
model_container = ModelContainer(model_path.resolve(), False, **load_data)
model_type = "draft" if model_container.draft_enabled else "model"
model_path = model_path / data.name
model_container = ModelContainer(model_path.resolve(), False, **data.dict())
load_status = model_container.load_gen(load_progress)
for (module, modules) in load_status:
if module == 0:
@ -92,10 +102,23 @@ async def load_model(data: ModelLoadRequest):
elif module == modules:
loading_bar.next()
loading_bar.finish()
response = ModelLoadResponse(
model_type=model_type,
module=module,
modules=modules,
status="finished"
)
yield response.json(ensure_ascii=False)
if model_container.draft_enabled:
model_type = "model"
else:
loading_bar.next()
response = ModelLoadResponse(
model_type=model_type,
module=module,
modules=modules,
status="processing"
@ -103,14 +126,6 @@ async def load_model(data: ModelLoadRequest):
yield response.json(ensure_ascii=False)
response = ModelLoadResponse(
module=module,
modules=modules,
status="finished"
)
yield response.json(ensure_ascii=False)
return EventSourceResponse(generator())
# Unload model endpoint