From 3084ef9fa1bfc1679890d3b7f81419b833a7f589 Mon Sep 17 00:00:00 2001 From: kingbri <8082010+kingbri1@users.noreply.github.com> Date: Wed, 16 Apr 2025 00:50:05 -0400 Subject: [PATCH] Model + API: Migrate to use BaseSamplerParams kwargs is pretty ugly when figuring out which arguments to use. The base requests falls back to defaults anyways, so pass in the params object as is. However, since Python's typing isn't like TypeScript where types can be transformed, the type hinting has a possiblity of None showing up despite there always being a value for some params. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com> --- backends/exllamav2/model.py | 185 ++++++++++++------------- common/sampling.py | 5 + endpoints/OAI/types/common.py | 4 - endpoints/OAI/utils/chat_completion.py | 17 ++- endpoints/OAI/utils/completion.py | 23 +-- 5 files changed, 113 insertions(+), 121 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 6accc88..6d69f63 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1,17 +1,13 @@ """The model container class for ExLlamaV2 models.""" -from functools import partial import aiofiles import asyncio import gc import math import pathlib import traceback -from backends.exllamav2.vision import clear_image_embedding_cache -from common.multimodal import MultimodalEmbeddingWrapper import torch import uuid -from copy import deepcopy from exllamav2 import ( ExLlamaV2, ExLlamaV2Config, @@ -32,7 +28,7 @@ from exllamav2.generator import ( ) from itertools import zip_longest from loguru import logger -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional from ruamel.yaml import YAML @@ -47,6 +43,7 @@ from backends.exllamav2.utils import ( hardware_supports_flash_attn, supports_paged_attn, ) +from backends.exllamav2.vision import clear_image_embedding_cache from common.concurrency import iterate_in_threadpool from common.gen_logging import ( log_generation_params, @@ -54,6 +51,8 @@ from common.gen_logging import ( log_prompt, log_response, ) +from common.multimodal import MultimodalEmbeddingWrapper +from common.sampling import BaseSamplerRequest from common.templating import ( PromptTemplate, TemplateLoadError, @@ -976,15 +975,20 @@ class ExllamaV2Container: async def generate( self, - prompt: str, request_id: str, - abort_event: asyncio.Event = None, - **kwargs, + prompt: str, + params: BaseSamplerRequest, + abort_event: Optional[asyncio.Event] = None, + mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, ): """Generate a response to a prompt.""" generations = [] async for generation in self.generate_gen( - prompt, request_id, abort_event, **kwargs + request_id, + prompt, + params, + abort_event, + mm_embeddings, ): generations.append(generation) @@ -1031,21 +1035,22 @@ class ExllamaV2Container: return joined_generation - def check_unsupported_settings(self, **kwargs): + def check_unsupported_settings(self, params: BaseSamplerRequest): """ Check and warn the user if a sampler is unsupported. Meant for dev wheels! """ - return kwargs + return params async def generate_gen( self, - prompt: str, request_id: str, + prompt: str, + params: BaseSamplerRequest, abort_event: Optional[asyncio.Event] = None, - **kwargs, + mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, ): """ Create generator function for prompt completion. @@ -1059,46 +1064,43 @@ class ExllamaV2Container: prompts = [prompt] - token_healing = kwargs.get("token_healing") - generate_window = max( - kwargs.get("generate_window"), self.config.max_seq_len // 8 - ) + # TODO: Not used for some reason? + generate_window = max(params.generate_window, self.config.max_seq_len // 8) # Sampler settings gen_settings = ExLlamaV2Sampler.Settings() # Check unsupported settings for dev wheels - kwargs = self.check_unsupported_settings(**kwargs) + params = self.check_unsupported_settings(params) # Apply settings - gen_settings.temperature = kwargs.get("temperature") - gen_settings.temperature_last = kwargs.get("temperature_last") - gen_settings.smoothing_factor = kwargs.get("smoothing_factor") - gen_settings.top_k = kwargs.get("top_k") - gen_settings.top_p = kwargs.get("top_p") - gen_settings.top_a = kwargs.get("top_a") - gen_settings.min_p = kwargs.get("min_p") - gen_settings.tfs = kwargs.get("tfs") - gen_settings.typical = kwargs.get("typical") - gen_settings.mirostat = kwargs.get("mirostat") - gen_settings.skew = kwargs.get("skew") + gen_settings.temperature = params.temperature + gen_settings.temperature_last = params.temperature_last + gen_settings.smoothing_factor = params.smoothing_factor + gen_settings.top_k = params.top_k + gen_settings.top_p = params.top_p + gen_settings.top_a = params.top_a + gen_settings.min_p = params.min_p + gen_settings.tfs = params.tfs + gen_settings.typical = params.typical + gen_settings.mirostat = params.mirostat + gen_settings.skew = params.skew # XTC - xtc_probability = kwargs.get("xtc_probability") - if xtc_probability > 0.0: - gen_settings.xtc_probability = xtc_probability + if params.xtc_probability > 0.0: + gen_settings.xtc_probability = params.xtc_probability # 0.1 is the default for this value - gen_settings.xtc_threshold = kwargs.get("xtc_threshold") + gen_settings.xtc_threshold = params.xtc_threshold # DynaTemp settings - max_temp = kwargs.get("max_temp") - min_temp = kwargs.get("min_temp") + max_temp = params.max_temp + min_temp = params.min_temp - if max_temp > min_temp: + if params.max_temp > params.min_temp: gen_settings.max_temp = max_temp gen_settings.min_temp = min_temp - gen_settings.temp_exponent = kwargs.get("temp_exponent") + gen_settings.temp_exponent = params.temp_exponent else: # Force to default values gen_settings.max_temp = 1.0 @@ -1115,11 +1117,11 @@ class ExllamaV2Container: ) # Default tau and eta fallbacks don't matter if mirostat is off - gen_settings.mirostat_tau = kwargs.get("mirostat_tau") - gen_settings.mirostat_eta = kwargs.get("mirostat_eta") + gen_settings.mirostat_tau = params.mirostat_tau + gen_settings.mirostat_eta = params.mirostat_eta # Set CFG scale and negative prompt - cfg_scale = kwargs.get("cfg_scale") + cfg_scale = params.cfg_scale negative_prompt = None if cfg_scale not in [None, 1.0]: if self.paged: @@ -1127,7 +1129,7 @@ class ExllamaV2Container: # If the negative prompt is empty, use the BOS token negative_prompt = unwrap( - kwargs.get("negative_prompt"), self.tokenizer.bos_token + params.negative_prompt, self.tokenizer.bos_token ) prompts.append(negative_prompt) @@ -1138,15 +1140,16 @@ class ExllamaV2Container: ) # Penalties - gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty") - gen_settings.token_frequency_penalty = kwargs.get("frequency_penalty") - gen_settings.token_presence_penalty = kwargs.get("presence_penalty") + gen_settings.token_repetition_penalty = params.repetition_penalty + gen_settings.token_frequency_penalty = params.frequency_penalty + gen_settings.token_presence_penalty = params.presence_penalty # Applies for all penalties despite being called token_repetition_range gen_settings.token_repetition_range = unwrap( - kwargs.get("penalty_range"), self.config.max_seq_len + params.penalty_range, self.config.max_seq_len ) + # TODO: Not used for some reason? # Dynamically scale penalty range to output tokens # Only do this if freq/pres pen is enabled # and the repetition range is -1 @@ -1164,54 +1167,51 @@ class ExllamaV2Container: else: fallback_decay = gen_settings.token_repetition_range gen_settings.token_repetition_decay = coalesce( - kwargs.get("repetition_decay"), fallback_decay, 0 + params.repetition_decay, fallback_decay, 0 ) # DRY options - dry_multiplier = kwargs.get("dry_multiplier") + dry_multiplier = params.dry_multiplier # < 0 = disabled if dry_multiplier > 0: gen_settings.dry_multiplier = dry_multiplier - - gen_settings.dry_allowed_length = kwargs.get("dry_allowed_length") - gen_settings.dry_base = kwargs.get("dry_base") + gen_settings.dry_allowed_length = params.dry_allowed_length + gen_settings.dry_base = params.dry_base # Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range # Use max_seq_len as the fallback to stay consistent - gen_settings.dry_range = unwrap( - kwargs.get("dry_range"), self.config.max_seq_len - ) + gen_settings.dry_range = unwrap(params.dry_range, self.config.max_seq_len) # Tokenize sequence breakers - dry_sequence_breakers_json = kwargs.get("dry_sequence_breakers") - if dry_sequence_breakers_json: + if params.dry_sequence_breakers: gen_settings.dry_sequence_breakers = { - self.encode_tokens(s)[-1] for s in dry_sequence_breakers_json + self.encode_tokens(s)[-1] for s in params.dry_sequence_breakers } # Initialize grammar handler grammar_handler = ExLlamaV2Grammar() # Add JSON schema filter if it exists - json_schema = kwargs.get("json_schema") - if json_schema: + if params.json_schema: grammar_handler.add_json_schema_filter( - json_schema, self.model, self.tokenizer + params.json_schema, self.model, self.tokenizer ) # Add regex filter if it exists - regex_pattern = kwargs.get("regex_pattern") - if regex_pattern: - grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer) + if params.regex_pattern: + grammar_handler.add_regex_filter( + params.regex_pattern, self.model, self.tokenizer + ) # Add EBNF filter if it exists - grammar_string = kwargs.get("grammar_string") - if grammar_string: - grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer) + if params.grammar_string: + grammar_handler.add_kbnf_filter( + params.grammar_string, self.model, self.tokenizer + ) # Set banned strings - banned_strings = kwargs.get("banned_strings") + banned_strings = params.banned_strings if banned_strings and len(grammar_handler.filters) > 0: logger.warning( "Disabling banned_strings because " @@ -1220,16 +1220,12 @@ class ExllamaV2Container: banned_strings = [] - stop_conditions = kwargs.get("stop") - add_bos_token = kwargs.get("add_bos_token"), True - ban_eos_token = kwargs.get("ban_eos_token"), False - logit_bias = kwargs.get("logit_bias") - - # Logprobs - request_logprobs = kwargs.get("logprobs") + stop_conditions = params.stop + add_bos_token = params.add_bos_token + ban_eos_token = params.ban_eos_token # Speculative Ngram - self.generator.speculative_ngram = kwargs.get("speculative_ngram") + self.generator.speculative_ngram = params.speculative_ngram # Override sampler settings for temp = 0 if gen_settings.temperature == 0: @@ -1244,17 +1240,15 @@ class ExllamaV2Container: ) # Set banned tokens - banned_tokens = kwargs.get("banned_tokens") - if banned_tokens: - gen_settings.disallow_tokens(self.tokenizer, banned_tokens) + if params.banned_tokens: + gen_settings.disallow_tokens(self.tokenizer, params.banned_tokens) # Set allowed tokens - allowed_tokens = kwargs.get("allowed_tokens") - if allowed_tokens: - gen_settings.allow_tokens(self.tokenizer, allowed_tokens) + if params.allowed_tokens: + gen_settings.allow_tokens(self.tokenizer, params.allowed_tokens) # Set logit bias - if logit_bias: + if params.logit_bias: # Create a vocab tensor if it doesn't exist for token biasing if gen_settings.token_bias is None: padding = -self.tokenizer.config.vocab_size % 32 @@ -1264,7 +1258,7 @@ class ExllamaV2Container: ) # Map logits to the tensor with their biases - for token_id, bias in logit_bias.items(): + for token_id, bias in params.logit_bias.items(): if 0 <= token_id < len(self.tokenizer.get_id_to_piece_list(True)): gen_settings.token_bias[token_id] = bias else: @@ -1289,7 +1283,7 @@ class ExllamaV2Container: stop_conditions += eos_tokens # Get multimodal embeddings if present - mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings") + # TODO: Remove kwargs and pass this as optional mm_embeddings_content = mm_embeddings.content if mm_embeddings else [] # Encode both positive and negative prompts @@ -1312,7 +1306,7 @@ class ExllamaV2Container: # Automatically set max_tokens to fill up the context # This should be an OK default, but may be changed in the future max_tokens = unwrap( - kwargs.get("max_tokens"), + params.max_tokens, self.config.max_seq_len - max(context_len, negative_context_len), ) if max_tokens < 1: @@ -1349,12 +1343,6 @@ class ExllamaV2Container: f"is greater than cache_size {self.cache_size}" ) - # Set min_tokens to generate while keeping EOS banned - min_tokens = kwargs.get("min_tokens") - - # This is an inverse of skip_special_tokens - decode_special_tokens = not kwargs.get("skip_special_tokens") - # Log prompt to console. Add the BOS token if specified log_prompt( f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}", @@ -1369,17 +1357,17 @@ class ExllamaV2Container: self.generator, input_ids=input_ids, max_new_tokens=max_tokens, - min_new_tokens=min_tokens, + min_new_tokens=params.min_tokens, gen_settings=gen_settings, stop_conditions=stop_conditions, - decode_special_tokens=decode_special_tokens, + decode_special_tokens=not params.skip_special_tokens, filters=grammar_handler.filters, filter_prefer_eos=bool(grammar_handler.filters), - return_probs=request_logprobs > 0, - return_top_tokens=request_logprobs, - return_logits=request_logprobs > 0, + return_probs=params.logprobs > 0, + return_top_tokens=params.logprobs, + return_logits=params.logprobs > 0, banned_strings=banned_strings, - token_healing=token_healing, + token_healing=params.token_healing, identifier=job_id, embeddings=mm_embeddings_content, ) @@ -1418,7 +1406,7 @@ class ExllamaV2Container: "offset": len(full_response), } - if request_logprobs > 0: + if params.logprobs > 0: # Get top tokens and probs top_tokens = unwrap( result.get("top_k_tokens"), @@ -1494,8 +1482,7 @@ class ExllamaV2Container: request_id=request_id, bos_token_id=self.tokenizer.bos_token_id, eos_token_id=eos_tokens, - **kwargs, - generate_window=generate_window, + **params.model_dump(), auto_scale_penalty_range=auto_scale_penalty_range, ) diff --git a/common/sampling.py b/common/sampling.py index c7ef934..1b4bc69 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -282,6 +282,11 @@ class BaseSamplerRequest(BaseModel): ge=0, ) + logprobs: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("logprobs", 0), + ge=0, + ) + @field_validator("top_k", mode="before") def convert_top_k(cls, v): """Fixes instance if Top-K is -1.""" diff --git a/endpoints/OAI/types/common.py b/endpoints/OAI/types/common.py index 640ead7..9c90188 100644 --- a/endpoints/OAI/types/common.py +++ b/endpoints/OAI/types/common.py @@ -32,10 +32,6 @@ class CommonCompletionRequest(BaseSamplerRequest): # Generation info (remainder is in BaseSamplerRequest superclass) stream: Optional[bool] = False stream_options: Optional[ChatCompletionStreamOptions] = None - logprobs: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("logprobs", 0), - ge=0, - ) response_format: Optional[CompletionResponseFormat] = Field( default_factory=CompletionResponseFormat ) diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index f66cb52..dcc0dea 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -333,11 +333,11 @@ async def stream_generate_chat_completion( _stream_collector( n, gen_queue, - prompt, request.state.id, + prompt, + task_gen_params, abort_event, - embeddings=embeddings, - **task_gen_params.model_dump(exclude={"prompt"}), + mm_embeddings=embeddings, ) ) @@ -422,10 +422,10 @@ async def generate_chat_completion( gen_tasks.append( asyncio.create_task( model.container.generate( - prompt, request.state.id, - embeddings=embeddings, - **data.model_dump(exclude={"prompt"}), + prompt, + data, + mm_embeddings=embeddings, ) ) ) @@ -465,7 +465,6 @@ async def generate_tool_calls( # FIXME: May not be necessary depending on how the codebase evolves tool_data = data.model_copy(deep=True) tool_data.json_schema = tool_data.tool_call_schema - gen_params = tool_data.model_dump() for idx, gen in enumerate(generations): if gen["stop_str"] in tool_data.tool_call_start: @@ -488,10 +487,10 @@ async def generate_tool_calls( gen_tasks.append( asyncio.create_task( model.container.generate( - pre_tool_prompt, request.state.id, + pre_tool_prompt, + tool_data, embeddings=mm_embeddings, - **gen_params, ) ) ) diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index 9fd8b90..1d706d4 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -8,12 +8,12 @@ import asyncio import pathlib from asyncio import CancelledError from fastapi import HTTPException, Request -from typing import List, Union - from loguru import logger +from typing import List, Optional, Union from common import model from common.auth import get_key_permission +from common.multimodal import MultimodalEmbeddingWrapper from common.networking import ( get_generator_error, handle_request_disconnect, @@ -86,16 +86,21 @@ def _create_response( async def _stream_collector( task_idx: int, gen_queue: asyncio.Queue, - prompt: str, request_id: str, + prompt: str, + params: CompletionRequest, abort_event: asyncio.Event, - **kwargs, + mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, ): """Collects a stream and places results in a common queue""" try: new_generation = model.container.generate_gen( - prompt, request_id, abort_event, **kwargs + request_id, + prompt, + params, + abort_event, + mm_embeddings, ) async for generation in new_generation: generation["index"] = task_idx @@ -195,10 +200,10 @@ async def stream_generate_completion( _stream_collector( n, gen_queue, - data.prompt, request.state.id, + data.prompt, + task_gen_params, abort_event, - **task_gen_params.model_dump(exclude={"prompt"}), ) ) @@ -256,9 +261,9 @@ async def generate_completion( gen_tasks.append( asyncio.create_task( model.container.generate( - data.prompt, request.state.id, - **task_gen_params.model_dump(exclude={"prompt"}), + data.prompt, + task_gen_params, ) ) )