API: Fix responses and some params
Responses were not being sent as proper JSON. Only run pydantic's JSON function on stream responses; FastAPI handles serialization for static responses itself.

Signed-off-by: kingbri <bdashore3@proton.me>
parent d8d61fa19b
commit 282b5b2931

2 changed files with 19 additions and 14 deletions
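For context, the behavior this commit leans on: FastAPI runs its own JSON encoder on a pydantic model returned from a route handler, so calling .json() there double-encodes the payload into a quoted string, while strings yielded into an SSE stream are sent verbatim and must be serialized by hand. A minimal sketch of the two cases (the app, routes, and model below are illustrative, not from this repo):

from fastapi import FastAPI
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

app = FastAPI()

class Greeting(BaseModel):
    message: str

@app.get("/static")
async def static_route():
    # FastAPI serializes the returned model itself; returning
    # Greeting(...).json() here would send a quoted JSON string instead.
    return Greeting(message="hello")

@app.get("/stream")
async def stream_route():
    async def generator():
        # SSE event data is a raw string, so dump the model manually.
        yield Greeting(message="hello").json()

    return EventSourceResponse(generator())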
@@ -79,7 +79,7 @@ class CommonCompletionRequest(BaseModel):
             "repetition_penalty": self.repetition_penalty,
             "repetition_penalty_range": self.repetition_penalty_range,
             "repetition_decay": self.repetition_decay,
-            "mirostat": True if self.mirostat_mode == 2 else False,
+            "mirostat": self.mirostat_mode == 2,
             "mirostat_tau": self.mirostat_tau,
             "mirostat_eta": self.mirostat_eta
         }
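The parameter tweak in the hunk above is a pure simplification: a comparison already evaluates to a bool, so the ternary was redundant. A standalone illustration:

mirostat_mode = 2

# Both expressions produce the same bool; the bare comparison suffices.
assert (True if mirostat_mode == 2 else False) == (mirostat_mode == 2)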
main.py (31 changed lines)
@@ -48,13 +48,14 @@ async def list_models():

     models = get_model_list(model_path)

-    return models.json()
+    return models

 # Currently loaded model endpoint
 @app.get("/v1/model", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
+@app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
 async def get_current_model():
     model_card = ModelCard(id=model_container.get_model_path().name)
-    return model_card.json()
+    return model_card

 # Load model endpoint
 @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
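The added /v1/internal/model/info route reuses the existing handler: FastAPI's route decorators return the wrapped function unchanged, so they can be stacked to serve one coroutine at several paths. A minimal sketch with made-up paths and without the dependency checks:

from fastapi import FastAPI

app = FastAPI()

# Stacked decorators register the same coroutine under two URLs.
@app.get("/v1/example")
@app.get("/v1/internal/example")
async def example():
    return {"ok": True}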
@@ -84,17 +85,21 @@ async def load_model(data: ModelLoadRequest):
             else:
                 loading_bar.next()

-            yield ModelLoadResponse(
+            response = ModelLoadResponse(
                 module=module,
                 modules=modules,
                 status="processing"
-            ).json()
+            )
+
+            yield response.json(ensure_ascii=False)

-        yield ModelLoadResponse(
+        response = ModelLoadResponse(
             module=module,
             modules=modules,
             status="finished"
-        ).json()
+        )
+
+        yield response.json(ensure_ascii=False)

     return EventSourceResponse(generator())

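Splitting construction from the yield in the hunk above makes the serialization step explicit: each yielded string becomes the data field of one SSE event. A reduced sketch of the pattern, using a stand-in model and an invented module count (only the field names come from ModelLoadResponse):

from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

class LoadEvent(BaseModel):  # stand-in for ModelLoadResponse
    module: int
    modules: int
    status: str

async def load_route():
    async def generator():
        total = 4  # assumed module count, for illustration only
        for module in range(1, total):
            yield LoadEvent(module=module, modules=total, status="processing").json(ensure_ascii=False)

        yield LoadEvent(module=total, modules=total, status="finished").json(ensure_ascii=False)

    return EventSourceResponse(generator())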
@@ -112,7 +117,7 @@ async def encode_tokens(data: TokenEncodeRequest):
     tokens = model_container.get_tokens(data.text, None, **data.get_params())[0].tolist()
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))

-    return response.json()
+    return response

 # Decode tokens endpoint
 @app.post("/v1/token/decode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@@ -120,7 +125,7 @@ async def decode_tokens(data: TokenDecodeRequest):
     message = model_container.get_tokens(None, data.tokens, **data.get_params())
     response = TokenDecodeResponse(text=message)

-    return response.json()
+    return response

 # Completions endpoint
 @app.post("/v1/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@@ -139,14 +144,14 @@ async def generate_completion(request: Request, data: CompletionRequest):

             response = create_completion_response(part, model_path.name)

-            yield response.json()
+            yield response.json(ensure_ascii=False)

         return EventSourceResponse(generator())
     else:
         response_text = model_container.generate(data.prompt, **data.to_gen_params())
         response = create_completion_response(response_text, model_path.name)

-        return response.json()
+        return response

 # Chat completions endpoint
 @app.post("/v1/chat/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
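pydantic v1's .json() forwards its keyword arguments to json.dumps, so ensure_ascii=False lets non-ASCII characters through verbatim instead of escaping them to \uXXXX sequences:

import json

# Default behavior escapes non-ASCII characters.
print(json.dumps({"text": "héllo"}))                      # {"text": "h\u00e9llo"}

# ensure_ascii=False keeps them as-is, which SSE handles fine as UTF-8.
print(json.dumps({"text": "héllo"}, ensure_ascii=False))  # {"text": "héllo"}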
@@ -172,14 +177,14 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                 model_path.name
             )

-            yield response.json()
+            yield response.json(ensure_ascii=False)

         return EventSourceResponse(generator())
     else:
         response_text = model_container.generate(prompt, **data.to_gen_params())
         response = create_chat_completion_response(response_text, model_path.name)

-        return response.json()
+        return response

 if __name__ == "__main__":
     # Initialize auth keys
@@ -196,7 +201,7 @@ if __name__ == "__main__":

     model_config = config.get("model", {})
     if "model_name" in model_config:
-        model_path = pathlib.Path(model_config.get("model", "models"))
+        model_path = pathlib.Path(model_config.get("model_dir", "models"))
         model_path = model_path / model_config["model_name"]

         model_container = ModelContainer(model_path, False, **model_config)
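The last hunk corrects the lookup key for the model directory: the sub-dict holds model_dir, not model, so the old code always fell back to the "models" default. A sketch under an assumed config shape (the actual config file is not shown in this diff):

import pathlib

# Hypothetical parsed config; only the keys the code above reads are assumed.
config = {"model": {"model_dir": "my_models", "model_name": "my-model"}}

model_config = config.get("model", {})
if "model_name" in model_config:
    model_path = pathlib.Path(model_config.get("model_dir", "models"))
    model_path = model_path / model_config["model_name"]
    print(model_path)  # my_models/my-model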