API: Default tool call ID and type
Doing this helps reduce the model's burden of generating the tool call ID and type (which is always "function"). Follow mistral's spec for tool call IDs by using a 9 character alphanumeric string. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
5b1db3ad83
commit
707d005aad
3 changed files with 8 additions and 11 deletions
|
|
@ -1,5 +1,6 @@
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
from typing import Dict, Literal
|
from typing import Dict, Literal
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
|
||||||
class Function(BaseModel):
|
class Function(BaseModel):
|
||||||
|
|
@ -29,6 +30,6 @@ class Tool(BaseModel):
|
||||||
class ToolCall(BaseModel):
|
class ToolCall(BaseModel):
|
||||||
"""Represents an OAI tool description."""
|
"""Represents an OAI tool description."""
|
||||||
|
|
||||||
id: str
|
id: str = Field(default_factory=lambda: str(uuid4()).replace("-", "")[:9])
|
||||||
function: Tool
|
function: Tool
|
||||||
type: Literal["function"]
|
type: Literal["function"] = "function"
|
||||||
|
|
|
||||||
|
|
@ -70,12 +70,11 @@ def _create_response(
|
||||||
|
|
||||||
logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
|
logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
|
||||||
|
|
||||||
# Initialize finish_reason with a default value or from generation data
|
# Set finish reason
|
||||||
finish_reason = generation.get("finish_reason", "stop")
|
|
||||||
|
|
||||||
# If a tool call is present, mark the finish reason as such
|
|
||||||
if message.tool_calls:
|
if message.tool_calls:
|
||||||
finish_reason = "tool_calls"
|
finish_reason = "tool_calls"
|
||||||
|
else:
|
||||||
|
finish_reason = generation.get("finish_reason", "stop")
|
||||||
|
|
||||||
choice = ChatCompletionRespChoice(
|
choice = ChatCompletionRespChoice(
|
||||||
index=index,
|
index=index,
|
||||||
|
|
@ -152,7 +151,6 @@ def _create_stream_chunk(
|
||||||
choice.finish_reason = "tool_calls"
|
choice.finish_reason = "tool_calls"
|
||||||
|
|
||||||
choices.append(choice)
|
choices.append(choice)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
message = ChatCompletionMessage(
|
message = ChatCompletionMessage(
|
||||||
role="assistant", content=unwrap(generation.get("text"), "")
|
role="assistant", content=unwrap(generation.get("text"), "")
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,6 @@ TOOL_CALL_SCHEMA = {
|
||||||
"items": {
|
"items": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"id": {"type": "string"},
|
|
||||||
"function": {
|
"function": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|
@ -23,9 +22,8 @@ TOOL_CALL_SCHEMA = {
|
||||||
},
|
},
|
||||||
"required": ["name", "arguments"],
|
"required": ["name", "arguments"],
|
||||||
},
|
},
|
||||||
"type": {"type": "string", "enum": ["function"]},
|
|
||||||
},
|
},
|
||||||
"required": ["id", "function", "type"],
|
"required": ["function"],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue