Forward exceptions from _stream_collector to stream_generate_(chat)_completion (#126)

This commit is contained in:
turboderp 2024-06-03 19:42:45 +02:00 committed by GitHub
parent 0eb8fa5d1e
commit 1951f7521c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 28 additions and 12 deletions

View file

@@ -211,14 +211,17 @@ async def _stream_collector(
):
"""Collects a stream and places results in a common queue"""
new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
async for generation in new_generation:
generation["index"] = task_idx
try:
new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
async for generation in new_generation:
generation["index"] = task_idx
await gen_queue.put(generation)
await gen_queue.put(generation)
if "finish_reason" in generation:
break
if "finish_reason" in generation:
break
except Exception as e:
await gen_queue.put(e)
async def stream_generate_chat_completion(
@@ -253,6 +256,11 @@ async def stream_generate_chat_completion(
handle_request_disconnect("Completion generation cancelled by user.")
generation = await gen_queue.get()
# Stream collector will push an exception to the queue if it fails
if isinstance(generation, Exception):
raise generation
response = _create_stream_chunk(const_id, generation, model_path.name)
yield response.model_dump_json()

View file

@@ -82,14 +82,17 @@ async def _stream_collector(
):
"""Collects a stream and places results in a common queue"""
new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
async for generation in new_generation:
generation["index"] = task_idx
try:
new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
async for generation in new_generation:
generation["index"] = task_idx
await gen_queue.put(generation)
await gen_queue.put(generation)
if "finish_reason" in generation:
break
if "finish_reason" in generation:
break
except Exception as e:
await gen_queue.put(e)
async def stream_generate_completion(
@@ -126,6 +129,11 @@ async def stream_generate_completion(
handle_request_disconnect("Completion generation cancelled by user.")
generation = await gen_queue.get()
# Stream collector will push an exception to the queue if it fails
if isinstance(generation, Exception):
raise generation
response = _create_response(generation, model_path.name)
yield response.model_dump_json()