A common problem in TabbyAPI: users trying to get up and running with a model often hit OOMs caused by max_seq_len, because model devs set max context values in the millions, which require a lot of VRAM. To idiot-proof first-time setup, make the fallback default 4096 so users can run their models. If a user still wants the model's native max_seq_len, they can set the value to -1.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
176 lines
5.2 KiB
Python
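The commit message above describes loader behavior that lives outside this file; the 4096 fallback pairs with the max_position_embeddings default defined below. A minimal sketch of that resolution logic, assuming a hypothetical resolve_max_seq_len helper and a user-supplied max_seq_len setting (both names invented for illustration):

def resolve_max_seq_len(user_value, hf_config):
    # Hypothetical sketch, not part of this file.
    # -1 means "trust the model's advertised maximum context".
    if user_value == -1:
        return hf_config.max_position_embeddings
    # Unset falls back to a safe default so first-time setups don't OOM.
    if user_value is None:
        return 4096
    return user_value
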
import aiofiles
import json
import pathlib
from loguru import logger
from pydantic import BaseModel
from typing import Dict, List, Optional, Set, Union


class GenerationConfig(BaseModel):
    """
    An abridged version of HuggingFace's GenerationConfig.
    Will be expanded as needed.
    """

    eos_token_id: Optional[Union[int, List[int]]] = None

    @classmethod
    async def from_directory(cls, model_directory: pathlib.Path):
        """Create an instance from a generation config file."""

        generation_config_path = model_directory / "generation_config.json"
        async with aiofiles.open(
            generation_config_path, "r", encoding="utf8"
        ) as generation_config_json:
            contents = await generation_config_json.read()
            generation_config_dict = json.loads(contents)
            return cls.model_validate(generation_config_dict)

    def eos_tokens(self):
        """Wrapper method to fetch EOS tokens."""

        if isinstance(self.eos_token_id, list):
            return self.eos_token_id
        elif isinstance(self.eos_token_id, int):
            return [self.eos_token_id]
        else:
            return []


class HuggingFaceConfig(BaseModel):
    """
    An abridged version of HuggingFace's model config.
    Will be expanded as needed.
    """

    # Conservative fallback if config.json omits this key
    max_position_embeddings: int = 4096
    eos_token_id: Optional[Union[int, List[int]]] = None
    quantization_config: Optional[Dict] = None

    @classmethod
    async def from_directory(cls, model_directory: pathlib.Path):
        """Create an instance from a model config file."""

        hf_config_path = model_directory / "config.json"
        async with aiofiles.open(
            hf_config_path, "r", encoding="utf8"
        ) as hf_config_json:
            contents = await hf_config_json.read()
            hf_config_dict = json.loads(contents)
            return cls.model_validate(hf_config_dict)

    def quant_method(self):
        """Wrapper method to fetch the quantization type."""

        if isinstance(self.quantization_config, Dict):
            return self.quantization_config.get("quant_method")
        else:
            return None

    def eos_tokens(self):
        """Wrapper method to fetch EOS tokens."""

        if isinstance(self.eos_token_id, list):
            return self.eos_token_id
        elif isinstance(self.eos_token_id, int):
            return [self.eos_token_id]
        else:
            return []


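# Illustrative note (not part of the original file): for a quantized model,
# config.json typically carries a block like
#     "quantization_config": {"quant_method": "gptq", "bits": 4}
# which quant_method() above reduces to the string "gptq"; unquantized
# models omit the key entirely, so the method returns None.

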
class TokenizerConfig(BaseModel):
    """
    An abridged version of HuggingFace's tokenizer config.
    """

    add_bos_token: Optional[bool] = True

    @classmethod
    async def from_directory(cls, model_directory: pathlib.Path):
        """Create an instance from a tokenizer config file."""

        tokenizer_config_path = model_directory / "tokenizer_config.json"
        async with aiofiles.open(
            tokenizer_config_path, "r", encoding="utf8"
        ) as tokenizer_config_json:
            contents = await tokenizer_config_json.read()
            tokenizer_config_dict = json.loads(contents)
            return cls.model_validate(tokenizer_config_dict)


class HFModel:
    """
    Unified container for HuggingFace model configuration files.
    These are abridged for hyper-specific model parameters not covered
    by most backends.

    Includes:
    - config.json
    - generation_config.json
    - tokenizer_config.json
    """

    hf_config: HuggingFaceConfig
    tokenizer_config: Optional[TokenizerConfig] = None
    generation_config: Optional[GenerationConfig] = None

    @classmethod
    async def from_directory(cls, model_directory: pathlib.Path):
        """Create an instance from a model directory."""

        self = cls()

        # A model must have an HF config
        try:
            self.hf_config = await HuggingFaceConfig.from_directory(model_directory)
        except Exception as exc:
            raise ValueError(
                f"Failed to load config.json from {model_directory}"
            ) from exc

        # The generation and tokenizer configs are optional
        try:
            self.generation_config = await GenerationConfig.from_directory(
                model_directory
            )
        except Exception:
            logger.warning(
                "Generation config file not found in model directory, skipping."
            )

        try:
            self.tokenizer_config = await TokenizerConfig.from_directory(
                model_directory
            )
        except Exception:
            logger.warning(
                "Tokenizer config file not found in model directory, skipping."
            )

        return self

    def quant_method(self):
        """Wrapper for the quantization method."""

        return self.hf_config.quant_method()

    def eos_tokens(self):
        """Combines and returns EOS tokens from various configs."""

        eos_ids: Set[int] = set()

        eos_ids.update(self.hf_config.eos_tokens())

        if self.generation_config:
            eos_ids.update(self.generation_config.eos_tokens())

        # Convert back to a list
        return list(eos_ids)

    def add_bos_token(self):
        """Wrapper for the tokenizer config's add_bos_token flag."""

        if self.tokenizer_config:
            return self.tokenizer_config.add_bos_token

        # Expected default
        return True
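
A short usage sketch (hypothetical caller, not part of this file), showing how the container is constructed and queried:

import asyncio

async def _demo():
    model_dir = pathlib.Path("models/my-model")  # hypothetical path
    hf_model = await HFModel.from_directory(model_dir)
    print(hf_model.eos_tokens())     # merged EOS ids from config.json + generation_config.json
    print(hf_model.quant_method())   # e.g. "gptq", or None for unquantized weights
    print(hf_model.add_bos_token())  # tokenizer flag, defaults to True if absent

asyncio.run(_demo())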