Tree: Add transformers_utils
Part of commit 8824ea0205
Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
8824ea0205
commit
67f061859d
1 changed files with 32 additions and 0 deletions
32
common/transformers_utils.py
Normal file
32
common/transformers_utils.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import json
|
||||
import pathlib
|
||||
from typing import List, Optional, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GenerationConfig(BaseModel):
|
||||
"""
|
||||
An abridged version of HuggingFace's GenerationConfig.
|
||||
Will be expanded as needed.
|
||||
"""
|
||||
|
||||
eos_token_id: Optional[Union[int, List[int]]] = None
|
||||
|
||||
@classmethod
|
||||
def from_file(self, model_directory: pathlib.Path):
|
||||
"""Create an instance from a generation config file."""
|
||||
|
||||
generation_config_path = model_directory / "generation_config.json"
|
||||
with open(
|
||||
generation_config_path, "r", encoding="utf8"
|
||||
) as generation_config_json:
|
||||
generation_config_dict = json.load(generation_config_json)
|
||||
return self.model_validate(generation_config_dict)
|
||||
|
||||
def eos_tokens(self):
|
||||
"""Wrapper method to fetch EOS tokens."""
|
||||
|
||||
if isinstance(self.eos_token_id, int):
|
||||
return [self.eos_token_id]
|
||||
else:
|
||||
return self.eos_token_id
|
||||
Loading…
Add table
Add a link
Reference in a new issue