Model: Migrate backend detection to a separate function

Seemed out of place in the common load function. In addition, rename
the transformers utils signature which actually takes a directory
instead of a file.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri 2025-05-08 23:42:39 -04:00
parent f8070e7707
commit cfee16905b
2 changed files with 23 additions and 18 deletions

View file

@ -57,6 +57,24 @@ def load_progress(module, modules):
yield module, modules
async def detect_backend(model_path: pathlib.Path) -> str:
"""Determine the appropriate backend based on model files and configuration."""
try:
hf_config = await HuggingFaceConfig.from_directory(model_path)
quant_method = hf_config.quant_method()
if quant_method == "exl3":
return "exllamav3"
else:
return "exllamav2"
except Exception as exc:
raise ValueError(
"Failed to read the model's config.json. "
f"Please check your model directory at {model_path}."
) from exc
async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
"""Sets overrides from a model folder's config yaml."""
@ -124,24 +142,11 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
kwargs = {**config.model_defaults, **kwargs}
kwargs = await apply_inline_overrides(model_path, **kwargs)
# Read config.json and detect the quant method
hf_config_path = model_path / "config.json"
if hf_config_path.exists():
try:
hf_config = await HuggingFaceConfig.from_file(model_path)
except Exception as exc:
raise ValueError(
"Failed to read the model's config.json. "
f"Please check your model directory at {model_path}."
) from exc
quant_method = hf_config.quant_method()
if quant_method == "exl3":
backend_name = "exllamav3"
else:
backend_name = "exllamav2"
# Create a new container and check if the right dependencies are installed
backend_name = unwrap(kwargs.get("backend"), backend_name).lower()
backend_name = unwrap(
kwargs.get("backend"), await detect_backend(model_path)
).lower()
print(backend_name)
container_class = _BACKEND_REGISTRY.get(backend_name)
if not container_class:

View file

@ -45,7 +45,7 @@ class HuggingFaceConfig(BaseModel):
quantization_config: Optional[Dict] = None
@classmethod
async def from_file(cls, model_directory: pathlib.Path):
async def from_directory(cls, model_directory: pathlib.Path):
"""Create an instance from a generation config file."""
hf_config_path = model_directory / "config.json"