Model: Add Exllamav3 sampler

File was not included in previous commit.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri 2025-05-01 18:23:45 -04:00
parent bdc5189a4b
commit eca403a0e4

View file

@ -0,0 +1,54 @@
from dataclasses import dataclass, field
from typing import List
from exllamav3.generator.sampler import (
CustomSampler,
SS_Temperature,
SS_RepP,
SS_PresFreqP,
SS_Argmax,
SS_MinP,
SS_TopK,
SS_TopP,
SS_Sample,
SS_Base,
)
@dataclass
class ExllamaV3SamplerBuilder:
"""
Custom sampler chain/stack for TabbyAPI
"""
stack: List[SS_Base] = field(default_factory=list)
def penalties(self, rep_p, freq_p, pres_p, penalty_range, rep_decay):
self.stack += [
SS_RepP(rep_p, penalty_range, rep_decay),
SS_PresFreqP(pres_p, freq_p, penalty_range, rep_decay),
]
def temperature(self, temp):
self.stack.append(SS_Temperature(temp))
def top_k(self, top_k):
self.stack.append(SS_TopK(top_k))
def top_p(self, top_p):
self.stack.append(SS_TopP(top_p))
def min_p(self, min_p):
self.stack.append(SS_MinP(min_p))
def greedy(self):
self.stack.append(SS_Argmax())
def build(self, greedy):
"""Builds the final sampler from stack."""
# Use greedy if temp is 0
if greedy:
return CustomSampler([SS_Argmax()])
else:
self.stack.append(SS_Sample())
return CustomSampler(self.stack)