From 7e007f0761ae29476a630b5e67331adc18382aca Mon Sep 17 00:00:00 2001
From: kingbri <8082010+kingbri1@users.noreply.github.com>
Date: Thu, 24 Apr 2025 21:19:03 -0400
Subject: [PATCH] Model: Handle finish chunks and logprobs in separate functions

Helps split up and trim the generate_gen function.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
---
 backends/exllamav2/model.py | 112 ++++++++++++++++++++----------------
 1 file changed, 62 insertions(+), 50 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 01d13d4..55aa497 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -1004,15 +1004,6 @@ class ExllamaV2Container(BaseModelContainer):
             params.penalty_range, self.config.max_seq_len
         )
 
-        # TODO: Not used for some reason?
-        # Dynamically scale penalty range to output tokens
-        # Only do this if freq/pres pen is enabled
-        # and the repetition range is -1
-        auto_scale_penalty_range = (
-            gen_settings.token_frequency_penalty != 0
-            or gen_settings.token_presence_penalty != 0
-        ) and gen_settings.token_repetition_range == -1
-
         # Always make sure the fallback is 0 if range < 0
         # It's technically fine to use -1, but this just validates the passed
         # fallback
@@ -1115,6 +1106,51 @@ class ExllamaV2Container(BaseModelContainer):
                     "in the model's vocab. Skipping."
                 )
 
+    # Adds logprobs to a generation chunk
+    def handle_logprobs(self, result: dict, generation: dict):
+        top_tokens = unwrap(
+            result.get("top_k_tokens"),
+            torch.empty((1, 0, 1), dtype=torch.long),
+        )
+
+        top_probs = unwrap(
+            result.get("top_k_probs"),
+            torch.empty((1, 0, 1), dtype=torch.float),
+        )
+
+        if top_tokens.numel() > 0 and top_probs.numel() > 0:
+            logprobs = self.get_logprobs(top_tokens, top_probs)
+            generation["logprobs"] = logprobs
+
+            # The first logprob is the selected token prob
+            generation["token_probs"] = {
+                token: logprobs[token] for token in list(logprobs.keys())[:1]
+            }
+
+    # Creates and returns a finish chunk
+    def handle_finish_chunk(self, result: dict, generation: dict):
+        eos_reason = result.get("eos_reason")
+
+        stop_str = None
+        if eos_reason == "max_new_tokens":
+            finish_reason = "length"
+        else:
+            finish_reason = "stop"
+            # Grab stop string if stop was the reason
+            if eos_reason == "stop_token":
+                stop_str = result.get("eos_triggering_token_str")
+            elif eos_reason == "stop_string":
+                stop_str = result.get("eos_triggering_string")
+
+        finish_chunk = {
+            "prompt_tokens": generation.get("prompt_tokens"),
+            "generated_tokens": generation.get("generated_tokens"),
+            "finish_reason": finish_reason,
+            "stop_str": stop_str,
+        }
+
+        return finish_chunk
+
     async def generate_gen(
         self,
         request_id: str,
@@ -1174,6 +1210,14 @@ class ExllamaV2Container(BaseModelContainer):
                 "Please use an ampere (30 series) or higher GPU for CFG support."
             )
 
+        # Dynamically scale penalty range to output tokens
+        # Only do this if freq/pres pen is enabled
+        # and the repetition range is -1
+        auto_scale_penalty_range = (
+            gen_settings.token_frequency_penalty != 0
+            or gen_settings.token_presence_penalty != 0
+        ) and gen_settings.token_repetition_range == -1
+
         stop_conditions = params.stop
         add_bos_token = params.add_bos_token
         ban_eos_token = params.ban_eos_token
@@ -1316,58 +1360,25 @@ class ExllamaV2Container(BaseModelContainer):
                         "offset": len(full_response),
                     }
 
+                    # Increase penalty range to generated token amount
+                    if auto_scale_penalty_range:
+                        gen_settings.token_repetition_range = generated_tokens
+
+                    # Handle logprobs
                     if params.logprobs > 0:
-                        # Get top tokens and probs
-                        top_tokens = unwrap(
-                            result.get("top_k_tokens"),
-                            torch.empty((1, 0, 1), dtype=torch.long),
-                        )
-
-                        top_probs = unwrap(
-                            result.get("top_k_probs"),
-                            torch.empty((1, 0, 1), dtype=torch.float),
-                        )
-
-                        if top_tokens.numel() > 0 and top_probs.numel() > 0:
-                            logprobs = self.get_logprobs(top_tokens, top_probs)
-                            generation["logprobs"] = logprobs
-
-                            # The first logprob is the selected token prob
-                            generation["token_probs"] = {
-                                token: logprobs[token]
-                                for token in list(logprobs.keys())[:1]
-                            }
+                        self.handle_logprobs(result, generation)
 
                     yield generation
 
-                    # Second yield if eos is true
+                    # Yield a finish chunk when generation is finished
                     if result.get("eos"):
                         log_response(request_id, full_response)
 
-                        eos_reason = result.get("eos_reason")
-
-                        stop_str = None
-                        if eos_reason == "max_new_tokens":
-                            finish_reason = "length"
-                        else:
-                            finish_reason = "stop"
-                            # Grab stop string if stop was the reason
-                            if eos_reason == "stop_token":
-                                stop_str = result.get("eos_triggering_token_str")
-                            elif eos_reason == "stop_string":
-                                stop_str = result.get("eos_triggering_string")
+                        generation = self.handle_finish_chunk(result, generation)
 
                         # Save the final result for metrics logging
                         metrics_result = result
 
-                        # Remove the token text
-                        generation = {
-                            "prompt_tokens": generation.get("prompt_tokens"),
-                            "generated_tokens": generation.get("generated_tokens"),
-                            "finish_reason": finish_reason,
-                            "stop_str": stop_str,
-                        }
-
                         yield generation
                         break
             except asyncio.CancelledError:
@@ -1394,6 +1405,7 @@ class ExllamaV2Container(BaseModelContainer):
             eos_token_id=eos_tokens,
             prompt=prompt,
             **params.model_dump(exclude={"prompt"}),
+            auto_scale_penalty_range=auto_scale_penalty_range,
         )
 
         # Log the metrics if present