Model: Add Exllamav3 sampler
File was not included in previous commit. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
bdc5189a4b
commit
eca403a0e4
1 changed file with 54 additions and 0 deletions
54
backends/exllamav3/sampler.py
Normal file
54
backends/exllamav3/sampler.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
from exllamav3.generator.sampler import (
|
||||
CustomSampler,
|
||||
SS_Temperature,
|
||||
SS_RepP,
|
||||
SS_PresFreqP,
|
||||
SS_Argmax,
|
||||
SS_MinP,
|
||||
SS_TopK,
|
||||
SS_TopP,
|
||||
SS_Sample,
|
||||
SS_Base,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class ExllamaV3SamplerBuilder:
    """
    Custom sampler chain/stack for TabbyAPI.

    Each helper method appends one or more ExLlamaV3 sampler stages to
    ``stack`` in the order the methods are called. ``build`` then wraps the
    accumulated stages in a ``CustomSampler``.
    """

    # Sampler stages collected so far, in insertion order.
    stack: List[SS_Base] = field(default_factory=list)

    def penalties(self, rep_p, freq_p, pres_p, penalty_range, rep_decay):
        """
        Append the repetition penalty stage followed by the combined
        presence/frequency penalty stage.

        Both stages receive the same ``penalty_range`` and ``rep_decay``.
        """

        self.stack += [
            SS_RepP(rep_p, penalty_range, rep_decay),
            # NOTE: SS_PresFreqP is given presence first, then frequency,
            # which is the reverse of this method's parameter order.
            SS_PresFreqP(pres_p, freq_p, penalty_range, rep_decay),
        ]

    def temperature(self, temp):
        """Append a temperature stage built from ``temp``."""

        self.stack.append(SS_Temperature(temp))

    def top_k(self, top_k):
        """Append a top-K stage built from ``top_k``."""

        self.stack.append(SS_TopK(top_k))

    def top_p(self, top_p):
        """Append a top-P stage built from ``top_p``."""

        self.stack.append(SS_TopP(top_p))

    def min_p(self, min_p):
        """Append a min-P stage built from ``min_p``."""

        self.stack.append(SS_MinP(min_p))

    def greedy(self):
        """Append an argmax stage to the stack."""

        self.stack.append(SS_Argmax())

    def build(self, greedy):
        """
        Builds the final sampler from stack.

        If ``greedy`` is truthy, the accumulated stack is ignored and a
        sampler containing only an argmax stage is returned. Otherwise the
        returned sampler contains the accumulated stages followed by a final
        ``SS_Sample`` stage.

        The builder's own ``stack`` is not modified, so calling ``build``
        more than once does not pile up extra ``SS_Sample`` stages.
        """

        # Use greedy if temp is 0 (the caller makes that decision and
        # passes the result in as the ``greedy`` flag).
        if greedy:
            return CustomSampler([SS_Argmax()])

        # Build from a fresh list instead of appending to self.stack, so the
        # terminal sampling stage is never duplicated on repeated builds and
        # the returned sampler does not share the builder's list.
        return CustomSampler([*self.stack, SS_Sample()])
|
||||
Loading…
Add table
Add a link
Reference in a new issue