Model: Correct exl3 generation, add concurrency, and cleanup

Fix the application of sampler parameters by adding a new sampler builder
interface. Also expose the generator class-wide and add a wait_for_jobs implementation.

Finally, allow inline loading to specify the backend.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri 2025-04-30 22:59:25 -04:00
parent c744790f14
commit 303e2dde12
2 changed files with 155 additions and 85 deletions
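
For reference, the sketch below condenses how the new builder interface is driven in generate_gen (based only on the calls visible in this diff; the actual ExllamaV3SamplerBuilder lives in backends/exllamav3/sampler.py and is not shown here, and the helper name build_sampler is illustrative). The previous code constructed a ComboSampler from a default BaseSamplerRequest, so the request's own sampling parameters were never applied.

# Illustrative sketch only: mirrors the call order used in generate_gen below.
# "builder" is an ExllamaV3SamplerBuilder and "params" a BaseSamplerRequest.
def build_sampler(builder, params, penalty_range, repetition_decay):
    # Penalties are applied first
    builder.penalties(
        params.repetition_penalty,
        params.frequency_penalty,
        params.presence_penalty,
        penalty_range,
        repetition_decay,
    )

    # Temperature position depends on temperature_last
    if not params.temperature_last:
        builder.temperature(params.temperature)

    # Alphabet samplers
    builder.top_k(params.top_k)
    builder.top_p(params.top_p)
    builder.min_p(params.min_p)

    if params.temperature_last:
        builder.temperature(params.temperature)

    # Greedy decoding when temperature is 0
    return builder.build(params.temperature == 0)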


@@ -14,7 +14,6 @@ from exllamav3 import (
     AsyncGenerator,
     AsyncJob,
     Cache,
-    ComboSampler,
     Config,
     Model,
     Tokenizer,
@@ -22,6 +21,7 @@ from exllamav3 import (
 from loguru import logger

 from backends.base_model_container import BaseModelContainer
+from backends.exllamav3.sampler import ExllamaV3SamplerBuilder
 from common.concurrency import iterate_in_threadpool
 from common.gen_logging import (
     log_generation_params,
@@ -32,7 +32,7 @@ from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate, find_prompt_template
 from common.transformers_utils import GenerationConfig
-from common.utils import unwrap
+from common.utils import coalesce, unwrap
 from endpoints.core.types.model import ModelCard
@@ -58,6 +58,7 @@ class ExllamaV3Container(BaseModelContainer):
     cache: Cache
     tokenizer: Tokenizer
     config: Config
+    generator: Optional[AsyncGenerator] = None
     gpu_split: List[float] | None = None
     gpu_split_auto: bool = True
     autosplit_reserve: List[float] = [96 / 1024]
@@ -123,13 +124,47 @@ class ExllamaV3Container(BaseModelContainer):
             # Reserve VRAM for each GPU
             self.autosplit_reserve = [
-                value/1024
-                for value in autosplit_reserve_megabytes
+                value / 1024 for value in autosplit_reserve_megabytes
             ]

         # TODO: speculative decoding

         return self

+    def model_info(self) -> ModelCard:
+        """
+        Returns a dictionary of the current model's configuration parameters.
+
+        Returns:
+            Model parameters provided by the backend
+        """
+
+        pass
+
+    async def wait_for_jobs(self, skip_wait: bool = False):
+        """
+        Polling to wait for any active generation jobs to complete.
+
+        Args:
+            skip_wait: If True, cancel jobs immediately instead of waiting.
+        """
+
+        if not self.generator:
+            return
+
+        # Immediately abort all jobs if asked
+        if skip_wait:
+            logger.warning(
+                "Immediately terminating all jobs. "
+                "Clients will have their requests cancelled.\n"
+            )
+
+            for job in self.active_job_ids.values():
+                if job:
+                    await job.cancel()
+
+        while len(self.active_job_ids) > 0:
+            await asyncio.sleep(0.01)
+
     async def load(self, progress_callback=None, **kwargs):
         """
         Loads the model into memory.
@@ -161,8 +196,8 @@ class ExllamaV3Container(BaseModelContainer):
             await self.wait_for_jobs(kwargs.get("skip_wait"))

             generator = self.load_model_sync(progress_callback)
-            async for module, modules in iterate_in_threadpool(generator):
-                yield module, modules
+            async for value in iterate_in_threadpool(generator):
+                yield value

             # Clean up any extra vram usage from torch and cuda
             # (Helps reduce VRAM bottlenecking on Windows)
@@ -184,11 +219,42 @@ class ExllamaV3Container(BaseModelContainer):
         for value in self.model.load_gen(
             reserve_per_device=self.autosplit_reserve,
             use_per_device=self.gpu_split,
-            callback=progress_callback
+            callback=progress_callback,
         ):
             if value:
                 yield value

+    async def create_generator(self):
+        """Create and save a Exllama generator class."""
+
+        try:
+            # Don't acquire locks unless a model is loaded
+            if self.loaded:
+                await self.load_lock.acquire()
+
+                # Immediately cancel all jobs
+                await self.wait_for_jobs(skip_wait=True)
+
+            # Create new generator
+            self.generator = AsyncGenerator(
+                model=self.model,
+                cache=self.cache,
+                tokenizer=self.tokenizer,
+                max_batch_size=self.max_batch_size,
+            )
+
+            # Update the state of the container var
+            if self.max_batch_size is None:
+                self.max_batch_size = self.generator.generator.max_batch_size
+        finally:
+            # This means the generator is being recreated
+            # The load lock is already released in the load function
+            if self.loaded:
+                self.load_lock.release()
+
+                async with self.load_condition:
+                    self.load_condition.notify_all()
+
     async def unload(self, loras_only: bool = False, **kwargs):
         """
         Unloads the model and associated resources from memory.
@@ -198,11 +264,15 @@ class ExllamaV3Container(BaseModelContainer):
             **kwargs: Additional unloading options (e.g., shutdown).
         """

-        try:
-            await self.load_lock.acquire()
-
-            # Wait for other jobs to finish
-            await self.wait_for_jobs(kwargs.get("skip_wait"))
+        # Used when shutting down the server
+        do_shutdown = kwargs.get("shutdown")
+
+        try:
+            if not do_shutdown:
+                await self.load_lock.acquire()
+
+                # Wait for other jobs to finish
+                await self.wait_for_jobs(kwargs.get("skip_wait"))

             self.model.unload()
             self.model = None
@@ -211,15 +281,21 @@ class ExllamaV3Container(BaseModelContainer):
             self.cache = None
             self.tokenizer = None

+            # Cleanup the generator from any pending jobs
+            if self.generator is not None:
+                await self.generator.close()
+                self.generator = None
+
             gc.collect()
             torch.cuda.empty_cache()

             logger.info("Model unloaded.")
         finally:
-            self.load_lock.release()
+            if not do_shutdown:
+                self.load_lock.release()

             async with self.load_condition:
                 self.load_condition.notify_all()

     def encode_tokens(self, text: str, **kwargs) -> List[int]:
         """
@@ -233,11 +309,15 @@ class ExllamaV3Container(BaseModelContainer):
             A list of integer token IDs.
         """

-        return self.tokenizer.encode(
-            text,
-            add_bos=unwrap(kwargs.get("add_bos_token"), True),
-            encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
-        ).flatten().tolist()
+        return (
+            self.tokenizer.encode(
+                text,
+                add_bos=unwrap(kwargs.get("add_bos_token"), True),
+                encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
+            )
+            .flatten()
+            .tolist()
+        )

     def decode_tokens(self, ids: List[int], **kwargs) -> str:
         """
@@ -278,26 +358,6 @@ class ExllamaV3Container(BaseModelContainer):
             "unk_token": self.tokenizer.unk_token,
         }

-    def model_info(self) -> ModelCard:
-        """
-        Returns a dictionary of the current model's configuration parameters.
-
-        Returns:
-            Model parameters provided by the backend
-        """
-
-        pass
-
-    async def wait_for_jobs(self, skip_wait: bool = False):
-        """
-        Waits for any active generation jobs to complete.
-
-        Args:
-            skip_wait: If True, cancel jobs immediately instead of waiting.
-        """
-
-        pass
-
     async def generate(
         self,
         request_id: str,
@@ -446,37 +506,6 @@ class ExllamaV3Container(BaseModelContainer):
         return finish_chunk

-    async def create_generator(self):
-        """Create and save a Exllama generator class."""
-
-        try:
-            # Don't acquire locks unless a model is loaded
-            if self.loaded:
-                await self.load_lock.acquire()
-
-                # Immediately cancel all jobs
-                await self.wait_for_jobs(skip_wait=True)
-
-            # Create new generator
-            self.generator = AsyncGenerator(
-                model=self.model,
-                cache=self.cache,
-                tokenizer=self.tokenizer,
-                max_batch_size=self.max_batch_size,
-            )
-
-            # Update the state of the container var
-            if self.max_batch_size is None:
-                self.max_batch_size = self.generator.generator.max_batch_size
-        finally:
-            # This means the generator is being recreated
-            # The load lock is already released in the load function
-            if self.loaded:
-                self.load_lock.release()
-
-                async with self.load_condition:
-                    self.load_condition.notify_all()
-
     async def generate_gen(
         self,
         request_id: str,
@@ -492,25 +521,58 @@ class ExllamaV3Container(BaseModelContainer):
         """

         chunk_tokens: torch.Tensor | tuple[torch.Tensor, torch.Tensor]

-        # FIXME: this is probably not right
-        base_sampler = BaseSamplerRequest()
-        sampler = ComboSampler(
-            rep_p=base_sampler.repetition_penalty,
-            pres_p=base_sampler.presence_penalty,
-            freq_p=base_sampler.frequency_penalty,
-            rep_sustain_range=base_sampler.penalty_range,
-            rep_decay_range=base_sampler.penalty_range,
-            temperature=base_sampler.temperature,
-            min_p=base_sampler.min_p,
-            top_k=base_sampler.top_k,
-            top_p=base_sampler.top_p,
-            temp_last=base_sampler.temperature_last,
-        )
+        sampler_builder = ExllamaV3SamplerBuilder()
+
+        # Penalties
+
+        # Set penalty range
+        penalty_range = unwrap(params.penalty_range, self.max_seq_len)
+
+        # Exl3's version of including the entire context
+        if penalty_range < 0:
+            penalty_range = 10e7
+
+        # Always make sure the fallback is 0 if range < 0
+        # It's technically fine to use -1, but this just validates the passed
+        # fallback
+        # Always default to 0 if something goes wrong
+        if params.penalty_range < 0:
+            fallback_decay = 0
+        else:
+            fallback_decay = params.penalty_range
+        repetition_decay = coalesce(params.repetition_decay, fallback_decay, 0)
+
+        # Apply penalties to builder
+        sampler_builder.penalties(
+            params.repetition_penalty,
+            params.frequency_penalty,
+            params.presence_penalty,
+            penalty_range,
+            repetition_decay,
+        )
+
+        # Apply temperature first to builder
+        if not params.temperature_last:
+            sampler_builder.temperature(params.temperature)
+
+        # Apply alphabet samplers to builder
+        sampler_builder.top_k(params.top_k)
+        sampler_builder.top_p(params.top_p)
+        sampler_builder.min_p(params.min_p)
+
+        # Apply temperature last to builder
+        if params.temperature_last:
+            sampler_builder.temperature(params.temperature)
+
+        # Build the sampler
+        # Set greedy if temperature is 0
+        sampler = sampler_builder.build(params.temperature == 0)

         # Dynamically scale penalty range to output tokens
         # Only do this if freq/pres pen is enabled
         # and the repetition range is -1
-        # TODO:
+        # TODO: This currently does not work in exl3
         # auto_scale_penalty_range = (
         #     gen_settings.token_frequency_penalty != 0
         #     or gen_settings.token_presence_penalty != 0
@@ -576,6 +638,7 @@ class ExllamaV3Container(BaseModelContainer):
             input_ids=self.tokenizer.encode(prompt, add_bos=False),
             max_new_tokens=max_tokens,
             stop_conditions=stop_conditions,
+            banned_strings=params.banned_strings,
         )

         generated_tokens = 0
@@ -585,6 +648,11 @@ class ExllamaV3Container(BaseModelContainer):
         # Get the generation status once it's ready
         try:
             async for result in job:
+                # Abort if the event is set while streaming
+                if abort_event and abort_event.is_set():
+                    await job.cancel()
+                    break
+
                 chunk = unwrap(result.get("text"), "")
                 if chunk:
                     chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))


@@ -163,8 +163,10 @@ class ModelConfig(BaseConfigModel):
             "Example: ['max_seq_len', 'cache_mode']."
         ),
     )
+
+    # Defaults to exllamav2 in common/model.py
     backend: Optional[str] = Field(
-        "exllamav2",
+        None,
         description=(
             "Backend to use for this model (default: exllamav2)\n"
             "Options: exllamav2, exllamav3",