Model: Add abort on generation
When the model is processing a prompt, add the ability to abort on request cancellation. This also serves as a catch for a SIGINT.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 7020a0a2d1
commit 07d9b7cf7b

3 changed files with 21 additions and 5 deletions
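The pattern in brief: the async endpoint creates a `threading.Event`, hands it to the blocking generator running on a worker thread, and sets the event when the request task is cancelled. A minimal standalone sketch of that flow, assuming Starlette's `iterate_in_threadpool` (the diff confirms the name but not the import path; the `generate`/`consume` names are illustrative, not from this commit):

```python
import asyncio
import threading

# The diff uses iterate_in_threadpool; Starlette's implementation is assumed here
from starlette.concurrency import iterate_in_threadpool


def generate(abort_event: threading.Event):
    """Blocking (sync) generator; checks the event between steps."""
    for token in ["Hello", ",", " world"]:
        if abort_event.is_set():
            return  # abort requested from the async side
        yield token


async def consume():
    abort_event = threading.Event()
    try:
        # The sync generator runs on a worker thread; values are fed back
        # into the event loop one at a time
        async for token in iterate_in_threadpool(generate(abort_event)):
            print(token)
    except asyncio.CancelledError:
        # A client disconnect (or SIGINT-driven shutdown) cancels the task;
        # setting the event tells the worker thread to stop generating
        abort_event.set()
        raise


if __name__ == "__main__":
    asyncio.run(consume())
```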
```diff
@@ -2,6 +2,7 @@
 import gc
 import pathlib
+import threading
 import time
 
 import torch
@@ -623,14 +624,18 @@ class ExllamaV2Container:
 
         return kwargs
 
-    async def generate_gen(self, prompt: str, **kwargs):
+    async def generate_gen(
+        self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs
+    ):
         """Basic async wrapper for completion generator"""
 
-        sync_generator = self.generate_gen_sync(prompt, **kwargs)
+        sync_generator = self.generate_gen_sync(prompt, abort_event, **kwargs)
         async for value in iterate_in_threadpool(sync_generator):
             yield value
 
-    def generate_gen_sync(self, prompt: str, **kwargs):
+    def generate_gen_sync(
+        self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs
+    ):
         """
         Create generator function for prompt completion.
 
@@ -893,6 +898,7 @@ class ExllamaV2Container:
                 return_probabilities=request_logprobs > 0,
                 return_top_tokens=request_logprobs,
                 return_logits=request_logprobs > 0,
+                abort_event=abort_event,
             )
         else:
             self.generator.begin_stream_ex(
@@ -903,6 +909,7 @@ class ExllamaV2Container:
                 return_probabilities=request_logprobs > 0,
                 return_top_tokens=request_logprobs,
                 return_logits=request_logprobs > 0,
+                abort_event=abort_event,
             )
 
         # Reset offsets for subsequent passes if the context is truncated
```
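The model container only forwards the event; honoring it is up to exllamav2's streaming generator, which receives it via `abort_event=` in `begin_stream_ex` above. Conceptually, the blocking decode loop checks the event between steps, along these lines (a hypothetical loop for illustration, not exllamav2's actual code):

```python
import threading
from typing import Optional


def stream_tokens(generator, abort_event: Optional[threading.Event] = None):
    """Hypothetical decode loop: bail out between steps once the event is set,
    rather than running the request to completion."""
    while True:
        if abort_event is not None and abort_event.is_set():
            break  # caller cancelled; stop decoding immediately
        chunk, eos, _ = generator.stream()  # one decode step (illustrative)
        yield chunk
        if eos:
            break
```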
```diff
@@ -2,6 +2,7 @@
 from asyncio import CancelledError
 import pathlib
+import threading
 from typing import Optional
 from uuid import uuid4
 
@@ -161,8 +162,11 @@ async def stream_generate_chat_completion(
     """Generator for the generation process."""
     try:
         const_id = f"chatcmpl-{uuid4().hex}"
+        abort_event = threading.Event()
 
-        new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
+        new_generation = model.container.generate_gen(
+            prompt, abort_event, **data.to_gen_params()
+        )
         async for generation in new_generation:
             response = _create_stream_chunk(const_id, generation, model_path.name)
 
@@ -174,6 +178,7 @@ async def stream_generate_chat_completion(
     except CancelledError:
         # Get out if the request gets disconnected
 
+        abort_event.set()
         handle_request_disconnect("Chat completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(
```
```diff
@@ -2,6 +2,7 @@
 import pathlib
 from asyncio import CancelledError
+import threading
 from fastapi import HTTPException
 from typing import Optional
 
@@ -64,8 +65,10 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli
     """Streaming generation for completions."""
 
     try:
+        abort_event = threading.Event()
+
         new_generation = model.container.generate_gen(
-            data.prompt, **data.to_gen_params()
+            data.prompt, abort_event, **data.to_gen_params()
         )
         async for generation in new_generation:
             response = _create_response(generation, model_path.name)
 
@@ -78,6 +81,7 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli
     except CancelledError:
         # Get out if the request gets disconnected
 
+        abort_event.set()
         handle_request_disconnect("Completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(
```
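As a usage note, cancellation is driven entirely by the client: dropping the streamed HTTP response cancels the endpoint task, the `except CancelledError` branch sets `abort_event`, and the worker thread stops decoding. A hedged client-side sketch (the URL, port, and payload are assumptions, not part of this commit):

```python
import httpx

with httpx.stream(
    "POST",
    "http://localhost:5000/v1/completions",  # assumed endpoint and port
    json={"prompt": "Once upon a time", "stream": True},
) as response:
    for line in response.iter_lines():
        print(line)
        break  # disconnect early; the server aborts generation
```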