Model: Correct exl3 generation, add concurrency, and cleanup
Fixes application of sampler parameters by adding a new sampler builder interface. Also exposes the generator class-wide and adds wait_for_jobs. Finally, allows inline loading to specify the backend.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
parent c744790f14
commit 303e2dde12

2 changed files with 155 additions and 85 deletions
@@ -14,7 +14,6 @@ from exllamav3 import (
     AsyncGenerator,
     AsyncJob,
     Cache,
-    ComboSampler,
     Config,
     Model,
     Tokenizer,
@@ -22,6 +21,7 @@ from exllamav3 import (
 from loguru import logger
 
 from backends.base_model_container import BaseModelContainer
+from backends.exllamav3.sampler import ExllamaV3SamplerBuilder
 from common.concurrency import iterate_in_threadpool
 from common.gen_logging import (
     log_generation_params,
@@ -32,7 +32,7 @@ from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate, find_prompt_template
 from common.transformers_utils import GenerationConfig
-from common.utils import unwrap
+from common.utils import coalesce, unwrap
 from endpoints.core.types.model import ModelCard
 
 
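
The coalesce helper newly imported above is not shown in this patch; judging from its call site later in the diff (coalesce(params.repetition_decay, fallback_decay, 0)), it presumably returns the first non-None argument. A minimal sketch under that assumption, not the real common/utils implementation:

from typing import Any, Optional


def coalesce(*args: Optional[Any]) -> Optional[Any]:
    """Assumed behavior: return the first argument that is not None."""
    return next((arg for arg in args if arg is not None), None)


assert coalesce(None, 5, 0) == 5
assert coalesce(None, None, 0) == 0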
@@ -58,6 +58,7 @@ class ExllamaV3Container(BaseModelContainer):
     cache: Cache
     tokenizer: Tokenizer
     config: Config
+    generator: Optional[AsyncGenerator] = None
     gpu_split: List[float] | None = None
     gpu_split_auto: bool = True
     autosplit_reserve: List[float] = [96 / 1024]
@@ -123,13 +124,47 @@ class ExllamaV3Container(BaseModelContainer):
 
         # Reserve VRAM for each GPU
         self.autosplit_reserve = [
-            value/1024
-            for value in autosplit_reserve_megabytes
+            value / 1024 for value in autosplit_reserve_megabytes
         ]
         # TODO: speculative decoding
 
         return self
 
+    def model_info(self) -> ModelCard:
+        """
+        Returns a dictionary of the current model's configuration parameters.
+
+        Returns:
+            Model parameters provided by the backend
+        """
+
+        pass
+
+    async def wait_for_jobs(self, skip_wait: bool = False):
+        """
+        Polling to wait for any active generation jobs to complete.
+
+        Args:
+            skip_wait: If True, cancel jobs immediately instead of waiting.
+        """
+
+        if not self.generator:
+            return
+
+        # Immediately abort all jobs if asked
+        if skip_wait:
+            logger.warning(
+                "Immediately terminating all jobs. "
+                "Clients will have their requests cancelled.\n"
+            )
+
+            for job in self.active_job_ids.values():
+                if job:
+                    await job.cancel()
+
+        while len(self.active_job_ids) > 0:
+            await asyncio.sleep(0.01)
+
     async def load(self, progress_callback=None, **kwargs):
         """
         Loads the model into memory.
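
The new wait_for_jobs polls active_job_ids until it drains, which relies on each generation request removing its own entry once it finishes (or is cancelled). A toy, self-contained illustration of that contract; the names here are illustrative, not the container's real internals:

import asyncio

# Toy registry standing in for the container's active_job_ids
active_job_ids: dict[str, object] = {}


async def fake_generation(request_id: str):
    await asyncio.sleep(0.05)  # pretend to generate
    del active_job_ids[request_id]  # completion clears the registry entry


async def wait_for_jobs():
    # Same drain loop as the container method, minus the skip_wait branch
    while len(active_job_ids) > 0:
        await asyncio.sleep(0.01)


async def main():
    active_job_ids["req-1"] = object()  # registered when generation starts
    task = asyncio.create_task(fake_generation("req-1"))
    await wait_for_jobs()
    await task
    print("drained")


asyncio.run(main())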
@@ -161,8 +196,8 @@ class ExllamaV3Container(BaseModelContainer):
         await self.wait_for_jobs(kwargs.get("skip_wait"))
 
         generator = self.load_model_sync(progress_callback)
-        async for module, modules in iterate_in_threadpool(generator):
-            yield module, modules
+        async for value in iterate_in_threadpool(generator):
+            yield value
 
         # Clean up any extra vram usage from torch and cuda
         # (Helps reduce VRAM bottlenecking on Windows)
@@ -184,11 +219,42 @@ class ExllamaV3Container(BaseModelContainer):
         for value in self.model.load_gen(
             reserve_per_device=self.autosplit_reserve,
             use_per_device=self.gpu_split,
-            callback=progress_callback
+            callback=progress_callback,
         ):
             if value:
                 yield value
 
+    async def create_generator(self):
+        """Create and save a Exllama generator class."""
+
+        try:
+            # Don't acquire locks unless a model is loaded
+            if self.loaded:
+                await self.load_lock.acquire()
+
+                # Immediately cancel all jobs
+                await self.wait_for_jobs(skip_wait=True)
+
+            # Create new generator
+            self.generator = AsyncGenerator(
+                model=self.model,
+                cache=self.cache,
+                tokenizer=self.tokenizer,
+                max_batch_size=self.max_batch_size,
+            )
+
+            # Update the state of the container var
+            if self.max_batch_size is None:
+                self.max_batch_size = self.generator.generator.max_batch_size
+        finally:
+            # This means the generator is being recreated
+            # The load lock is already released in the load function
+            if self.loaded:
+                self.load_lock.release()
+
+                async with self.load_condition:
+                    self.load_condition.notify_all()
+
     async def unload(self, loras_only: bool = False, **kwargs):
         """
         Unloads the model and associated resources from memory.
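
Per the comments in the moved create_generator, the method covers two paths: during the first load the load lock is already held by load() and self.loaded is still False, while a later rebuild on a live model takes and releases the lock itself and cancels in-flight jobs first. A hedged usage sketch of the rebuild path, assuming an already-loaded container:

# Sketch only: `container` is assumed to be an already-loaded ExllamaV3Container.
async def rebuild_generator(container, new_max_batch_size=None):
    if new_max_batch_size is not None:
        container.max_batch_size = new_max_batch_size

    # Because container.loaded is True here, create_generator acquires and
    # releases the load lock itself and cancels any in-flight jobs first
    await container.create_generator()
    return container.generator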
@@ -198,11 +264,15 @@ class ExllamaV3Container(BaseModelContainer):
             **kwargs: Additional unloading options (e.g., shutdown).
         """
 
-        try:
-            await self.load_lock.acquire()
+        # Used when shutting down the server
+        do_shutdown = kwargs.get("shutdown")
 
-            # Wait for other jobs to finish
-            await self.wait_for_jobs(kwargs.get("skip_wait"))
+        try:
+            if not do_shutdown:
+                await self.load_lock.acquire()
+
+                # Wait for other jobs to finish
+                await self.wait_for_jobs(kwargs.get("skip_wait"))
 
             self.model.unload()
             self.model = None
@@ -211,15 +281,21 @@ class ExllamaV3Container(BaseModelContainer):
             self.cache = None
             self.tokenizer = None
 
+            # Cleanup the generator from any pending jobs
+            if self.generator is not None:
+                await self.generator.close()
+                self.generator = None
+
             gc.collect()
             torch.cuda.empty_cache()
 
             logger.info("Model unloaded.")
         finally:
-            self.load_lock.release()
+            if not do_shutdown:
+                self.load_lock.release()
 
             async with self.load_condition:
                 self.load_condition.notify_all()
 
     def encode_tokens(self, text: str, **kwargs) -> List[int]:
         """
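
unload() now reads a do_shutdown flag so a server shutdown skips both the load lock and the job wait, while a normal unload still serializes against loads; either way the generator is closed and VRAM is reclaimed. A toy version of the skip-the-lock-on-shutdown pattern, not the real container:

import asyncio


class ToyContainer:
    """Illustrative only; mirrors the lock handling in unload(), nothing more."""

    def __init__(self):
        self.load_lock = asyncio.Lock()
        self.resources = {"model": object()}

    async def unload(self, **kwargs):
        do_shutdown = kwargs.get("shutdown")

        try:
            if not do_shutdown:
                # Normal unload: serialize with concurrent loads
                await self.load_lock.acquire()

            self.resources.clear()
            print("resources released")
        finally:
            if not do_shutdown:
                self.load_lock.release()


asyncio.run(ToyContainer().unload(shutdown=True))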
@@ -233,11 +309,15 @@ class ExllamaV3Container(BaseModelContainer):
             A list of integer token IDs.
         """
 
-        return self.tokenizer.encode(
-            text,
-            add_bos=unwrap(kwargs.get("add_bos_token"), True),
-            encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
-        ).flatten().tolist()
+        return (
+            self.tokenizer.encode(
+                text,
+                add_bos=unwrap(kwargs.get("add_bos_token"), True),
+                encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
+            )
+            .flatten()
+            .tolist()
+        )
 
     def decode_tokens(self, ids: List[int], **kwargs) -> str:
         """
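
The reshaped encode_tokens still returns a flat List[int]; the chained .flatten().tolist() presumably collapses the (1, n) tensor that exllamav3's Tokenizer.encode returns. A tensor-only sketch of that step (the token IDs below are made up):

import torch

ids = torch.tensor([[1, 15043, 3186]])  # hypothetical (1, n) encode output
flat = ids.flatten().tolist()           # drop the batch dim, then to plain ints
assert flat == [1, 15043, 3186]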
@@ -278,26 +358,6 @@ class ExllamaV3Container(BaseModelContainer):
             "unk_token": self.tokenizer.unk_token,
         }
 
-    def model_info(self) -> ModelCard:
-        """
-        Returns a dictionary of the current model's configuration parameters.
-
-        Returns:
-            Model parameters provided by the backend
-        """
-
-        pass
-
-    async def wait_for_jobs(self, skip_wait: bool = False):
-        """
-        Waits for any active generation jobs to complete.
-
-        Args:
-            skip_wait: If True, cancel jobs immediately instead of waiting.
-        """
-
-        pass
-
     async def generate(
         self,
         request_id: str,
@@ -446,37 +506,6 @@ class ExllamaV3Container(BaseModelContainer):
 
         return finish_chunk
 
-    async def create_generator(self):
-        """Create and save a Exllama generator class."""
-
-        try:
-            # Don't acquire locks unless a model is loaded
-            if self.loaded:
-                await self.load_lock.acquire()
-
-                # Immediately cancel all jobs
-                await self.wait_for_jobs(skip_wait=True)
-
-            # Create new generator
-            self.generator = AsyncGenerator(
-                model=self.model,
-                cache=self.cache,
-                tokenizer=self.tokenizer,
-                max_batch_size=self.max_batch_size,
-            )
-
-            # Update the state of the container var
-            if self.max_batch_size is None:
-                self.max_batch_size = self.generator.generator.max_batch_size
-        finally:
-            # This means the generator is being recreated
-            # The load lock is already released in the load function
-            if self.loaded:
-                self.load_lock.release()
-
-                async with self.load_condition:
-                    self.load_condition.notify_all()
-
     async def generate_gen(
         self,
         request_id: str,
@@ -492,25 +521,58 @@ class ExllamaV3Container(BaseModelContainer):
         """
         chunk_tokens: torch.Tensor | tuple[torch.Tensor, torch.Tensor]
 
-        # FIXME: this is probably not right
-        base_sampler = BaseSamplerRequest()
-        sampler = ComboSampler(
-            rep_p=base_sampler.repetition_penalty,
-            pres_p=base_sampler.presence_penalty,
-            freq_p=base_sampler.frequency_penalty,
-            rep_sustain_range=base_sampler.penalty_range,
-            rep_decay_range=base_sampler.penalty_range,
-            temperature=base_sampler.temperature,
-            min_p=base_sampler.min_p,
-            top_k=base_sampler.top_k,
-            top_p=base_sampler.top_p,
-            temp_last=base_sampler.temperature_last,
-        )
+        sampler_builder = ExllamaV3SamplerBuilder()
+
+        # Penalties
+
+        # Set penalty range
+        penalty_range = unwrap(params.penalty_range, self.max_seq_len)
+
+        # Exl3's version of including the entire context
+        if penalty_range < 0:
+            penalty_range = 10e7
+
+        # Always make sure the fallback is 0 if range < 0
+        # It's technically fine to use -1, but this just validates the passed
+        # fallback
+        # Always default to 0 if something goes wrong
+        if params.penalty_range < 0:
+            fallback_decay = 0
+        else:
+            fallback_decay = params.penalty_range
+
+        repetition_decay = coalesce(params.repetition_decay, fallback_decay, 0)
+
+        # Apply penalties to builder
+        sampler_builder.penalties(
+            params.repetition_penalty,
+            params.frequency_penalty,
+            params.presence_penalty,
+            penalty_range,
+            repetition_decay,
+        )
+
+        # Apply temperature first to builder
+        if not params.temperature_last:
+            sampler_builder.temperature(params.temperature)
+
+        # Apply alphabet samplers to builder
+        sampler_builder.top_k(params.top_k)
+        sampler_builder.top_p(params.top_p)
+        sampler_builder.min_p(params.min_p)
+
+        # Apply temperature last to builder
+        if params.temperature_last:
+            sampler_builder.temperature(params.temperature)
+
+        # Build the sampler
+        # Set greedy if temperature is 0
+        sampler = sampler_builder.build(params.temperature == 0)
 
         # Dynamically scale penalty range to output tokens
         # Only do this if freq/pres pen is enabled
         # and the repetition range is -1
-        # TODO:
+        # TODO: This currently does not work in exl3
         # auto_scale_penalty_range = (
         #     gen_settings.token_frequency_penalty != 0
         #     or gen_settings.token_presence_penalty != 0
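
ExllamaV3SamplerBuilder itself is defined in backends/exllamav3/sampler.py and is not shown in this diff; the call sites above only imply its surface: penalties(), temperature(), top_k(), top_p(), min_p(), and build(greedy). A minimal sketch of an interface with that shape, purely illustrative (the real builder presumably assembles exllamav3 sampler stages rather than a plain dict):

from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ExllamaV3SamplerBuilderSketch:
    """Hypothetical shape of the builder used above; not the real implementation."""

    steps: List[Dict[str, Any]] = field(default_factory=list)

    def penalties(self, rep_p, freq_p, pres_p, sustain_range, decay_range):
        self.steps.append(
            {
                "kind": "penalties",
                "rep_p": rep_p,
                "freq_p": freq_p,
                "pres_p": pres_p,
                "sustain_range": sustain_range,
                "decay_range": decay_range,
            }
        )

    def temperature(self, value: float):
        self.steps.append({"kind": "temperature", "value": value})

    def top_k(self, value: int):
        self.steps.append({"kind": "top_k", "value": value})

    def top_p(self, value: float):
        self.steps.append({"kind": "top_p", "value": value})

    def min_p(self, value: float):
        self.steps.append({"kind": "min_p", "value": value})

    def build(self, greedy: bool = False):
        # The real builder would return an exllamav3 sampler object here;
        # this sketch just returns the recorded pipeline for inspection.
        return {"greedy": greedy, "steps": self.steps}


builder = ExllamaV3SamplerBuilderSketch()
builder.penalties(1.1, 0.0, 0.0, 1024, 0)
builder.top_k(40)
builder.top_p(0.9)
builder.min_p(0.05)
builder.temperature(0.8)
print(builder.build(greedy=False))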
@@ -576,6 +638,7 @@ class ExllamaV3Container(BaseModelContainer):
             input_ids=self.tokenizer.encode(prompt, add_bos=False),
             max_new_tokens=max_tokens,
             stop_conditions=stop_conditions,
+            banned_strings=params.banned_strings,
         )
 
         generated_tokens = 0
@@ -585,6 +648,11 @@ class ExllamaV3Container(BaseModelContainer):
         # Get the generation status once it's ready
         try:
             async for result in job:
+                # Abort if the event is set while streaming
+                if abort_event and abort_event.is_set():
+                    await job.cancel()
+                    break
+
                 chunk = unwrap(result.get("text"), "")
                 if chunk:
                     chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))
@@ -163,8 +163,10 @@ class ModelConfig(BaseConfigModel):
             "Example: ['max_seq_len', 'cache_mode']."
         ),
     )
+
+    # Defaults to exllamav2 in common/model.py
     backend: Optional[str] = Field(
-        "exllamav2",
+        None,
         description=(
            "Backend to use for this model (default: exllamav2)\n"
            "Options: exllamav2, exllamav3",
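
With the field default changed from "exllamav2" to None, an unset backend no longer gets hard-coded at the config layer; per the new comment, the exllamav2 fallback is applied later in common/model.py (not shown in this diff). A hedged sketch of that resolution step:

from typing import Optional

from pydantic import BaseModel


class ModelLoadRequest(BaseModel):
    """Illustrative subset of the config model; only the backend field matters here."""

    backend: Optional[str] = None


def resolve_backend(request: ModelLoadRequest) -> str:
    # Assumption: common/model.py falls back to exllamav2 when backend is unset
    return request.backend if request.backend is not None else "exllamav2"


print(resolve_backend(ModelLoadRequest()))                     # exllamav2
print(resolve_backend(ModelLoadRequest(backend="exllamav3")))  # exllamav3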