Model: Correct exl3 generation, add concurrency, and cleanup
Fixes application of sampler parameters by adding a new sampler builder interface. Also expose the generator class-wide and add wait_for_jobs. Finally, allow inline loading to specify the backend. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
c744790f14
commit
303e2dde12
2 changed files with 155 additions and 85 deletions
|
|
@ -14,7 +14,6 @@ from exllamav3 import (
|
|||
AsyncGenerator,
|
||||
AsyncJob,
|
||||
Cache,
|
||||
ComboSampler,
|
||||
Config,
|
||||
Model,
|
||||
Tokenizer,
|
||||
|
|
@ -22,6 +21,7 @@ from exllamav3 import (
|
|||
from loguru import logger
|
||||
|
||||
from backends.base_model_container import BaseModelContainer
|
||||
from backends.exllamav3.sampler import ExllamaV3SamplerBuilder
|
||||
from common.concurrency import iterate_in_threadpool
|
||||
from common.gen_logging import (
|
||||
log_generation_params,
|
||||
|
|
@ -32,7 +32,7 @@ from common.multimodal import MultimodalEmbeddingWrapper
|
|||
from common.sampling import BaseSamplerRequest
|
||||
from common.templating import PromptTemplate, find_prompt_template
|
||||
from common.transformers_utils import GenerationConfig
|
||||
from common.utils import unwrap
|
||||
from common.utils import coalesce, unwrap
|
||||
from endpoints.core.types.model import ModelCard
|
||||
|
||||
|
||||
|
|
@ -58,6 +58,7 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
cache: Cache
|
||||
tokenizer: Tokenizer
|
||||
config: Config
|
||||
generator: Optional[AsyncGenerator] = None
|
||||
gpu_split: List[float] | None = None
|
||||
gpu_split_auto: bool = True
|
||||
autosplit_reserve: List[float] = [96 / 1024]
|
||||
|
|
@ -123,13 +124,47 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
|
||||
# Reserve VRAM for each GPU
|
||||
self.autosplit_reserve = [
|
||||
value/1024
|
||||
for value in autosplit_reserve_megabytes
|
||||
value / 1024 for value in autosplit_reserve_megabytes
|
||||
]
|
||||
# TODO: speculative decoding
|
||||
|
||||
return self
|
||||
|
||||
def model_info(self) -> ModelCard:
|
||||
"""
|
||||
Returns a dictionary of the current model's configuration parameters.
|
||||
|
||||
Returns:
|
||||
Model parameters provided by the backend
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
async def wait_for_jobs(self, skip_wait: bool = False):
|
||||
"""
|
||||
Polling to wait for any active generation jobs to complete.
|
||||
|
||||
Args:
|
||||
skip_wait: If True, cancel jobs immediately instead of waiting.
|
||||
"""
|
||||
|
||||
if not self.generator:
|
||||
return
|
||||
|
||||
# Immediately abort all jobs if asked
|
||||
if skip_wait:
|
||||
logger.warning(
|
||||
"Immediately terminating all jobs. "
|
||||
"Clients will have their requests cancelled.\n"
|
||||
)
|
||||
|
||||
for job in self.active_job_ids.values():
|
||||
if job:
|
||||
await job.cancel()
|
||||
|
||||
while len(self.active_job_ids) > 0:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
async def load(self, progress_callback=None, **kwargs):
|
||||
"""
|
||||
Loads the model into memory.
|
||||
|
|
@ -161,8 +196,8 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
await self.wait_for_jobs(kwargs.get("skip_wait"))
|
||||
|
||||
generator = self.load_model_sync(progress_callback)
|
||||
async for module, modules in iterate_in_threadpool(generator):
|
||||
yield module, modules
|
||||
async for value in iterate_in_threadpool(generator):
|
||||
yield value
|
||||
|
||||
# Clean up any extra vram usage from torch and cuda
|
||||
# (Helps reduce VRAM bottlenecking on Windows)
|
||||
|
|
@ -184,11 +219,42 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
for value in self.model.load_gen(
|
||||
reserve_per_device=self.autosplit_reserve,
|
||||
use_per_device=self.gpu_split,
|
||||
callback=progress_callback
|
||||
callback=progress_callback,
|
||||
):
|
||||
if value:
|
||||
yield value
|
||||
|
||||
async def create_generator(self):
|
||||
"""Create and save a Exllama generator class."""
|
||||
|
||||
try:
|
||||
# Don't acquire locks unless a model is loaded
|
||||
if self.loaded:
|
||||
await self.load_lock.acquire()
|
||||
|
||||
# Immediately cancel all jobs
|
||||
await self.wait_for_jobs(skip_wait=True)
|
||||
|
||||
# Create new generator
|
||||
self.generator = AsyncGenerator(
|
||||
model=self.model,
|
||||
cache=self.cache,
|
||||
tokenizer=self.tokenizer,
|
||||
max_batch_size=self.max_batch_size,
|
||||
)
|
||||
|
||||
# Update the state of the container var
|
||||
if self.max_batch_size is None:
|
||||
self.max_batch_size = self.generator.generator.max_batch_size
|
||||
finally:
|
||||
# This means the generator is being recreated
|
||||
# The load lock is already released in the load function
|
||||
if self.loaded:
|
||||
self.load_lock.release()
|
||||
|
||||
async with self.load_condition:
|
||||
self.load_condition.notify_all()
|
||||
|
||||
async def unload(self, loras_only: bool = False, **kwargs):
|
||||
"""
|
||||
Unloads the model and associated resources from memory.
|
||||
|
|
@ -198,7 +264,11 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
**kwargs: Additional unloading options (e.g., shutdown).
|
||||
"""
|
||||
|
||||
# Used when shutting down the server
|
||||
do_shutdown = kwargs.get("shutdown")
|
||||
|
||||
try:
|
||||
if not do_shutdown:
|
||||
await self.load_lock.acquire()
|
||||
|
||||
# Wait for other jobs to finish
|
||||
|
|
@ -211,11 +281,17 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
self.cache = None
|
||||
self.tokenizer = None
|
||||
|
||||
# Cleanup the generator from any pending jobs
|
||||
if self.generator is not None:
|
||||
await self.generator.close()
|
||||
self.generator = None
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
logger.info("Model unloaded.")
|
||||
finally:
|
||||
if not do_shutdown:
|
||||
self.load_lock.release()
|
||||
|
||||
async with self.load_condition:
|
||||
|
|
@ -233,11 +309,15 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
A list of integer token IDs.
|
||||
"""
|
||||
|
||||
return self.tokenizer.encode(
|
||||
return (
|
||||
self.tokenizer.encode(
|
||||
text,
|
||||
add_bos=unwrap(kwargs.get("add_bos_token"), True),
|
||||
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
|
||||
).flatten().tolist()
|
||||
)
|
||||
.flatten()
|
||||
.tolist()
|
||||
)
|
||||
|
||||
def decode_tokens(self, ids: List[int], **kwargs) -> str:
|
||||
"""
|
||||
|
|
@ -278,26 +358,6 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
"unk_token": self.tokenizer.unk_token,
|
||||
}
|
||||
|
||||
def model_info(self) -> ModelCard:
|
||||
"""
|
||||
Returns a dictionary of the current model's configuration parameters.
|
||||
|
||||
Returns:
|
||||
Model parameters provided by the backend
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
async def wait_for_jobs(self, skip_wait: bool = False):
|
||||
"""
|
||||
Waits for any active generation jobs to complete.
|
||||
|
||||
Args:
|
||||
skip_wait: If True, cancel jobs immediately instead of waiting.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
request_id: str,
|
||||
|
|
@ -446,37 +506,6 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
|
||||
return finish_chunk
|
||||
|
||||
async def create_generator(self):
|
||||
"""Create and save a Exllama generator class."""
|
||||
|
||||
try:
|
||||
# Don't acquire locks unless a model is loaded
|
||||
if self.loaded:
|
||||
await self.load_lock.acquire()
|
||||
|
||||
# Immediately cancel all jobs
|
||||
await self.wait_for_jobs(skip_wait=True)
|
||||
|
||||
# Create new generator
|
||||
self.generator = AsyncGenerator(
|
||||
model=self.model,
|
||||
cache=self.cache,
|
||||
tokenizer=self.tokenizer,
|
||||
max_batch_size=self.max_batch_size,
|
||||
)
|
||||
|
||||
# Update the state of the container var
|
||||
if self.max_batch_size is None:
|
||||
self.max_batch_size = self.generator.generator.max_batch_size
|
||||
finally:
|
||||
# This means the generator is being recreated
|
||||
# The load lock is already released in the load function
|
||||
if self.loaded:
|
||||
self.load_lock.release()
|
||||
|
||||
async with self.load_condition:
|
||||
self.load_condition.notify_all()
|
||||
|
||||
async def generate_gen(
|
||||
self,
|
||||
request_id: str,
|
||||
|
|
@ -492,25 +521,58 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
"""
|
||||
chunk_tokens: torch.Tensor | tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
# FIXME: this is probably not right
|
||||
base_sampler = BaseSamplerRequest()
|
||||
sampler = ComboSampler(
|
||||
rep_p=base_sampler.repetition_penalty,
|
||||
pres_p=base_sampler.presence_penalty,
|
||||
freq_p=base_sampler.frequency_penalty,
|
||||
rep_sustain_range=base_sampler.penalty_range,
|
||||
rep_decay_range=base_sampler.penalty_range,
|
||||
temperature=base_sampler.temperature,
|
||||
min_p=base_sampler.min_p,
|
||||
top_k=base_sampler.top_k,
|
||||
top_p=base_sampler.top_p,
|
||||
temp_last=base_sampler.temperature_last,
|
||||
sampler_builder = ExllamaV3SamplerBuilder()
|
||||
|
||||
# Penalties
|
||||
|
||||
# Set penalty range
|
||||
penalty_range = unwrap(params.penalty_range, self.max_seq_len)
|
||||
|
||||
# Exl3's version of including the entire context
|
||||
if penalty_range < 0:
|
||||
penalty_range = 10e7
|
||||
|
||||
# Always make sure the fallback is 0 if range < 0
|
||||
# It's technically fine to use -1, but this just validates the passed
|
||||
# fallback
|
||||
# Always default to 0 if something goes wrong
|
||||
if params.penalty_range < 0:
|
||||
fallback_decay = 0
|
||||
else:
|
||||
fallback_decay = params.penalty_range
|
||||
|
||||
repetition_decay = coalesce(params.repetition_decay, fallback_decay, 0)
|
||||
|
||||
# Apply penalties to builder
|
||||
sampler_builder.penalties(
|
||||
params.repetition_penalty,
|
||||
params.frequency_penalty,
|
||||
params.presence_penalty,
|
||||
penalty_range,
|
||||
repetition_decay,
|
||||
)
|
||||
|
||||
# Apply temperature first to builder
|
||||
if not params.temperature_last:
|
||||
sampler_builder.temperature(params.temperature)
|
||||
|
||||
# Apply alphabet samplers to builder
|
||||
sampler_builder.top_k(params.top_k)
|
||||
sampler_builder.top_p(params.top_p)
|
||||
sampler_builder.min_p(params.min_p)
|
||||
|
||||
# Apply temperature last to builder
|
||||
if params.temperature_last:
|
||||
sampler_builder.temperature(params.temperature)
|
||||
|
||||
# Build the sampler
|
||||
# Set greedy if temperature is 0
|
||||
sampler = sampler_builder.build(params.temperature == 0)
|
||||
|
||||
# Dynamically scale penalty range to output tokens
|
||||
# Only do this if freq/pres pen is enabled
|
||||
# and the repetition range is -1
|
||||
# TODO:
|
||||
# TODO: This currently does not work in exl3
|
||||
# auto_scale_penalty_range = (
|
||||
# gen_settings.token_frequency_penalty != 0
|
||||
# or gen_settings.token_presence_penalty != 0
|
||||
|
|
@ -576,6 +638,7 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
input_ids=self.tokenizer.encode(prompt, add_bos=False),
|
||||
max_new_tokens=max_tokens,
|
||||
stop_conditions=stop_conditions,
|
||||
banned_strings=params.banned_strings,
|
||||
)
|
||||
|
||||
generated_tokens = 0
|
||||
|
|
@ -585,6 +648,11 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
# Get the generation status once it's ready
|
||||
try:
|
||||
async for result in job:
|
||||
# Abort if the event is set while streaming
|
||||
if abort_event and abort_event.is_set():
|
||||
await job.cancel()
|
||||
break
|
||||
|
||||
chunk = unwrap(result.get("text"), "")
|
||||
if chunk:
|
||||
chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))
|
||||
|
|
|
|||
|
|
@ -163,8 +163,10 @@ class ModelConfig(BaseConfigModel):
|
|||
"Example: ['max_seq_len', 'cache_mode']."
|
||||
),
|
||||
)
|
||||
|
||||
# Defaults to exllamav2 in common/model.py
|
||||
backend: Optional[str] = Field(
|
||||
"exllamav2",
|
||||
None,
|
||||
description=(
|
||||
"Backend to use for this model (default: exllamav2)\n"
|
||||
"Options: exllamav2, exllamav3",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue