Exl3: Couldn't wait
Just copied some stuff around and it ended up working for basic use.
parent b4ff2f23cf
commit daae9ec43d
2 changed files with 221 additions and 10 deletions
@@ -533,8 +533,7 @@ class ExllamaV2Container(BaseModelContainer):
         # Load draft model if a config is present
         if self.draft_config:
             self.draft_model = ExLlamaV2(self.draft_config)
-            if not self.quiet:
-                logger.info("Loading draft model: " + self.draft_config.model_dir)
+            logger.info("Loading draft model: " + self.draft_config.model_dir)

             # Draft uses the autosplit loader, so create a cache that reflects this
             draft_cache_class = self.get_cache_class(self.draft_cache_mode)
@@ -587,8 +586,7 @@ class ExllamaV2Container(BaseModelContainer):
                     yield value

         self.model = ExLlamaV2(self.config)
-        if not self.quiet:
-            logger.info("Loading model: " + self.config.model_dir)
+        logger.info("Loading model: " + self.config.model_dir)

         # Get class of the model cache
         cache_class = self.get_cache_class(self.cache_mode)

@@ -16,12 +16,12 @@ from backends.base_model_container import BaseModelContainer
 from common.concurrency import iterate_in_threadpool
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
-from common.templating import PromptTemplate
+from common.templating import PromptTemplate, find_prompt_template
 from common.transformers_utils import GenerationConfig
 from common.utils import unwrap
 from endpoints.core.types.model import ModelCard

-from exllamav3 import Config, Model, Cache, Tokenizer
+from exllamav3 import AsyncGenerator, AsyncJob, Config, Model, Cache, Tokenizer


 class ExllamaV3Container(BaseModelContainer):
@@ -46,6 +46,8 @@ class ExllamaV3Container(BaseModelContainer):
     cache: Cache
     tokenizer: Tokenizer
     config: Config
+    gpu_split: List[float] = []
+    max_seq_len: int = 2048

     # Required methods
     @classmethod
@@ -74,6 +76,16 @@ class ExllamaV3Container(BaseModelContainer):

         max_seq_len = kwargs.get("max_seq_len")
         self.cache = Cache(self.model, max_num_tokens=max_seq_len)
+        gpu_split = unwrap(kwargs.get("gpu_split"), [])
+
+        # Set GPU split options
+        # Enable manual GPU split if provided
+        if gpu_split:
+            self.gpu_split = gpu_split
+        # Try to set prompt template
+        self.prompt_template = await find_prompt_template(
+            kwargs.get("prompt_template"), model_directory
+        )

         return self

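As annotation only, here is a hypothetical call into the factory method extended above. Only the max_seq_len, gpu_split, and prompt_template kwargs appear in this hunk; the name of the factory classmethod (assumed to be create from the base container), the model_directory argument, and the sample values are assumptions rather than part of the commit.

import asyncio

async def load_container():
    # Placeholder values: gpu_split is the manual per-device split wired into
    # self.gpu_split above, and prompt_template=None leaves detection to
    # find_prompt_template against the model directory.
    container = await ExllamaV3Container.create(
        model_directory="/path/to/exl3-model",
        max_seq_len=4096,
        gpu_split=[20.0, 24.0],
        prompt_template=None,
    )
    return container

container = asyncio.run(load_container())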
@@ -128,7 +140,10 @@ class ExllamaV3Container(BaseModelContainer):
     # TODO: Add draft loading
     @torch.inference_mode()
     def load_model_sync(self, progress_callback=None):
-        for value in self.model.load_gen(callback=progress_callback):
+        for value in self.model.load_gen(
+            use_per_device=self.gpu_split,
+            callback=progress_callback
+        ):
             if value:
                 yield value

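Because load_model_sync stays a plain generator, a caller can drive it step by step. A minimal sketch of that follows, assuming an already-constructed container; the print-based reporting is illustrative and not part of the commit, and the yielded values are whatever exllamav3's Model.load_gen reports (the diff only guarantees falsy values are filtered out).

def load_with_progress(container):
    # Drain the load generator, reporting each progress value it yields.
    for value in container.load_model_sync(progress_callback=None):
        print("load step:", value)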
@@ -263,7 +278,58 @@ class ExllamaV3Container(BaseModelContainer):
             A dictionary containing the generation info
         """

-        pass
+        generations = []
+        async for generation in self.stream_generate(
+            request_id,
+            prompt,
+            params,
+            abort_event,
+            mm_embeddings,
+        ):
+            generations.append(generation)
+
+        joined_generation = {
+            "text": "",
+            "prompt_tokens": 0,
+            "generation_tokens": 0,
+            "tool_calls": None,
+            "offset": [],
+            "token_probs": {},
+            "logprobs": [],
+        }
+
+        if generations:
+            # Get finish_reason first and then shift where -1 points to
+            if "finish_reason" in generations[-1]:
+                finish_reason_gen = generations.pop()
+                joined_generation["finish_reason"] = finish_reason_gen.get(
+                    "finish_reason"
+                )
+                joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
+            else:
+                joined_generation["finish_reason"] = "stop"
+
+            if len(generations) > 0:
+                for generation in generations:
+                    joined_generation["text"] += unwrap(generation.get("text"), "")
+                    joined_generation["offset"].append(unwrap(generation.get("offset"), -1))
+                    joined_generation["token_probs"].update(
+                        unwrap(generation.get("token_probs"), {})
+                    )
+
+                    # Include empty logprob dicts for index preservation
+                    joined_generation["logprobs"].append(
+                        unwrap(generation.get("logprobs"), {})
+                    )
+
+                joined_generation["prompt_tokens"] = unwrap(
+                    generations[-1].get("prompt_tokens"), 0
+                )
+                joined_generation["generated_tokens"] = unwrap(
+                    generations[-1].get("generated_tokens"), 0
+                )
+
+        return joined_generation
+
     async def stream_generate(
         self,
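As a standalone illustration of the joining logic above (not part of the commit): given a few streamed chunk dicts shaped like the ones generate_gen yields, the trailing chunk carrying finish_reason is popped off first and the remaining chunks are folded into a single response. The unwrap helper here only mimics common.utils.unwrap, and the sample chunks are made up.

def unwrap(value, default):
    # Stand-in for common.utils.unwrap: fall back when the value is missing.
    return value if value is not None else default

chunks = [
    {"text": "Hello", "offset": 5, "prompt_tokens": 3, "generated_tokens": 1},
    {"text": " world", "offset": 11, "prompt_tokens": 3, "generated_tokens": 2},
    {"finish_reason": "stop", "stop_str": None},
]

joined = {"text": "", "offset": [], "finish_reason": None}
if chunks and "finish_reason" in chunks[-1]:
    joined["finish_reason"] = chunks.pop().get("finish_reason")
for chunk in chunks:
    joined["text"] += unwrap(chunk.get("text"), "")
    joined["offset"].append(unwrap(chunk.get("offset"), -1))

print(joined)  # {'text': 'Hello world', 'offset': [5, 11], 'finish_reason': 'stop'}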
@@ -287,5 +353,152 @@ class ExllamaV3Container(BaseModelContainer):
             Generation chunks
         """

-        if False:
-            yield
+        try:
+            # Wait for load lock to be freed before processing
+            # Mainly used for loras and other operations where the class is available
+            async with self.load_condition:
+                await self.load_condition.wait_for(lambda: not self.load_lock.locked())
+
+            # If the model is being unloaded, don't accept new requests
+            if not self.loaded:
+                raise RuntimeError(
+                    "Model is being unloaded. Cannot process new generation requests."
+                )
+
+            # Mark that the job is running
+            self.active_job_ids[request_id] = None
+
+            # Yield from the internal generator
+            async for generation_chunk in self.generate_gen(
+                request_id=request_id,
+                prompt=prompt,
+                params=params,
+                abort_event=abort_event,
+                mm_embeddings=mm_embeddings,
+            ):
+                yield generation_chunk
+        finally:
+            # Clean up and remove the job from active IDs
+            del self.active_job_ids[request_id]
+
+    def handle_finish_chunk(self, result: dict, generation: dict):
+        eos_reason = result.get("eos_reason")
+
+        stop_str = None
+        if eos_reason == "max_new_tokens":
+            finish_reason = "length"
+        else:
+            finish_reason = "stop"
+            # Grab stop string if stop was the reason
+            if eos_reason == "stop_token":
+                stop_str = result.get("eos_triggering_token_str")
+            elif eos_reason == "stop_string":
+                stop_str = result.get("eos_triggering_string")
+
+        finish_chunk = {
+            "prompt_tokens": generation.get("prompt_tokens"),
+            "generated_tokens": generation.get("generated_tokens"),
+            "finish_reason": finish_reason,
+            "stop_str": stop_str,
+        }
+
+        return finish_chunk
+
+    async def generate_gen(
+        self,
+        request_id: str,
+        prompt: str,
+        params: BaseSamplerRequest,
+        abort_event: Optional[asyncio.Event] = None,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+    ):
+        """
+        Create generator function for prompt completion.
+
+        for kwargs, check common/sampling.py
+        """
+        chunk_tokens: torch.Tensor | tuple[torch.Tensor, torch.Tensor]
+
+        prompts = [prompt]
+        stop_conditions = params.stop
+        add_bos_token = params.add_bos_token
+
+        # Fetch EOS tokens from generation_config if they exist
+        eos_tokens = (
+            self.generation_config.eos_tokens()
+            if self.generation_config
+            else [self.tokenizer.eos_token_id]
+        )
+
+        stop_conditions += eos_tokens
+
+        input_ids = [
+            self.tokenizer.encode(
+                prompt,
+                add_bos=add_bos_token,
+                encode_special_tokens=True,
+            )
+            for prompt in prompts
+        ]
+
+        # The first index will always be the positive prompt
+        context_len = input_ids[0].size(dim=-1)
+
+        # Automatically set max_tokens to fill up the context
+        # This should be an OK default, but may be changed in the future
+        max_tokens = unwrap(
+            params.max_tokens,
+            self.max_seq_len - context_len,
+        )
+        if max_tokens < 1:
+            logger.warning("max_tokens must be a positive integer, setting to 1.")
+            max_tokens = 1
+
+        # Determine if the negative context or the context length is bigger
+        context_to_check = context_len
+
+        # Check total length of prompt against max context length
+        if context_to_check > self.max_seq_len:
+            preamble = "Prompt"
+
+            raise ValueError(
+                f"{preamble} length {context_to_check} is greater than "
+                f"max_seq_len {self.max_seq_len}"
+            )
+
+        self.generator = AsyncGenerator(
+            model=self.model,
+            cache=self.cache,
+            tokenizer=self.tokenizer,
+        )
+
+        generation = {}
+        print(max_tokens)
+        job = AsyncJob(
+            self.generator,
+            input_ids=self.tokenizer.encode(prompt, add_bos=False),
+            max_new_tokens=max_tokens,
+            stop_conditions=stop_conditions,
+        )
+        generated_tokens = 0
+        full_response = ""
+        async for result in job:
+            chunk = unwrap(result.get("text"), "")
+            if chunk:
+                chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))
+                full_response += chunk
+                if isinstance(chunk_tokens, torch.Tensor):
+                    generated_tokens += chunk_tokens.size(dim=0)
+                generation = {
+                    "text": chunk,
+                    "prompt_tokens": context_len,
+                    "generated_tokens": generated_tokens,
+                    "offset": len(full_response),
+                }
+                yield generation
+
+            if result.get("eos"):
+                generation = self.handle_finish_chunk(result, generation)
+                yield generation
+        # Assign the active job to the request ID
+        self.active_job_ids[request_id] = job
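To round this off, a sketch (not part of the commit) of how a caller might consume the new streaming path, based only on the signatures visible above: stream_generate is driven with the same five arguments generate() passes, text chunks carry a "text" key, and the final chunk carries "finish_reason". The request-ID scheme, the params object, and the already-loaded container are assumptions.

import asyncio
import uuid

async def run_prompt(container, params):
    # Each request gets its own ID so it can be tracked in active_job_ids.
    request_id = str(uuid.uuid4())
    output = ""
    async for chunk in container.stream_generate(
        request_id,
        "Write a haiku about speculative decoding.",
        params,
        None,  # abort_event
        None,  # mm_embeddings
    ):
        if "text" in chunk:
            output += chunk["text"]
        if "finish_reason" in chunk:
            print(f"finish_reason={chunk['finish_reason']}")
    return output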