Merge pull request #373 from AUTOMATIC1111/exl3-logprobs
add logprobs support for exl3
commit 6623dbcd86
1 changed file with 41 additions and 0 deletions
@@ -2,6 +2,7 @@ import asyncio
 import gc
 import pathlib
 import re
+from itertools import zip_longest
 from typing import (
     Any,
     AsyncIterator,
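The only import added is zip_longest, used below in get_logprobs to pair decoded tokens with their cleaned log values. Unlike zip, it pads the shorter iterable with None rather than truncating, so a length mismatch between the two lists cannot silently drop tokens:

    from itertools import zip_longest

    dict(zip_longest(["a", "b", "c"], [-0.1, -2.3]))
    # -> {"a": -0.1, "b": -2.3, "c": None}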
@@ -608,6 +609,21 @@ class ExllamaV3Container(BaseModelContainer):
             "unk_token": self.tokenizer.unk_token,
         }
 
+    def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
+        top_tokens = [
+            self.tokenizer.get_id_to_piece_list(True)[index]
+            for index in token_ids.flatten().tolist()
+        ]
+
+        top_values = torch.log(token_probs).flatten().tolist()
+
+        # Cannot return -inf in JSON
+        cleaned_values = [
+            -1000 if value == float("-inf") else value for value in top_values
+        ]
+
+        return dict(zip_longest(top_tokens, cleaned_values))
+
     async def generate(
         self,
         request_id: str,
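get_logprobs flattens the top-k id and probability tensors, maps each id to its decoded piece, takes the log of each probability, and clamps -inf (from zero-probability entries) to -1000 so the values survive JSON serialization. A minimal standalone sketch of the same transform, with a hypothetical pieces list standing in for self.tokenizer.get_id_to_piece_list(True):

    import torch

    pieces = ["<unk>", "Hello", " world"]  # hypothetical id -> piece table
    token_ids = torch.tensor([[[1, 2]]])
    token_probs = torch.tensor([[[0.9, 0.0]]])

    tokens = [pieces[i] for i in token_ids.flatten().tolist()]
    values = torch.log(token_probs).flatten().tolist()
    # Cannot return -inf in JSON, so clamp it
    cleaned = [-1000 if v == float("-inf") else v for v in values]
    print(dict(zip(tokens, cleaned)))  # {'Hello': -0.105..., ' world': -1000}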
@@ -730,6 +746,26 @@ class ExllamaV3Container(BaseModelContainer):
             # Clean up and remove the job from active IDs
             del self.active_job_ids[request_id]
 
+    def handle_logprobs(self, result: dict, generation: dict):
+        top_tokens = unwrap(
+            result.get("top_k_tokens"),
+            torch.empty((1, 0, 1), dtype=torch.long),
+        )
+
+        top_probs = unwrap(
+            result.get("top_k_probs"),
+            torch.empty((1, 0, 1), dtype=torch.float),
+        )
+
+        if top_tokens.numel() > 0 and top_probs.numel() > 0:
+            logprobs = self.get_logprobs(top_tokens, top_probs)
+            generation["logprobs"] = logprobs
+
+            # The first logprob is the selected token prob
+            generation["token_probs"] = {
+                token: logprobs[token] for token in list(logprobs.keys())[:1]
+            }
+
     def handle_finish_chunk(self, result: dict, request_id: str, full_text: str):
         eos_reason = result.get("eos_reason")
 
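handle_logprobs reads the top-k tensors attached to a result chunk, falling back to empty (1, 0, 1) tensors so a chunk without top-k data is a no-op. A sketch of that guard, assuming unwrap(value, default) simply returns default when value is None:

    import torch

    def unwrap(value, default):  # assumed semantics of the project's helper
        return default if value is None else value

    result = {}  # a chunk that carries no top-k data
    top_tokens = unwrap(
        result.get("top_k_tokens"),
        torch.empty((1, 0, 1), dtype=torch.long),
    )
    print(top_tokens.numel())  # 0, so the logprobs branch is skipped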
@@ -915,6 +951,7 @@ class ExllamaV3Container(BaseModelContainer):
             stop_conditions=stop_conditions,
             banned_strings=params.banned_strings,
             embeddings=mm_embeddings_content,
+            return_top_tokens=params.logprobs,
         )
 
         generated_tokens = 0
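params.logprobs is forwarded as return_top_tokens, so the backend only computes top-k candidates when a client actually asks for them. Assuming the server exposes an OpenAI-style completions route (endpoint, port, and payload shape are illustrative, not confirmed by this diff), a request like the following would exercise the new path:

    import requests

    resp = requests.post(
        "http://localhost:5000/v1/completions",  # illustrative endpoint
        json={"prompt": "Hello", "max_tokens": 8, "logprobs": 5},
    )
    print(resp.json())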
@@ -948,6 +985,10 @@ class ExllamaV3Container(BaseModelContainer):
                 "generated_tokens": generated_tokens,
                 "offset": len(full_response),
             }
+
+            if params.logprobs > 0:
+                self.handle_logprobs(result, generation)
+
             yield generation
 
             if result.get("eos"):
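With params.logprobs > 0, each streamed chunk now carries the two extra keys before it is yielded. A hypothetical chunk (keys other than generated_tokens, offset, logprobs, and token_probs are assumed, and the values are made up):

    generation = {
        "generated_tokens": 1,
        "offset": 5,
        "logprobs": {" world": -0.105, " there": -2.4},  # top-k candidates
        "token_probs": {" world": -0.105},  # first entry = the sampled token
    }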