Exl3: Couldn't wait

Just copied some stuff around and it ended up working for basic use.
This commit is contained in:
randoentity 2025-04-29 23:57:53 +02:00 committed by kingbri
parent b4ff2f23cf
commit daae9ec43d
2 changed files with 221 additions and 10 deletions

View file

@ -533,8 +533,7 @@ class ExllamaV2Container(BaseModelContainer):
# Load draft model if a config is present
if self.draft_config:
self.draft_model = ExLlamaV2(self.draft_config)
if not self.quiet:
logger.info("Loading draft model: " + self.draft_config.model_dir)
logger.info("Loading draft model: " + self.draft_config.model_dir)
# Draft uses the autosplit loader, so create a cache that reflects this
draft_cache_class = self.get_cache_class(self.draft_cache_mode)
@ -587,8 +586,7 @@ class ExllamaV2Container(BaseModelContainer):
yield value
self.model = ExLlamaV2(self.config)
if not self.quiet:
logger.info("Loading model: " + self.config.model_dir)
logger.info("Loading model: " + self.config.model_dir)
# Get class of the model cache
cache_class = self.get_cache_class(self.cache_mode)

View file

@ -16,12 +16,12 @@ from backends.base_model_container import BaseModelContainer
from common.concurrency import iterate_in_threadpool
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate
from common.templating import PromptTemplate, find_prompt_template
from common.transformers_utils import GenerationConfig
from common.utils import unwrap
from endpoints.core.types.model import ModelCard
from exllamav3 import Config, Model, Cache, Tokenizer
from exllamav3 import AsyncGenerator, AsyncJob, Config, Model, Cache, Tokenizer
class ExllamaV3Container(BaseModelContainer):
@ -46,6 +46,8 @@ class ExllamaV3Container(BaseModelContainer):
cache: Cache
tokenizer: Tokenizer
config: Config
gpu_split: List[float] = []
max_seq_len: int = 2048
# Required methods
@classmethod
@ -74,6 +76,16 @@ class ExllamaV3Container(BaseModelContainer):
max_seq_len = kwargs.get("max_seq_len")
self.cache = Cache(self.model, max_num_tokens=max_seq_len)
gpu_split = unwrap(kwargs.get("gpu_split"), [])
# Set GPU split options
# Enable manual GPU split if provided
if gpu_split:
self.gpu_split = gpu_split
# Try to set prompt template
self.prompt_template = await find_prompt_template(
kwargs.get("prompt_template"), model_directory
)
return self
@ -128,7 +140,10 @@ class ExllamaV3Container(BaseModelContainer):
# TODO: Add draft loading
@torch.inference_mode()
def load_model_sync(self, progress_callback=None):
for value in self.model.load_gen(callback=progress_callback):
for value in self.model.load_gen(
use_per_device=self.gpu_split,
callback=progress_callback
):
if value:
yield value
@ -263,7 +278,58 @@ class ExllamaV3Container(BaseModelContainer):
A dictionary containing the generation info
"""
pass
generations = []
async for generation in self.stream_generate(
request_id,
prompt,
params,
abort_event,
mm_embeddings,
):
generations.append(generation)
joined_generation = {
"text": "",
"prompt_tokens": 0,
"generation_tokens": 0,
"tool_calls": None,
"offset": [],
"token_probs": {},
"logprobs": [],
}
if generations:
# Get finish_reason first and then shift where -1 points to
if "finish_reason" in generations[-1]:
finish_reason_gen = generations.pop()
joined_generation["finish_reason"] = finish_reason_gen.get(
"finish_reason"
)
joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
else:
joined_generation["finish_reason"] = "stop"
if len(generations) > 0:
for generation in generations:
joined_generation["text"] += unwrap(generation.get("text"), "")
joined_generation["offset"].append(unwrap(generation.get("offset"), -1))
joined_generation["token_probs"].update(
unwrap(generation.get("token_probs"), {})
)
# Include empty logprob dicts for index preservation
joined_generation["logprobs"].append(
unwrap(generation.get("logprobs"), {})
)
joined_generation["prompt_tokens"] = unwrap(
generations[-1].get("prompt_tokens"), 0
)
joined_generation["generated_tokens"] = unwrap(
generations[-1].get("generated_tokens"), 0
)
return joined_generation
async def stream_generate(
self,
@ -287,5 +353,152 @@ class ExllamaV3Container(BaseModelContainer):
Generation chunks
"""
if False:
yield
try:
# Wait for load lock to be freed before processing
# Mainly used for loras and other operations where the class is available
async with self.load_condition:
await self.load_condition.wait_for(lambda: not self.load_lock.locked())
# If the model is being unloaded, don't accept new requests
if not self.loaded:
raise RuntimeError(
"Model is being unloaded. Cannot process new generation requests."
)
# Mark that the job is running
self.active_job_ids[request_id] = None
# Yield from the internal generator
async for generation_chunk in self.generate_gen(
request_id=request_id,
prompt=prompt,
params=params,
abort_event=abort_event,
mm_embeddings=mm_embeddings,
):
yield generation_chunk
finally:
# Clean up and remove the job from active IDs
del self.active_job_ids[request_id]
def handle_finish_chunk(self, result: dict, generation: dict):
eos_reason = result.get("eos_reason")
stop_str = None
if eos_reason == "max_new_tokens":
finish_reason = "length"
else:
finish_reason = "stop"
# Grab stop string if stop was the reason
if eos_reason == "stop_token":
stop_str = result.get("eos_triggering_token_str")
elif eos_reason == "stop_string":
stop_str = result.get("eos_triggering_string")
finish_chunk = {
"prompt_tokens": generation.get("prompt_tokens"),
"generated_tokens": generation.get("generated_tokens"),
"finish_reason": finish_reason,
"stop_str": stop_str,
}
return finish_chunk
async def generate_gen(
self,
request_id: str,
prompt: str,
params: BaseSamplerRequest,
abort_event: Optional[asyncio.Event] = None,
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
):
"""
Create generator function for prompt completion.
for kwargs, check common/sampling.py
"""
chunk_tokens: torch.Tensor | tuple[torch.Tensor, torch.Tensor]
prompts = [prompt]
stop_conditions = params.stop
add_bos_token = params.add_bos_token
# Fetch EOS tokens from generation_config if they exist
eos_tokens = (
self.generation_config.eos_tokens()
if self.generation_config
else [self.tokenizer.eos_token_id]
)
stop_conditions += eos_tokens
input_ids = [
self.tokenizer.encode(
prompt,
add_bos=add_bos_token,
encode_special_tokens=True,
)
for prompt in prompts
]
# The first index will always be the positive prompt
context_len = input_ids[0].size(dim=-1)
# Automatically set max_tokens to fill up the context
# This should be an OK default, but may be changed in the future
max_tokens = unwrap(
params.max_tokens,
self.max_seq_len - context_len,
)
if max_tokens < 1:
logger.warning("max_tokens must be a positive integer, setting to 1.")
max_tokens = 1
# Determine if the negative context or the context length is bigger
context_to_check = context_len
# Check total length of prompt against max context length
if context_to_check > self.max_seq_len:
preamble = "Prompt"
raise ValueError(
f"{preamble} length {context_to_check} is greater than "
f"max_seq_len {self.max_seq_len}"
)
self.generator = AsyncGenerator(
model=self.model,
cache=self.cache,
tokenizer=self.tokenizer,
)
generation = {}
print(max_tokens)
job = AsyncJob(
self.generator,
input_ids=self.tokenizer.encode(prompt, add_bos=False),
max_new_tokens=max_tokens,
stop_conditions=stop_conditions,
)
generated_tokens = 0
full_response = ""
async for result in job:
chunk = unwrap(result.get("text"), "")
if chunk:
chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))
full_response += chunk
if isinstance(chunk_tokens, torch.Tensor):
generated_tokens += chunk_tokens.size(dim=0)
generation = {
"text": chunk,
"prompt_tokens": context_len,
"generated_tokens": generated_tokens,
"offset": len(full_response),
}
yield generation
if result.get("eos"):
generation = self.handle_finish_chunk(result, generation)
yield generation
# Assign the active job to the request ID
self.active_job_ids[request_id] = job