Tree: Fix classmethod usage
Instead of self, use cls which passes a type of the class. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
2c3bc71afa
commit
5e8ff9a004
2 changed files with 8 additions and 8 deletions
|
|
@ -111,7 +111,7 @@ class PromptTemplate:
|
|||
self.template = self.compile(raw_template)
|
||||
|
||||
@classmethod
|
||||
async def from_file(self, template_path: pathlib.Path):
|
||||
async def from_file(cls, template_path: pathlib.Path):
|
||||
"""Get a template from a jinja file."""
|
||||
|
||||
# Add the jinja extension if it isn't provided
|
||||
|
|
@ -126,7 +126,7 @@ class PromptTemplate:
|
|||
template_path, "r", encoding="utf8"
|
||||
) as raw_template_stream:
|
||||
contents = await raw_template_stream.read()
|
||||
return PromptTemplate(
|
||||
return cls(
|
||||
name=template_name,
|
||||
raw_template=contents,
|
||||
)
|
||||
|
|
@ -138,7 +138,7 @@ class PromptTemplate:
|
|||
|
||||
@classmethod
|
||||
async def from_model_json(
|
||||
self, json_path: pathlib.Path, key: str, name: Optional[str] = None
|
||||
cls, json_path: pathlib.Path, key: str, name: Optional[str] = None
|
||||
):
|
||||
"""Get a template from a JSON file. Requires a key and template name"""
|
||||
if not json_path.exists():
|
||||
|
|
@ -177,7 +177,7 @@ class PromptTemplate:
|
|||
)
|
||||
else:
|
||||
# Can safely assume the chat template is the old style
|
||||
return PromptTemplate(
|
||||
return cls(
|
||||
name="from_tokenizer_config",
|
||||
raw_template=chat_template,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class GenerationConfig(BaseModel):
|
|||
bad_words_ids: Optional[List[List[int]]] = None
|
||||
|
||||
@classmethod
|
||||
async def from_file(self, model_directory: pathlib.Path):
|
||||
async def from_file(cls, model_directory: pathlib.Path):
|
||||
"""Create an instance from a generation config file."""
|
||||
|
||||
generation_config_path = model_directory / "generation_config.json"
|
||||
|
|
@ -24,7 +24,7 @@ class GenerationConfig(BaseModel):
|
|||
generation_config_path, "r", encoding="utf8"
|
||||
) as generation_config_json:
|
||||
generation_config_dict = json.load(generation_config_json)
|
||||
return self.model_validate(generation_config_dict)
|
||||
return cls.model_validate(generation_config_dict)
|
||||
|
||||
def eos_tokens(self):
|
||||
"""Wrapper method to fetch EOS tokens."""
|
||||
|
|
@ -44,7 +44,7 @@ class HuggingFaceConfig(BaseModel):
|
|||
badwordsids: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
async def from_file(self, model_directory: pathlib.Path):
|
||||
async def from_file(cls, model_directory: pathlib.Path):
|
||||
"""Create an instance from a generation config file."""
|
||||
|
||||
hf_config_path = model_directory / "config.json"
|
||||
|
|
@ -53,7 +53,7 @@ class HuggingFaceConfig(BaseModel):
|
|||
) as hf_config_json:
|
||||
contents = await hf_config_json.read()
|
||||
hf_config_dict = json.loads(contents)
|
||||
return self.model_validate(hf_config_dict)
|
||||
return cls.model_validate(hf_config_dict)
|
||||
|
||||
def get_badwordsids(self):
|
||||
"""Wrapper method to fetch badwordsids."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue