API: Fix sequential requests
FastAPI handles request queueing in an unexpected way: if an await is used inside an async def endpoint, requests are no longer executed sequentially. Restore sequential execution by using a semaphore to limit concurrent execution of generator functions. Also scaffold the framework for moving generator functions into their own file.

Signed-off-by: kingbri <bdashore3@proton.me>
parent e740b53478
commit ed6c962aad
2 changed files with 26 additions and 5 deletions
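For context, a minimal sketch of the behavior the message describes (illustration only, not part of the commit; the endpoint and helper names are hypothetical): any await inside an async def endpoint hands control back to the event loop, so FastAPI interleaves in-flight requests instead of finishing them one at a time, and an asyncio.Semaphore(1) is one way to force them back into sequence.

from asyncio import Semaphore, sleep
from fastapi import FastAPI

app = FastAPI()
demo_semaphore = Semaphore(1)

@app.get("/demo")
async def demo():
    # Without the semaphore, two concurrent calls interleave at the await;
    # with Semaphore(1), the second call waits until the first returns.
    async with demo_semaphore:
        await sleep(1)  # stands in for a long generation step
        return {"status": "done"}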
generators.py  (new file, 10 additions)
@@ -0,0 +1,10 @@
from asyncio import Semaphore
from typing import AsyncGenerator

generate_semaphore = Semaphore(1)

# Async generation that blocks on a semaphore
async def generate_with_semaphore(generator: AsyncGenerator):
    async with generate_semaphore:
        async for result in generator():
            yield result
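A quick way to see the serialization in isolation (a hedged sketch, not part of the commit; slow_gen and the timings are made up): two consumers iterate generate_with_semaphore concurrently, and the second only starts once the first has drained its generator. Note that the generator function itself is passed, not a call to it, since generate_with_semaphore invokes generator() internally.

import asyncio
from generators import generate_with_semaphore

async def slow_gen():
    for i in range(3):
        await asyncio.sleep(0.1)
        yield i

async def consume(tag: str):
    # Pass the generator function itself; generate_with_semaphore calls it.
    async for item in generate_with_semaphore(slow_gen):
        print(tag, item)

async def main():
    # Prints all of "a" before any of "b" because Semaphore(1) serializes them.
    await asyncio.gather(consume("a"), consume("b"))

asyncio.run(main())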
main.py  (21 changed lines)
@@ -7,6 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from model import ModelContainer
from progress.bar import IncrementalBar
from generators import generate_with_semaphore
from OAI.types.completion import CompletionRequest
from OAI.types.chat_completion import ChatCompletionRequest
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
@@ -200,10 +201,15 @@ async def generate_completion(request: Request, data: CompletionRequest):
                    model_path.name)

                yield get_sse_packet(response.json(ensure_ascii=False))
            except GeneratorExit:
                print("Completion response aborted")
            except Exception as e:
                yield get_generator_error(e)

        return StreamingResponse(generator(), media_type = "text/event-stream")
        return StreamingResponse(
            generate_with_semaphore(generator),
            media_type = "text/event-stream"
        )
    else:
        response_text, prompt_tokens, completion_tokens = model_container.generate(data.prompt, **data.to_gen_params())
        response = create_completion_response(response_text,
@@ -238,12 +244,14 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                    model_path.name
                )

                yield get_sse_packet(response.json(ensure_ascii=False))
                yield get_sse_packet(response.json(ensure_ascii=False))
            except GeneratorExit:
                print("Chat completion response aborted")
            except Exception as e:
                yield get_generator_error(e)
            finally:

                # Always finish no matter what
                # FIXME: An error currently fires here since the generator is closed, move this somewhere else
                finish_response = create_chat_completion_stream_chunk(
                    const_id,
                    finish_reason = "stop"
@@ -251,7 +259,10 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest

                yield get_sse_packet(finish_response.json(ensure_ascii=False))

        return StreamingResponse(generator(), media_type = "text/event-stream")
        return StreamingResponse(
            generate_with_semaphore(generator),
            media_type = "text/event-stream"
        )
    else:
        response_text, prompt_tokens, completion_tokens = model_container.generate(prompt, **data.to_gen_params())
        response = create_chat_completion_response(response_text,
@@ -283,7 +294,7 @@ if __name__ == "__main__":
    if "model_name" in model_config:
        model_path = pathlib.Path(model_config.get("model_dir") or "models")
        model_path = model_path / model_config.get("model_name")

    model_container = ModelContainer(model_path.resolve(), False, **model_config)
    load_status = model_container.load_gen(load_progress)
    for (module, modules) in load_status:
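On the FIXME in the chat completion handler above: once a generator has received GeneratorExit (for example because the client disconnected and the stream was closed), yielding again from a finally block is not allowed in Python. A standalone illustration of that language behavior (not code from this repo):

import asyncio

async def gen():
    try:
        yield 1
    finally:
        # Yielding here after the generator was closed raises
        # RuntimeError: async generator ignored GeneratorExit
        yield 2

async def main():
    g = gen()
    print(await g.__anext__())      # prints 1
    try:
        await g.aclose()            # GeneratorExit is thrown in at the yield
    except RuntimeError as e:
        print(e)                    # "async generator ignored GeneratorExit"

asyncio.run(main())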