Merge branch 'main' of https://github.com/ziadloo/tabbyAPI into ziadloo-main

kingbri 2023-11-30 01:01:48 -05:00
commit e703c716ee
6 changed files with 34 additions and 25 deletions


@@ -32,8 +32,6 @@ class ChatCompletionResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time()))
     model: str
     object: str = "chat.completion"
-
-    # TODO: Add usage stats
     usage: Optional[UsageStats] = None
 
 class ChatCompletionStreamChunk(BaseModel):


@@ -8,8 +8,8 @@ class LogProbs(BaseModel):
     top_logprobs: List[Dict[str, float]] = Field(default_factory=list)
 
 class UsageStats(BaseModel):
-    completion_tokens: int
     prompt_tokens: int
+    completion_tokens: int
     total_tokens: int
 
 class CommonCompletionRequest(BaseModel):


@@ -22,6 +22,4 @@ class CompletionResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time()))
     model: str
     object: str = "text_completion"
-
-    # TODO: Add usage stats
     usage: Optional[UsageStats] = None
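
A minimal sketch (not part of the commit, models trimmed to the fields shown above) of how the populated usage block now serializes, assuming the pydantic v1-style .json() call used in main.py:

    from time import time
    from typing import Optional

    from pydantic import BaseModel, Field

    class UsageStats(BaseModel):
        prompt_tokens: int
        completion_tokens: int
        total_tokens: int

    class CompletionResponse(BaseModel):
        created: int = Field(default_factory=lambda: int(time()))
        model: str
        object: str = "text_completion"
        usage: Optional[UsageStats] = None

    response = CompletionResponse(
        model = "example-model",  # placeholder name
        usage = UsageStats(prompt_tokens = 12, completion_tokens = 4, total_tokens = 16),
    )
    print(response.json())  # the usage block now appears in the serialized payload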


@@ -1,5 +1,5 @@
 import os, pathlib
-from OAI.types.completion import CompletionResponse, CompletionRespChoice
+from OAI.types.completion import CompletionResponse, CompletionRespChoice, UsageStats
 from OAI.types.chat_completion import (
     ChatCompletionMessage,
     ChatCompletionRespChoice,
@@ -20,9 +20,7 @@ try:
 except ImportError:
     _fastchat_available = False
 
-def create_completion_response(text: str, model_name: Optional[str]):
-    # TODO: Add method to get token amounts in model for UsageStats
-
+def create_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
     choice = CompletionRespChoice(
         finish_reason = "Generated",
         text = text
@@ -30,14 +28,15 @@ def create_completion_response(text: str, model_name: Optional[str]):
 
     response = CompletionResponse(
         choices = [choice],
-        model = model_name or ""
+        model = model_name or "",
+        usage = UsageStats(prompt_tokens = prompt_tokens,
+                           completion_tokens = completion_tokens,
+                           total_tokens = prompt_tokens + completion_tokens)
     )
 
     return response
 
-def create_chat_completion_response(text: str, model_name: Optional[str]):
-    # TODO: Add method to get token amounts in model for UsageStats
-
+def create_chat_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
     message = ChatCompletionMessage(
         role = "assistant",
         content = text
@@ -50,7 +49,10 @@ def create_chat_completion_response(text: str, model_name: Optional[str]):
 
     response = ChatCompletionResponse(
         choices = [choice],
-        model = model_name or ""
+        model = model_name or "",
+        usage = UsageStats(prompt_tokens = prompt_tokens,
+                           completion_tokens = completion_tokens,
+                           total_tokens = prompt_tokens + completion_tokens)
     )
 
     return response
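
A hypothetical call (not in the diff) showing how the updated helper would be used, with create_completion_response and UsageStats from above in scope; the values and model name are placeholders:

    response = create_completion_response(
        "Hello, world!",   # text
        12,                # prompt_tokens
        4,                 # completion_tokens
        "example-model",   # model_name (placeholder)
    )
    assert response.usage.total_tokens == 16
    print(response.json(ensure_ascii=False))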

main.py

@@ -188,11 +188,14 @@ async def generate_completion(request: Request, data: CompletionRequest):
     async def generator():
         try:
             new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params())
-            for part in new_generation:
+            for (part, prompt_tokens, completion_tokens) in new_generation:
                 if await request.is_disconnected():
                     break
 
-                response = create_completion_response(part, model_path.name)
+                response = create_completion_response(part,
+                                                      prompt_tokens,
+                                                      completion_tokens,
+                                                      model_path.name)
 
                 yield response.json(ensure_ascii=False)
         except Exception as e:
@@ -200,8 +203,11 @@ async def generate_completion(request: Request, data: CompletionRequest):
 
         return EventSourceResponse(generator())
     else:
-        response_text = model_container.generate(data.prompt, **data.to_gen_params())
-        response = create_completion_response(response_text, model_path.name)
+        response_text, prompt_tokens, completion_tokens = model_container.generate(data.prompt, **data.to_gen_params())
+        response = create_completion_response(response_text,
+                                              prompt_tokens,
+                                              completion_tokens,
+                                              model_path.name)
 
         return response
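
A toy illustration (stand-in generator, not repo code) of the streaming contract the endpoint now relies on: each yielded item is a (chunk, prompt_tokens, generated_tokens) tuple, so every streamed response can carry usage counts:

    def fake_generate_gen(prompt: str):
        # Stand-in for ModelContainer.generate_gen(): yields canned chunks
        # together with token counts instead of real model output.
        prompt_tokens = len(prompt.split())
        for generated_tokens, chunk in enumerate(["Hel", "lo", "!"], start=1):
            yield chunk, prompt_tokens, generated_tokens

    for part, prompt_tokens, completion_tokens in fake_generate_gen("Say hello to me"):
        print(part, prompt_tokens, completion_tokens)
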
@@ -219,7 +225,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
     const_id = f"chatcmpl-{uuid4().hex}"
     async def generator():
         try:
-            new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
+            new_generation, prompt_tokens, completion_tokens = model_container.generate_gen(prompt, **data.to_gen_params())
             for part in new_generation:
                 if await request.is_disconnected():
                     break
@@ -236,8 +242,11 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
 
         return EventSourceResponse(generator())
     else:
-        response_text = model_container.generate(prompt, **data.to_gen_params())
-        response = create_chat_completion_response(response_text, model_path.name)
+        response_text, prompt_tokens, completion_tokens = model_container.generate(prompt, **data.to_gen_params())
+        response = create_chat_completion_response(response_text,
+                                                   prompt_tokens,
+                                                   completion_tokens,
+                                                   model_path.name)
 
         return response
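
A toy sketch of the non-streaming path (stand-in functions, not repo code): generate() now returns (text, prompt_tokens, completion_tokens), and the endpoint forwards the counts into the response instead of discarding them:

    def generate_stub(prompt: str):
        # Stand-in for ModelContainer.generate() after this commit.
        return "Hi there!", len(prompt.split()), 3

    def make_response_stub(text, prompt_tokens, completion_tokens, model_name):
        # Stand-in for create_chat_completion_response(): builds a dict with a
        # usage block the same way the real helper builds UsageStats.
        return {
            "model": model_name,
            "text": text,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

    response_text, prompt_tokens, completion_tokens = generate_stub("Say hi to me")
    print(make_response_stub(response_text, prompt_tokens, completion_tokens, "example-model"))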


@@ -226,9 +226,9 @@ class ModelContainer:
 
     def generate(self, prompt: str, **kwargs):
-        gen = self.generate_gen(prompt, **kwargs)
-        reponse = "".join(gen)
-        return reponse
+        gen = list(self.generate_gen(prompt, **kwargs))
+        reponse = "".join(map(lambda o: o[0], gen))
+        return reponse, gen[-1][1], gen[-1][2]
 
     def generate_gen(self, prompt: str, **kwargs):
         """
def generate_gen(self, prompt: str, **kwargs):
"""
@@ -345,6 +345,8 @@ class ModelContainer:
                 "Generation is truncated and metrics may not be accurate."
             )
 
+        prompt_tokens = ids.shape[-1]
+
         # Begin
         generated_tokens = 0
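
The new prompt_tokens value assumes ids is the encoded prompt as a (batch, seq_len) tensor, so its last dimension is the prompt length; a small sketch (torch used only for illustration):

    import torch

    ids = torch.tensor([[101, 7592, 2088, 102]])  # hypothetical encoded prompt, shape (1, 4)
    prompt_tokens = ids.shape[-1]                 # prompt length in tokens -> 4
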
@@ -390,7 +392,7 @@ class ModelContainer:
             elapsed = now - last_chunk_time
 
             if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_tokens):
-                yield chunk_buffer
+                yield chunk_buffer, prompt_tokens, generated_tokens
                 full_response += chunk_buffer
                 chunk_buffer = ""
                 last_chunk_time = now
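
A minimal sketch (stub generator, not repo code) of the aggregation the reworked generate() performs over this yield shape: join the text chunks and take the token counts from the last yielded tuple:

    def generate_gen_stub():
        # Stand-in for generate_gen(): yields (chunk, prompt_tokens, generated_tokens).
        yield "Hel", 5, 1
        yield "lo", 5, 2
        yield "!", 5, 3

    gen = list(generate_gen_stub())
    response = "".join(chunk for chunk, _, _ in gen)
    prompt_tokens, completion_tokens = gen[-1][1], gen[-1][2]
    assert (response, prompt_tokens, completion_tokens) == ("Hello!", 5, 3)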