Api: Add token endpoints

Adds /v1/token/encode and /v1/token/decode endpoints for converting text to token ids and back, with parameters controlling BOS insertion and special-token handling.

Signed-off-by: kingbri <bdashore3@proton.me>

Author: kingbri
Date:   2023-11-14 22:43:37 -05:00
Parent: 2d741653c3
Commit: 8fea5391a8

3 changed files with 60 additions and 1 deletion

OAI/types/token.py (new file, 30 additions)

@@ -0,0 +1,30 @@
+from pydantic import BaseModel
+from typing import List
+
+class CommonTokenRequest(BaseModel):
+    add_bos: bool = True
+    encode_special_tokens: bool = True
+    decode_special_tokens: bool = True
+
+    def get_params(self):
+        return {
+            "add_bos": self.add_bos,
+            "encode_special_tokens": self.encode_special_tokens,
+            "decode_special_tokens": self.decode_special_tokens
+        }
+
+class TokenEncodeRequest(CommonTokenRequest):
+    text: str
+
+class TokenEncodeResponse(BaseModel):
+    tokens: List[int]
+    length: int
+
+class TokenDecodeRequest(CommonTokenRequest):
+    tokens: List[int]
+
+class TokenDecodeResponse(BaseModel):
+    text: str
+
+class TokenCountResponse(BaseModel):
+    length: int
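All request models inherit the three tokenizer flags from CommonTokenRequest, so each endpoint can forward them with a single get_params() call; TokenCountResponse is defined here but not yet wired to an endpoint in this commit. A minimal sketch of how the shared flags flatten into kwargs (field values are illustrative, not defaults from the commit):

from OAI.types.token import TokenEncodeRequest

# Unset fields keep the defaults declared on CommonTokenRequest
request = TokenEncodeRequest(text="Hello, world!", add_bos=False)

# get_params() flattens the shared flags into kwargs for get_tokens()
print(request.get_params())
# -> {'add_bos': False, 'encode_special_tokens': True, 'decode_special_tokens': True}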

main.py (15 additions)

@@ -8,6 +8,7 @@ from progress.bar import IncrementalBar
 from sse_starlette import EventSourceResponse
 from OAI.types.completions import CompletionRequest
 from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
+from OAI.types.token import TokenEncodeRequest, TokenEncodeResponse, TokenDecodeRequest, TokenDecodeResponse
 from OAI.utils import create_completion_response, get_model_list
 from typing import Optional
 from utils import load_progress
@@ -80,6 +81,20 @@ async def unload_model():
     model_container.unload()
     model_container = None

+@app.post("/v1/token/encode", dependencies=[Depends(check_api_key)])
+async def encode_tokens(data: TokenEncodeRequest):
+    tokens = model_container.get_tokens(data.text, None, **data.get_params())[0].tolist()
+    response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
+
+    return response.model_dump_json()
+
+@app.post("/v1/token/decode", dependencies=[Depends(check_api_key)])
+async def decode_tokens(data: TokenDecodeRequest):
+    message = model_container.get_tokens(None, data.tokens, **data.get_params())
+    response = TokenDecodeResponse(text=message)
+
+    return response.model_dump_json()
+
 @app.post("/v1/completions", dependencies=[Depends(check_api_key)])
 async def generate_completion(request: Request, data: CompletionRequest):
     if data.stream:
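A hypothetical client for the two new routes. The base URL, port, and the x-api-key header name are assumptions, not taken from this commit. Note also that returning model_dump_json() (a str) from a FastAPI route gets JSON-encoded a second time, so a client may receive a quoted JSON string rather than an object; the sketch unwraps both cases:

import json
import requests  # assumed client library

BASE_URL = "http://localhost:5000"  # assumed host/port
HEADERS = {"x-api-key": "your-api-key"}  # assumed auth header name

def unwrap(resp):
    # model_dump_json() returns a str, which FastAPI serializes again;
    # parse the inner JSON document if that happened
    body = resp.json()
    return json.loads(body) if isinstance(body, str) else body

encoded = unwrap(requests.post(
    f"{BASE_URL}/v1/token/encode",
    headers=HEADERS,
    json={"text": "Hello, world!", "add_bos": True},
))

decoded = unwrap(requests.post(
    f"{BASE_URL}/v1/token/decode",
    headers=HEADERS,
    json={"tokens": encoded["tokens"]},
))

print(encoded["length"], decoded["text"])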

Model container module (file name not shown in this view; 15 additions, 1 deletion)

@@ -11,7 +11,7 @@ from exllamav2.generator import (
     ExLlamaV2StreamingGenerator,
     ExLlamaV2Sampler
 )
-from typing import Optional
+from typing import List, Optional

 # Bytes to reserve on first device when loading with auto split
 auto_split_reserve_bytes = 96 * 1024**2
@@ -195,6 +195,20 @@ class ModelContainer:
         torch.cuda.empty_cache()

+    # Common function for token operations
+    def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs):
+        if text:
+            # Assume token encoding
+            return self.tokenizer.encode(
+                text, add_bos = kwargs.get("add_bos", True),
+                encode_special_tokens = kwargs.get("encode_special_tokens", True)
+            )
+
+        if ids:
+            # Assume token decoding
+            ids = torch.tensor([ids])
+            return self.tokenizer.decode(ids, decode_special_tokens = kwargs.get("decode_special_tokens", True))[0]
+
     def generate(self, prompt: str, **kwargs):
         gen = self.generate_gen(prompt, **kwargs)
         response = "".join(gen)
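For reference, a sketch of how the endpoints drive this helper, assuming `container` is an already-loaded ModelContainer (a hypothetical variable). The tokenizer's encode returns a batch tensor of shape (1, n), which is why the encode endpoint indexes [0] before .tolist(), while get_tokens re-wraps a flat id list into a batch tensor before detokenizing. Both guards are truthiness checks, so an empty string or empty token list falls through and returns None:

# Hypothetical round trip through ModelContainer.get_tokens
params = {
    "add_bos": True,
    "encode_special_tokens": True,
    "decode_special_tokens": True,
}

# Encoding: tokenizer.encode returns a (1, n) tensor, so take row 0
ids = container.get_tokens("Hello, world!", None, **params)[0].tolist()

# Decoding: the flat id list is wrapped back into a batch tensor internally
text = container.get_tokens(None, ids, **params)
print(text)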