Model: Add support for HuggingFace config and bad_words_ids

This is necessary for Kobold's API. Current models use bad_words_ids
in generation_config.json, but for some reason, they're also present
in the model's config.json.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-07-26 18:23:22 -04:00
parent 545e26608f
commit 7522b1447b
4 changed files with 68 additions and 9 deletions

View file

@ -47,7 +47,7 @@ from common.templating import (
TemplateLoadError,
find_template_from_model,
)
from common.transformers_utils import GenerationConfig
from common.transformers_utils import GenerationConfig, HuggingFaceConfig
from common.utils import coalesce, unwrap
@ -72,6 +72,7 @@ class ExllamaV2Container:
draft_cache_mode: str = "FP16"
max_batch_size: int = 20
generation_config: Optional[GenerationConfig] = None
hf_config: Optional[HuggingFaceConfig] = None
# GPU split vars
gpu_split: Optional[list] = None
@ -186,6 +187,9 @@ class ExllamaV2Container:
except AttributeError:
pass
# Create the hf_config
self.hf_config = HuggingFaceConfig.from_file(model_directory)
# Then override the base_seq_len if present
override_base_seq_len = kwargs.get("override_base_seq_len")
if override_base_seq_len:
@ -268,15 +272,8 @@ class ExllamaV2Container:
else:
self.cache_size = self.config.max_seq_len
# Try to set prompt template
self.prompt_template = self.find_prompt_template(
kwargs.get("prompt_template"), model_directory
)
# Load generation config overrides
generation_config_path = (
pathlib.Path(self.config.model_dir) / "generation_config.json"
)
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = GenerationConfig.from_file(
@ -288,6 +285,11 @@ class ExllamaV2Container:
"Skipping generation config load because of an unexpected error."
)
# Try to set prompt template
self.prompt_template = self.find_prompt_template(
kwargs.get("prompt_template"), model_directory
)
# Catch all for template lookup errors
if self.prompt_template:
logger.info(

View file

@ -1,6 +1,7 @@
import json
import pathlib
from typing import List, Optional, Union
from loguru import logger
from pydantic import BaseModel
@ -11,6 +12,7 @@ class GenerationConfig(BaseModel):
"""
eos_token_id: Optional[Union[int, List[int]]] = None
bad_words_ids: Optional[List[List[int]]] = None
@classmethod
def from_file(self, model_directory: pathlib.Path):
@ -30,3 +32,40 @@ class GenerationConfig(BaseModel):
return [self.eos_token_id]
else:
return self.eos_token_id
class HuggingFaceConfig(BaseModel):
    """
    An abridged version of HuggingFace's model config.
    Will be expanded as needed.
    """

    # Kobold-era models store banned token IDs as a JSON-encoded string
    # (not a parsed list) under this key in config.json.
    badwordsids: Optional[str] = None

    @classmethod
    def from_file(cls, model_directory: pathlib.Path):
        """Create an instance from a model directory's config.json.

        Raises FileNotFoundError if config.json does not exist;
        malformed JSON or schema mismatches propagate from json.load
        and pydantic's model_validate respectively.
        """

        hf_config_path = model_directory / "config.json"
        with open(
            hf_config_path, "r", encoding="utf8"
        ) as hf_config_json:
            hf_config_dict = json.load(hf_config_json)

        return cls.model_validate(hf_config_dict)

    def get_badwordsids(self):
        """Wrapper method to fetch badwordsids.

        Returns the decoded list of banned token-ID lists, or an empty
        list when the field is absent or not valid JSON.
        """

        # Guard clause: field missing or empty string
        if not self.badwordsids:
            return []

        try:
            return json.loads(self.badwordsids)
        except json.JSONDecodeError:
            logger.warning(
                "Skipping badwordsids from config.json "
                "since it's not a valid array."
            )
            return []

View file

@ -18,3 +18,9 @@ def prune_dict(input_dict):
"""Trim out instances of None from a dictionary."""
return {k: v for k, v in input_dict.items() if v is not None}
def flat_map(input_list):
    """Flattens a list of lists into a single list."""
    flattened = []
    for inner in input_list:
        flattened.extend(inner)
    return flattened

View file

@ -1,7 +1,9 @@
from typing import List, Optional
from pydantic import BaseModel, Field
from common import model
from common.sampling import BaseSamplerRequest
from common.utils import flat_map, unwrap
class GenerateRequest(BaseSamplerRequest):
@ -14,6 +16,16 @@ class GenerateRequest(BaseSamplerRequest):
if self.penalty_range == 0:
self.penalty_range = -1
# Move badwordsids into banned tokens for generation
if self.use_default_badwordsids:
bad_words_ids = unwrap(
model.container.generation_config.bad_words_ids,
model.container.hf_config.get_badwordsids()
)
if bad_words_ids:
self.banned_tokens += flat_map(bad_words_ids)
return super().to_gen_params(**kwargs)