Model: Migrate backend detection to a separate function
Seemed out of place in the common load function. In addition, rename the transformers utils signature which actually takes a directory instead of a file. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
f8070e7707
commit
cfee16905b
2 changed files with 23 additions and 18 deletions
|
|
@ -57,6 +57,24 @@ def load_progress(module, modules):
|
|||
yield module, modules
|
||||
|
||||
|
||||
async def detect_backend(model_path: pathlib.Path) -> str:
|
||||
"""Determine the appropriate backend based on model files and configuration."""
|
||||
|
||||
try:
|
||||
hf_config = await HuggingFaceConfig.from_directory(model_path)
|
||||
quant_method = hf_config.quant_method()
|
||||
|
||||
if quant_method == "exl3":
|
||||
return "exllamav3"
|
||||
else:
|
||||
return "exllamav2"
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
"Failed to read the model's config.json. "
|
||||
f"Please check your model directory at {model_path}."
|
||||
) from exc
|
||||
|
||||
|
||||
async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
|
||||
"""Sets overrides from a model folder's config yaml."""
|
||||
|
||||
|
|
@ -124,24 +142,11 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
|||
kwargs = {**config.model_defaults, **kwargs}
|
||||
kwargs = await apply_inline_overrides(model_path, **kwargs)
|
||||
|
||||
# Read config.json and detect the quant method
|
||||
hf_config_path = model_path / "config.json"
|
||||
if hf_config_path.exists():
|
||||
try:
|
||||
hf_config = await HuggingFaceConfig.from_file(model_path)
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
"Failed to read the model's config.json. "
|
||||
f"Please check your model directory at {model_path}."
|
||||
) from exc
|
||||
quant_method = hf_config.quant_method()
|
||||
if quant_method == "exl3":
|
||||
backend_name = "exllamav3"
|
||||
else:
|
||||
backend_name = "exllamav2"
|
||||
|
||||
# Create a new container and check if the right dependencies are installed
|
||||
backend_name = unwrap(kwargs.get("backend"), backend_name).lower()
|
||||
backend_name = unwrap(
|
||||
kwargs.get("backend"), await detect_backend(model_path)
|
||||
).lower()
|
||||
print(backend_name)
|
||||
container_class = _BACKEND_REGISTRY.get(backend_name)
|
||||
|
||||
if not container_class:
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class HuggingFaceConfig(BaseModel):
|
|||
quantization_config: Optional[Dict] = None
|
||||
|
||||
@classmethod
|
||||
async def from_file(cls, model_directory: pathlib.Path):
|
||||
async def from_directory(cls, model_directory: pathlib.Path):
|
||||
"""Create an instance from a generation config file."""
|
||||
|
||||
hf_config_path = model_directory / "config.json"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue