From 8fea5391a8047bf43bc406e4ef2cdcec07817af0 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Tue, 14 Nov 2023 22:43:37 -0500
Subject: [PATCH] Api: Add token endpoints

Support for encoding and decoding with various parameters.

Signed-off-by: kingbri
---
 OAI/types/token.py | 30 ++++++++++++++++++++++++++++++
 main.py            | 15 +++++++++++++++
 model.py           | 16 +++++++++++++++-
 3 files changed, 60 insertions(+), 1 deletion(-)
 create mode 100644 OAI/types/token.py

diff --git a/OAI/types/token.py b/OAI/types/token.py
new file mode 100644
index 0000000..fdc59c6
--- /dev/null
+++ b/OAI/types/token.py
@@ -0,0 +1,30 @@
+from pydantic import BaseModel
+from typing import List
+
+class CommonTokenRequest(BaseModel):
+    add_bos: bool = True
+    encode_special_tokens: bool = True
+    decode_special_tokens: bool = True
+
+    def get_params(self):
+        return {
+            "add_bos": self.add_bos,
+            "encode_special_tokens": self.encode_special_tokens,
+            "decode_special_tokens": self.decode_special_tokens
+        }
+
+class TokenEncodeRequest(CommonTokenRequest):
+    text: str
+
+class TokenEncodeResponse(BaseModel):
+    tokens: List[int]
+    length: int
+
+class TokenDecodeRequest(CommonTokenRequest):
+    tokens: List[int]
+
+class TokenDecodeResponse(BaseModel):
+    text: str
+
+class TokenCountResponse(BaseModel):
+    length: int
diff --git a/main.py b/main.py
index ef39427..b5a971e 100644
--- a/main.py
+++ b/main.py
@@ -8,6 +8,7 @@ from progress.bar import IncrementalBar
 from sse_starlette import EventSourceResponse
 from OAI.types.completions import CompletionRequest
 from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
+from OAI.types.token import TokenEncodeRequest, TokenEncodeResponse, TokenDecodeRequest, TokenDecodeResponse
 from OAI.utils import create_completion_response, get_model_list
 from typing import Optional
 from utils import load_progress
@@ -80,6 +81,20 @@ async def unload_model():
     model_container.unload()
     model_container = None
 
+@app.post("/v1/token/encode", dependencies=[Depends(check_api_key)])
+async def encode_tokens(data: TokenEncodeRequest):
+    tokens = model_container.get_tokens(data.text, None, **data.get_params())[0].tolist()
+    response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
+
+    return response.model_dump_json()
+
+@app.post("/v1/token/decode", dependencies=[Depends(check_api_key)])
+async def decode_tokens(data: TokenDecodeRequest):
+    message = model_container.get_tokens(None, data.tokens, **data.get_params())
+    response = TokenDecodeResponse(text=message)
+
+    return response.model_dump_json()
+
 @app.post("/v1/completions", dependencies=[Depends(check_api_key)])
 async def generate_completion(request: Request, data: CompletionRequest):
     if data.stream:
diff --git a/model.py b/model.py
index 41b51c9..e07291c 100644
--- a/model.py
+++ b/model.py
@@ -11,7 +11,7 @@ from exllamav2.generator import(
     ExLlamaV2StreamingGenerator,
     ExLlamaV2Sampler
 )
-from typing import Optional
+from typing import List, Optional
 
 # Bytes to reserve on first device when loading with auto split
 auto_split_reserve_bytes = 96 * 1024**2
@@ -195,6 +195,20 @@ class ModelContainer:
 
         torch.cuda.empty_cache()
 
+    # Common function for token operations
+    def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs):
+        if text:
+            # Assume token encoding
+            return self.tokenizer.encode(
+                text, add_bos = kwargs.get("add_bos", True),
+                encode_special_tokens = kwargs.get("encode_special_tokens", True)
+            )
+        if ids:
+            # Assume token decoding
+            ids = torch.tensor([ids])
+            return self.tokenizer.decode(ids,
+                decode_special_tokens = kwargs.get("decode_special_tokens", True))[0]
+
     def generate(self, prompt: str, **kwargs):
         gen = self.generate_gen(prompt, **kwargs)
         reponse = "".join(gen)
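
For reference, a minimal client-side sketch of exercising the new endpoints. It is not part of the patch: the base URL, port, x-api-key header name, key value, and the example token ids are assumptions for illustration only; substitute whatever host and authentication scheme the running server actually uses.

    # Sketch only: call the new /v1/token/encode and /v1/token/decode routes.
    # BASE_URL and HEADERS are assumed values, not defined by this patch.
    import requests

    BASE_URL = "http://localhost:8000"         # assumed host/port
    HEADERS = {"x-api-key": "your-api-key"}    # assumed header checked by check_api_key

    # Encode a string into token ids
    encode_body = {"text": "Hello, world!", "add_bos": True, "encode_special_tokens": True}
    encode_resp = requests.post(f"{BASE_URL}/v1/token/encode", json=encode_body, headers=HEADERS)
    print(encode_resp.text)  # serialized TokenEncodeResponse: tokens + length

    # Decode token ids back into text (ids are placeholders; they are model-specific)
    decode_body = {"tokens": [1, 2, 3, 4], "decode_special_tokens": True}
    decode_resp = requests.post(f"{BASE_URL}/v1/token/decode", json=decode_body, headers=HEADERS)
    print(decode_resp.text)  # serialized TokenDecodeResponse: the reconstructed text

One note on the handlers: because they return model_dump_json() (a plain string), FastAPI will typically JSON-encode that string again, so clients may receive a quoted JSON string rather than a bare object. Returning the Pydantic model instance itself would let FastAPI serialize it exactly once.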