Model: Store active jobs in tabby
Rather than relying on the generator, use tabby to store the active job IDs.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
parent 1afc9b983e
commit 3f1d5d396e
1 changed file with 16 additions and 9 deletions
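In short: the container now tracks its own requests in active_job_ids rather than reading the generator's internal job table. A request claims a slot before its job exists, attaches the real job once it is created, and deletes its entry when it finishes; unload cancels whatever is tracked and waits for the dict to drain. A minimal sketch of that lifecycle, runnable without exllamav2 (Job is a hypothetical stand-in for ExLlamaV2DynamicJobAsync):

    from typing import Dict, Optional

    class Job:
        """Hypothetical stand-in for ExLlamaV2DynamicJobAsync."""

    active_job_ids: Dict[str, Optional[Job]] = {}
    request_id = "example-request"

    active_job_ids[request_id] = None    # 1. claim the slot before the job exists
    active_job_ids[request_id] = Job()   # 2. attach the real job once created
    del active_job_ids[request_id]       # 3. release the slot when generation ends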
@@ -106,7 +106,7 @@ class ExllamaV2Container:
     # Load synchronization
     # The lock keeps load tasks sequential
     # The condition notifies any waiting tasks
-    active_job_ids: Dict[str, ExLlamaV2DynamicJobAsync] = {}
+    active_job_ids: Dict[str, Optional[ExLlamaV2DynamicJobAsync]] = {}
     load_lock: asyncio.Lock = asyncio.Lock()
     load_condition: asyncio.Condition = asyncio.Condition()

@@ -531,12 +531,11 @@ class ExllamaV2Container:
                 "Clients will have their requests cancelled.\n"
             )

-            # Requires a copy to avoid errors during iteration
-            jobs_copy = self.generator.jobs.copy()
-            for job in jobs_copy.values():
-                await job.cancel()
+            for job in self.active_job_ids.values():
+                if job:
+                    await job.cancel()

-        while self.generator.jobs:
+        while len(self.active_job_ids) > 0:
             await asyncio.sleep(0.01)

     async def load(self, progress_callback=None):
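The unload path no longer copies generator.jobs; it cancels whatever tabby is tracking (skipping None placeholders) and then polls until every generate task has removed its own entry. A self-contained sketch of that handshake, with FakeJob as a hypothetical stand-in exposing the same cancel() surface as ExLlamaV2DynamicJobAsync:

    import asyncio
    from typing import Dict, Optional

    class FakeJob:
        """Hypothetical stand-in; only mirrors the cancel() call used above."""

        def __init__(self) -> None:
            self.cancelled = False

        async def cancel(self) -> None:
            self.cancelled = True

    active_job_ids: Dict[str, Optional[FakeJob]] = {}

    async def fake_generate(request_id: str) -> None:
        job = FakeJob()
        active_job_ids[request_id] = job
        while not job.cancelled:
            await asyncio.sleep(0.01)
        # Each generate task removes its own entry, which releases unload()
        del active_job_ids[request_id]

    async def unload() -> None:
        # Cancel every tracked job; None marks a request with no job object yet
        for job in active_job_ids.values():
            if job:
                await job.cancel()
        # Poll until every generate task has cleaned up, as in the diff
        while len(active_job_ids) > 0:
            await asyncio.sleep(0.01)

    async def main() -> None:
        tasks = [asyncio.create_task(fake_generate(f"req-{i}")) for i in range(3)]
        await asyncio.sleep(0.05)  # give the tasks time to register
        await unload()
        await asyncio.gather(*tasks)

    asyncio.run(main())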
@@ -1237,6 +1236,9 @@ class ExllamaV2Container:
         async with self.load_condition:
             await self.load_condition.wait_for(lambda: not self.load_lock.locked())

+        # Mark that the job is running
+        self.active_job_ids[request_id] = None
+
         prompts = [prompt]
         gen_settings = ExLlamaV2Sampler.Settings()
         grammar_handler = ExLlamaV2Grammar()
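Registering the request with a None placeholder immediately after the load gate closes a race: without it, a request that has passed the lock check but not yet constructed its job would be invisible to unload's drain loop, and the model could be pulled out from under it. A short sketch of the ordering (only active_job_ids and request_id come from the diff; the job object is a stand-in):

    active_job_ids = {}
    request_id = "example-request"

    active_job_ids[request_id] = None  # unload() counts this entry from now on
    # ... prompts, sampler settings, and grammar setup happen here ...
    job = object()                     # stand-in for the job built further down
    active_job_ids[request_id] = job   # the placeholder is swapped for the real job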
@@ -1354,7 +1356,6 @@ class ExllamaV2Container:

         # Create and add a new job
-        # Don't use the request ID here as there can be multiple jobs per request
         job_id = uuid.uuid4().hex
         job = ExLlamaV2DynamicJobAsync(
             self.generator,
             input_ids=input_ids,
@@ -1370,10 +1371,13 @@ class ExllamaV2Container:
             return_logits=params.logprobs > 0,
             banned_strings=banned_strings,
             token_healing=params.token_healing,
-            identifier=job_id,
+            identifier=request_id,
             embeddings=mm_embeddings_content,
         )

+        # Assign the active job to the request ID
+        self.active_job_ids[request_id] = job
+
         # Save generated tokens and full response
         # Copy over max seq len incase model is unloaded and stored jobs can complete
         # Full response is required for offset calculation
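With one tracked entry per request, the request ID can double as the job identifier, which is why the per-job uuid stops being passed as identifier. A sketch of the create-then-attach step (Job is a hypothetical stand-in; the real constructor is the ExLlamaV2DynamicJobAsync call shown above):

    from typing import Dict, Optional

    class Job:
        """Hypothetical stand-in; the real class is ExLlamaV2DynamicJobAsync."""

        def __init__(self, identifier: str) -> None:
            # exllamav2 echoes this identifier back in every result it emits
            self.identifier = identifier

    active_job_ids: Dict[str, Optional[Job]] = {}
    request_id = "example-request"

    active_job_ids[request_id] = None    # placeholder set earlier in generate()
    job = Job(identifier=request_id)     # results will now carry the request ID
    active_job_ids[request_id] = job     # assign the active job to the request ID
    assert active_job_ids[request_id] is job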
@@ -1393,7 +1397,7 @@ class ExllamaV2Container:
             stage = result.get("stage")
             result_id = result.get("identifier")

-            if stage == "streaming" and result_id == job_id:
+            if stage == "streaming" and result_id == request_id:
                 chunk = unwrap(result.get("text"), "")
                 full_response += chunk

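On the consuming side, each streamed result dict carries the identifier the job was created with, so the loop can match on request_id directly. A runnable sketch with a stand-in async job; the result shape (stage, identifier, text) mirrors what the loop above reads, and `result.get("text") or ""` stands in for tabby's unwrap helper:

    import asyncio

    class FakeJob:
        """Stand-in async job; yields result dicts shaped like exllamav2's."""

        def __init__(self, identifier: str, chunks: list) -> None:
            self.identifier = identifier
            self.chunks = chunks

        def __aiter__(self):
            return self._stream()

        async def _stream(self):
            for chunk in self.chunks:
                yield {"stage": "streaming", "identifier": self.identifier, "text": chunk}

    async def main() -> None:
        request_id = "example-request"
        job = FakeJob(request_id, ["Hello", ", ", "world"])

        full_response = ""
        async for result in job:
            stage = result.get("stage")
            result_id = result.get("identifier")
            # Same guard as the diff: only consume chunks tagged with this request
            if stage == "streaming" and result_id == request_id:
                full_response += result.get("text") or ""

        print(full_response)  # -> Hello, world

    asyncio.run(main())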
@@ -1501,3 +1505,6 @@ class ExllamaV2Container:
             context_len,
             max_seq_len,
         )
+
+        # Remove the job from active IDs
+        del self.active_job_ids[request_id]
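The last hunk closes the lifecycle: generate deletes its entry once the response is finalized, which is what lets the drain loop in unload exit. A compact sketch of that contract; the try/finally is an assumption of this sketch (the diff shows a plain del at the end of the method), included so the entry is removed even if generation raises:

    import asyncio
    from typing import Dict, Optional

    active_job_ids: Dict[str, Optional[object]] = {}

    async def generate(request_id: str) -> str:
        # Mark that the job is running, as in the diff
        active_job_ids[request_id] = None
        try:
            await asyncio.sleep(0)  # placeholder for job creation and streaming
            return "done"
        finally:
            # Sketch's assumption: finally guarantees cleanup on error paths too
            del active_job_ids[request_id]

    print(asyncio.run(generate("example-request")))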