add logprobs support for exl3

This commit is contained in:
AUTOMATIC 2025-08-03 11:42:32 +03:00
parent 03d72a37be
commit 056527ceb3

View file

@ -2,6 +2,7 @@ import asyncio
import gc
import pathlib
import re
from itertools import zip_longest
from typing import (
Any,
AsyncIterator,
@ -608,6 +609,22 @@ class ExllamaV3Container(BaseModelContainer):
"unk_token": self.tokenizer.unk_token,
}
def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
top_tokens = [
self.tokenizer.get_id_to_piece_list(True)[index]
for index in token_ids.flatten().tolist()
]
top_values = torch.log(token_probs).flatten().tolist()
# Cannot return -inf in JSON
cleaned_values = [
-1000 if value == float("-inf") else value for value in top_values
]
return dict(zip_longest(top_tokens, cleaned_values))
async def generate(
self,
request_id: str,
@ -730,6 +747,26 @@ class ExllamaV3Container(BaseModelContainer):
# Clean up and remove the job from active IDs
del self.active_job_ids[request_id]
def handle_logprobs(self, result: dict, generation: dict):
top_tokens = unwrap(
result.get("top_k_tokens"),
torch.empty((1, 0, 1), dtype=torch.long),
)
top_probs = unwrap(
result.get("top_k_probs"),
torch.empty((1, 0, 1), dtype=torch.float),
)
if top_tokens.numel() > 0 and top_probs.numel() > 0:
logprobs = self.get_logprobs(top_tokens, top_probs)
generation["logprobs"] = logprobs
# The first logprob is the selected token prob
generation["token_probs"] = {
token: logprobs[token] for token in list(logprobs.keys())[:1]
}
def handle_finish_chunk(self, result: dict, request_id: str, full_text: str):
eos_reason = result.get("eos_reason")
@ -915,6 +952,7 @@ class ExllamaV3Container(BaseModelContainer):
stop_conditions=stop_conditions,
banned_strings=params.banned_strings,
embeddings=mm_embeddings_content,
return_top_tokens=params.logprobs,
)
generated_tokens = 0
@ -948,6 +986,10 @@ class ExllamaV3Container(BaseModelContainer):
"generated_tokens": generated_tokens,
"offset": len(full_response),
}
if params.logprobs > 0:
self.handle_logprobs(result, generation)
yield generation
if result.get("eos"):