From 056527ceb30f0c4974412a6b95b79d9e2de1b6f8 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <->
Date: Sun, 3 Aug 2025 11:42:32 +0300
Subject: [PATCH] add logprobs support for exl3

---
 backends/exllamav3/model.py | 42 +++++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py
index 6c1fa79..69ee0d4 100644
--- a/backends/exllamav3/model.py
+++ b/backends/exllamav3/model.py
@@ -2,6 +2,7 @@ import asyncio
 import gc
 import pathlib
 import re
+from itertools import zip_longest
 from typing import (
     Any,
     AsyncIterator,
@@ -608,6 +609,22 @@ class ExllamaV3Container(BaseModelContainer):
             "unk_token": self.tokenizer.unk_token,
         }
 
+    def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
+        top_tokens = [
+            self.tokenizer.get_id_to_piece_list(True)[index]
+            for index in token_ids.flatten().tolist()
+        ]
+
+        top_values = torch.log(token_probs).flatten().tolist()
+
+        # Cannot return -inf in JSON
+        cleaned_values = [
+            -1000 if value == float("-inf") else value for value in top_values
+        ]
+
+        return dict(zip_longest(top_tokens, cleaned_values))
+
+
     async def generate(
         self,
         request_id: str,
@@ -730,6 +747,26 @@ class ExllamaV3Container(BaseModelContainer):
         # Clean up and remove the job from active IDs
         del self.active_job_ids[request_id]
 
+    def handle_logprobs(self, result: dict, generation: dict):
+        top_tokens = unwrap(
+            result.get("top_k_tokens"),
+            torch.empty((1, 0, 1), dtype=torch.long),
+        )
+
+        top_probs = unwrap(
+            result.get("top_k_probs"),
+            torch.empty((1, 0, 1), dtype=torch.float),
+        )
+
+        if top_tokens.numel() > 0 and top_probs.numel() > 0:
+            logprobs = self.get_logprobs(top_tokens, top_probs)
+            generation["logprobs"] = logprobs
+
+            # The first logprob is the selected token prob
+            generation["token_probs"] = {
+                token: logprobs[token] for token in list(logprobs.keys())[:1]
+            }
+
     def handle_finish_chunk(self, result: dict, request_id: str, full_text: str):
         eos_reason = result.get("eos_reason")
 
@@ -915,6 +952,7 @@ class ExllamaV3Container(BaseModelContainer):
             stop_conditions=stop_conditions,
             banned_strings=params.banned_strings,
             embeddings=mm_embeddings_content,
+            return_top_tokens=params.logprobs,
         )
 
         generated_tokens = 0
@@ -948,6 +986,10 @@ class ExllamaV3Container(BaseModelContainer):
                     "generated_tokens": generated_tokens,
                     "offset": len(full_response),
                 }
+
+                if params.logprobs > 0:
+                    self.handle_logprobs(result, generation)
+
                 yield generation
 
                 if result.get("eos"):
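
Reviewer note: the heart of the patch is the JSON-safe logprob conversion in
get_logprobs. Below is a minimal standalone sketch of just that step, assuming
a pre-decoded list of token pieces in place of the tokenizer lookup; the name
to_json_safe_logprobs and the sample values are illustrative, not part of the
patch.

    import torch
    from itertools import zip_longest

    def to_json_safe_logprobs(pieces: list[str], probs: torch.Tensor) -> dict:
        # Natural log of the top-k probabilities returned by the sampler
        values = torch.log(probs).flatten().tolist()

        # JSON cannot encode -inf (the log of a zero probability), so clamp
        # it to a large negative sentinel, mirroring the patch's -1000
        cleaned = [-1000 if v == float("-inf") else v for v in values]

        # Pair each decoded token piece with its cleaned logprob
        return dict(zip_longest(pieces, cleaned))

    print(to_json_safe_logprobs(["Hello", " world", "!"], torch.tensor([0.7, 0.3, 0.0])))
    # {'Hello': -0.3567, ' world': -1.2040, '!': -1000}  (values rounded)

One side effect worth flagging: because the result is keyed by token piece,
duplicate pieces in the top-k list collapse into a single dict entry, and
zip_longest pads with None if the two lists ever diverge in length.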
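
For API consumers, each streamed chunk gains two keys when params.logprobs > 0.
A sketch of one yielded chunk follows, showing only the fields visible in this
diff; the values are made up and any other fields of the generation dict are
omitted.

    generation = {
        "generated_tokens": 3,
        "offset": 12,
        # all top-k candidates for this step, token piece -> logprob
        "logprobs": {" world": -1.2, " there": -2.3, ",": -3.1},
        # first entry only: the token that was actually sampled
        "token_probs": {" world": -1.2},
    }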