Model: Extract settings creation to a separate function

Maybe move this out of the class entirely, but for now, it makes
sense to encapsulate this logic.
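
As a rough sketch, the extracted helper fills caller-owned objects in
place instead of returning new ones (illustrative fragment only; the
real bodies are in the diff below):

    def assign_gen_params(self, params, gen_settings, grammar_handler, banned_strings):
        # Mutate the caller's objects in place
        gen_settings.temperature = params.temperature
        # ...plus the other sampler fields, grammar filters, and banned strings

    async def generate_gen(self, request_id, prompt, params):
        gen_settings = ExLlamaV2Sampler.Settings()
        grammar_handler = ExLlamaV2Grammar()
        banned_strings = []
        self.assign_gen_params(params, gen_settings, grammar_handler, banned_strings)

Note that this shape only works because the helper mutates the passed-in
objects; rebinding a parameter inside it (e.g. banned_strings = [...])
would not propagate back to the caller.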

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
commit 2f5235e1a3
parent 5697204e47
Author: kingbri
Date:   2025-04-16 12:57:27 -04:00


@@ -1044,35 +1044,13 @@ class ExllamaV2Container:
 
         return params
 
-    async def generate_gen(
+    def assign_gen_params(
         self,
-        request_id: str,
-        prompt: str,
         params: BaseSamplerRequest,
-        abort_event: Optional[asyncio.Event] = None,
-        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+        gen_settings: ExLlamaV2Sampler.Settings,
+        grammar_handler: ExLlamaV2Grammar,
+        banned_strings: List[str],
     ):
-        """
-        Create generator function for prompt completion.
-
-        for kwargs, check common/sampling.py
-        """
-
-        # Wait for load lock to be freed before processing
-        async with self.load_condition:
-            await self.load_condition.wait_for(lambda: not self.load_lock.locked())
-
-        prompts = [prompt]
-
-        # TODO: Not used for some reason?
-        generate_window = max(params.generate_window, self.config.max_seq_len // 8)
-
-        # Sampler settings
-        gen_settings = ExLlamaV2Sampler.Settings()
-
         # Check unsupported settings for dev wheels
         params = self.check_unsupported_settings(params)
 
         # Apply settings
         gen_settings.temperature = params.temperature
         gen_settings.temperature_last = params.temperature_last
@@ -1120,25 +1098,6 @@ class ExllamaV2Container:
         gen_settings.mirostat_tau = params.mirostat_tau
         gen_settings.mirostat_eta = params.mirostat_eta
 
-        # Set CFG scale and negative prompt
-        cfg_scale = params.cfg_scale
-        negative_prompt = None
-        if cfg_scale not in [None, 1.0]:
-            if self.paged:
-                gen_settings.cfg_scale = cfg_scale
-
-                # If the negative prompt is empty, use the BOS token
-                negative_prompt = unwrap(
-                    params.negative_prompt, self.tokenizer.bos_token
-                )
-
-                prompts.append(negative_prompt)
-            else:
-                logger.warning(
-                    "CFG is currently disabled because paged mode is disabled. "
-                    "Please use an ampere (30 series) or higher GPU for CFG support."
-                )
-
         # Penalties
         gen_settings.token_repetition_penalty = params.repetition_penalty
         gen_settings.token_frequency_penalty = params.frequency_penalty
@@ -1189,9 +1148,6 @@ class ExllamaV2Container:
             self.encode_tokens(s)[-1] for s in params.dry_sequence_breakers
         }
 
-        # Initialize grammar handler
-        grammar_handler = ExLlamaV2Grammar()
-
         # Add JSON schema filter if it exists
         if params.json_schema:
             grammar_handler.add_json_schema_filter(
@@ -1220,10 +1176,6 @@ class ExllamaV2Container:
-        banned_strings = []
-
-        stop_conditions = params.stop
-        add_bos_token = params.add_bos_token
-        ban_eos_token = params.ban_eos_token
-
         # Speculative Ngram
         self.generator.speculative_ngram = params.speculative_ngram
@@ -1267,6 +1219,62 @@ class ExllamaV2Container:
                     "in the model's vocab. Skipping."
                 )
 
+    async def generate_gen(
+        self,
+        request_id: str,
+        prompt: str,
+        params: BaseSamplerRequest,
+        abort_event: Optional[asyncio.Event] = None,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+    ):
+        """
+        Create generator function for prompt completion.
+
+        for kwargs, check common/sampling.py
+        """
+
+        # Wait for load lock to be freed before processing
+        async with self.load_condition:
+            await self.load_condition.wait_for(lambda: not self.load_lock.locked())
+
+        prompts = [prompt]
+        gen_settings = ExLlamaV2Sampler.Settings()
+        grammar_handler = ExLlamaV2Grammar()
+        banned_strings = []
+
+        # TODO: Not used for some reason?
+        generate_window = max(params.generate_window, self.config.max_seq_len // 8)
+
+        self.assign_gen_params(
+            params,
+            gen_settings,
+            grammar_handler,
+            banned_strings,
+        )
+
+        # Set CFG scale and negative prompt
+        cfg_scale = params.cfg_scale
+        negative_prompt = None
+        if cfg_scale not in [None, 1.0]:
+            if self.paged:
+                gen_settings.cfg_scale = cfg_scale
+
+                # If the negative prompt is empty, use the BOS token
+                negative_prompt = unwrap(
+                    params.negative_prompt, self.tokenizer.bos_token
+                )
+
+                prompts.append(negative_prompt)
+            else:
+                logger.warning(
+                    "CFG is currently disabled because paged mode is disabled. "
+                    "Please use an ampere (30 series) or higher GPU for CFG support."
+                )
+
+        stop_conditions = params.stop
+        add_bos_token = params.add_bos_token
+        ban_eos_token = params.ban_eos_token
+
         # Fetch EOS tokens from generation_config if they exist
         eos_tokens = (
             self.generation_config.eos_tokens()
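
For context on the CFG block in the new generate_gen: when cfg_scale is
active, the positive and the (possibly BOS-only) negative prompt are
batched together and the engine blends their logits every step. A
conceptual sketch, not exllamav2's actual implementation (cfg_blend is
an illustrative name):

    def cfg_blend(positive_logits, negative_logits, cfg_scale):
        # Classifier-free guidance mix: 1.0 is a no-op; larger values
        # push sampling further toward the positive prompt.
        return [n + cfg_scale * (p - n) for p, n in zip(positive_logits, negative_logits)]

This is also why None and 1.0 skip the branch entirely: at scale 1.0
the blend reduces to the positive logits unchanged.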
@@ -1283,7 +1291,6 @@ class ExllamaV2Container:
             stop_conditions += eos_tokens
 
         # Get multimodal embeddings if present
-        # TODO: Remove kwargs and pass this as optional
         mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
 
         # Encode both positive and negative prompts
@@ -1316,16 +1323,14 @@ class ExllamaV2Container:
         # Determine if the negative context or the context length is bigger
         context_to_check = max(negative_context_len, context_len)
 
-        # Check highest possible total length of request
-        if context_to_check + max_tokens > self.config.max_seq_len:
+        # Check total length of prompt against max context length
+        if context_to_check > self.config.max_seq_len:
             preamble = (
-                "Negative prompt request"
-                if negative_context_len > context_len
-                else "Request"
+                "Negative prompt" if negative_context_len > context_len else "Prompt"
             )
 
             raise ValueError(
-                f"{preamble} length {context_to_check} + {max_tokens} is greater than "
+                f"{preamble} length {context_to_check} is greater than "
                 f"max_seq_len {self.config.max_seq_len}"
             )
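
Concrete numbers for the relaxed validation, assuming max_seq_len = 4096:

    max_seq_len, context_len, max_tokens = 4096, 4000, 300
    old_rejects = context_len + max_tokens > max_seq_len  # True: rejected up front
    new_rejects = context_len > max_seq_len               # False: accepted

A request like this used to fail even though the prompt itself fits;
now only the prompt length is validated here, with the generated token
count presumably capped downstream.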
@@ -1482,8 +1487,8 @@ class ExllamaV2Container:
             request_id=request_id,
             bos_token_id=self.tokenizer.bos_token_id,
             eos_token_id=eos_tokens,
-            **params.model_dump(),
             auto_scale_penalty_range=auto_scale_penalty_range,
+            prompt=prompt,
+            **params.model_dump(exclude={"prompt"}),
         )
 
         # Log the metrics if present
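
The model_dump change pairs with the new explicit prompt=prompt kwarg:
pydantic's model_dump(exclude=...) drops the named fields, so prompt is
not passed twice (which would raise "got multiple values for keyword
argument 'prompt'"). A trimmed stand-in for the request model:

    from pydantic import BaseModel

    class BaseSamplerRequest(BaseModel):  # stand-in; the real model has many more fields
        prompt: str = ""
        temperature: float = 1.0

    params = BaseSamplerRequest(prompt="hi", temperature=0.7)
    params.model_dump()                    # {'prompt': 'hi', 'temperature': 0.7}
    params.model_dump(exclude={"prompt"})  # {'temperature': 0.7}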