diff --git a/main.py b/main.py index 4227409..05c880e 100644 --- a/main.py +++ b/main.py @@ -124,7 +124,10 @@ async def unload_model(): # Encode tokens endpoint @app.post("/v1/token/encode", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) async def encode_tokens(data: TokenEncodeRequest): - tokens = model_container.get_tokens(data.text, None, **data.get_params())[0].tolist() + raw_tokens = model_container.get_tokens(data.text, None, **data.get_params()) + + # Explicit None check: truth-testing a Torch tensor raises an ambiguous-boolean error + tokens = raw_tokens[0].tolist() if raw_tokens is not None else [] response = TokenEncodeResponse(tokens=tokens, length=len(tokens)) return response @@ -133,7 +136,7 @@ async def encode_tokens(data: TokenEncodeRequest): @app.post("/v1/token/decode", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) async def decode_tokens(data: TokenDecodeRequest): message = model_container.get_tokens(None, data.tokens, **data.get_params()) - response = TokenDecodeResponse(text=message) + response = TokenDecodeResponse(text=message or "") return response