diff --git a/common/model.py b/common/model.py index a26aa17..16138a3 100644 --- a/common/model.py +++ b/common/model.py @@ -68,10 +68,14 @@ def detect_backend(hf_model: HFModel) -> str: return "exllamav2" -async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs): - """Sets overrides from a model folder's config yaml.""" +async def apply_load_defaults(model_path: pathlib.Path, **kwargs): + """ + Applies model load overrides. + Sources are from inline config and use_as_default. + Currently agnostic due to different schemas for API and config. + """ - override_config_path = model_dir / "tabby_config.yml" + override_config_path = model_path / "tabby_config.yml" if not override_config_path.exists(): return kwargs @@ -88,20 +92,23 @@ async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs): yaml = YAML(typ="safe") inline_config = unwrap(yaml.load(contents), {}) - # Check for inline model overrides + # Check for inline model overrides and merge config defaults model_inline_config = unwrap(inline_config.get("model"), {}) if model_inline_config: - overrides = {**model_inline_config} + overrides = {**model_inline_config, **config.model_defaults} else: logger.warning( "Cannot find inline model overrides. " 'Make sure they are nested under a "model:" key' ) - # Merge draft overrides beforehand + # Merge draft overrides beforehand and merge config defaults draft_inline_config = unwrap(inline_config.get("draft_model"), {}) if draft_inline_config: - overrides["draft_model"] = {**draft_inline_config} + overrides["draft_model"] = { + **draft_inline_config, + **config.draft_model_defaults, + } # Merge the override and model kwargs # No need to preserve the original overrides dict @@ -143,8 +150,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): # Merge with config and inline defaults # TODO: Figure out a way to do this with Pydantic validation # and ModelLoadRequest. Pydantic doesn't have async validators - kwargs = {**config.model_defaults, **kwargs} - kwargs = await apply_inline_overrides(model_path, **kwargs) + kwargs = await apply_load_defaults(model_path, **kwargs) # Fetch the extra HF configuration options hf_model = await HFModel.from_directory(model_path) diff --git a/common/tabby_config.py b/common/tabby_config.py index 865ae28..9c4cc5d 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -20,6 +20,7 @@ class TabbyConfig(TabbyConfigModel): # Persistent defaults # TODO: make this pydantic? model_defaults: dict = {} + draft_model_defaults: dict = {} def load(self, arguments: Optional[dict] = None): """Synchronously loads the global application config""" @@ -50,7 +51,7 @@ class TabbyConfig(TabbyConfigModel): if hasattr(self.model, field): self.model_defaults[field] = getattr(config.model, field) elif hasattr(self.draft_model, field): - self.model_defaults[field] = getattr(config.draft_model, field) + self.draft_model_defaults[field] = getattr(config.draft_model, field) else: logger.error( f"invalid item {field} in config option `model.use_as_default`" diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 6217475..800c7d9 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -193,18 +193,6 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: model_path = pathlib.Path(config.model.model_dir) model_path = model_path / data.model_name - draft_model_path = None - if data.draft_model: - if not data.draft_model.draft_model_name: - error_message = handle_request_error( - "Could not find the draft model name for model load.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - draft_model_path = config.draft_model.draft_model_dir - if not model_path.exists(): error_message = handle_request_error( "Could not find the model path for load. Check model name or config.yml?", @@ -213,9 +201,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: raise HTTPException(400, error_message) - return EventSourceResponse( - stream_model_load(data, model_path, draft_model_path), ping=maxsize - ) + return EventSourceResponse(stream_model_load(data, model_path), ping=maxsize) # Unload model endpoint diff --git a/endpoints/core/utils/model.py b/endpoints/core/utils/model.py index 0805e00..20c9433 100644 --- a/endpoints/core/utils/model.py +++ b/endpoints/core/utils/model.py @@ -77,16 +77,16 @@ def get_dummy_models(): async def stream_model_load( data: ModelLoadRequest, model_path: pathlib.Path, - draft_model_path: str, ): """Request generation wrapper for the loading process.""" # Get trimmed load data load_data = data.model_dump(exclude_none=True) - # Set the draft model path if it exists - if draft_model_path: - load_data["draft_model"]["draft_model_dir"] = draft_model_path + # Set the draft model directory + load_data.setdefault("draft_model", {})["draft_model_dir"] = ( + config.draft_model.draft_model_dir + ) load_status = model.load_model_gen( model_path, skip_wait=data.skip_queue, **load_data