Model: Correct exl3 generation, add concurrency, and clean up

Fixes application of sampler parameters by adding a new sampler builder
interface. Also exposes the generator class-wide and adds wait_for_jobs.

Finally, allows inline loading to specify the backend.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri 2025-04-30 22:59:25 -04:00
parent c744790f14
commit 303e2dde12
2 changed files with 155 additions and 85 deletions
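The ExllamaV3SamplerBuilder used by this commit is imported from backends/exllamav3/sampler.py, whose definition does not appear in the diff below. Purely as an illustration of the interface implied by its call sites (penalties, temperature, top_k, top_p, min_p, and build), a minimal sketch might look like the following; the internals that map the collected settings onto exllamav3's sampler primitives are assumed rather than taken from the commit.

# Hypothetical sketch of ExllamaV3SamplerBuilder, inferred from its call sites.
# How the steps are translated into exllamav3 sampler objects is an assumption.
class ExllamaV3SamplerBuilder:
    def __init__(self):
        # Sampler steps are collected in call order and composed in build()
        self.steps: list = []

    def penalties(self, rep_p, freq_p, pres_p, penalty_range, decay):
        self.steps.append(("penalties", rep_p, freq_p, pres_p, penalty_range, decay))

    def temperature(self, temperature):
        self.steps.append(("temperature", temperature))

    def top_k(self, top_k):
        self.steps.append(("top_k", top_k))

    def top_p(self, top_p):
        self.steps.append(("top_p", top_p))

    def min_p(self, min_p):
        self.steps.append(("min_p", min_p))

    def build(self, greedy: bool = False):
        # When greedy is True (temperature == 0), return an argmax sampler;
        # otherwise compose the collected steps into a single exllamav3 sampler.
        ...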

@@ -14,7 +14,6 @@ from exllamav3 import (
AsyncGenerator,
AsyncJob,
Cache,
ComboSampler,
Config,
Model,
Tokenizer,
@@ -22,6 +21,7 @@ from exllamav3 import (
from loguru import logger
from backends.base_model_container import BaseModelContainer
from backends.exllamav3.sampler import ExllamaV3SamplerBuilder
from common.concurrency import iterate_in_threadpool
from common.gen_logging import (
log_generation_params,
@@ -32,7 +32,7 @@ from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate, find_prompt_template
from common.transformers_utils import GenerationConfig
from common.utils import unwrap
from common.utils import coalesce, unwrap
from endpoints.core.types.model import ModelCard
@@ -58,6 +58,7 @@ class ExllamaV3Container(BaseModelContainer):
cache: Cache
tokenizer: Tokenizer
config: Config
generator: Optional[AsyncGenerator] = None
gpu_split: List[float] | None = None
gpu_split_auto: bool = True
autosplit_reserve: List[float] = [96 / 1024]
@@ -123,13 +124,47 @@ class ExllamaV3Container(BaseModelContainer):
# Reserve VRAM for each GPU
self.autosplit_reserve = [
value/1024
for value in autosplit_reserve_megabytes
value / 1024 for value in autosplit_reserve_megabytes
]
# TODO: speculative decoding
return self
def model_info(self) -> ModelCard:
"""
Returns a dictionary of the current model's configuration parameters.
Returns:
Model parameters provided by the backend
"""
pass
async def wait_for_jobs(self, skip_wait: bool = False):
"""
Polls until all active generation jobs have completed.
Args:
skip_wait: If True, cancel jobs immediately instead of waiting.
"""
if not self.generator:
return
# Immediately abort all jobs if asked
if skip_wait:
logger.warning(
"Immediately terminating all jobs. "
"Clients will have their requests cancelled.\n"
)
for job in self.active_job_ids.values():
if job:
await job.cancel()
while len(self.active_job_ids) > 0:
await asyncio.sleep(0.01)
async def load(self, progress_callback=None, **kwargs):
"""
Loads the model into memory.
@@ -161,8 +196,8 @@ class ExllamaV3Container(BaseModelContainer):
await self.wait_for_jobs(kwargs.get("skip_wait"))
generator = self.load_model_sync(progress_callback)
async for module, modules in iterate_in_threadpool(generator):
yield module, modules
async for value in iterate_in_threadpool(generator):
yield value
# Clean up any extra vram usage from torch and cuda
# (Helps reduce VRAM bottlenecking on Windows)
@@ -184,11 +219,42 @@ class ExllamaV3Container(BaseModelContainer):
for value in self.model.load_gen(
reserve_per_device=self.autosplit_reserve,
use_per_device=self.gpu_split,
callback=progress_callback
callback=progress_callback,
):
if value:
yield value
async def create_generator(self):
"""Create and save a Exllama generator class."""
try:
# Don't acquire locks unless a model is loaded
if self.loaded:
await self.load_lock.acquire()
# Immediately cancel all jobs
await self.wait_for_jobs(skip_wait=True)
# Create new generator
self.generator = AsyncGenerator(
model=self.model,
cache=self.cache,
tokenizer=self.tokenizer,
max_batch_size=self.max_batch_size,
)
# Update the state of the container var
if self.max_batch_size is None:
self.max_batch_size = self.generator.generator.max_batch_size
finally:
# Only release the lock when the generator is being recreated;
# on an initial load, the load function releases the lock itself
if self.loaded:
self.load_lock.release()
async with self.load_condition:
self.load_condition.notify_all()
async def unload(self, loras_only: bool = False, **kwargs):
"""
Unloads the model and associated resources from memory.
@@ -198,7 +264,11 @@ class ExllamaV3Container(BaseModelContainer):
**kwargs: Additional unloading options (e.g., shutdown).
"""
# Used when shutting down the server
do_shutdown = kwargs.get("shutdown")
try:
if not do_shutdown:
await self.load_lock.acquire()
# Wait for other jobs to finish
@@ -211,11 +281,17 @@ class ExllamaV3Container(BaseModelContainer):
self.cache = None
self.tokenizer = None
# Cleanup the generator from any pending jobs
if self.generator is not None:
await self.generator.close()
self.generator = None
gc.collect()
torch.cuda.empty_cache()
logger.info("Model unloaded.")
finally:
if not do_shutdown:
self.load_lock.release()
async with self.load_condition:
@@ -233,11 +309,15 @@ class ExllamaV3Container(BaseModelContainer):
A list of integer token IDs.
"""
return self.tokenizer.encode(
return (
self.tokenizer.encode(
text,
add_bos=unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
).flatten().tolist()
)
.flatten()
.tolist()
)
def decode_tokens(self, ids: List[int], **kwargs) -> str:
"""
@@ -278,26 +358,6 @@ class ExllamaV3Container(BaseModelContainer):
"unk_token": self.tokenizer.unk_token,
}
def model_info(self) -> ModelCard:
"""
Returns a dictionary of the current model's configuration parameters.
Returns:
Model parameters provided by the backend
"""
pass
async def wait_for_jobs(self, skip_wait: bool = False):
"""
Waits for any active generation jobs to complete.
Args:
skip_wait: If True, cancel jobs immediately instead of waiting.
"""
pass
async def generate(
self,
request_id: str,
@@ -446,37 +506,6 @@ class ExllamaV3Container(BaseModelContainer):
return finish_chunk
async def create_generator(self):
"""Create and save a Exllama generator class."""
try:
# Don't acquire locks unless a model is loaded
if self.loaded:
await self.load_lock.acquire()
# Immediately cancel all jobs
await self.wait_for_jobs(skip_wait=True)
# Create new generator
self.generator = AsyncGenerator(
model=self.model,
cache=self.cache,
tokenizer=self.tokenizer,
max_batch_size=self.max_batch_size,
)
# Update the state of the container var
if self.max_batch_size is None:
self.max_batch_size = self.generator.generator.max_batch_size
finally:
# This means the generator is being recreated
# The load lock is already released in the load function
if self.loaded:
self.load_lock.release()
async with self.load_condition:
self.load_condition.notify_all()
async def generate_gen(
self,
request_id: str,
@@ -492,25 +521,58 @@ class ExllamaV3Container(BaseModelContainer):
"""
chunk_tokens: torch.Tensor | tuple[torch.Tensor, torch.Tensor]
# FIXME: this is probably not right
base_sampler = BaseSamplerRequest()
sampler = ComboSampler(
rep_p=base_sampler.repetition_penalty,
pres_p=base_sampler.presence_penalty,
freq_p=base_sampler.frequency_penalty,
rep_sustain_range=base_sampler.penalty_range,
rep_decay_range=base_sampler.penalty_range,
temperature=base_sampler.temperature,
min_p=base_sampler.min_p,
top_k=base_sampler.top_k,
top_p=base_sampler.top_p,
temp_last=base_sampler.temperature_last,
sampler_builder = ExllamaV3SamplerBuilder()
# Penalties
# Set penalty range
penalty_range = unwrap(params.penalty_range, self.max_seq_len)
# Exl3's version of including the entire context
if penalty_range < 0:
penalty_range = 10e7
# Make sure the decay fallback is 0 when the passed range is negative
# Using -1 is technically fine, but this validates the passed fallback
if params.penalty_range < 0:
fallback_decay = 0
else:
fallback_decay = params.penalty_range
repetition_decay = coalesce(params.repetition_decay, fallback_decay, 0)
# Apply penalties to builder
sampler_builder.penalties(
params.repetition_penalty,
params.frequency_penalty,
params.presence_penalty,
penalty_range,
repetition_decay,
)
# Apply temperature first to builder
if not params.temperature_last:
sampler_builder.temperature(params.temperature)
# Apply alphabet samplers to builder
sampler_builder.top_k(params.top_k)
sampler_builder.top_p(params.top_p)
sampler_builder.min_p(params.min_p)
# Apply temperature last to builder
if params.temperature_last:
sampler_builder.temperature(params.temperature)
# Build the sampler
# Set greedy if temperature is 0
sampler = sampler_builder.build(params.temperature == 0)
# Dynamically scale penalty range to output tokens
# Only do this if freq/pres pen is enabled
# and the repetition range is -1
# TODO:
# TODO: This currently does not work in exl3
# auto_scale_penalty_range = (
# gen_settings.token_frequency_penalty != 0
# or gen_settings.token_presence_penalty != 0
@@ -576,6 +638,7 @@ class ExllamaV3Container(BaseModelContainer):
input_ids=self.tokenizer.encode(prompt, add_bos=False),
max_new_tokens=max_tokens,
stop_conditions=stop_conditions,
banned_strings=params.banned_strings,
)
generated_tokens = 0
@@ -585,6 +648,11 @@ class ExllamaV3Container(BaseModelContainer):
# Get the generation status once it's ready
try:
async for result in job:
# Abort if the event is set while streaming
if abort_event and abort_event.is_set():
await job.cancel()
break
chunk = unwrap(result.get("text"), "")
if chunk:
chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))

@@ -163,8 +163,10 @@ class ModelConfig(BaseConfigModel):
"Example: ['max_seq_len', 'cache_mode']."
),
)
# Defaults to exllamav2 in common/model.py
backend: Optional[str] = Field(
"exllamav2",
None,
description=(
"Backend to use for this model (default: exllamav2)\n"
"Options: exllamav2, exllamav3",