From 9390d362dd99dfe29ed7eb9b44db9998aeb20443 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Thu, 1 Aug 2024 00:19:21 -0400
Subject: [PATCH] Model: Log generation params and metrics after the
 prompt/response

A user's prompt and response can be large in the console. Therefore,
always log the smaller payloads (e.g. gen params + metrics) after the
large chunks.

However, it's recommended to keep prompt logging off anyway, since it
will result in console spam.

Signed-off-by: kingbri
---
 backends/exllamav2/model.py | 76 ++++++++++++++++++++-----------------
 1 file changed, 41 insertions(+), 35 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index c7c032a..373a753 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -1126,31 +1126,6 @@ class ExllamaV2Container:
         # This is an inverse of skip_special_tokens
         decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False)
 
-        # Log generation options to console
-        # Some options are too large, so log the args instead
-        log_generation_params(
-            request_id=request_id,
-            max_tokens=max_tokens,
-            min_tokens=min_tokens,
-            stream=kwargs.get("stream"),
-            **gen_settings_log_dict,
-            token_healing=token_healing,
-            auto_scale_penalty_range=auto_scale_penalty_range,
-            generate_window=generate_window,
-            bos_token_id=self.tokenizer.bos_token_id,
-            eos_token_id=eos_tokens,
-            add_bos_token=add_bos_token,
-            ban_eos_token=ban_eos_token,
-            skip_special_tokens=not decode_special_tokens,
-            speculative_ngram=self.generator.speculative_ngram,
-            logprobs=request_logprobs,
-            stop_conditions=stop_conditions,
-            banned_tokens=banned_tokens,
-            banned_strings=banned_strings,
-            logit_bias=logit_bias,
-            filters=grammar_handler.filters,
-        )
-
         # Log prompt to console
         log_prompt(prompt, request_id, negative_prompt)
 
@@ -1181,6 +1156,7 @@ class ExllamaV2Container:
         max_seq_len = self.config.max_seq_len
         generated_tokens = 0
         full_response = ""
+        metrics_result = {}
 
         # Get the generation status once it's ready
         try:
@@ -1241,16 +1217,8 @@ class ExllamaV2Container:
                         "length" if eos_reason == "max_new_tokens" else "stop"
                     )
 
-                    log_metrics(
-                        result.get("time_enqueued"),
-                        result.get("prompt_tokens"),
-                        result.get("cached_tokens"),
-                        result.get("time_prefill"),
-                        result.get("new_tokens"),
-                        result.get("time_generate"),
-                        context_len,
-                        max_seq_len,
-                    )
+                    # Save the final result for metrics logging
+                    metrics_result = result
 
                 # Remove the token text
                 generation = {
@@ -1274,3 +1242,41 @@ class ExllamaV2Container:
             asyncio.ensure_future(self.create_generator())
 
             raise ex
+        finally:
+            # Log generation options to console
+            # Some options are too large, so log the args instead
+            log_generation_params(
+                request_id=request_id,
+                max_tokens=max_tokens,
+                min_tokens=min_tokens,
+                stream=kwargs.get("stream"),
+                **gen_settings_log_dict,
+                token_healing=token_healing,
+                auto_scale_penalty_range=auto_scale_penalty_range,
+                generate_window=generate_window,
+                bos_token_id=self.tokenizer.bos_token_id,
+                eos_token_id=eos_tokens,
+                add_bos_token=add_bos_token,
+                ban_eos_token=ban_eos_token,
+                skip_special_tokens=not decode_special_tokens,
+                speculative_ngram=self.generator.speculative_ngram,
+                logprobs=request_logprobs,
+                stop_conditions=stop_conditions,
+                banned_tokens=banned_tokens,
+                banned_strings=banned_strings,
+                logit_bias=logit_bias,
+                filters=grammar_handler.filters,
+            )
+
+            # Log the metrics if present
+            if metrics_result:
+                log_metrics(
+                    metrics_result.get("time_enqueued"),
+                    metrics_result.get("prompt_tokens"),
+                    metrics_result.get("cached_tokens"),
+                    metrics_result.get("time_prefill"),
+                    metrics_result.get("new_tokens"),
+                    metrics_result.get("time_generate"),
+                    context_len,
+                    max_seq_len,
+                )