Model: Create universal HFModel class
The HFModel class serves to coalesce all config files that contain random keys which are required for model usage. Adding this base class allows us to expand as HuggingFace randomly changes their JSON schemas over time, reducing the brunt that backend devs need to feel when their next model isn't supported. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
7900b72848
commit
390daeb92f
5 changed files with 149 additions and 127 deletions
|
|
@ -17,7 +17,7 @@ from common.logger import get_loading_progress_bar
|
|||
from common.networking import handle_request_error
|
||||
from common.tabby_config import config
|
||||
from common.optional_dependencies import dependencies
|
||||
from common.transformers_utils import HuggingFaceConfig
|
||||
from common.transformers_utils import HFModel
|
||||
from common.utils import unwrap
|
||||
|
||||
# Global variables for model container
|
||||
|
|
@ -57,22 +57,15 @@ def load_progress(module, modules):
|
|||
yield module, modules
|
||||
|
||||
|
||||
async def detect_backend(model_path: pathlib.Path) -> str:
|
||||
def detect_backend(hf_model: HFModel) -> str:
|
||||
"""Determine the appropriate backend based on model files and configuration."""
|
||||
|
||||
try:
|
||||
hf_config = await HuggingFaceConfig.from_directory(model_path)
|
||||
quant_method = hf_config.quant_method()
|
||||
quant_method = hf_model.quant_method()
|
||||
|
||||
if quant_method == "exl3":
|
||||
return "exllamav3"
|
||||
else:
|
||||
return "exllamav2"
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
"Failed to read the model's config.json. "
|
||||
f"Please check your model directory at {model_path}."
|
||||
) from exc
|
||||
if quant_method == "exl3":
|
||||
return "exllamav3"
|
||||
else:
|
||||
return "exllamav2"
|
||||
|
||||
|
||||
async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
|
||||
|
|
@ -142,28 +135,29 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
|||
kwargs = {**config.model_defaults, **kwargs}
|
||||
kwargs = await apply_inline_overrides(model_path, **kwargs)
|
||||
|
||||
# Fetch the extra HF configuration options
|
||||
hf_model = await HFModel.from_directory(model_path)
|
||||
|
||||
# Create a new container and check if the right dependencies are installed
|
||||
backend_name = unwrap(
|
||||
kwargs.get("backend"), await detect_backend(model_path)
|
||||
).lower()
|
||||
container_class = _BACKEND_REGISTRY.get(backend_name)
|
||||
backend = unwrap(kwargs.get("backend"), detect_backend(hf_model))
|
||||
container_class = _BACKEND_REGISTRY.get(backend)
|
||||
|
||||
if not container_class:
|
||||
available_backends = list(_BACKEND_REGISTRY.keys())
|
||||
if backend_name in available_backends:
|
||||
if backend in available_backends:
|
||||
raise ValueError(
|
||||
f"Backend '{backend_name}' selected, but required dependencies "
|
||||
f"Backend '{backend}' selected, but required dependencies "
|
||||
"are not installed."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid backend '{backend_name}'. "
|
||||
f"Invalid backend '{backend}'. "
|
||||
f"Available backends: {available_backends}"
|
||||
)
|
||||
|
||||
logger.info(f"Using backend {backend_name}")
|
||||
logger.info(f"Using backend {backend}")
|
||||
new_container: BaseModelContainer = await container_class.create(
|
||||
model_path.resolve(), **kwargs
|
||||
model_path.resolve(), hf_model, **kwargs
|
||||
)
|
||||
|
||||
# Add possible types of models that can be loaded
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue