diff --git a/OAI/types/model.py b/OAI/types/model.py index cc08742..f82daf2 100644 --- a/OAI/types/model.py +++ b/OAI/types/model.py @@ -12,6 +12,10 @@ class ModelList(BaseModel): object: str = "list" data: List[ModelCard] = Field(default_factory=list) +class DraftModelLoadRequest(BaseModel): + draft_model_name: str + draft_rope_alpha: float = 1.0 + class ModelLoadRequest(BaseModel): name: str max_seq_len: Optional[int] = 4096 @@ -21,8 +25,10 @@ class ModelLoadRequest(BaseModel): rope_alpha: Optional[float] = 1.0 no_flash_attention: Optional[bool] = False low_mem: Optional[bool] = False + draft: Optional[DraftModelLoadRequest] = None class ModelLoadResponse(BaseModel): + model_type: str = "model" module: int modules: int status: str diff --git a/main.py b/main.py index 747c120..5634561 100644 --- a/main.py +++ b/main.py @@ -73,18 +73,28 @@ async def load_model(data: ModelLoadRequest): if model_container and model_container.model: raise HTTPException(400, "A model is already loaded! Please unload it first.") + if not data.name: + raise HTTPException(400, "model_name not found.") + + model_config = config.get("model") or {} + model_path = pathlib.Path(model_config.get("model_dir") or "models") + model_path = model_path / data.name + + load_data = data.dict() + if data.draft and "draft" in model_config: + draft_config = model_config.get("draft") or {} + + if not data.draft.draft_model_name: + raise HTTPException(400, "draft_model_name was not found inside the draft object.") + + load_data["draft_model_dir"] = draft_config.get("draft_model_dir") or "models" + def generator(): global model_container - model_config = config.get("model") or {} - if "model_dir" in model_config: - model_path = pathlib.Path(model_config["model_dir"]) - else: - model_path = pathlib.Path("models") + model_container = ModelContainer(model_path.resolve(), False, **load_data) + model_type = "draft" if model_container.draft_enabled else "model" - model_path = model_path / data.name - - model_container = ModelContainer(model_path.resolve(), False, **data.dict()) load_status = model_container.load_gen(load_progress) for (module, modules) in load_status: if module == 0: @@ -92,10 +102,23 @@ async def load_model(data: ModelLoadRequest): elif module == modules: loading_bar.next() loading_bar.finish() + + response = ModelLoadResponse( + model_type=model_type, + module=module, + modules=modules, + status="finished" + ) + + yield response.json(ensure_ascii=False) + + if model_container.draft_enabled: + model_type = "model" else: loading_bar.next() response = ModelLoadResponse( + model_type=model_type, module=module, modules=modules, status="processing" @@ -103,14 +126,6 @@ async def load_model(data: ModelLoadRequest): yield response.json(ensure_ascii=False) - response = ModelLoadResponse( - module=module, - modules=modules, - status="finished" - ) - - yield response.json(ensure_ascii=False) - return EventSourceResponse(generator()) # Unload model endpoint