Model: Handle finish chunks and logprobs in separate functions
Helps split up and trim the generate_gen function.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
parent f2c7da2faf
commit 7e007f0761
1 changed file with 62 additions and 50 deletions
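
For reference, a minimal standalone sketch of the finish-chunk mapping that the new handle_finish_chunk method introduces, written here as a free function so it runs without the ExllamaV2Container state. The result/generation dict keys mirror the diff below; the function name build_finish_chunk and the sample values in the driver are made up for illustration and are not part of this commit.

# Sketch only: mirrors the eos_reason -> finish_reason/stop_str mapping of the
# committed handle_finish_chunk, without the container or logprobs handling.
def build_finish_chunk(result: dict, generation: dict) -> dict:
    eos_reason = result.get("eos_reason")

    stop_str = None
    if eos_reason == "max_new_tokens":
        finish_reason = "length"
    else:
        finish_reason = "stop"
        # Grab the stop string if a stop condition ended generation
        if eos_reason == "stop_token":
            stop_str = result.get("eos_triggering_token_str")
        elif eos_reason == "stop_string":
            stop_str = result.get("eos_triggering_string")

    # The finish chunk keeps only the token counts and stop info,
    # replacing the streamed chunk that still carries token text
    return {
        "prompt_tokens": generation.get("prompt_tokens"),
        "generated_tokens": generation.get("generated_tokens"),
        "finish_reason": finish_reason,
        "stop_str": stop_str,
    }

if __name__ == "__main__":
    # Hypothetical chunk values, not taken from a real generation
    result = {"eos": True, "eos_reason": "stop_string", "eos_triggering_string": "\n\n"}
    generation = {"text": "Hello!", "prompt_tokens": 12, "generated_tokens": 5}
    print(build_finish_chunk(result, generation))
    # {'prompt_tokens': 12, 'generated_tokens': 5, 'finish_reason': 'stop', 'stop_str': '\n\n'}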
@@ -1004,15 +1004,6 @@ class ExllamaV2Container(BaseModelContainer):
             params.penalty_range, self.config.max_seq_len
         )
 
-        # TODO: Not used for some reason?
-        # Dynamically scale penalty range to output tokens
-        # Only do this if freq/pres pen is enabled
-        # and the repetition range is -1
-        auto_scale_penalty_range = (
-            gen_settings.token_frequency_penalty != 0
-            or gen_settings.token_presence_penalty != 0
-        ) and gen_settings.token_repetition_range == -1
-
         # Always make sure the fallback is 0 if range < 0
         # It's technically fine to use -1, but this just validates the passed
         # fallback
@@ -1115,6 +1106,51 @@ class ExllamaV2Container(BaseModelContainer):
                     "in the model's vocab. Skipping."
                 )
 
+    # Adds logprobs to a generation chunk
+    def handle_logprobs(self, result: dict, generation: dict):
+        top_tokens = unwrap(
+            result.get("top_k_tokens"),
+            torch.empty((1, 0, 1), dtype=torch.long),
+        )
+
+        top_probs = unwrap(
+            result.get("top_k_probs"),
+            torch.empty((1, 0, 1), dtype=torch.float),
+        )
+
+        if top_tokens.numel() > 0 and top_probs.numel() > 0:
+            logprobs = self.get_logprobs(top_tokens, top_probs)
+            generation["logprobs"] = logprobs
+
+            # The first logprob is the selected token prob
+            generation["token_probs"] = {
+                token: logprobs[token] for token in list(logprobs.keys())[:1]
+            }
+
+    # Creates and returns a finish chunk
+    def handle_finish_chunk(self, result: dict, generation: dict):
+        eos_reason = result.get("eos_reason")
+
+        stop_str = None
+        if eos_reason == "max_new_tokens":
+            finish_reason = "length"
+        else:
+            finish_reason = "stop"
+            # Grab stop string if stop was the reason
+            if eos_reason == "stop_token":
+                stop_str = result.get("eos_triggering_token_str")
+            elif eos_reason == "stop_string":
+                stop_str = result.get("eos_triggering_string")
+
+        finish_chunk = {
+            "prompt_tokens": generation.get("prompt_tokens"),
+            "generated_tokens": generation.get("generated_tokens"),
+            "finish_reason": finish_reason,
+            "stop_str": stop_str,
+        }
+
+        return finish_chunk
+
     async def generate_gen(
         self,
         request_id: str,
@@ -1174,6 +1210,14 @@ class ExllamaV2Container(BaseModelContainer):
                 "Please use an ampere (30 series) or higher GPU for CFG support."
             )
 
+        # Dynamically scale penalty range to output tokens
+        # Only do this if freq/pres pen is enabled
+        # and the repetition range is -1
+        auto_scale_penalty_range = (
+            gen_settings.token_frequency_penalty != 0
+            or gen_settings.token_presence_penalty != 0
+        ) and gen_settings.token_repetition_range == -1
+
         stop_conditions = params.stop
         add_bos_token = params.add_bos_token
         ban_eos_token = params.ban_eos_token
@@ -1316,58 +1360,25 @@ class ExllamaV2Container(BaseModelContainer):
                     "offset": len(full_response),
                 }
 
                # Increase penalty range to generated token amount
                if auto_scale_penalty_range:
                    gen_settings.token_repetition_range = generated_tokens
 
                # Handle logprobs
                if params.logprobs > 0:
-                    # Get top tokens and probs
-                    top_tokens = unwrap(
-                        result.get("top_k_tokens"),
-                        torch.empty((1, 0, 1), dtype=torch.long),
-                    )
-
-                    top_probs = unwrap(
-                        result.get("top_k_probs"),
-                        torch.empty((1, 0, 1), dtype=torch.float),
-                    )
-
-                    if top_tokens.numel() > 0 and top_probs.numel() > 0:
-                        logprobs = self.get_logprobs(top_tokens, top_probs)
-                        generation["logprobs"] = logprobs
-
-                        # The first logprob is the selected token prob
-                        generation["token_probs"] = {
-                            token: logprobs[token]
-                            for token in list(logprobs.keys())[:1]
-                        }
+                    self.handle_logprobs(result, generation)
 
                yield generation
 
-                # Second yield if eos is true
+                # Yield a finish chunk when generation is finished
                if result.get("eos"):
                    log_response(request_id, full_response)
 
-                    eos_reason = result.get("eos_reason")
-
-                    stop_str = None
-                    if eos_reason == "max_new_tokens":
-                        finish_reason = "length"
-                    else:
-                        finish_reason = "stop"
-                        # Grab stop string if stop was the reason
-                        if eos_reason == "stop_token":
-                            stop_str = result.get("eos_triggering_token_str")
-                        elif eos_reason == "stop_string":
-                            stop_str = result.get("eos_triggering_string")
+                    generation = self.handle_finish_chunk(result, generation)
 
                    # Save the final result for metrics logging
                    metrics_result = result
 
-                    # Remove the token text
-                    generation = {
-                        "prompt_tokens": generation.get("prompt_tokens"),
-                        "generated_tokens": generation.get("generated_tokens"),
-                        "finish_reason": finish_reason,
-                        "stop_str": stop_str,
-                    }
-
                    yield generation
                    break
        except asyncio.CancelledError:
@@ -1394,6 +1405,7 @@ class ExllamaV2Container(BaseModelContainer):
                 eos_token_id=eos_tokens,
                 prompt=prompt,
                 **params.model_dump(exclude={"prompt"}),
+                auto_scale_penalty_range=auto_scale_penalty_range,
             )
 
             # Log the metrics if present