API: Fix responses and some params

Responses were not being properly sent as JSON. Only run pydantic's
JSON function on streaming responses; FastAPI handles serialization
for static responses on its own.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2023-11-16 17:11:55 -05:00
parent d8d61fa19b
commit 282b5b2931
2 changed files with 19 additions and 14 deletions
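
Note: the pattern behind this fix, as a minimal standalone sketch. The
DemoResponse model and the routes are hypothetical; FastAPI, pydantic v1
(whose .json() forwards json.dumps kwargs such as ensure_ascii), and
sse-starlette's EventSourceResponse are the pieces the diff already uses.

from fastapi import FastAPI
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

app = FastAPI()

class DemoResponse(BaseModel):
    text: str

@app.get("/static")
async def static_endpoint():
    # Return the model itself: FastAPI serializes it to JSON. Returning
    # DemoResponse(...).json() here would double-encode, so the client
    # gets a JSON string containing JSON instead of a JSON object,
    # which is the bug described above.
    return DemoResponse(text="hello")

@app.get("/stream")
async def stream_endpoint():
    async def generator():
        for word in ("hello", "world"):
            response = DemoResponse(text=word)
            # SSE data fields are plain strings, so serialize explicitly;
            # ensure_ascii=False keeps non-ASCII text readable on the wire.
            yield response.json(ensure_ascii=False)

    return EventSourceResponse(generator())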


@@ -79,7 +79,7 @@ class CommonCompletionRequest(BaseModel):
             "repetition_penalty": self.repetition_penalty,
             "repetition_penalty_range": self.repetition_penalty_range,
             "repetition_decay": self.repetition_decay,
-            "mirostat": True if self.mirostat_mode == 2 else False,
+            "mirostat": self.mirostat_mode == 2,
             "mirostat_tau": self.mirostat_tau,
             "mirostat_eta": self.mirostat_eta
         }

main.py

@@ -48,13 +48,14 @@ async def list_models():
     models = get_model_list(model_path)
 
-    return models.json()
+    return models
 
 # Currently loaded model endpoint
 @app.get("/v1/model", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
+@app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def get_current_model():
     model_card = ModelCard(id=model_container.get_model_path().name)
 
-    return model_card.json()
+    return model_card
 
 # Load model endpoint
 @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
@@ -84,17 +85,21 @@ async def load_model(data: ModelLoadRequest):
         else:
             loading_bar.next()
 
-        yield ModelLoadResponse(
+        response = ModelLoadResponse(
             module=module,
             modules=modules,
             status="processing"
-        ).json()
+        )
+
+        yield response.json(ensure_ascii=False)
 
-    yield ModelLoadResponse(
+    response = ModelLoadResponse(
         module=module,
         modules=modules,
         status="finished"
-    ).json()
+    )
+
+    yield response.json(ensure_ascii=False)
 
     return EventSourceResponse(generator())
@@ -112,7 +117,7 @@ async def encode_tokens(data: TokenEncodeRequest):
     tokens = model_container.get_tokens(data.text, None, **data.get_params())[0].tolist()
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
 
-    return response.json()
+    return response
 
 # Decode tokens endpoint
 @app.post("/v1/token/decode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@@ -120,7 +125,7 @@ async def decode_tokens(data: TokenDecodeRequest):
     message = model_container.get_tokens(None, data.tokens, **data.get_params())
     response = TokenDecodeResponse(text=message)
 
-    return response.json()
+    return response
 
 # Completions endpoint
 @app.post("/v1/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@@ -139,14 +144,14 @@ async def generate_completion(request: Request, data: CompletionRequest):
             response = create_completion_response(part, model_path.name)
 
-            yield response.json()
+            yield response.json(ensure_ascii=False)
 
         return EventSourceResponse(generator())
 
     else:
         response_text = model_container.generate(data.prompt, **data.to_gen_params())
         response = create_completion_response(response_text, model_path.name)
 
-        return response.json()
+        return response
 
 # Chat completions endpoint
 @app.post("/v1/chat/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@@ -172,14 +177,14 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                 model_path.name
             )
-            yield response.json()
+            yield response.json(ensure_ascii=False)
 
         return EventSourceResponse(generator())
 
     else:
         response_text = model_container.generate(prompt, **data.to_gen_params())
         response = create_chat_completion_response(response_text, model_path.name)
 
-        return response.json()
+        return response
 
 if __name__ == "__main__":
     # Initialize auth keys
@@ -196,7 +201,7 @@ if __name__ == "__main__":
     model_config = config.get("model", {})
 
     if "model_name" in model_config:
-        model_path = pathlib.Path(model_config.get("model", "models"))
+        model_path = pathlib.Path(model_config.get("model_dir", "models"))
         model_path = model_path / model_config["model_name"]
 
         model_container = ModelContainer(model_path, False, **model_config)
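
For the model_dir fix above, a hypothetical sketch of the config shape the
startup code implies; the key names come from the diff, the values are
placeholders.

# Hypothetical config matching the keys read at startup.
config = {
    "model": {
        "model_dir": "models",      # base directory holding model folders
        "model_name": "my-model",   # folder of the model to load
    }
}
# The fixed lookup resolves model_path to models/my-model; the old code
# looked up a "model" key instead and so fell back to the "models" default.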