fixup: autosplit, start work on metrics

randoentity 2025-04-30 11:10:03 +02:00 committed by kingbri
parent 306fc7cd15
commit c0f268f33e


@@ -1,8 +1,6 @@
import asyncio
import gc
import math
import pathlib
from loguru import logger
from typing import (
    Any,
    AsyncIterator,
@@ -12,9 +10,14 @@ from typing import (
)
import torch
from exllamav3 import AsyncGenerator, AsyncJob, Cache, Config, Model, Tokenizer
from loguru import logger
from backends.base_model_container import BaseModelContainer
from common.concurrency import iterate_in_threadpool
from common.gen_logging import (
    log_metrics,
)
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate, find_prompt_template
@@ -22,8 +25,6 @@ from common.transformers_utils import GenerationConfig
from common.utils import unwrap
from endpoints.core.types.model import ModelCard
from exllamav3 import AsyncGenerator, AsyncJob, Config, Model, Cache, Tokenizer


class ExllamaV3Container(BaseModelContainer):
    """Abstract base class for model containers."""
@@ -112,7 +113,7 @@ class ExllamaV3Container(BaseModelContainer):
        # Reserve VRAM for each GPU
        self.autosplit_reserve = [
            int(math.ceil(value/1024))
            value/1024
            for value in autosplit_reserve_megabytes
        ]

        # TODO: speculative decoding
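
Reading this hunk, the autosplit fixup appears to drop the math.ceil rounding: assuming the upper line is the old code and the lower line its replacement, the per-GPU reserve is now handed to the loader as a fractional gigabyte value instead of being rounded up to a whole gigabyte. A minimal sketch of the difference, using a made-up autosplit_reserve_megabytes input rather than anything from the commit:

import math

autosplit_reserve_megabytes = [512, 96]  # hypothetical per-GPU reserve in MB

old_reserve = [int(math.ceil(value / 1024)) for value in autosplit_reserve_megabytes]
new_reserve = [value / 1024 for value in autosplit_reserve_megabytes]

print(old_reserve)  # [1, 1]         -> each GPU reserves a full gigabyte
print(new_reserve)  # [0.5, 0.09375] -> reserve matches the requested amount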
@@ -504,15 +505,17 @@
        )
        generation = {}
        print(max_tokens)
        job = AsyncJob(
            self.generator,
            input_ids=self.tokenizer.encode(prompt, add_bos=False),
            max_new_tokens=max_tokens,
            stop_conditions=stop_conditions,
        )

        generated_tokens = 0
        full_response = ""
        metrics_result = {}

        async for result in job:
            chunk = unwrap(result.get("text"), "")
            if chunk:
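
For context, the streaming pattern this hunk sets up can be read on its own: an AsyncJob is created from the shared generator and iterated asynchronously, with each result dict carrying a decoded "text" chunk and, on the final iteration, an "eos" flag plus the timing fields used below. A rough, self-contained sketch of that pattern, assuming an already-initialised exllamav3 generator and tokenizer (the helper name and return shape are illustrative, not from the commit):

from exllamav3 import AsyncJob

async def stream_once(generator, tokenizer, prompt, max_tokens, stop_conditions):
    # Mirrors the AsyncJob call in the diff; everything else is a sketch
    job = AsyncJob(
        generator,
        input_ids=tokenizer.encode(prompt, add_bos=False),
        max_new_tokens=max_tokens,
        stop_conditions=stop_conditions,
    )

    full_response = ""
    async for result in job:
        # Each result dict may carry a freshly decoded text chunk
        full_response += result.get("text") or ""
        if result.get("eos"):
            # The final result also carries the counters and timings
            # that the commit starts forwarding to log_metrics
            return full_response, result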
@@ -530,6 +533,25 @@
            if result.get("eos"):
                generation = self.handle_finish_chunk(result, generation)

                # Save the final result for metrics logging
                metrics_result = result

                yield generation
                break

        # Assign the active job to the request ID
        self.active_job_ids[request_id] = job

        # Log the metrics if present
        if metrics_result:
            log_metrics(
                request_id,
                metrics_result.get("time_enqueued"),
                metrics_result.get("prompt_tokens"),
                metrics_result.get("cached_tokens"),
                metrics_result.get("time_prefill"),
                metrics_result.get("new_tokens"),
                metrics_result.get("time_generate"),
                context_len,
                self.max_seq_len,
            )
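
The log_metrics helper is only visible here through its call site, so its argument order is the one fact this hunk establishes. A hypothetical shape of that helper, with the throughput arithmetic being my own assumption rather than the actual common.gen_logging implementation:

from loguru import logger


def log_metrics(
    request_id,
    time_enqueued,
    prompt_tokens,
    cached_tokens,
    time_prefill,
    new_tokens,
    time_generate,
    context_len,
    max_seq_len,
):
    # Hypothetical: derive throughput figures from the raw counters the job reports
    prefill_speed = (prompt_tokens - cached_tokens) / time_prefill if time_prefill else 0.0
    generate_speed = new_tokens / time_generate if time_generate else 0.0

    logger.info(
        f"Metrics (ID: {request_id}): {new_tokens} tokens generated in "
        f"{time_generate:.2f} s (queued {time_enqueued:.2f} s, "
        f"prefill {prefill_speed:.2f} T/s, generate {generate_speed:.2f} T/s, "
        f"context {context_len} of {max_seq_len} tokens)"
    )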