Tree: Fix classmethod usage

Instead of self, use cls which passes a type of the class.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-09-10 20:52:29 -04:00
parent 2c3bc71afa
commit 5e8ff9a004
2 changed files with 8 additions and 8 deletions

View file

@ -111,7 +111,7 @@ class PromptTemplate:
self.template = self.compile(raw_template)
@classmethod
async def from_file(self, template_path: pathlib.Path):
async def from_file(cls, template_path: pathlib.Path):
"""Get a template from a jinja file."""
# Add the jinja extension if it isn't provided
@ -126,7 +126,7 @@ class PromptTemplate:
template_path, "r", encoding="utf8"
) as raw_template_stream:
contents = await raw_template_stream.read()
return PromptTemplate(
return cls(
name=template_name,
raw_template=contents,
)
@ -138,7 +138,7 @@ class PromptTemplate:
@classmethod
async def from_model_json(
self, json_path: pathlib.Path, key: str, name: Optional[str] = None
cls, json_path: pathlib.Path, key: str, name: Optional[str] = None
):
"""Get a template from a JSON file. Requires a key and template name"""
if not json_path.exists():
@ -177,7 +177,7 @@ class PromptTemplate:
)
else:
# Can safely assume the chat template is the old style
return PromptTemplate(
return cls(
name="from_tokenizer_config",
raw_template=chat_template,
)

View file

@ -16,7 +16,7 @@ class GenerationConfig(BaseModel):
bad_words_ids: Optional[List[List[int]]] = None
@classmethod
async def from_file(self, model_directory: pathlib.Path):
async def from_file(cls, model_directory: pathlib.Path):
"""Create an instance from a generation config file."""
generation_config_path = model_directory / "generation_config.json"
@ -24,7 +24,7 @@ class GenerationConfig(BaseModel):
generation_config_path, "r", encoding="utf8"
) as generation_config_json:
generation_config_dict = json.load(generation_config_json)
return self.model_validate(generation_config_dict)
return cls.model_validate(generation_config_dict)
def eos_tokens(self):
"""Wrapper method to fetch EOS tokens."""
@ -44,7 +44,7 @@ class HuggingFaceConfig(BaseModel):
badwordsids: Optional[str] = None
@classmethod
async def from_file(self, model_directory: pathlib.Path):
async def from_file(cls, model_directory: pathlib.Path):
"""Create an instance from a generation config file."""
hf_config_path = model_directory / "config.json"
@ -53,7 +53,7 @@ class HuggingFaceConfig(BaseModel):
) as hf_config_json:
contents = await hf_config_json.read()
hf_config_dict = json.loads(contents)
return self.model_validate(hf_config_dict)
return cls.model_validate(hf_config_dict)
def get_badwordsids(self):
"""Wrapper method to fetch badwordsids."""