API: Use FastAPI streaming instead of sse_starlette

sse_starlette kept firing ping events whenever the generator took too
long to yield one. Rather than using a hacky workaround, switch to
FastAPI's built-in streaming response and construct SSE packets with
a utility function.

This makes the API more robust and removes an extra dependency.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2023-12-01 01:54:35 -05:00
parent 6493b1d2aa
commit ae69b18583
4 changed files with 15 additions and 12 deletions
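
For context, the pattern this commit moves to looks roughly like the sketch below: a plain FastAPI StreamingResponse wrapping an async generator, with each JSON payload framed as an SSE event by a small helper. This is an illustrative sketch, not the repository's exact code; the endpoint path and payload are made up.

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

def get_sse_packet(json_data: str) -> str:
    # Frame a JSON string as a server-sent-events "data" field,
    # terminated by the blank line that ends an SSE event.
    return f"data: {json_data}\n\n"

@app.get("/v1/stream-example")  # hypothetical endpoint, for illustration only
async def stream_example():
    async def generator():
        for i in range(3):
            yield get_sse_packet(f'{{"chunk": {i}}}')
    return StreamingResponse(generator(), media_type="text/event-stream")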

main.py (20 lines changed)

@@ -4,9 +4,9 @@ import pathlib
 from auth import check_admin_key, check_api_key, load_auth_keys
 from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 from model import ModelContainer
 from progress.bar import IncrementalBar
-from sse_starlette import EventSourceResponse
 from OAI.types.completion import CompletionRequest
 from OAI.types.chat_completion import ChatCompletionRequest
 from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
@@ -24,7 +24,7 @@ from OAI.utils import (
     create_chat_completion_stream_chunk
 )
 from typing import Optional
-from utils import get_generator_error, load_progress
+from utils import get_generator_error, get_sse_packet, load_progress
 from uuid import uuid4

 app = FastAPI()
@@ -126,7 +126,7 @@ async def load_model(data: ModelLoadRequest):
                 status="finished"
             )
-            yield response.json(ensure_ascii=False)
+            yield get_sse_packet(response.json(ensure_ascii=False))

             if model_container.draft_enabled:
                 model_type = "model"
@@ -140,7 +140,7 @@ async def load_model(data: ModelLoadRequest):
                     status="processing"
                 )
-                yield response.json(ensure_ascii=False)
+                yield get_sse_packet(response.json(ensure_ascii=False))
         except Exception as e:
             yield get_generator_error(e)
             load_failed = True
@@ -149,7 +149,7 @@ async def load_model(data: ModelLoadRequest):
             model_container.unload()
             model_container = None

-    return EventSourceResponse(generator())
+    return StreamingResponse(generator(), media_type="text/event-stream")

 # Unload model endpoint
 @app.get("/v1/model/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
@@ -199,11 +199,11 @@ async def generate_completion(request: Request, data: CompletionRequest):
                     completion_tokens,
                     model_path.name)
-                yield response.json(ensure_ascii=False)
+                yield get_sse_packet(response.json(ensure_ascii=False))
             except Exception as e:
                 yield get_generator_error(e)

-        return EventSourceResponse(generator())
+        return StreamingResponse(generator(), media_type="text/event-stream")
     else:
         response_text, prompt_tokens, completion_tokens = model_container.generate(data.prompt, **data.to_gen_params())
         response = create_completion_response(response_text,
@@ -238,7 +238,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                     model_path.name
                 )
-                yield response.json(ensure_ascii=False)
+                yield get_sse_packet(response.json(ensure_ascii=False))
             except Exception as e:
                 yield get_generator_error(e)
             finally:
@@ -249,9 +249,9 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                     finish_reason = "stop"
                 )
-                yield finish_response.json(ensure_ascii=False)
+                yield get_sse_packet(finish_response.json(ensure_ascii=False))

-        return EventSourceResponse(generator())
+        return StreamingResponse(generator(), media_type="text/event-stream")
     else:
         response_text, prompt_tokens, completion_tokens = model_container.generate(prompt, **data.to_gen_params())
         response = create_chat_completion_response(response_text,
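
A StreamingResponse with media_type="text/event-stream" is just raw bytes on the wire, so the client does the SSE parsing itself. A minimal consumption sketch, assuming the httpx package and the hypothetical endpoint from the sketch above:

import httpx

# Stream the response and print the JSON payload of each SSE event.
with httpx.stream("GET", "http://localhost:8000/v1/stream-example") as response:
    for line in response.iter_lines():
        if line.startswith("data: "):
            print(line[len("data: "):])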

model.py (1 line changed)

@@ -311,6 +311,7 @@ class ModelContainer:
         stop_conditions: List[Union[str, int]] = kwargs.get("stop", [])
         ban_eos_token = kwargs.get("ban_eos_token", False)

+        # Ban the EOS token if specified. If not, append to stop conditions as well.
         if ban_eos_token:

requirements.txt (1 line changed)

@@ -2,7 +2,6 @@
 pydantic < 2,>= 1
 PyYAML
 progress
-sse_starlette
 uvicorn

 # Wheels

utils.py (5 lines changed)

@@ -26,4 +26,7 @@ def get_generator_error(exception: Exception):
     # Log and send the exception
     print(f"\n{generator_error.error.trace}")
-    return generator_error.json()
+    return get_sse_packet(generator_error.json(ensure_ascii=False))
+
+def get_sse_packet(json_data: str):
+    return f"data: {json_data}\n\n"