Api: Add token endpoints
Support for encoding and decoding with various parameters.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 2d741653c3
commit 8fea5391a8

3 changed files with 60 additions and 1 deletion
OAI/types/token.py  (new file, +30)

@@ -0,0 +1,30 @@
+from pydantic import BaseModel
+from typing import List
+
+class CommonTokenRequest(BaseModel):
+    add_bos: bool = True
+    encode_special_tokens: bool = True
+    decode_special_tokens: bool = True
+
+    def get_params(self):
+        return {
+            "add_bos": self.add_bos,
+            "encode_special_tokens": self.encode_special_tokens,
+            "decode_special_tokens": self.decode_special_tokens
+        }
+
+class TokenEncodeRequest(CommonTokenRequest):
+    text: str
+
+class TokenEncodeResponse(BaseModel):
+    tokens: List[int]
+    length: int
+
+class TokenDecodeRequest(CommonTokenRequest):
+    tokens: List[int]
+
+class TokenDecodeResponse(BaseModel):
+    text: str
+
+class TokenCountResponse(BaseModel):
+    length: int
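The shared CommonTokenRequest base lets each route forward the optional flags to the tokenizer in one call via get_params(). A minimal sketch of how these models behave (the script is illustrative only, not part of the commit):

    from OAI.types.token import TokenEncodeRequest

    req = TokenEncodeRequest(text="Hello world", add_bos=False)
    print(req.get_params())
    # {'add_bos': False, 'encode_special_tokens': True, 'decode_special_tokens': True}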
main.py  (+15)

@@ -8,6 +8,7 @@ from progress.bar import IncrementalBar
 from sse_starlette import EventSourceResponse
 from OAI.types.completions import CompletionRequest
 from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
+from OAI.types.token import TokenEncodeRequest, TokenEncodeResponse, TokenDecodeRequest, TokenDecodeResponse
 from OAI.utils import create_completion_response, get_model_list
 from typing import Optional
 from utils import load_progress

@@ -80,6 +81,20 @@ async def unload_model():
     model_container.unload()
     model_container = None

+@app.post("/v1/token/encode", dependencies=[Depends(check_api_key)])
+async def encode_tokens(data: TokenEncodeRequest):
+    tokens = model_container.get_tokens(data.text, None, **data.get_params())[0].tolist()
+    response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
+
+    return response.model_dump_json()
+
+@app.post("/v1/token/decode", dependencies=[Depends(check_api_key)])
+async def decode_tokens(data: TokenDecodeRequest):
+    message = model_container.get_tokens(None, data.tokens, **data.get_params())
+    response = TokenDecodeResponse(text=message)
+
+    return response.model_dump_json()
+
 @app.post("/v1/completions", dependencies=[Depends(check_api_key)])
 async def generate_completion(request: Request, data: CompletionRequest):
     if data.stream:
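With the server running, the new routes accept JSON bodies matching the request models above and return the serialized response models. A rough usage sketch; the base URL, port, and API-key header are assumptions here and depend on how the server and check_api_key are configured:

    import requests  # illustrative client, not part of the commit

    BASE_URL = "http://localhost:8000"   # assumption: local default
    HEADERS = {"x-api-key": "changeme"}  # assumption: whatever check_api_key expects

    encode_resp = requests.post(
        f"{BASE_URL}/v1/token/encode",
        json={"text": "Hello world", "add_bos": True},
        headers=HEADERS,
    )
    print(encode_resp.text)  # JSON body with "tokens" and "length"

    decode_resp = requests.post(
        f"{BASE_URL}/v1/token/decode",
        json={"tokens": [1, 2, 3]},  # placeholder token ids
        headers=HEADERS,
    )
    print(decode_resp.text)  # JSON body with "text"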
model.py  (+15, -1)

@@ -11,7 +11,7 @@ from exllamav2.generator import(
     ExLlamaV2StreamingGenerator,
     ExLlamaV2Sampler
 )
-from typing import Optional
+from typing import List, Optional

 # Bytes to reserve on first device when loading with auto split
 auto_split_reserve_bytes = 96 * 1024**2

@@ -195,6 +195,20 @@ class ModelContainer:
         torch.cuda.empty_cache()


+    # Common function for token operations
+    def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs):
+        if text:
+            # Assume token encoding
+            return self.tokenizer.encode(
+                text, add_bos = kwargs.get("add_bos", True),
+                encode_special_tokens = kwargs.get("encode_special_tokens", True)
+            )
+        if ids:
+            # Assume token decoding
+            ids = torch.tensor([ids])
+            return self.tokenizer.decode(ids, decode_special_tokens = kwargs.get("decode_special_tokens", True))[0]
+
+
     def generate(self, prompt: str, **kwargs):
         gen = self.generate_gen(prompt, **kwargs)
         reponse = "".join(gen)
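get_tokens is the single helper both routes share: a text argument selects encoding, a list of token ids selects decoding. A quick sketch of calling it directly on an already-loaded ModelContainer (the container setup is assumed and not shown in this diff):

    # `container` is assumed to be a loaded ModelContainer from model.py
    ids = container.get_tokens("Hello world", None, add_bos=True)[0].tolist()
    text = container.get_tokens(None, ids, decode_special_tokens=True)

The [0].tolist() mirrors how main.py flattens the (1, n) tensor that the encode path returns.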