Requirements: Add fastchat and override pydantic

Use an older version of pydantic to stay compatible with fastchat

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2023-11-15 01:00:08 -05:00
parent bbb59d0747
commit 1f444c8fb7
2 changed files with 8 additions and 8 deletions
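
For context on the diff below: pydantic v2 renamed the v1 serialization method .json() to .model_dump_json(), so pinning pydantic back to v1 means every response model must return to the v1 spelling. A minimal sketch of the split; the ModelCard here is a simplified stand-in, not the project's real class:

# Minimal sketch of the pydantic v1/v2 serialization split.
# This ModelCard is a simplified stand-in, not the class from main.py.
from pydantic import BaseModel

class ModelCard(BaseModel):
    id: str

card = ModelCard(id="example-model")

# pydantic v1 spelling (what this commit reverts to):
print(card.json())  # -> {"id": "example-model"}

# pydantic v2 spelling (what this commit removes):
# print(card.model_dump_json())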

main.py (16 changed lines)

@@ -25,7 +25,7 @@ async def list_models():
     model_config = config["model"]
     models = get_model_list(pathlib.Path(model_config["model_dir"] or "models"))
-    return models.model_dump_json()
+    return models.json()
 
 @app.get("/v1/model", dependencies=[Depends(check_api_key)])
 async def get_current_model():
@@ -33,7 +33,7 @@ async def get_current_model():
         return HTTPException(400, "No models are loaded.")
 
     model_card = ModelCard(id=model_container.get_model_path().name)
-    return model_card.model_dump_json()
+    return model_card.json()
 
 @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
 async def load_model(data: ModelLoadRequest):
@@ -61,13 +61,13 @@ async def load_model(data: ModelLoadRequest):
             module=module,
             modules=modules,
             status="processing"
-        ).model_dump_json()
+        ).json()
 
         yield ModelLoadResponse(
             module=module,
             modules=modules,
             status="finished"
-        ).model_dump_json()
+        ).json()
 
     return EventSourceResponse(generator())
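
Both yields above feed an SSE stream: EventSourceResponse sends each string the generator yields as the data payload of one server-sent event, which is why every pydantic model is serialized to a JSON string before being yielded. A minimal sketch of the pattern, assuming sse-starlette and pydantic v1; the route path and Status model are illustrative, not taken from main.py:

# Minimal SSE streaming sketch (assumes fastapi, sse-starlette, pydantic v1).
from fastapi import FastAPI
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

app = FastAPI()

class Status(BaseModel):  # illustrative stand-in for ModelLoadResponse
    status: str

@app.get("/v1/demo/stream")  # hypothetical route, used only for this sketch
async def stream_status():
    async def generator():
        # Each yielded JSON string becomes one SSE event's data payload.
        yield Status(status="processing").json()
        yield Status(status="finished").json()
    return EventSourceResponse(generator())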
@@ -86,14 +86,14 @@ async def encode_tokens(data: TokenEncodeRequest):
     tokens = model_container.get_tokens(data.text, None, **data.get_params())[0].tolist()
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
-    return response.model_dump_json()
+    return response.json()
 
 @app.post("/v1/token/decode", dependencies=[Depends(check_api_key)])
 async def decode_tokens(data: TokenDecodeRequest):
     message = model_container.get_tokens(None, data.tokens, **data.get_params())
     response = TokenDecodeResponse(text=message)
-    return response.model_dump_json()
+    return response.json()
 
 @app.post("/v1/completions", dependencies=[Depends(check_api_key)])
 async def generate_completion(request: Request, data: CompletionRequest):
@@ -106,14 +106,14 @@ async def generate_completion(request: Request, data: CompletionRequest):
                 response = create_completion_response(part, index, model_container.get_model_path().name)
-                yield response.model_dump_json()
+                yield response.json()
 
         return EventSourceResponse(generator())
     else:
         response_text = model_container.generate(**data.to_gen_params())
         response = create_completion_response(response_text, 0, model_container.get_model_path().name)
-        return response.model_dump_json()
+        return response.json()
 
 if __name__ == "__main__":
     # Initialize auth keys

Binary file not shown.
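
The unrendered file is presumably the requirements list named in the commit title. A hypothetical excerpt matching the commit message; FastChat is published on PyPI as fschat, and both the package names and the version bound here are assumptions rather than the commit's actual contents:

# Hypothetical requirements excerpt -- the real pins in this commit are not shown.
fschat
pydantic < 2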