Model: Add logprobs support
Returns token offsets, selected tokens, the post-sampling probabilities of the selected tokens, and the normalized pre-sampling probability of selecting a token (for efficiency purposes). Currently only for text completions; chat completions will follow in a later commit.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 2642ef7156
commit 0af6a38af3

6 changed files with 145 additions and 52 deletions
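
To make the commit message concrete: the new data lands in the OAI-style logprobs block of each completion choice. A hedged sketch of the shape a client sees, where the field names come from the CompletionLogProbs model added below and every value is invented:

# Illustrative only -- not actual model output
choice_logprobs = {
    "text_offset": [0, 5, 6],              # character offset of each token in the text
    "tokens": ["Hello", ",", " world"],    # selected tokens
    "token_logprobs": [0.92, 0.88, 0.75],  # post-sampling probability of each selected token
    "top_logprobs": [                      # normalized pre-sampling logprobs of the top candidates
        {"Hello": -0.08, "Hi": -2.61},
        {",": -0.12, "!": -2.30},
        {" world": -0.29, " there": -1.71},
    ],
}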

@@ -1,19 +1,10 @@
 """ Common types for OAI. """
 from pydantic import BaseModel, Field
-from typing import List, Dict, Optional
+from typing import Optional

 from common.sampling import BaseSamplerRequest


-class LogProbs(BaseModel):
-    """Represents log probabilities."""
-
-    text_offset: List[int] = Field(default_factory=list)
-    token_logprobs: List[Optional[float]] = Field(default_factory=list)
-    tokens: List[str] = Field(default_factory=list)
-    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
-
-
 class UsageStats(BaseModel):
     """Represents usage stats."""

@@ -29,6 +20,10 @@ class CommonCompletionRequest(BaseSamplerRequest):
     # This parameter is not used, the loaded model is used instead
     model: Optional[str] = None

+    # Generation info (remainder is in BaseSamplerRequest superclass)
+    stream: Optional[bool] = False
+    logprobs: Optional[int] = 0
+
     # Extra OAI request stuff
     best_of: Optional[int] = Field(
         description="Not parsed. Only used for OAI compliance.", default=None
@@ -36,9 +31,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
     echo: Optional[bool] = Field(
         description="Not parsed. Only used for OAI compliance.", default=False
     )
-    logprobs: Optional[int] = Field(
-        description="Not parsed. Only used for OAI compliance.", default=None
-    )
     n: Optional[int] = Field(
         description="Not parsed. Only used for OAI compliance.", default=1
     )
@@ -49,5 +41,7 @@ class CommonCompletionRequest(BaseSamplerRequest):
         description="Not parsed. Only used for OAI compliance.", default=None
     )

-    # Generation info (remainder is in BaseSamplerRequest superclass)
-    stream: Optional[bool] = False
+    def to_gen_params(self):
+        extra_gen_params = {"logprobs": self.logprobs}
+
+        return super().to_gen_params(**extra_gen_params)
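
Net effect of the hunks above: logprobs stops being an unparsed OAI-compliance field and becomes a real request field (default 0) that the new to_gen_params override forwards to the backend. A minimal sketch of that flow, assuming the remaining BaseSamplerRequest fields all have usable defaults:

# Hypothetical usage; only `logprobs` matters for this commit
request = CommonCompletionRequest(logprobs=5)

gen_params = request.to_gen_params()  # super().to_gen_params(**{"logprobs": 5})
print(gen_params["logprobs"])         # -> 5, read later by generate_gen() as request_logprobs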

@@ -1,10 +1,19 @@
 """ Completion API protocols """
 from pydantic import BaseModel, Field
 from time import time
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 from uuid import uuid4

-from OAI.types.common import CommonCompletionRequest, LogProbs, UsageStats
+from OAI.types.common import CommonCompletionRequest, UsageStats
+
+
+class CompletionLogProbs(BaseModel):
+    """Represents log probabilities for a completion request."""
+
+    text_offset: List[int] = Field(default_factory=list)
+    token_logprobs: List[Optional[float]] = Field(default_factory=list)
+    tokens: List[str] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


 class CompletionRespChoice(BaseModel):
@@ -13,7 +22,7 @@ class CompletionRespChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
     index: int = 0
     finish_reason: str
-    logprobs: Optional[LogProbs] = None
+    logprobs: Optional[CompletionLogProbs] = None
     text: str
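
CompletionLogProbs mirrors the legacy OAI text-completion logprobs object, and CompletionRespChoice now carries it as an optional attachment. A quick construction sketch (the numbers are invented):

# Illustrative only
logprob_block = CompletionLogProbs(
    text_offset=[0],
    tokens=["Hello"],
    token_logprobs=[0.92],
    top_logprobs=[{"Hello": -0.08, "Hi": -2.61}],
)
choice = CompletionRespChoice(finish_reason="Generated", text="Hello", logprobs=logprob_block)
print(choice.model_dump_json())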

@@ -9,22 +9,40 @@ from OAI.types.chat_completion import (
     ChatCompletionResponse,
     ChatCompletionStreamChoice,
 )
-from OAI.types.completion import CompletionResponse, CompletionRespChoice
+from OAI.types.completion import (
+    CompletionResponse,
+    CompletionRespChoice,
+    CompletionLogProbs,
+)
 from OAI.types.common import UsageStats


-def create_completion_response(
-    text: str,
-    prompt_tokens: int,
-    completion_tokens: int,
-    model_name: Optional[str],
-):
+def create_completion_response(**kwargs):
     """Create a completion response from the provided text."""
-    choice = CompletionRespChoice(finish_reason="Generated", text=text)
+
+    token_probs = unwrap(kwargs.get("token_probs"), {})
+    logprobs = unwrap(kwargs.get("logprobs"), [])
+    offset = unwrap(kwargs.get("offset"), [])
+
+    logprob_response = CompletionLogProbs(
+        text_offset=offset if isinstance(offset, list) else [offset],
+        token_logprobs=token_probs.values(),
+        tokens=token_probs.keys(),
+        top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
+    )
+
+    choice = CompletionRespChoice(
+        finish_reason="Generated",
+        text=unwrap(kwargs.get("text"), ""),
+        logprobs=logprob_response,
+    )
+
+    prompt_tokens = unwrap(kwargs.get("prompt_tokens"), 0)
+    completion_tokens = unwrap(kwargs.get("completion_tokens"), 0)

     response = CompletionResponse(
         choices=[choice],
-        model=unwrap(model_name, ""),
+        model=unwrap(kwargs.get("model_name"), ""),
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
@@ -37,12 +55,12 @@ def create_completion_response(

 def create_chat_completion_response(
     text: str,
-    prompt_tokens: int,
-    completion_tokens: int,
+    prompt_tokens: Optional[int],
+    completion_tokens: Optional[int],
     model_name: Optional[str],
 ):
     """Create a chat completion response from the provided text."""
-    message = ChatCompletionMessage(role="assistant", content=text)
+    message = ChatCompletionMessage(role="assistant", content=unwrap(text, ""))

     choice = ChatCompletionRespChoice(finish_reason="Generated", message=message)
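
Since the helper now takes **kwargs, callers can splat a generation dict straight into it. A hedged example call; the keys are the ones the function reads via kwargs.get(), while the values and the model name are invented:

# Illustrative only
generation = {
    "text": "Hello, world",
    "prompt_tokens": 4,
    "completion_tokens": 3,
    "offset": [0, 5, 6],
    "token_probs": {"Hello": 0.92, ",": 0.88, " world": 0.75},
    "logprobs": [{"Hello": -0.08}, {",": -0.12}, {" world": -0.29}],
}

response = create_completion_response(**generation, model_name="some-model")
print(response.model_dump_json())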

@@ -472,20 +472,72 @@ class ExllamaV2Container:
             "unk_token": self.tokenizer.unk_token,
         }

+    def get_logprobs(self, logits: torch.Tensor, max_logprobs: int):
+        normalized_logits = torch.log_softmax(logits, dim=-1)
+        top_values, top_ids = torch.topk(normalized_logits, max_logprobs, dim=-1)
+
+        top_tokens = list(
+            map(
+                lambda index: self.tokenizer.extended_id_to_piece.get(
+                    index, self.tokenizer.id_to_piece[index]
+                ),
+                top_ids[0].tolist(),
+            )
+        )
+        top_values = top_values[0].tolist()
+
+        return dict(zip(top_tokens, top_values, strict=True))
+
+    def get_token_probs(self, token_ids: torch.tensor, token_probs: torch.Tensor):
+        tokens = list(
+            map(
+                lambda index: self.tokenizer.extended_id_to_piece.get(
+                    index, self.tokenizer.id_to_piece[index]
+                ),
+                token_ids[0].tolist(),
+            )
+        )
+
+        return dict(zip(tokens, token_probs[0].tolist(), strict=True))
+
+    def generate(self, prompt: str, **kwargs):
+        """Generate a response to a prompt"""
+        generations = list(self.generate_gen(prompt, **kwargs))
+
+        joined_generation = {
+            "chunk": "",
+            "prompt_tokens": 0,
+            "generation_tokens": 0,
+            "offset": [],
+            "token_probs": {},
+            "logprobs": [],
+        }
+
+        if generations:
+            for generation in generations:
+                joined_generation["chunk"] += unwrap(generation.get("chunk"), "")
+                joined_generation["offset"].append(unwrap(generation.get("offset"), []))
+                joined_generation["token_probs"].update(
+                    unwrap(generation.get("token_probs"), {})
+                )
+                joined_generation["logprobs"].append(
+                    unwrap(generation.get("logprobs"), {})
+                )
+
+            joined_generation["prompt_tokens"] = unwrap(
+                generations[-1].get("prompt_tokens"), 0
+            )
+            joined_generation["generation_tokens"] = unwrap(
+                generations[-1].get("generated_tokens"), 0
+            )
+
+        return joined_generation
+
     def check_unsupported_settings(self, **kwargs):
         """Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""

         pass

-    def generate(self, prompt: str, **kwargs):
-        """Generate a response to a prompt"""
-        generation = list(self.generate_gen(prompt, **kwargs))
-        if generation:
-            response = "".join(map(lambda chunk: chunk[0], generation))
-            return response, generation[-1][1], generation[-1][2]
-
-        return "", 0, 0
-
     # pylint: disable=too-many-locals,too-many-branches,too-many-statements
     def generate_gen(self, prompt: str, **kwargs):
         """
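
get_logprobs covers the pre-sampling half of the feature: log_softmax turns a raw logits row into normalized log probabilities (their exponentials sum to 1 over the vocabulary), and topk keeps the max_logprobs best candidates, which the method then maps to token pieces via the tokenizer. A standalone sketch of the same math on a fake 1 x 4 vocabulary (values invented):

import torch

logits = torch.tensor([[2.0, 1.0, 0.1, -3.0]])
logprobs = torch.log_softmax(logits, dim=-1)       # normalized over the last dim
top_values, top_ids = torch.topk(logprobs, 2, dim=-1)

# Keys here are raw token ids; the real method resolves them to token strings
print(dict(zip(top_ids[0].tolist(), top_values[0].tolist())))  # approx {0: -0.42, 1: -1.42}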

@@ -639,6 +691,7 @@ class ExllamaV2Container:
         add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
         ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
         logit_bias = kwargs.get("logit_bias")
+        request_logprobs = unwrap(kwargs.get("logprobs"), 0)

         # Override sampler settings for temp = 0
         if gen_settings.temperature == 0:
@@ -657,6 +710,7 @@ class ExllamaV2Container:
             generate_window=generate_window,
             add_bos_token=add_bos_token,
             ban_eos_token=ban_eos_token,
+            logprobs=request_logprobs,
             stop_conditions=stop_conditions,
             logit_bias=logit_bias,
         )
@@ -758,7 +812,7 @@ class ExllamaV2Container:
                 gen_settings.token_repetition_range = generated_tokens

             # Generate
-            chunk, eos, tokens, _, _ = self.generator.stream()
+            chunk, eos, tokens, token_probs, logits = self.generator.stream()

             if token_healing:
                 # Extract healed token
@@ -780,7 +834,27 @@ class ExllamaV2Container:
             if chunk_buffer != "" and (
                 elapsed > stream_interval or eos or generated_tokens == max_tokens
             ):
-                yield chunk_buffer, prompt_tokens, generated_tokens
+                generation = {
+                    "chunk": chunk_buffer,
+                    "prompt_tokens": prompt_tokens,
+                    "generated_tokens": generated_tokens,
+                    "offset": len(full_response),
+                }
+
+                if request_logprobs > 0:
+                    # Get sampled token probs
+                    if token_probs.numel() > 0 and tokens.numel() > 0:
+                        generation["token_probs"] = self.get_token_probs(
+                            tokens, token_probs
+                        )
+
+                    # Get logprob choices
+                    if logits.numel() > 0:
+                        generation["logprobs"] = self.get_logprobs(
+                            logits, request_logprobs
+                        )
+
+                yield generation
                 full_response += chunk_buffer
                 chunk_buffer = ""
                 last_chunk_time = now
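
Downstream consumers now receive dicts instead of tuples. When logprobs are requested, a single yielded chunk looks roughly like the sketch below (every value invented), and generate() folds the list of chunks into one joined dict with the same keys:

# Hypothetical shape of one streamed chunk
generation = {
    "chunk": " world",                              # decoded text flushed in this interval
    "prompt_tokens": 4,
    "generated_tokens": 2,
    "offset": 5,                                    # len(full_response) before this chunk
    "token_probs": {" world": 0.83},                # post-sampling probability of the sampled token
    "logprobs": {" world": -0.19, " there": -2.41}, # top-k normalized logprobs before sampling
}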

@@ -159,7 +159,7 @@ class BaseSamplerRequest(BaseModel):
         examples=[1.0],
     )

-    def to_gen_params(self):
+    def to_gen_params(self, **kwargs):
         """Converts samplers to internal generation params"""

         # Add forced overrides if present
@@ -201,7 +201,7 @@ class BaseSamplerRequest(BaseModel):
             "negative_prompt": self.negative_prompt,
         }

-        return gen_params
+        return {**gen_params, **kwargs}


 # Global for default overrides

main.py

@@ -458,12 +458,13 @@ async def generate_completion(request: Request, data: CompletionRequest):
             new_generation = MODEL_CONTAINER.generate_gen(
                 data.prompt, **data.to_gen_params()
             )
-            for part, prompt_tokens, completion_tokens in new_generation:
+            for generation in new_generation:
                 if await request.is_disconnected():
                     break

                 response = create_completion_response(
-                    part, prompt_tokens, completion_tokens, model_path.name
+                    **generation,
+                    model_name=model_path.name,
                 )

                 yield get_sse_packet(response.model_dump_json())
@@ -479,13 +480,10 @@ async def generate_completion(request: Request, data: CompletionRequest):
         generate_with_semaphore(generator), media_type="text/event-stream"
     )

-    response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
+    generation = await call_with_semaphore(
         partial(MODEL_CONTAINER.generate, data.prompt, **data.to_gen_params())
     )

-    response = create_completion_response(
-        response_text, prompt_tokens, completion_tokens, model_path.name
-    )
+    response = create_completion_response(**generation)

     return response

@@ -545,12 +543,12 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
             new_generation = MODEL_CONTAINER.generate_gen(
                 prompt, **data.to_gen_params()
             )
-            for part, _, _ in new_generation:
+            for generation in new_generation:
                 if await request.is_disconnected():
                     break

                 response = create_chat_completion_stream_chunk(
-                    const_id, part, model_path.name
+                    const_id, generation.get("chunk"), model_path.name
                 )

                 yield get_sse_packet(response.model_dump_json())
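
From the client side, the feature is exercised through the standard OAI text-completion endpoint by setting logprobs in the request body. A hedged example, assuming a local server at an assumed default address and omitting any auth headers the deployment may require:

import requests

resp = requests.post(
    "http://localhost:5000/v1/completions",  # assumed host/port; adjust to your deployment
    json={"prompt": "Hello,", "max_tokens": 8, "logprobs": 3, "stream": False},
)
print(resp.json()["choices"][0]["logprobs"])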