Model: Add logprobs support

Returns token offsets, the selected tokens, each selected token's
post-sampling probability, and the normalized probabilities of the top
candidate tokens, taken pre-sampling for efficiency.

Only for text completions; chat completion support will come in a later commit.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-02-07 21:41:15 -05:00 committed by Brian Dashore
parent 2642ef7156
commit 0af6a38af3
6 changed files with 145 additions and 52 deletions
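For reference, the logprobs block in a text completion response is a set of parallel per-token lists. An illustrative example (values made up) for a two-token generation, using the field names of the CompletionLogProbs model introduced below:

    # Illustrative shape only; see CompletionLogProbs in the diff below.
    logprobs_block = {
        "text_offset": [0, 5],
        "tokens": ["Hello", " world"],
        "token_logprobs": [-0.11, -0.45],
        "top_logprobs": [
            {"Hello": -0.11, "Hi": -2.31},
            {" world": -0.45, " there": -1.87},
        ],
    }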

View file

@@ -1,19 +1,10 @@
""" Common types for OAI. """
from pydantic import BaseModel, Field
from typing import List, Dict, Optional
from typing import Optional
from common.sampling import BaseSamplerRequest
class LogProbs(BaseModel):
"""Represents log probabilities."""
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
class UsageStats(BaseModel):
"""Represents usage stats."""
@@ -29,6 +20,10 @@ class CommonCompletionRequest(BaseSamplerRequest):
# This parameter is not used, the loaded model is used instead
model: Optional[str] = None
# Generation info (remainder is in BaseSamplerRequest superclass)
stream: Optional[bool] = False
logprobs: Optional[int] = 0
# Extra OAI request stuff
best_of: Optional[int] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
@@ -36,9 +31,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
echo: Optional[bool] = Field(
description="Not parsed. Only used for OAI compliance.", default=False
)
logprobs: Optional[int] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)
n: Optional[int] = Field(
description="Not parsed. Only used for OAI compliance.", default=1
)
@@ -49,5 +41,7 @@ class CommonCompletionRequest(BaseSamplerRequest):
description="Not parsed. Only used for OAI compliance.", default=None
)
# Generation info (remainder is in BaseSamplerRequest superclass)
stream: Optional[bool] = False
def to_gen_params(self):
extra_gen_params = {"logprobs": self.logprobs}
return super().to_gen_params(**extra_gen_params)
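With logprobs now parsed rather than ignored, a client can request per-token alternatives directly in the completion request. A hypothetical request body (prompt and max_tokens are standard OAI completion fields, not part of this diff):

    # Hypothetical client payload; "logprobs" is the newly parsed field.
    payload = {
        "prompt": "Once upon a time",
        "max_tokens": 16,
        "stream": False,
        "logprobs": 3,  # return up to the top-3 candidates per generated token
    }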

View file

@@ -1,10 +1,19 @@
""" Completion API protocols """
from pydantic import BaseModel, Field
from time import time
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union
from uuid import uuid4
from OAI.types.common import CommonCompletionRequest, LogProbs, UsageStats
from OAI.types.common import CommonCompletionRequest, UsageStats
class CompletionLogProbs(BaseModel):
"""Represents log probabilities for a completion request."""
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
class CompletionRespChoice(BaseModel):
@@ -13,7 +22,7 @@ class CompletionRespChoice(BaseModel):
# Index is 0 since we aren't using multiple choices
index: int = 0
finish_reason: str
logprobs: Optional[LogProbs] = None
logprobs: Optional[CompletionLogProbs] = None
text: str
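Putting the new types together, a response choice carrying logprobs might be built like this (a sketch that assumes the module path shown in this diff):

    from OAI.types.completion import CompletionLogProbs, CompletionRespChoice

    choice = CompletionRespChoice(
        finish_reason="Generated",
        text="Hello world",
        logprobs=CompletionLogProbs(
            text_offset=[0, 5],
            tokens=["Hello", " world"],
            token_logprobs=[-0.11, -0.45],
            top_logprobs=[{"Hello": -0.11}, {" world": -0.45}],
        ),
    )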

View file

@@ -9,22 +9,40 @@ from OAI.types.chat_completion import (
ChatCompletionResponse,
ChatCompletionStreamChoice,
)
from OAI.types.completion import CompletionResponse, CompletionRespChoice
from OAI.types.completion import (
CompletionResponse,
CompletionRespChoice,
CompletionLogProbs,
)
from OAI.types.common import UsageStats
def create_completion_response(
text: str,
prompt_tokens: int,
completion_tokens: int,
model_name: Optional[str],
):
def create_completion_response(**kwargs):
"""Create a completion response from the provided text."""
choice = CompletionRespChoice(finish_reason="Generated", text=text)
token_probs = unwrap(kwargs.get("token_probs"), {})
logprobs = unwrap(kwargs.get("logprobs"), [])
offset = unwrap(kwargs.get("offset"), [])
logprob_response = CompletionLogProbs(
text_offset=offset if isinstance(offset, list) else [offset],
token_logprobs=token_probs.values(),
tokens=token_probs.keys(),
top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
)
choice = CompletionRespChoice(
finish_reason="Generated",
text=unwrap(kwargs.get("text"), ""),
logprobs=logprob_response,
)
prompt_tokens = unwrap(kwargs.get("prompt_tokens"), 0)
completion_tokens = unwrap(kwargs.get("completion_tokens"), 0)
response = CompletionResponse(
choices=[choice],
model=unwrap(model_name, ""),
model=unwrap(kwargs.get("model_name"), ""),
usage=UsageStats(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
@@ -37,12 +55,12 @@ def create_completion_response(
def create_chat_completion_response(
text: str,
prompt_tokens: int,
completion_tokens: int,
prompt_tokens: Optional[int],
completion_tokens: Optional[int],
model_name: Optional[str],
):
"""Create a chat completion response from the provided text."""
message = ChatCompletionMessage(role="assistant", content=text)
message = ChatCompletionMessage(role="assistant", content=unwrap(text, ""))
choice = ChatCompletionRespChoice(finish_reason="Generated", message=message)
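Going back to create_completion_response: it now reads everything from keyword arguments, so a hypothetical direct call (values made up) looks like this:

    response = create_completion_response(
        text="Hello world",
        prompt_tokens=4,
        completion_tokens=2,
        model_name="my-model",  # hypothetical model name
        offset=[0, 5],
        token_probs={"Hello": -0.11, " world": -0.45},
        logprobs=[{"Hello": -0.11, "Hi": -2.31}, {" world": -0.45}],
    )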

View file

@@ -472,20 +472,72 @@ class ExllamaV2Container:
"unk_token": self.tokenizer.unk_token,
}
def get_logprobs(self, logits: torch.Tensor, max_logprobs: int):
normalized_logits = torch.log_softmax(logits, dim=-1)
top_values, top_ids = torch.topk(normalized_logits, max_logprobs, dim=-1)
top_tokens = list(
map(
lambda index: self.tokenizer.extended_id_to_piece.get(
index, self.tokenizer.id_to_piece[index]
),
top_ids[0].tolist(),
)
)
top_values = top_values[0].tolist()
return dict(zip(top_tokens, top_values, strict=True))
def get_token_probs(self, token_ids: torch.tensor, token_probs: torch.Tensor):
tokens = list(
map(
lambda index: self.tokenizer.extended_id_to_piece.get(
index, self.tokenizer.id_to_piece[index]
),
token_ids[0].tolist(),
)
)
return dict(zip(tokens, token_probs[0].tolist(), strict=True))
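For intuition, here is a standalone sketch of the top-k extraction these helpers perform, using a toy vocabulary in place of the tokenizer's id-to-piece mapping:

    import torch

    vocab = ["the", "a", "an", "dog", "cat"]               # toy stand-in for the tokenizer
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])   # batch of 1, vocab of 5

    # Normalize to log-probabilities, then keep the k most likely candidates.
    normalized = torch.log_softmax(logits, dim=-1)
    top_values, top_ids = torch.topk(normalized, k=3, dim=-1)

    top_tokens = [vocab[i] for i in top_ids[0].tolist()]
    print(dict(zip(top_tokens, top_values[0].tolist())))
    # -> roughly {'the': -0.50, 'a': -1.50, 'an': -2.00}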
def generate(self, prompt: str, **kwargs):
"""Generate a response to a prompt"""
generations = list(self.generate_gen(prompt, **kwargs))
joined_generation = {
"chunk": "",
"prompt_tokens": 0,
"generation_tokens": 0,
"offset": [],
"token_probs": {},
"logprobs": [],
}
if generations:
for generation in generations:
joined_generation["chunk"] += unwrap(generation.get("chunk"), "")
joined_generation["offset"].append(unwrap(generation.get("offset"), []))
joined_generation["token_probs"].update(
unwrap(generation.get("token_probs"), {})
)
joined_generation["logprobs"].append(
unwrap(generation.get("logprobs"), {})
)
joined_generation["prompt_tokens"] = unwrap(
generations[-1].get("prompt_tokens"), 0
)
joined_generation["generation_tokens"] = unwrap(
generations[-1].get("generated_tokens"), 0
)
return joined_generation
def check_unsupported_settings(self, **kwargs):
"""Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
pass
def generate(self, prompt: str, **kwargs):
"""Generate a response to a prompt"""
generation = list(self.generate_gen(prompt, **kwargs))
if generation:
response = "".join(map(lambda chunk: chunk[0], generation))
return response, generation[-1][1], generation[-1][2]
return "", 0, 0
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
def generate_gen(self, prompt: str, **kwargs):
"""
@@ -639,6 +691,7 @@ class ExllamaV2Container:
add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
logit_bias = kwargs.get("logit_bias")
request_logprobs = unwrap(kwargs.get("logprobs"), 0)
# Override sampler settings for temp = 0
if gen_settings.temperature == 0:
@@ -657,6 +710,7 @@ class ExllamaV2Container:
generate_window=generate_window,
add_bos_token=add_bos_token,
ban_eos_token=ban_eos_token,
logprobs=request_logprobs,
stop_conditions=stop_conditions,
logit_bias=logit_bias,
)
@@ -758,7 +812,7 @@ class ExllamaV2Container:
gen_settings.token_repetition_range = generated_tokens
# Generate
chunk, eos, tokens, _, _ = self.generator.stream()
chunk, eos, tokens, token_probs, logits = self.generator.stream()
if token_healing:
# Extract healed token
@@ -780,7 +834,27 @@ class ExllamaV2Container:
if chunk_buffer != "" and (
elapsed > stream_interval or eos or generated_tokens == max_tokens
):
yield chunk_buffer, prompt_tokens, generated_tokens
generation = {
"chunk": chunk_buffer,
"prompt_tokens": prompt_tokens,
"generated_tokens": generated_tokens,
"offset": len(full_response),
}
if request_logprobs > 0:
# Get sampled token probs
if token_probs.numel() > 0 and tokens.numel() > 0:
generation["token_probs"] = self.get_token_probs(
tokens, token_probs
)
# Get logprob choices
if logits.numel() > 0:
generation["logprobs"] = self.get_logprobs(
logits, request_logprobs
)
yield generation
full_response += chunk_buffer
chunk_buffer = ""
last_chunk_time = now
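Note that the offset recorded at each yield is the character position of the chunk within the text generated so far (len(full_response) before the chunk is appended), and generate() collects those positions into a list. A tiny illustration with made-up chunking:

    chunks = ["Hel", "lo wor", "ld!"]        # hypothetical stream chunks
    full_response, offsets = "", []
    for chunk in chunks:
        offsets.append(len(full_response))   # what generation["offset"] captures
        full_response += chunk
    print(offsets)  # [0, 3, 9] -- per-chunk start positions in the final text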

View file

@@ -159,7 +159,7 @@ class BaseSamplerRequest(BaseModel):
examples=[1.0],
)
def to_gen_params(self):
def to_gen_params(self, **kwargs):
"""Converts samplers to internal generation params"""
# Add forced overrides if present
@@ -201,7 +201,7 @@ class BaseSamplerRequest(BaseModel):
"negative_prompt": self.negative_prompt,
}
return gen_params
return {**gen_params, **kwargs}
# Global for default overrides
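The dict-unpacking merge in to_gen_params is what lets a subclass append backend-specific parameters such as logprobs without the base sampler request knowing about them. A minimal illustration:

    gen_params = {"temperature": 1.0, "top_p": 0.9}
    extra = {"logprobs": 5}
    print({**gen_params, **extra})
    # {'temperature': 1.0, 'top_p': 0.9, 'logprobs': 5}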

main.py
View file

@@ -458,12 +458,13 @@ async def generate_completion(request: Request, data: CompletionRequest):
new_generation = MODEL_CONTAINER.generate_gen(
data.prompt, **data.to_gen_params()
)
for part, prompt_tokens, completion_tokens in new_generation:
for generation in new_generation:
if await request.is_disconnected():
break
response = create_completion_response(
part, prompt_tokens, completion_tokens, model_path.name
**generation,
model_name=model_path.name,
)
yield get_sse_packet(response.model_dump_json())
@@ -479,13 +480,10 @@ async def generate_completion(request: Request, data: CompletionRequest):
generate_with_semaphore(generator), media_type="text/event-stream"
)
response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
generation = await call_with_semaphore(
partial(MODEL_CONTAINER.generate, data.prompt, **data.to_gen_params())
)
response = create_completion_response(
response_text, prompt_tokens, completion_tokens, model_path.name
)
response = create_completion_response(**generation)
return response
@@ -545,12 +543,12 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
new_generation = MODEL_CONTAINER.generate_gen(
prompt, **data.to_gen_params()
)
for part, _, _ in new_generation:
for generation in new_generation:
if await request.is_disconnected():
break
response = create_chat_completion_stream_chunk(
const_id, part, model_path.name
const_id, generation.get("chunk"), model_path.name
)
yield get_sse_packet(response.model_dump_json())
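As the commit message notes, the chat path only consumes the text chunk for now; the per-token data is already present in each yielded dict, so a later commit can forward it the same way. A hedged sketch of what one streamed dict carries (keys from the diff, populated when logprobs are requested; values illustrative):

    generation = {
        "chunk": " world",
        "prompt_tokens": 4,
        "generated_tokens": 2,
        "offset": 5,                                    # start of this chunk in the full text
        "token_probs": {" world": 0.64},                # post-sampling probability of the chosen token
        "logprobs": {" world": -0.45, " there": -1.87}, # pre-sampling top candidates (log-probabilities)
    }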