tabbyAPI-ollama/common/transformers_utils.py
kingbri 7baef05b49 Transformers Utils: Fix file read
Use asynchronous JSON reading

Signed-off-by: kingbri <bdashore3@proton.me>
2024-09-10 22:41:39 -04:00

74 lines
2.2 KiB
Python

import aiofiles
import json
import pathlib
from typing import List, Optional, Union
from loguru import logger
from pydantic import BaseModel
class GenerationConfig(BaseModel):
    """
    Trimmed-down counterpart of HuggingFace's GenerationConfig.

    Only the fields declared below are modelled; more can be added
    as they become necessary.
    """

    # Either a single token ID or a list of token IDs; None when absent.
    eos_token_id: Optional[Union[int, List[int]]] = None
    # A list of token-ID lists; None when absent.
    bad_words_ids: Optional[List[List[int]]] = None

    @classmethod
    async def from_file(cls, model_directory: pathlib.Path):
        """
        Build an instance from a model's generation config file.

        Reads "generation_config.json" inside model_directory
        asynchronously, parses it as JSON and validates the resulting
        object against this model.
        """

        config_path = model_directory / "generation_config.json"
        async with aiofiles.open(config_path, "r", encoding="utf8") as config_file:
            raw_text = await config_file.read()

        return cls.model_validate(json.loads(raw_text))

    def eos_tokens(self):
        """
        Fetch the EOS token IDs.

        A lone integer ID is wrapped in a one-element list; any other
        value (a list, or None when unset) is handed back unchanged.
        """

        token_ids = self.eos_token_id
        return [token_ids] if isinstance(token_ids, int) else token_ids
class HuggingFaceConfig(BaseModel):
    """
    An abridged version of HuggingFace's model config.
    Will be expanded as needed.
    """

    # Raw string value from config.json; get_badwordsids() parses it as JSON.
    badwordsids: Optional[str] = None

    @classmethod
    async def from_file(cls, model_directory: pathlib.Path):
        """
        Create an instance from a model's config.json file.

        Reads "config.json" inside model_directory asynchronously,
        parses it as JSON and validates the resulting object against
        this model. A json.JSONDecodeError propagates to the caller if
        the file does not contain valid JSON.
        """

        hf_config_path = model_directory / "config.json"
        async with aiofiles.open(
            hf_config_path, "r", encoding="utf8"
        ) as hf_config_json:
            contents = await hf_config_json.read()

        hf_config_dict = json.loads(contents)
        return cls.model_validate(hf_config_dict)

    def get_badwordsids(self):
        """
        Wrapper method to fetch badwordsids.

        Returns the JSON array stored in the badwordsids string as a
        list. Returns an empty list when the field is unset or empty,
        when it is not valid JSON, or when it is valid JSON that is not
        an array; the latter two cases also log a warning.
        """

        if not self.badwordsids:
            return []

        try:
            bad_words_list = json.loads(self.badwordsids)
        except json.JSONDecodeError:
            bad_words_list = None

        # Valid JSON is not necessarily an array ("5", "null", "{}" all
        # parse fine), so check the decoded type before returning it.
        if not isinstance(bad_words_list, list):
            logger.warning(
                "Skipping badwordsids from config.json "
                "since it's not a valid array."
            )
            return []

        return bad_words_list