API: Persist request IDs and append full_text to finish chunk
Adding these to each generation chunk helps remove redundancy and unnecessary request ID operations.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
parent e77fa0b7a8
commit 0b4ca567f8

3 changed files with 19 additions and 9 deletions
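The practical upshot for anything consuming these generation streams: the finish chunk now carries the request ID and the accumulated text itself, so a caller no longer has to track either separately. A minimal consumer sketch under that assumption; the function and variable names are illustrative, only the chunk keys ("text", "finish_reason", "request_id", "full_text") come from this diff:

# Hypothetical consumer of a chunk stream shaped like the dicts in this diff.
# Not part of the commit; it only illustrates the new finish-chunk fields.
async def consume_stream(stream):
    finish_chunk = {}
    async for chunk in stream:
        if "finish_reason" in chunk:
            # The finish chunk now carries the summary fields directly
            finish_chunk = chunk
            break
        print(chunk.get("text", ""), end="", flush=True)

    # No separate request-ID bookkeeping or manual text concatenation needed
    return finish_chunk.get("request_id"), finish_chunk.get("full_text", "")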
@@ -909,7 +909,9 @@ class ExllamaV2Container(BaseModelContainer):
             generations.append(generation)
 
         joined_generation = {
+            "request_id": "",
             "text": "",
+            "full_text": "",
             "prompt_tokens": 0,
             "gen_tokens": 0,
             "tool_calls": None,
@@ -923,12 +925,12 @@ class ExllamaV2Container(BaseModelContainer):
         if "finish_reason" in generations[-1]:
             finish_chunk = generations.pop()
             joined_generation = {**joined_generation, **finish_chunk}
+            joined_generation["text"] = joined_generation.get("full_text", "")
         else:
             joined_generation["finish_reason"] = "stop"
 
         if len(generations) > 0:
             for generation in generations:
-                joined_generation["text"] += unwrap(generation.get("text"), "")
                 joined_generation["offset"].append(unwrap(generation.get("offset"), -1))
                 joined_generation["token_probs"].update(
                     unwrap(generation.get("token_probs"), {})
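The redundancy this hunk targets, per the commit message: by the time the finish chunk is produced, full_text already holds the whole response accumulated during streaming, so the joined (non-streaming) result can take it verbatim rather than re-concatenating the per-chunk "text" fields. A rough illustration of that invariant with made-up literal values:

# Hypothetical chunk list; only the key names match the diff above.
chunks = [
    {"text": "Hel"},
    {"text": "lo."},
    {"request_id": "req-1", "finish_reason": "stop", "full_text": "Hello."},
]

finish_chunk = chunks[-1]
streamed_text = "".join(c.get("text", "") for c in chunks[:-1])
assert finish_chunk["full_text"] == streamed_text  # same string, no second pass needed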
@@ -1170,7 +1172,7 @@ class ExllamaV2Container(BaseModelContainer):
         }
 
     # Creates and returns a finish chunk
-    def handle_finish_chunk(self, result: dict, generation: dict):
+    def handle_finish_chunk(self, result: dict, request_id: str, full_text: str):
         eos_reason = result.get("eos_reason")
 
         stop_str = None
@@ -1204,6 +1206,7 @@ class ExllamaV2Container(BaseModelContainer):
         total_time = round(queue_time + prompt_time + gen_time, 2)
 
         finish_chunk = {
+            "request_id": request_id,
             "prompt_tokens": prompt_tokens,
             "prompt_time": round(prompt_time, 2),
             "prompt_tokens_per_sec": prompt_ts,
@@ -1215,6 +1218,7 @@ class ExllamaV2Container(BaseModelContainer):
             "cached_tokens": cached_tokens,
             "finish_reason": finish_reason,
             "stop_str": stop_str,
+            "full_text": full_text,
         }
 
         return finish_chunk
@@ -1414,6 +1418,7 @@ class ExllamaV2Container(BaseModelContainer):
                 generated_tokens += chunk_tokens.size(dim=0)
 
                 generation = {
+                    "request_id": request_id,
                     "text": chunk,
                     "prompt_tokens": context_len,
                     "generated_tokens": generated_tokens,
@@ -1434,7 +1439,9 @@ class ExllamaV2Container(BaseModelContainer):
                 if result.get("eos"):
                     log_response(request_id, full_response)
 
-                    finish_chunk = self.handle_finish_chunk(result, generation)
+                    finish_chunk = self.handle_finish_chunk(
+                        result, request_id, full_response
+                    )
 
                     # Save the final result for metrics logging
                     metrics_result = finish_chunk
@@ -730,7 +730,7 @@ class ExllamaV3Container(BaseModelContainer):
         # Clean up and remove the job from active IDs
         del self.active_job_ids[request_id]
 
-    def handle_finish_chunk(self, result: dict, generation: dict):
+    def handle_finish_chunk(self, result: dict, request_id: str, full_text: str):
         eos_reason = result.get("eos_reason")
 
         stop_str = None
@@ -764,6 +764,7 @@ class ExllamaV3Container(BaseModelContainer):
         total_time = round(queue_time + prompt_time + gen_time, 2)
 
         finish_chunk = {
+            "request_id": request_id,
             "prompt_tokens": prompt_tokens,
             "prompt_time": round(prompt_time, 2),
             "prompt_tokens_per_sec": prompt_ts,
@@ -775,6 +776,7 @@ class ExllamaV3Container(BaseModelContainer):
             "cached_tokens": cached_tokens,
             "finish_reason": finish_reason,
             "stop_str": stop_str,
+            "full_text": full_text,
         }
 
         return finish_chunk
@@ -940,6 +942,7 @@ class ExllamaV3Container(BaseModelContainer):
                 # gen_settings.token_repetition_range = generated_tokens
 
                 generation = {
+                    "request_id": request_id,
                     "text": chunk,
                     "prompt_tokens": context_len,
                     "generated_tokens": generated_tokens,
@@ -948,7 +951,9 @@ class ExllamaV3Container(BaseModelContainer):
                 yield generation
 
                 if result.get("eos"):
-                    finish_chunk = self.handle_finish_chunk(result, generation)
+                    finish_chunk = self.handle_finish_chunk(
+                        result, request_id, full_response
+                    )
 
                     # Save the final result for metrics logging
                     metrics_result = finish_chunk
@@ -364,7 +364,6 @@ async def stream_generate_chat_completion(
                         data,
                         [generation],
                         request,
-                        current_generation_text=current_generation_text,
                     )
 
                     # Only one generation present in this case
@@ -468,7 +467,6 @@ async def generate_tool_calls(
     data: ChatCompletionRequest,
     generations: List[str],
     request: Request,
-    current_generation_text: str = None,
 ):
     gen_tasks: List[asyncio.Task] = []
     tool_start = model.container.prompt_template.metadata.tool_start
@@ -487,11 +485,11 @@
             logger.info(f"Detected tool call in chat completion request {request.state.id}")
 
             # Append the existing generation text if present
-            precursor_text = current_generation_text or gen.get("text")
+            precursor_text = gen.get("full_text")
             if precursor_text:
                 prompt = prompt + precursor_text
 
-            gen_request_id = _parse_gen_request_id(data.n, request.state.id, idx)
+            gen_request_id = gen.get("request_id")
             tool_request_id = f"{gen_request_id}-tool"
 
             gen_tasks.append(
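Taken together with the backend hunks above, generate_tool_calls no longer needs a current_generation_text parameter or a request-ID parser: both values ride along on the generation dict. A hedged sketch of that per-generation logic with a hypothetical dict (the values are invented; only the key names come from this diff):

# Hypothetical generation dict as yielded by the updated backends.
gen = {
    "request_id": "chatcmpl-123-0",  # invented example ID
    "text": "{\"name\": \"lookup\"",
    "full_text": "Let me call a tool. {\"name\": \"lookup\"",
}

prompt = "rendered chat prompt goes here"  # placeholder prompt text

# Mirrors the new lines in the hunk above:
precursor_text = gen.get("full_text")
if precursor_text:
    prompt = prompt + precursor_text

gen_request_id = gen.get("request_id")
tool_request_id = f"{gen_request_id}-tool"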