Model: Extract settings creation to a separate function
Maybe move this out of the class entirely, but for now, it makes sense to encapsulate this logic.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
commit 2f5235e1a3
parent 5697204e47
1 changed file with 66 additions and 61 deletions
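The change is an extract-method refactor: `generate_gen` keeps ownership of the request flow, creates the mutable collaborators (`ExLlamaV2Sampler.Settings`, `ExLlamaV2Grammar`, and the banned-strings list) up front, and hands them to the new `assign_gen_params`, which fills them in from the request in place. Below is a minimal, self-contained sketch of that shape; the `SamplerSettings`, `SamplerRequest`, and `Container` stand-ins and their fields are illustrative assumptions, not the real exllamav2/tabbyAPI classes, and only the mutate-in-place calling convention mirrors the diff.

```python
# Sketch of the extracted-helper shape used in this commit.
# SamplerSettings/SamplerRequest/Container are simplified stand-ins,
# not the real ExLlamaV2Sampler.Settings or BaseSamplerRequest classes.
from dataclasses import dataclass, field
from typing import List


@dataclass
class SamplerSettings:  # stand-in for ExLlamaV2Sampler.Settings
    temperature: float = 1.0
    temperature_last: bool = False


@dataclass
class SamplerRequest:  # stand-in for BaseSamplerRequest
    temperature: float = 1.0
    temperature_last: bool = False
    banned_strings: List[str] = field(default_factory=list)


class Container:
    def assign_gen_params(
        self,
        params: SamplerRequest,
        gen_settings: SamplerSettings,
        banned_strings: List[str],
    ):
        # Populate the objects the caller created; nothing is returned.
        gen_settings.temperature = params.temperature
        gen_settings.temperature_last = params.temperature_last
        banned_strings.extend(params.banned_strings)

    async def generate_gen(self, prompt: str, params: SamplerRequest):
        # The caller owns the objects and passes them in to be filled.
        gen_settings = SamplerSettings()
        banned_strings: List[str] = []
        self.assign_gen_params(params, gen_settings, banned_strings)
        return gen_settings, banned_strings


if __name__ == "__main__":
    import asyncio

    req = SamplerRequest(temperature=0.7, banned_strings=["foo"])
    settings, banned = asyncio.run(Container().generate_gen("hi", req))
    print(settings.temperature, banned)  # 0.7 ['foo']
```

Because the helper only mutates what it is given, rebinding a name inside `assign_gen_params` (for example `banned_strings = [...]`) would not reach the caller; list contents have to be appended or extended in place, which is why `generate_gen` creates the empty list first and passes it down.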
@@ -1044,35 +1044,13 @@ class ExllamaV2Container:

        return params

    async def generate_gen(
    def assign_gen_params(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
        gen_settings: ExLlamaV2Sampler.Settings,
        grammar_handler: ExLlamaV2Grammar,
        banned_strings: List[str],
    ):
        """
        Create generator function for prompt completion.

        for kwargs, check common/sampling.py
        """

        # Wait for load lock to be freed before processing
        async with self.load_condition:
            await self.load_condition.wait_for(lambda: not self.load_lock.locked())

        prompts = [prompt]

        # TODO: Not used for some reason?
        generate_window = max(params.generate_window, self.config.max_seq_len // 8)

        # Sampler settings
        gen_settings = ExLlamaV2Sampler.Settings()

        # Check unsupported settings for dev wheels
        params = self.check_unsupported_settings(params)

        # Apply settings
        gen_settings.temperature = params.temperature
        gen_settings.temperature_last = params.temperature_last
@@ -1120,25 +1098,6 @@ class ExllamaV2Container:
        gen_settings.mirostat_tau = params.mirostat_tau
        gen_settings.mirostat_eta = params.mirostat_eta

        # Set CFG scale and negative prompt
        cfg_scale = params.cfg_scale
        negative_prompt = None
        if cfg_scale not in [None, 1.0]:
            if self.paged:
                gen_settings.cfg_scale = cfg_scale

                # If the negative prompt is empty, use the BOS token
                negative_prompt = unwrap(
                    params.negative_prompt, self.tokenizer.bos_token
                )

                prompts.append(negative_prompt)
            else:
                logger.warning(
                    "CFG is currently disabled because paged mode is disabled. "
                    "Please use an ampere (30 series) or higher GPU for CFG support."
                )

        # Penalties
        gen_settings.token_repetition_penalty = params.repetition_penalty
        gen_settings.token_frequency_penalty = params.frequency_penalty
@@ -1189,9 +1148,6 @@ class ExllamaV2Container:
                self.encode_tokens(s)[-1] for s in params.dry_sequence_breakers
            }

        # Initialize grammar handler
        grammar_handler = ExLlamaV2Grammar()

        # Add JSON schema filter if it exists
        if params.json_schema:
            grammar_handler.add_json_schema_filter(
@@ -1220,10 +1176,6 @@ class ExllamaV2Container:

            banned_strings = []

        stop_conditions = params.stop
        add_bos_token = params.add_bos_token
        ban_eos_token = params.ban_eos_token

        # Speculative Ngram
        self.generator.speculative_ngram = params.speculative_ngram

@@ -1267,6 +1219,62 @@ class ExllamaV2Container:
                    "in the model's vocab. Skipping."
                )

    async def generate_gen(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ):
        """
        Create generator function for prompt completion.

        for kwargs, check common/sampling.py
        """

        # Wait for load lock to be freed before processing
        async with self.load_condition:
            await self.load_condition.wait_for(lambda: not self.load_lock.locked())

        prompts = [prompt]
        gen_settings = ExLlamaV2Sampler.Settings()
        grammar_handler = ExLlamaV2Grammar()
        banned_strings = []

        # TODO: Not used for some reason?
        generate_window = max(params.generate_window, self.config.max_seq_len // 8)

        self.assign_gen_params(
            params,
            gen_settings,
            grammar_handler,
            banned_strings,
        )

        # Set CFG scale and negative prompt
        cfg_scale = params.cfg_scale
        negative_prompt = None
        if cfg_scale not in [None, 1.0]:
            if self.paged:
                gen_settings.cfg_scale = cfg_scale

                # If the negative prompt is empty, use the BOS token
                negative_prompt = unwrap(
                    params.negative_prompt, self.tokenizer.bos_token
                )

                prompts.append(negative_prompt)
            else:
                logger.warning(
                    "CFG is currently disabled because paged mode is disabled. "
                    "Please use an ampere (30 series) or higher GPU for CFG support."
                )

        stop_conditions = params.stop
        add_bos_token = params.add_bos_token
        ban_eos_token = params.ban_eos_token

        # Fetch EOS tokens from generation_config if they exist
        eos_tokens = (
            self.generation_config.eos_tokens()
@@ -1283,7 +1291,6 @@ class ExllamaV2Container:
            stop_conditions += eos_tokens

        # Get multimodal embeddings if present
        # TODO: Remove kwargs and pass this as optional
        mm_embeddings_content = mm_embeddings.content if mm_embeddings else []

        # Encode both positive and negative prompts
@@ -1316,16 +1323,14 @@ class ExllamaV2Container:
        # Determine if the negative context or the context length is bigger
        context_to_check = max(negative_context_len, context_len)

        # Check highest possible total length of request
        if context_to_check + max_tokens > self.config.max_seq_len:
        # Check total length of prompt against max context length
        if context_to_check > self.config.max_seq_len:
            preamble = (
                "Negative prompt request"
                if negative_context_len > context_len
                else "Request"
                "Negative prompt" if negative_context_len > context_len else "Prompt"
            )

            raise ValueError(
                f"{preamble} length {context_to_check} + {max_tokens} is greater than "
                f"{preamble} length {context_to_check} is greater than "
                f"max_seq_len {self.config.max_seq_len}"
            )
@@ -1482,8 +1487,8 @@ class ExllamaV2Container:
            request_id=request_id,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=eos_tokens,
            **params.model_dump(),
            auto_scale_penalty_range=auto_scale_penalty_range,
            prompt=prompt,
            **params.model_dump(exclude={"prompt"}),
        )

        # Log the metrics if present
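The final hunk passes `prompt` to the job explicitly and excludes it from the dumped request, so the same field is not supplied twice as a keyword argument. A small sketch of that Pydantic v2 `model_dump(exclude=...)` behaviour follows; `Req` and `make_job` are hypothetical stand-ins, and only the `model_dump` call itself reflects the real API.

```python
# Minimal sketch of the model_dump(exclude=...) pattern in the last hunk.
# Req and make_job are hypothetical; only the Pydantic v2 call is real.
from pydantic import BaseModel


class Req(BaseModel):
    prompt: str = ""
    temperature: float = 1.0


def make_job(prompt: str, **kwargs):
    return {"prompt": prompt, **kwargs}


req = Req(prompt="hello", temperature=0.7)

# Passing prompt explicitly AND dumping the whole model would raise
# "got multiple values for keyword argument 'prompt'":
#   make_job(prompt=req.prompt, **req.model_dump())
# Excluding the field from the dump avoids the collision:
job = make_job(prompt=req.prompt, **req.model_dump(exclude={"prompt"}))
print(job)  # {'prompt': 'hello', 'temperature': 0.7}
```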