diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 6accc88..6d69f63 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -1,17 +1,13 @@
 """The model container class for ExLlamaV2 models."""
 
-from functools import partial
 import aiofiles
 import asyncio
 import gc
 import math
 import pathlib
 import traceback
-from backends.exllamav2.vision import clear_image_embedding_cache
-from common.multimodal import MultimodalEmbeddingWrapper
 import torch
 import uuid
-from copy import deepcopy
 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Config,
@@ -32,7 +28,7 @@ from exllamav2.generator import (
 )
 from itertools import zip_longest
 from loguru import logger
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional
 
 from ruamel.yaml import YAML
 
@@ -47,6 +43,7 @@ from backends.exllamav2.utils import (
     hardware_supports_flash_attn,
     supports_paged_attn,
 )
+from backends.exllamav2.vision import clear_image_embedding_cache
 from common.concurrency import iterate_in_threadpool
 from common.gen_logging import (
     log_generation_params,
@@ -54,6 +51,8 @@ from common.gen_logging import (
     log_prompt,
     log_response,
 )
+from common.multimodal import MultimodalEmbeddingWrapper
+from common.sampling import BaseSamplerRequest
 from common.templating import (
     PromptTemplate,
     TemplateLoadError,
@@ -976,15 +975,20 @@ class ExllamaV2Container:
 
     async def generate(
         self,
-        prompt: str,
         request_id: str,
-        abort_event: asyncio.Event = None,
-        **kwargs,
+        prompt: str,
+        params: BaseSamplerRequest,
+        abort_event: Optional[asyncio.Event] = None,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
     ):
         """Generate a response to a prompt."""
         generations = []
         async for generation in self.generate_gen(
-            prompt, request_id, abort_event, **kwargs
+            request_id,
+            prompt,
+            params,
+            abort_event,
+            mm_embeddings,
         ):
             generations.append(generation)
 
@@ -1031,21 +1035,22 @@ class ExllamaV2Container:
 
         return joined_generation
 
-    def check_unsupported_settings(self, **kwargs):
+    def check_unsupported_settings(self, params: BaseSamplerRequest):
         """
         Check and warn the user if a sampler is unsupported.
 
         Meant for dev wheels!
         """
 
-        return kwargs
+        return params
 
     async def generate_gen(
         self,
-        prompt: str,
         request_id: str,
+        prompt: str,
+        params: BaseSamplerRequest,
         abort_event: Optional[asyncio.Event] = None,
-        **kwargs,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
     ):
         """
         Create generator function for prompt completion.
@@ -1059,46 +1064,43 @@ class ExllamaV2Container:
 
         prompts = [prompt]
 
-        token_healing = kwargs.get("token_healing")
-        generate_window = max(
-            kwargs.get("generate_window"), self.config.max_seq_len // 8
-        )
+        # TODO: Not used for some reason?
+        generate_window = max(params.generate_window, self.config.max_seq_len // 8)
 
         # Sampler settings
         gen_settings = ExLlamaV2Sampler.Settings()
 
         # Check unsupported settings for dev wheels
-        kwargs = self.check_unsupported_settings(**kwargs)
+        params = self.check_unsupported_settings(params)
 
         # Apply settings
-        gen_settings.temperature = kwargs.get("temperature")
-        gen_settings.temperature_last = kwargs.get("temperature_last")
-        gen_settings.smoothing_factor = kwargs.get("smoothing_factor")
-        gen_settings.top_k = kwargs.get("top_k")
-        gen_settings.top_p = kwargs.get("top_p")
-        gen_settings.top_a = kwargs.get("top_a")
-        gen_settings.min_p = kwargs.get("min_p")
-        gen_settings.tfs = kwargs.get("tfs")
-        gen_settings.typical = kwargs.get("typical")
-        gen_settings.mirostat = kwargs.get("mirostat")
-        gen_settings.skew = kwargs.get("skew")
+        gen_settings.temperature = params.temperature
+        gen_settings.temperature_last = params.temperature_last
+        gen_settings.smoothing_factor = params.smoothing_factor
+        gen_settings.top_k = params.top_k
+        gen_settings.top_p = params.top_p
+        gen_settings.top_a = params.top_a
+        gen_settings.min_p = params.min_p
+        gen_settings.tfs = params.tfs
+        gen_settings.typical = params.typical
+        gen_settings.mirostat = params.mirostat
+        gen_settings.skew = params.skew
 
         # XTC
-        xtc_probability = kwargs.get("xtc_probability")
-        if xtc_probability > 0.0:
-            gen_settings.xtc_probability = xtc_probability
+        if params.xtc_probability > 0.0:
+            gen_settings.xtc_probability = params.xtc_probability
 
             # 0.1 is the default for this value
-            gen_settings.xtc_threshold = kwargs.get("xtc_threshold")
+            gen_settings.xtc_threshold = params.xtc_threshold
 
         # DynaTemp settings
-        max_temp = kwargs.get("max_temp")
-        min_temp = kwargs.get("min_temp")
+        max_temp = params.max_temp
+        min_temp = params.min_temp
 
-        if max_temp > min_temp:
+        if params.max_temp > params.min_temp:
             gen_settings.max_temp = max_temp
             gen_settings.min_temp = min_temp
-            gen_settings.temp_exponent = kwargs.get("temp_exponent")
+            gen_settings.temp_exponent = params.temp_exponent
         else:
             # Force to default values
             gen_settings.max_temp = 1.0
@@ -1115,11 +1117,11 @@ class ExllamaV2Container:
             )
 
         # Default tau and eta fallbacks don't matter if mirostat is off
-        gen_settings.mirostat_tau = kwargs.get("mirostat_tau")
-        gen_settings.mirostat_eta = kwargs.get("mirostat_eta")
+        gen_settings.mirostat_tau = params.mirostat_tau
+        gen_settings.mirostat_eta = params.mirostat_eta
 
         # Set CFG scale and negative prompt
-        cfg_scale = kwargs.get("cfg_scale")
+        cfg_scale = params.cfg_scale
         negative_prompt = None
         if cfg_scale not in [None, 1.0]:
             if self.paged:
@@ -1127,7 +1129,7 @@ class ExllamaV2Container:
 
                 # If the negative prompt is empty, use the BOS token
                 negative_prompt = unwrap(
-                    kwargs.get("negative_prompt"), self.tokenizer.bos_token
+                    params.negative_prompt, self.tokenizer.bos_token
                 )
 
                 prompts.append(negative_prompt)
@@ -1138,15 +1140,16 @@ class ExllamaV2Container:
                 )
 
         # Penalties
-        gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty")
-        gen_settings.token_frequency_penalty = kwargs.get("frequency_penalty")
-        gen_settings.token_presence_penalty = kwargs.get("presence_penalty")
+        gen_settings.token_repetition_penalty = params.repetition_penalty
+        gen_settings.token_frequency_penalty = params.frequency_penalty
+        gen_settings.token_presence_penalty = params.presence_penalty
 
         # Applies for all penalties despite being called token_repetition_range
         gen_settings.token_repetition_range = unwrap(
-            kwargs.get("penalty_range"), self.config.max_seq_len
+            params.penalty_range, self.config.max_seq_len
         )
 
+        # TODO: Not used for some reason?
         # Dynamically scale penalty range to output tokens
         # Only do this if freq/pres pen is enabled
         # and the repetition range is -1
@@ -1164,54 +1167,51 @@ class ExllamaV2Container:
         else:
             fallback_decay = gen_settings.token_repetition_range
         gen_settings.token_repetition_decay = coalesce(
-            kwargs.get("repetition_decay"), fallback_decay, 0
+            params.repetition_decay, fallback_decay, 0
         )
 
         # DRY options
-        dry_multiplier = kwargs.get("dry_multiplier")
+        dry_multiplier = params.dry_multiplier
 
         # < 0 = disabled
         if dry_multiplier > 0:
             gen_settings.dry_multiplier = dry_multiplier
-
-            gen_settings.dry_allowed_length = kwargs.get("dry_allowed_length")
-            gen_settings.dry_base = kwargs.get("dry_base")
+            gen_settings.dry_allowed_length = params.dry_allowed_length
+            gen_settings.dry_base = params.dry_base
 
             # Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range
             # Use max_seq_len as the fallback to stay consistent
-            gen_settings.dry_range = unwrap(
-                kwargs.get("dry_range"), self.config.max_seq_len
-            )
+            gen_settings.dry_range = unwrap(params.dry_range, self.config.max_seq_len)
 
             # Tokenize sequence breakers
-            dry_sequence_breakers_json = kwargs.get("dry_sequence_breakers")
-            if dry_sequence_breakers_json:
+            if params.dry_sequence_breakers:
                 gen_settings.dry_sequence_breakers = {
-                    self.encode_tokens(s)[-1] for s in dry_sequence_breakers_json
+                    self.encode_tokens(s)[-1] for s in params.dry_sequence_breakers
                 }
 
         # Initialize grammar handler
        grammar_handler = ExLlamaV2Grammar()
 
         # Add JSON schema filter if it exists
-        json_schema = kwargs.get("json_schema")
-        if json_schema:
+        if params.json_schema:
             grammar_handler.add_json_schema_filter(
-                json_schema, self.model, self.tokenizer
+                params.json_schema, self.model, self.tokenizer
             )
 
         # Add regex filter if it exists
-        regex_pattern = kwargs.get("regex_pattern")
-        if regex_pattern:
-            grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer)
+        if params.regex_pattern:
+            grammar_handler.add_regex_filter(
+                params.regex_pattern, self.model, self.tokenizer
+            )
 
         # Add EBNF filter if it exists
-        grammar_string = kwargs.get("grammar_string")
-        if grammar_string:
-            grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer)
+        if params.grammar_string:
+            grammar_handler.add_kbnf_filter(
+                params.grammar_string, self.model, self.tokenizer
+            )
 
         # Set banned strings
-        banned_strings = kwargs.get("banned_strings")
+        banned_strings = params.banned_strings
         if banned_strings and len(grammar_handler.filters) > 0:
             logger.warning(
                 "Disabling banned_strings because "
@@ -1220,16 +1220,12 @@ class ExllamaV2Container:
 
             banned_strings = []
 
-        stop_conditions = kwargs.get("stop")
-        add_bos_token = kwargs.get("add_bos_token"), True
-        ban_eos_token = kwargs.get("ban_eos_token"), False
-        logit_bias = kwargs.get("logit_bias")
-
-        # Logprobs
-        request_logprobs = kwargs.get("logprobs")
+        stop_conditions = params.stop
+        add_bos_token = params.add_bos_token
+        ban_eos_token = params.ban_eos_token
 
         # Speculative Ngram
-        self.generator.speculative_ngram = kwargs.get("speculative_ngram")
+        self.generator.speculative_ngram = params.speculative_ngram
 
         # Override sampler settings for temp = 0
         if gen_settings.temperature == 0:
@@ -1244,17 +1240,15 @@ class ExllamaV2Container:
             )
 
         # Set banned tokens
-        banned_tokens = kwargs.get("banned_tokens")
-        if banned_tokens:
-            gen_settings.disallow_tokens(self.tokenizer, banned_tokens)
+        if params.banned_tokens:
+            gen_settings.disallow_tokens(self.tokenizer, params.banned_tokens)
 
         # Set allowed tokens
-        allowed_tokens = kwargs.get("allowed_tokens")
-        if allowed_tokens:
-            gen_settings.allow_tokens(self.tokenizer, allowed_tokens)
+        if params.allowed_tokens:
+            gen_settings.allow_tokens(self.tokenizer, params.allowed_tokens)
 
         # Set logit bias
-        if logit_bias:
+        if params.logit_bias:
             # Create a vocab tensor if it doesn't exist for token biasing
             if gen_settings.token_bias is None:
                 padding = -self.tokenizer.config.vocab_size % 32
@@ -1264,7 +1258,7 @@ class ExllamaV2Container:
                 )
 
             # Map logits to the tensor with their biases
-            for token_id, bias in logit_bias.items():
+            for token_id, bias in params.logit_bias.items():
                 if 0 <= token_id < len(self.tokenizer.get_id_to_piece_list(True)):
                     gen_settings.token_bias[token_id] = bias
                 else:
@@ -1289,7 +1283,7 @@ class ExllamaV2Container:
             stop_conditions += eos_tokens
 
         # Get multimodal embeddings if present
-        mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
+        # TODO: Remove kwargs and pass this as optional
         mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
 
         # Encode both positive and negative prompts
@@ -1312,7 +1306,7 @@ class ExllamaV2Container:
         # Automatically set max_tokens to fill up the context
         # This should be an OK default, but may be changed in the future
         max_tokens = unwrap(
-            kwargs.get("max_tokens"),
+            params.max_tokens,
             self.config.max_seq_len - max(context_len, negative_context_len),
         )
         if max_tokens < 1:
@@ -1349,12 +1343,6 @@ class ExllamaV2Container:
                 f"is greater than cache_size {self.cache_size}"
             )
 
-        # Set min_tokens to generate while keeping EOS banned
-        min_tokens = kwargs.get("min_tokens")
-
-        # This is an inverse of skip_special_tokens
-        decode_special_tokens = not kwargs.get("skip_special_tokens")
-
         # Log prompt to console. Add the BOS token if specified
         log_prompt(
             f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}",
@@ -1369,17 +1357,17 @@ class ExllamaV2Container:
             self.generator,
             input_ids=input_ids,
             max_new_tokens=max_tokens,
-            min_new_tokens=min_tokens,
+            min_new_tokens=params.min_tokens,
             gen_settings=gen_settings,
             stop_conditions=stop_conditions,
-            decode_special_tokens=decode_special_tokens,
+            decode_special_tokens=not params.skip_special_tokens,
             filters=grammar_handler.filters,
             filter_prefer_eos=bool(grammar_handler.filters),
-            return_probs=request_logprobs > 0,
-            return_top_tokens=request_logprobs,
-            return_logits=request_logprobs > 0,
+            return_probs=params.logprobs > 0,
+            return_top_tokens=params.logprobs,
+            return_logits=params.logprobs > 0,
             banned_strings=banned_strings,
-            token_healing=token_healing,
+            token_healing=params.token_healing,
             identifier=job_id,
             embeddings=mm_embeddings_content,
         )
@@ -1418,7 +1406,7 @@ class ExllamaV2Container:
                     "offset": len(full_response),
                 }
 
-                if request_logprobs > 0:
+                if params.logprobs > 0:
                     # Get top tokens and probs
                     top_tokens = unwrap(
                         result.get("top_k_tokens"),
@@ -1494,8 +1482,7 @@ class ExllamaV2Container:
             request_id=request_id,
             bos_token_id=self.tokenizer.bos_token_id,
             eos_token_id=eos_tokens,
-            **kwargs,
-            generate_window=generate_window,
+            **params.model_dump(),
             auto_scale_penalty_range=auto_scale_penalty_range,
         )
 
diff --git a/common/sampling.py b/common/sampling.py
index c7ef934..1b4bc69 100644
--- a/common/sampling.py
+++ b/common/sampling.py
@@ -282,6 +282,11 @@ class BaseSamplerRequest(BaseModel):
         ge=0,
     )
 
+    logprobs: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("logprobs", 0),
+        ge=0,
+    )
+
     @field_validator("top_k", mode="before")
     def convert_top_k(cls, v):
         """Fixes instance if Top-K is -1."""
diff --git a/endpoints/OAI/types/common.py b/endpoints/OAI/types/common.py
index 640ead7..9c90188 100644
--- a/endpoints/OAI/types/common.py
+++ b/endpoints/OAI/types/common.py
@@ -32,10 +32,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
     # Generation info (remainder is in BaseSamplerRequest superclass)
     stream: Optional[bool] = False
     stream_options: Optional[ChatCompletionStreamOptions] = None
-    logprobs: Optional[int] = Field(
-        default_factory=lambda: get_default_sampler_value("logprobs", 0),
-        ge=0,
-    )
     response_format: Optional[CompletionResponseFormat] = Field(
         default_factory=CompletionResponseFormat
     )
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index f66cb52..dcc0dea 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -333,11 +333,11 @@ async def stream_generate_chat_completion(
                 _stream_collector(
                     n,
                     gen_queue,
-                    prompt,
                     request.state.id,
+                    prompt,
+                    task_gen_params,
                     abort_event,
-                    embeddings=embeddings,
-                    **task_gen_params.model_dump(exclude={"prompt"}),
+                    mm_embeddings=embeddings,
                 )
             )
 
@@ -422,10 +422,10 @@ async def generate_chat_completion(
         gen_tasks.append(
             asyncio.create_task(
                 model.container.generate(
-                    prompt,
                     request.state.id,
-                    embeddings=embeddings,
-                    **data.model_dump(exclude={"prompt"}),
+                    prompt,
+                    data,
+                    mm_embeddings=embeddings,
                 )
             )
         )
@@ -465,7 +465,6 @@ async def generate_tool_calls(
     # FIXME: May not be necessary depending on how the codebase evolves
     tool_data = data.model_copy(deep=True)
     tool_data.json_schema = tool_data.tool_call_schema
-    gen_params = tool_data.model_dump()
 
     for idx, gen in enumerate(generations):
         if gen["stop_str"] in tool_data.tool_call_start:
@@ -488,10 +487,10 @@ async def generate_tool_calls(
             gen_tasks.append(
                 asyncio.create_task(
                     model.container.generate(
-                        pre_tool_prompt,
                         request.state.id,
-                        embeddings=mm_embeddings,
-                        **gen_params,
+                        pre_tool_prompt,
+                        tool_data,
+                        mm_embeddings=mm_embeddings,
                     )
                 )
             )
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index 9fd8b90..1d706d4 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -8,12 +8,12 @@ import asyncio
 import pathlib
 from asyncio import CancelledError
 from fastapi import HTTPException, Request
-from typing import List, Union
-
 from loguru import logger
+from typing import List, Optional, Union
 
 from common import model
 from common.auth import get_key_permission
+from common.multimodal import MultimodalEmbeddingWrapper
 from common.networking import (
     get_generator_error,
     handle_request_disconnect,
@@ -86,16 +86,21 @@ def _create_response(
 async def _stream_collector(
     task_idx: int,
     gen_queue: asyncio.Queue,
-    prompt: str,
     request_id: str,
+    prompt: str,
+    params: CompletionRequest,
     abort_event: asyncio.Event,
-    **kwargs,
+    mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
 ):
     """Collects a stream and places results in a common queue"""
 
     try:
         new_generation = model.container.generate_gen(
-            prompt, request_id, abort_event, **kwargs
+            request_id,
+            prompt,
+            params,
+            abort_event,
+            mm_embeddings,
         )
         async for generation in new_generation:
             generation["index"] = task_idx
@@ -195,10 +200,10 @@ async def stream_generate_completion(
                 _stream_collector(
                     n,
                     gen_queue,
-                    data.prompt,
                     request.state.id,
+                    data.prompt,
+                    task_gen_params,
                     abort_event,
-                    **task_gen_params.model_dump(exclude={"prompt"}),
                 )
             )
 
@@ -256,9 +261,9 @@ async def generate_completion(
         gen_tasks.append(
             asyncio.create_task(
                 model.container.generate(
-                    data.prompt,
                     request.state.id,
-                    **task_gen_params.model_dump(exclude={"prompt"}),
+                    data.prompt,
+                    task_gen_params,
                 )
             )
         )
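
Usage note (a minimal sketch, not part of the patch): after this change, callers build a BaseSamplerRequest and pass it to the container as a typed object instead of spreading sampler kwargs. The snippet below assumes a model container is already loaded through the usual common.model startup path, that temperature/max_tokens are valid BaseSamplerRequest fields (they are read as params.temperature and params.max_tokens in the diff), and that the joined generation dict carries its output under a "text" key; the uuid-based request id stands in for request.state.id used by the endpoints.

# Sketch of the new call pattern (illustration only, not part of the diff).
import asyncio
import uuid

from common import model
from common.sampling import BaseSamplerRequest


async def demo_generate(prompt: str) -> None:
    # Typed sampler params replace the old **kwargs spread
    params = BaseSamplerRequest(temperature=0.7, max_tokens=128)

    generation = await model.container.generate(
        str(uuid.uuid4()),   # request_id (endpoints pass request.state.id here)
        prompt,
        params,
        mm_embeddings=None,  # optional MultimodalEmbeddingWrapper for vision models
    )

    # Assumption: the joined generation dict exposes the completion as "text"
    print(generation.get("text", ""))


if __name__ == "__main__":
    asyncio.run(demo_generate("Hello!"))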