Model: Handle finish chunks and logprobs in separate functions

Helps split up and trim the generate_gen function.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri 2025-04-24 21:19:03 -04:00
parent f2c7da2faf
commit 7e007f0761


@@ -1004,15 +1004,6 @@ class ExllamaV2Container(BaseModelContainer):
            params.penalty_range, self.config.max_seq_len
        )

        # TODO: Not used for some reason?
        # Dynamically scale penalty range to output tokens
        # Only do this if freq/pres pen is enabled
        # and the repetition range is -1
        auto_scale_penalty_range = (
            gen_settings.token_frequency_penalty != 0
            or gen_settings.token_presence_penalty != 0
        ) and gen_settings.token_repetition_range == -1

        # Always make sure the fallback is 0 if range < 0
        # It's technically fine to use -1, but this just validates the passed
        # fallback
@@ -1115,6 +1106,51 @@ class ExllamaV2Container(BaseModelContainer):
                    "in the model's vocab. Skipping."
                )

    # Adds logprobs to a generation chunk
    def handle_logprobs(self, result: dict, generation: dict):
        top_tokens = unwrap(
            result.get("top_k_tokens"),
            torch.empty((1, 0, 1), dtype=torch.long),
        )

        top_probs = unwrap(
            result.get("top_k_probs"),
            torch.empty((1, 0, 1), dtype=torch.float),
        )

        if top_tokens.numel() > 0 and top_probs.numel() > 0:
            logprobs = self.get_logprobs(top_tokens, top_probs)
            generation["logprobs"] = logprobs

            # The first logprob is the selected token prob
            generation["token_probs"] = {
                token: logprobs[token] for token in list(logprobs.keys())[:1]
            }

    # Creates and returns a finish chunk
    def handle_finish_chunk(self, result: dict, generation: dict):
        eos_reason = result.get("eos_reason")

        stop_str = None
        if eos_reason == "max_new_tokens":
            finish_reason = "length"
        else:
            finish_reason = "stop"

            # Grab stop string if stop was the reason
            if eos_reason == "stop_token":
                stop_str = result.get("eos_triggering_token_str")
            elif eos_reason == "stop_string":
                stop_str = result.get("eos_triggering_string")

        finish_chunk = {
            "prompt_tokens": generation.get("prompt_tokens"),
            "generated_tokens": generation.get("generated_tokens"),
            "finish_reason": finish_reason,
            "stop_str": stop_str,
        }

        return finish_chunk

    async def generate_gen(
        self,
        request_id: str,
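
For reference, the eos_reason to finish_reason/stop_str mapping that handle_finish_chunk implements above can be exercised on its own. A minimal standalone sketch follows; the example result dicts are made up and only mimic the fields the method reads, not real ExLlamaV2 job output:

from typing import Optional, Tuple


def map_eos_reason(result: dict) -> Tuple[str, Optional[str]]:
    # Mirrors handle_finish_chunk's reason mapping, for illustration only
    eos_reason = result.get("eos_reason")

    stop_str = None
    if eos_reason == "max_new_tokens":
        finish_reason = "length"
    else:
        finish_reason = "stop"

        # Grab the triggering stop token/string when a stop condition fired
        if eos_reason == "stop_token":
            stop_str = result.get("eos_triggering_token_str")
        elif eos_reason == "stop_string":
            stop_str = result.get("eos_triggering_string")

    return finish_reason, stop_str


# Made-up example inputs
print(map_eos_reason({"eos_reason": "max_new_tokens"}))  # ('length', None)
print(map_eos_reason({"eos_reason": "stop_token", "eos_triggering_token_str": "</s>"}))  # ('stop', '</s>')
print(map_eos_reason({"eos_reason": "stop_string", "eos_triggering_string": "\nUSER:"}))  # ('stop', '\nUSER:')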
@@ -1174,6 +1210,14 @@ class ExllamaV2Container(BaseModelContainer):
                "Please use an ampere (30 series) or higher GPU for CFG support."
            )

        # Dynamically scale penalty range to output tokens
        # Only do this if freq/pres pen is enabled
        # and the repetition range is -1
        auto_scale_penalty_range = (
            gen_settings.token_frequency_penalty != 0
            or gen_settings.token_presence_penalty != 0
        ) and gen_settings.token_repetition_range == -1

        stop_conditions = params.stop
        add_bos_token = params.add_bos_token
        ban_eos_token = params.ban_eos_token
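
The block re-added above only arms auto-scaling when a frequency or presence penalty is active and the repetition range was left at -1; the next hunk then bumps the range to the current generated token count on every chunk. A tiny standalone illustration of that condition, where GenSettings is a stand-in dataclass rather than ExLlamaV2's real sampler settings class:

from dataclasses import dataclass


@dataclass
class GenSettings:
    # Stand-in for the sampler settings object; field names mirror the diff
    token_frequency_penalty: float = 0.0
    token_presence_penalty: float = 0.0
    token_repetition_range: int = -1


gen_settings = GenSettings(token_presence_penalty=0.5)

# Same condition as above: a penalty is enabled and the range is the -1 default
auto_scale_penalty_range = (
    gen_settings.token_frequency_penalty != 0
    or gen_settings.token_presence_penalty != 0
) and gen_settings.token_repetition_range == -1

generated_tokens = 42
if auto_scale_penalty_range:
    # Inside the generation loop, the range follows the generated token count
    gen_settings.token_repetition_range = generated_tokens

print(auto_scale_penalty_range, gen_settings.token_repetition_range)  # True 42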
@@ -1316,58 +1360,25 @@ class ExllamaV2Container(BaseModelContainer):
                    "offset": len(full_response),
                }

                # Increase penalty range to generated token amount
                if auto_scale_penalty_range:
                    gen_settings.token_repetition_range = generated_tokens

                # Handle logprobs
                if params.logprobs > 0:
                    # Get top tokens and probs
                    top_tokens = unwrap(
                        result.get("top_k_tokens"),
                        torch.empty((1, 0, 1), dtype=torch.long),
                    )

                    top_probs = unwrap(
                        result.get("top_k_probs"),
                        torch.empty((1, 0, 1), dtype=torch.float),
                    )

                    if top_tokens.numel() > 0 and top_probs.numel() > 0:
                        logprobs = self.get_logprobs(top_tokens, top_probs)
                        generation["logprobs"] = logprobs

                        # The first logprob is the selected token prob
                        generation["token_probs"] = {
                            token: logprobs[token]
                            for token in list(logprobs.keys())[:1]
                        }

                    self.handle_logprobs(result, generation)

                yield generation

                # Second yield if eos is true
                # Yield a finish chunk when generation is finished
                if result.get("eos"):
                    log_response(request_id, full_response)

                    eos_reason = result.get("eos_reason")

                    stop_str = None
                    if eos_reason == "max_new_tokens":
                        finish_reason = "length"
                    else:
                        finish_reason = "stop"

                        # Grab stop string if stop was the reason
                        if eos_reason == "stop_token":
                            stop_str = result.get("eos_triggering_token_str")
                        elif eos_reason == "stop_string":
                            stop_str = result.get("eos_triggering_string")

                    generation = self.handle_finish_chunk(result, generation)

                    # Save the final result for metrics logging
                    metrics_result = result

                    # Remove the token text
                    generation = {
                        "prompt_tokens": generation.get("prompt_tokens"),
                        "generated_tokens": generation.get("generated_tokens"),
                        "finish_reason": finish_reason,
                        "stop_str": stop_str,
                    }

                    yield generation
                    break
        except asyncio.CancelledError:
@@ -1394,6 +1405,7 @@ class ExllamaV2Container(BaseModelContainer):
            eos_token_id=eos_tokens,
            prompt=prompt,
            **params.model_dump(exclude={"prompt"}),
            auto_scale_penalty_range=auto_scale_penalty_range,
        )

        # Log the metrics if present
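
Downstream of this commit, callers of generate_gen still see the same chunk shapes: streaming chunks that may carry logprobs/token_probs, followed by one finish chunk with prompt_tokens, generated_tokens, finish_reason and stop_str. A hypothetical consumer sketch; the container/request setup and the exact generate_gen signature are placeholders here, only the chunk keys come from the diff above:

async def consume_stream(container, request_id: str, prompt: str, params):
    # Hypothetical caller: how the container and sampler params are built is
    # out of scope here; only the chunk keys below are taken from the diff.
    async for chunk in container.generate_gen(request_id, prompt, params):
        if "finish_reason" in chunk:
            # Final chunk assembled by handle_finish_chunk
            print(
                f"\nfinished: reason={chunk['finish_reason']}, "
                f"stop_str={chunk['stop_str']}, "
                f"generated_tokens={chunk['generated_tokens']}"
            )
        else:
            # Streaming chunk; logprobs are attached when params.logprobs > 0
            print(chunk.get("text", ""), end="", flush=True)
            if "token_probs" in chunk:
                print(chunk["token_probs"])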