Add new samplers

This commit is contained in:
turboderp 2023-11-12 08:12:08 +01:00
parent a10c14d357
commit 4fa4386275

View file

@ -204,7 +204,12 @@ class ModelContainer:
'temperature' (float): Sampling temperature (default: 0.8)
'top_k' (int): Sampling top-K (default: 100)
'top_p' (float): Sampling top-P (default: 0.8)
'min_p' (float): Sampling min-P (default: 0.0)
'tfs' (float): Tail-free sampling (default: 0.0)
'typical' (float): Sampling typical (default: 0.0)
'mirostat' (bool): Use Mirostat (default: False)
'mirostat_tau' (float) Mirostat tau parameter (default: 1.5)
'mirostat_eta' (float) Mirostat eta parameter (default: 0.1)
'token_repetition_penalty' (float): Token repetition/presence penalty (default: 1.15)
'token_repetition_range' (int): Repetition penalty range (default: whole context)
'token_repetition_decay' (int): Repetition penalty range (default: same as range)
@ -228,7 +233,12 @@ class ModelContainer:
gen_settings.temperature = kwargs.get("temperature", 0.8)
gen_settings.top_k = kwargs.get("top_k", 100)
gen_settings.top_p = kwargs.get("top_p", 0.8)
gen_settings.min_p = kwargs.get("min_p", 0.0)
gen_settings.tfs = kwargs.get("tfs", 0.0)
gen_settings.typical = kwargs.get("typical", 0.0)
gen_settings.mirostat = kwargs.get("mirostat", False)
gen_settings.mirostat_tau = kwargs.get("mirostat_tau", 1.5)
gen_settings.mirostat_eta = kwargs.get("mirostat_eta", 0.1)
gen_settings.token_repetition_penalty = kwargs.get("token_repetition_penalty", 1.15)
gen_settings.token_repetition_range = kwargs.get("token_repetition_range", self.config.max_seq_len)
gen_settings.token_repetition_decay = kwargs.get("token_repetition_decay", gen_settings.token_repetition_range)