Model: Migrate backend detection to a separate function
Seemed out of place in the common load function. In addition, rename the transformers utils signature which actually takes a directory instead of a file. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
f8070e7707
commit
cfee16905b
2 changed files with 23 additions and 18 deletions
|
|
@ -57,6 +57,24 @@ def load_progress(module, modules):
|
||||||
yield module, modules
|
yield module, modules
|
||||||
|
|
||||||
|
|
||||||
|
async def detect_backend(model_path: pathlib.Path) -> str:
|
||||||
|
"""Determine the appropriate backend based on model files and configuration."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
hf_config = await HuggingFaceConfig.from_directory(model_path)
|
||||||
|
quant_method = hf_config.quant_method()
|
||||||
|
|
||||||
|
if quant_method == "exl3":
|
||||||
|
return "exllamav3"
|
||||||
|
else:
|
||||||
|
return "exllamav2"
|
||||||
|
except Exception as exc:
|
||||||
|
raise ValueError(
|
||||||
|
"Failed to read the model's config.json. "
|
||||||
|
f"Please check your model directory at {model_path}."
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
|
||||||
async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
|
async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
|
||||||
"""Sets overrides from a model folder's config yaml."""
|
"""Sets overrides from a model folder's config yaml."""
|
||||||
|
|
||||||
|
|
@ -124,24 +142,11 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
||||||
kwargs = {**config.model_defaults, **kwargs}
|
kwargs = {**config.model_defaults, **kwargs}
|
||||||
kwargs = await apply_inline_overrides(model_path, **kwargs)
|
kwargs = await apply_inline_overrides(model_path, **kwargs)
|
||||||
|
|
||||||
# Read config.json and detect the quant method
|
|
||||||
hf_config_path = model_path / "config.json"
|
|
||||||
if hf_config_path.exists():
|
|
||||||
try:
|
|
||||||
hf_config = await HuggingFaceConfig.from_file(model_path)
|
|
||||||
except Exception as exc:
|
|
||||||
raise ValueError(
|
|
||||||
"Failed to read the model's config.json. "
|
|
||||||
f"Please check your model directory at {model_path}."
|
|
||||||
) from exc
|
|
||||||
quant_method = hf_config.quant_method()
|
|
||||||
if quant_method == "exl3":
|
|
||||||
backend_name = "exllamav3"
|
|
||||||
else:
|
|
||||||
backend_name = "exllamav2"
|
|
||||||
|
|
||||||
# Create a new container and check if the right dependencies are installed
|
# Create a new container and check if the right dependencies are installed
|
||||||
backend_name = unwrap(kwargs.get("backend"), backend_name).lower()
|
backend_name = unwrap(
|
||||||
|
kwargs.get("backend"), await detect_backend(model_path)
|
||||||
|
).lower()
|
||||||
|
print(backend_name)
|
||||||
container_class = _BACKEND_REGISTRY.get(backend_name)
|
container_class = _BACKEND_REGISTRY.get(backend_name)
|
||||||
|
|
||||||
if not container_class:
|
if not container_class:
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ class HuggingFaceConfig(BaseModel):
|
||||||
quantization_config: Optional[Dict] = None
|
quantization_config: Optional[Dict] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_file(cls, model_directory: pathlib.Path):
|
async def from_directory(cls, model_directory: pathlib.Path):
|
||||||
"""Create an instance from a generation config file."""
|
"""Create an instance from a generation config file."""
|
||||||
|
|
||||||
hf_config_path = model_directory / "config.json"
|
hf_config_path = model_directory / "config.json"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue