Model + API: Migrate to use BaseSamplerParams

kwargs is pretty ugly when figuring out which arguments to use. The
base request falls back to defaults anyway, so pass in the params
object as-is.

However, since Python's typing isn't like TypeScript, where types can
be transformed, some params are still hinted as Optional (possibly
None) even though they always have a value at runtime.
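
For illustration, the pattern looks roughly like the sketch below. This
is a simplified, hypothetical example (SamplerParams and generate are
stand-ins, not the actual BaseSamplerRequest or container API in this
diff); it shows the params object being passed through as-is and why a
pydantic field with a default_factory still hints as Optional even
though validation always produces a value.

    # Minimal sketch with assumed names, not the project's real classes
    from typing import Optional

    from pydantic import BaseModel, Field


    class SamplerParams(BaseModel):
        # Declared Optional so defaults can be filled in lazily; the type
        # checker still sees Optional[...] even though validation always
        # produces a value.
        temperature: Optional[float] = Field(default_factory=lambda: 1.0)
        top_k: Optional[int] = Field(default_factory=lambda: 0)


    async def generate(request_id: str, prompt: str, params: SamplerParams):
        # Before: generate(prompt, request_id, **params.model_dump())
        # After: the validated params object is handed over unchanged
        temperature = params.temperature  # Optional[float] to the checker
        return f"{request_id}: temp={temperature}, prompt={prompt!r}"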

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri 2025-04-16 00:50:05 -04:00
parent dcb36e9ab2
commit 3084ef9fa1
5 changed files with 113 additions and 121 deletions


@@ -1,17 +1,13 @@
"""The model container class for ExLlamaV2 models."""
from functools import partial
import aiofiles
import asyncio
import gc
import math
import pathlib
import traceback
from backends.exllamav2.vision import clear_image_embedding_cache
from common.multimodal import MultimodalEmbeddingWrapper
import torch
import uuid
from copy import deepcopy
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Config,
@@ -32,7 +28,7 @@ from exllamav2.generator import (
)
from itertools import zip_longest
from loguru import logger
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional
from ruamel.yaml import YAML
@@ -47,6 +43,7 @@ from backends.exllamav2.utils import (
hardware_supports_flash_attn,
supports_paged_attn,
)
from backends.exllamav2.vision import clear_image_embedding_cache
from common.concurrency import iterate_in_threadpool
from common.gen_logging import (
log_generation_params,
@@ -54,6 +51,8 @@ from common.gen_logging import (
log_prompt,
log_response,
)
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import (
PromptTemplate,
TemplateLoadError,
@@ -976,15 +975,20 @@ class ExllamaV2Container:
async def generate(
self,
prompt: str,
request_id: str,
abort_event: asyncio.Event = None,
**kwargs,
prompt: str,
params: BaseSamplerRequest,
abort_event: Optional[asyncio.Event] = None,
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
):
"""Generate a response to a prompt."""
generations = []
async for generation in self.generate_gen(
prompt, request_id, abort_event, **kwargs
request_id,
prompt,
params,
abort_event,
mm_embeddings,
):
generations.append(generation)
@@ -1031,21 +1035,22 @@ class ExllamaV2Container:
return joined_generation
def check_unsupported_settings(self, **kwargs):
def check_unsupported_settings(self, params: BaseSamplerRequest):
"""
Check and warn the user if a sampler is unsupported.
Meant for dev wheels!
"""
return kwargs
return params
async def generate_gen(
self,
prompt: str,
request_id: str,
prompt: str,
params: BaseSamplerRequest,
abort_event: Optional[asyncio.Event] = None,
**kwargs,
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
):
"""
Create generator function for prompt completion.
@@ -1059,46 +1064,43 @@ class ExllamaV2Container:
prompts = [prompt]
token_healing = kwargs.get("token_healing")
generate_window = max(
kwargs.get("generate_window"), self.config.max_seq_len // 8
)
# TODO: Not used for some reason?
generate_window = max(params.generate_window, self.config.max_seq_len // 8)
# Sampler settings
gen_settings = ExLlamaV2Sampler.Settings()
# Check unsupported settings for dev wheels
kwargs = self.check_unsupported_settings(**kwargs)
params = self.check_unsupported_settings(params)
# Apply settings
gen_settings.temperature = kwargs.get("temperature")
gen_settings.temperature_last = kwargs.get("temperature_last")
gen_settings.smoothing_factor = kwargs.get("smoothing_factor")
gen_settings.top_k = kwargs.get("top_k")
gen_settings.top_p = kwargs.get("top_p")
gen_settings.top_a = kwargs.get("top_a")
gen_settings.min_p = kwargs.get("min_p")
gen_settings.tfs = kwargs.get("tfs")
gen_settings.typical = kwargs.get("typical")
gen_settings.mirostat = kwargs.get("mirostat")
gen_settings.skew = kwargs.get("skew")
gen_settings.temperature = params.temperature
gen_settings.temperature_last = params.temperature_last
gen_settings.smoothing_factor = params.smoothing_factor
gen_settings.top_k = params.top_k
gen_settings.top_p = params.top_p
gen_settings.top_a = params.top_a
gen_settings.min_p = params.min_p
gen_settings.tfs = params.tfs
gen_settings.typical = params.typical
gen_settings.mirostat = params.mirostat
gen_settings.skew = params.skew
# XTC
xtc_probability = kwargs.get("xtc_probability")
if xtc_probability > 0.0:
gen_settings.xtc_probability = xtc_probability
if params.xtc_probability > 0.0:
gen_settings.xtc_probability = params.xtc_probability
# 0.1 is the default for this value
gen_settings.xtc_threshold = kwargs.get("xtc_threshold")
gen_settings.xtc_threshold = params.xtc_threshold
# DynaTemp settings
max_temp = kwargs.get("max_temp")
min_temp = kwargs.get("min_temp")
max_temp = params.max_temp
min_temp = params.min_temp
if max_temp > min_temp:
if params.max_temp > params.min_temp:
gen_settings.max_temp = max_temp
gen_settings.min_temp = min_temp
gen_settings.temp_exponent = kwargs.get("temp_exponent")
gen_settings.temp_exponent = params.temp_exponent
else:
# Force to default values
gen_settings.max_temp = 1.0
@@ -1115,11 +1117,11 @@ class ExllamaV2Container:
)
# Default tau and eta fallbacks don't matter if mirostat is off
gen_settings.mirostat_tau = kwargs.get("mirostat_tau")
gen_settings.mirostat_eta = kwargs.get("mirostat_eta")
gen_settings.mirostat_tau = params.mirostat_tau
gen_settings.mirostat_eta = params.mirostat_eta
# Set CFG scale and negative prompt
cfg_scale = kwargs.get("cfg_scale")
cfg_scale = params.cfg_scale
negative_prompt = None
if cfg_scale not in [None, 1.0]:
if self.paged:
@@ -1127,7 +1129,7 @@ class ExllamaV2Container:
# If the negative prompt is empty, use the BOS token
negative_prompt = unwrap(
kwargs.get("negative_prompt"), self.tokenizer.bos_token
params.negative_prompt, self.tokenizer.bos_token
)
prompts.append(negative_prompt)
@@ -1138,15 +1140,16 @@ class ExllamaV2Container:
)
# Penalties
gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty")
gen_settings.token_frequency_penalty = kwargs.get("frequency_penalty")
gen_settings.token_presence_penalty = kwargs.get("presence_penalty")
gen_settings.token_repetition_penalty = params.repetition_penalty
gen_settings.token_frequency_penalty = params.frequency_penalty
gen_settings.token_presence_penalty = params.presence_penalty
# Applies for all penalties despite being called token_repetition_range
gen_settings.token_repetition_range = unwrap(
kwargs.get("penalty_range"), self.config.max_seq_len
params.penalty_range, self.config.max_seq_len
)
# TODO: Not used for some reason?
# Dynamically scale penalty range to output tokens
# Only do this if freq/pres pen is enabled
# and the repetition range is -1
@@ -1164,54 +1167,51 @@ class ExllamaV2Container:
else:
fallback_decay = gen_settings.token_repetition_range
gen_settings.token_repetition_decay = coalesce(
kwargs.get("repetition_decay"), fallback_decay, 0
params.repetition_decay, fallback_decay, 0
)
# DRY options
dry_multiplier = kwargs.get("dry_multiplier")
dry_multiplier = params.dry_multiplier
# < 0 = disabled
if dry_multiplier > 0:
gen_settings.dry_multiplier = dry_multiplier
gen_settings.dry_allowed_length = kwargs.get("dry_allowed_length")
gen_settings.dry_base = kwargs.get("dry_base")
gen_settings.dry_allowed_length = params.dry_allowed_length
gen_settings.dry_base = params.dry_base
# Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range
# Use max_seq_len as the fallback to stay consistent
gen_settings.dry_range = unwrap(
kwargs.get("dry_range"), self.config.max_seq_len
)
gen_settings.dry_range = unwrap(params.dry_range, self.config.max_seq_len)
# Tokenize sequence breakers
dry_sequence_breakers_json = kwargs.get("dry_sequence_breakers")
if dry_sequence_breakers_json:
if params.dry_sequence_breakers:
gen_settings.dry_sequence_breakers = {
self.encode_tokens(s)[-1] for s in dry_sequence_breakers_json
self.encode_tokens(s)[-1] for s in params.dry_sequence_breakers
}
# Initialize grammar handler
grammar_handler = ExLlamaV2Grammar()
# Add JSON schema filter if it exists
json_schema = kwargs.get("json_schema")
if json_schema:
if params.json_schema:
grammar_handler.add_json_schema_filter(
json_schema, self.model, self.tokenizer
params.json_schema, self.model, self.tokenizer
)
# Add regex filter if it exists
regex_pattern = kwargs.get("regex_pattern")
if regex_pattern:
grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer)
if params.regex_pattern:
grammar_handler.add_regex_filter(
params.regex_pattern, self.model, self.tokenizer
)
# Add EBNF filter if it exists
grammar_string = kwargs.get("grammar_string")
if grammar_string:
grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer)
if params.grammar_string:
grammar_handler.add_kbnf_filter(
params.grammar_string, self.model, self.tokenizer
)
# Set banned strings
banned_strings = kwargs.get("banned_strings")
banned_strings = params.banned_strings
if banned_strings and len(grammar_handler.filters) > 0:
logger.warning(
"Disabling banned_strings because "
@@ -1220,16 +1220,12 @@ class ExllamaV2Container:
banned_strings = []
stop_conditions = kwargs.get("stop")
add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
logit_bias = kwargs.get("logit_bias")
# Logprobs
request_logprobs = kwargs.get("logprobs")
stop_conditions = params.stop
add_bos_token = params.add_bos_token
ban_eos_token = params.ban_eos_token
# Speculative Ngram
self.generator.speculative_ngram = kwargs.get("speculative_ngram")
self.generator.speculative_ngram = params.speculative_ngram
# Override sampler settings for temp = 0
if gen_settings.temperature == 0:
@@ -1244,17 +1240,15 @@ class ExllamaV2Container:
)
# Set banned tokens
banned_tokens = kwargs.get("banned_tokens")
if banned_tokens:
gen_settings.disallow_tokens(self.tokenizer, banned_tokens)
if params.banned_tokens:
gen_settings.disallow_tokens(self.tokenizer, params.banned_tokens)
# Set allowed tokens
allowed_tokens = kwargs.get("allowed_tokens")
if allowed_tokens:
gen_settings.allow_tokens(self.tokenizer, allowed_tokens)
if params.allowed_tokens:
gen_settings.allow_tokens(self.tokenizer, params.allowed_tokens)
# Set logit bias
if logit_bias:
if params.logit_bias:
# Create a vocab tensor if it doesn't exist for token biasing
if gen_settings.token_bias is None:
padding = -self.tokenizer.config.vocab_size % 32
@@ -1264,7 +1258,7 @@ class ExllamaV2Container:
)
# Map logits to the tensor with their biases
for token_id, bias in logit_bias.items():
for token_id, bias in params.logit_bias.items():
if 0 <= token_id < len(self.tokenizer.get_id_to_piece_list(True)):
gen_settings.token_bias[token_id] = bias
else:
@@ -1289,7 +1283,7 @@ class ExllamaV2Container:
stop_conditions += eos_tokens
# Get multimodal embeddings if present
mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
# TODO: Remove kwargs and pass this as optional
mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
# Encode both positive and negative prompts
@@ -1312,7 +1306,7 @@ class ExllamaV2Container:
# Automatically set max_tokens to fill up the context
# This should be an OK default, but may be changed in the future
max_tokens = unwrap(
kwargs.get("max_tokens"),
params.max_tokens,
self.config.max_seq_len - max(context_len, negative_context_len),
)
if max_tokens < 1:
@@ -1349,12 +1343,6 @@ class ExllamaV2Container:
f"is greater than cache_size {self.cache_size}"
)
# Set min_tokens to generate while keeping EOS banned
min_tokens = kwargs.get("min_tokens")
# This is an inverse of skip_special_tokens
decode_special_tokens = not kwargs.get("skip_special_tokens")
# Log prompt to console. Add the BOS token if specified
log_prompt(
f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}",
@@ -1369,17 +1357,17 @@ class ExllamaV2Container:
self.generator,
input_ids=input_ids,
max_new_tokens=max_tokens,
min_new_tokens=min_tokens,
min_new_tokens=params.min_tokens,
gen_settings=gen_settings,
stop_conditions=stop_conditions,
decode_special_tokens=decode_special_tokens,
decode_special_tokens=not params.skip_special_tokens,
filters=grammar_handler.filters,
filter_prefer_eos=bool(grammar_handler.filters),
return_probs=request_logprobs > 0,
return_top_tokens=request_logprobs,
return_logits=request_logprobs > 0,
return_probs=params.logprobs > 0,
return_top_tokens=params.logprobs,
return_logits=params.logprobs > 0,
banned_strings=banned_strings,
token_healing=token_healing,
token_healing=params.token_healing,
identifier=job_id,
embeddings=mm_embeddings_content,
)
@@ -1418,7 +1406,7 @@ class ExllamaV2Container:
"offset": len(full_response),
}
if request_logprobs > 0:
if params.logprobs > 0:
# Get top tokens and probs
top_tokens = unwrap(
result.get("top_k_tokens"),
@@ -1494,8 +1482,7 @@ class ExllamaV2Container:
request_id=request_id,
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=eos_tokens,
**kwargs,
generate_window=generate_window,
**params.model_dump(),
auto_scale_penalty_range=auto_scale_penalty_range,
)


@@ -282,6 +282,11 @@ class BaseSamplerRequest(BaseModel):
ge=0,
)
logprobs: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("logprobs", 0),
ge=0,
)
@field_validator("top_k", mode="before")
def convert_top_k(cls, v):
"""Fixes instance if Top-K is -1."""


@@ -32,10 +32,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
# Generation info (remainder is in BaseSamplerRequest superclass)
stream: Optional[bool] = False
stream_options: Optional[ChatCompletionStreamOptions] = None
logprobs: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("logprobs", 0),
ge=0,
)
response_format: Optional[CompletionResponseFormat] = Field(
default_factory=CompletionResponseFormat
)


@@ -333,11 +333,11 @@ async def stream_generate_chat_completion(
_stream_collector(
n,
gen_queue,
prompt,
request.state.id,
prompt,
task_gen_params,
abort_event,
embeddings=embeddings,
**task_gen_params.model_dump(exclude={"prompt"}),
mm_embeddings=embeddings,
)
)
@@ -422,10 +422,10 @@ async def generate_chat_completion(
gen_tasks.append(
asyncio.create_task(
model.container.generate(
prompt,
request.state.id,
embeddings=embeddings,
**data.model_dump(exclude={"prompt"}),
prompt,
data,
mm_embeddings=embeddings,
)
)
)
@@ -465,7 +465,6 @@ async def generate_tool_calls(
# FIXME: May not be necessary depending on how the codebase evolves
tool_data = data.model_copy(deep=True)
tool_data.json_schema = tool_data.tool_call_schema
gen_params = tool_data.model_dump()
for idx, gen in enumerate(generations):
if gen["stop_str"] in tool_data.tool_call_start:
@@ -488,10 +487,10 @@ async def generate_tool_calls(
gen_tasks.append(
asyncio.create_task(
model.container.generate(
pre_tool_prompt,
request.state.id,
pre_tool_prompt,
tool_data,
embeddings=mm_embeddings,
**gen_params,
)
)
)


@@ -8,12 +8,12 @@ import asyncio
import pathlib
from asyncio import CancelledError
from fastapi import HTTPException, Request
from typing import List, Union
from loguru import logger
from typing import List, Optional, Union
from common import model
from common.auth import get_key_permission
from common.multimodal import MultimodalEmbeddingWrapper
from common.networking import (
get_generator_error,
handle_request_disconnect,
@@ -86,16 +86,21 @@ def _create_response(
async def _stream_collector(
task_idx: int,
gen_queue: asyncio.Queue,
prompt: str,
request_id: str,
prompt: str,
params: CompletionRequest,
abort_event: asyncio.Event,
**kwargs,
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
):
"""Collects a stream and places results in a common queue"""
try:
new_generation = model.container.generate_gen(
prompt, request_id, abort_event, **kwargs
request_id,
prompt,
params,
abort_event,
mm_embeddings,
)
async for generation in new_generation:
generation["index"] = task_idx
@@ -195,10 +200,10 @@ async def stream_generate_completion(
_stream_collector(
n,
gen_queue,
data.prompt,
request.state.id,
data.prompt,
task_gen_params,
abort_event,
**task_gen_params.model_dump(exclude={"prompt"}),
)
)
@@ -256,9 +261,9 @@ async def generate_completion(
gen_tasks.append(
asyncio.create_task(
model.container.generate(
data.prompt,
request.state.id,
**task_gen_params.model_dump(exclude={"prompt"}),
data.prompt,
task_gen_params,
)
)
)