diff --git a/common/model.py b/common/model.py
index 161c7dc..67af212 100644
--- a/common/model.py
+++ b/common/model.py
@@ -57,6 +57,27 @@ def load_progress(module, modules):
     yield module, modules
 
 
+async def detect_backend(model_path: pathlib.Path) -> str:
+    """Pick the backend name from the quant method in the model's config.json.
+
+    Returns "exllamav3" when the quant method is "exl3" and "exllamav2"
+    for anything else. Raises ValueError (chained to the original error)
+    when config.json cannot be loaded from model_path.
+    """
+
+    try:
+        hf_config = await HuggingFaceConfig.from_directory(model_path)
+    except Exception as exc:
+        raise ValueError(
+            "Failed to read the model's config.json. "
+            f"Please check your model directory at {model_path}."
+        ) from exc
+
+    # Keep the lookup outside the try so only load failures are reported as
+    # config.json read errors
+    return "exllamav3" if hf_config.quant_method() == "exl3" else "exllamav2"
+
+
 async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
     """Sets overrides from a model folder's config yaml."""
 
@@ -124,24 +145,12 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     kwargs = {**config.model_defaults, **kwargs}
     kwargs = await apply_inline_overrides(model_path, **kwargs)
 
-    # Read config.json and detect the quant method
-    hf_config_path = model_path / "config.json"
-    if hf_config_path.exists():
-        try:
-            hf_config = await HuggingFaceConfig.from_file(model_path)
-        except Exception as exc:
-            raise ValueError(
-                "Failed to read the model's config.json. "
-                f"Please check your model directory at {model_path}."
-            ) from exc
-        quant_method = hf_config.quant_method()
-        if quant_method == "exl3":
-            backend_name = "exllamav3"
-        else:
-            backend_name = "exllamav2"
-
     # Create a new container and check if the right dependencies are installed
-    backend_name = unwrap(kwargs.get("backend"), backend_name).lower()
+    # An explicit backend wins; only read config.json when none was given
+    backend_name = kwargs.get("backend")
+    if backend_name is None:
+        backend_name = await detect_backend(model_path)
+    backend_name = backend_name.lower()
     container_class = _BACKEND_REGISTRY.get(backend_name)
 
     if not container_class:
diff --git a/common/transformers_utils.py b/common/transformers_utils.py
index cd79f00..d1e5ac1 100644
--- a/common/transformers_utils.py
+++ b/common/transformers_utils.py
@@ -45,7 +45,7 @@ class HuggingFaceConfig(BaseModel):
     quantization_config: Optional[Dict] = None
 
     @classmethod
-    async def from_file(cls, model_directory: pathlib.Path):
+    async def from_directory(cls, model_directory: pathlib.Path):
         """Create an instance from a generation config file."""
 
         hf_config_path = model_directory / "config.json"