OAI: Add ability to specify fastchat prompt template
Sometimes fastchat may not be able to detect the prompt template from the model path. Therefore, add the ability to set it in config.yml or via the request object itself. Also send the provided prompt template on model info requests.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent 9f195af5ad
commit db87efde4a
7 changed files with 34 additions and 8 deletions
@@ -25,6 +25,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
     # Messages
     # Take in a string as well even though it's not part of the OAI spec
     messages: Union[str, List[ChatCompletionMessage]]
+    prompt_template: Optional[str] = None

 class ChatCompletionResponse(BaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
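For context, a chat completion request that exercises the new field might look like the sketch below. The base URL, endpoint path, and template name are assumptions for illustration; only the prompt_template key itself comes from the schema change above.

# Minimal sketch of a request that sets the new field.
# The base URL and endpoint path are assumptions; only the
# "prompt_template" key is taken from the schema change above.
import requests

payload = {
    "messages": [{"role": "user", "content": "Hello there!"}],
    "prompt_template": "llama-2",  # example fastchat template name
}

response = requests.post(
    "http://localhost:5000/v1/chat/completions",  # assumed server address
    json=payload,
)
print(response.json())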
@@ -1,4 +1,4 @@
-from pydantic import BaseModel, Field;
+from pydantic import BaseModel, Field
 from time import time
 from typing import Optional, List
@@ -6,6 +6,7 @@ class ModelCardParameters(BaseModel):
     max_seq_len: Optional[int] = 4096
     rope_scale: Optional[float] = 1.0
     rope_alpha: Optional[float] = 1.0
+    prompt_template: Optional[str] = None
     draft: Optional['ModelCard'] = None

 class ModelCard(BaseModel):
@@ -34,6 +35,7 @@ class ModelLoadRequest(BaseModel):
     rope_alpha: Optional[float] = 1.0
     no_flash_attention: Optional[bool] = False
     low_mem: Optional[bool] = False
+    prompt_template: Optional[str] = None
     draft: Optional[DraftModelLoadRequest] = None

 class ModelLoadResponse(BaseModel):
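The same override is available when loading a model. A sketch of a load request body is below; the endpoint path and model name are assumptions, while the prompt_template key mirrors the ModelLoadRequest field added above.

# Sketch of a model load request body that pins a prompt template.
# The endpoint path and "name" key are assumptions; "prompt_template"
# mirrors the ModelLoadRequest field shown above.
import requests

load_payload = {
    "name": "MyModel-GPTQ",            # hypothetical model folder
    "prompt_template": "vicuna_v1.1",  # example fastchat template name
}

requests.post("http://localhost:5000/v1/model/load", json=load_payload)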
OAI/utils.py (18 changed lines)

@@ -1,5 +1,5 @@
-import os, pathlib
-from OAI.types.completion import CompletionResponse, CompletionRespChoice, UsageStats
+import pathlib
+from OAI.types.completion import CompletionResponse, CompletionRespChoice
 from OAI.types.chat_completion import (
     ChatCompletionMessage,
     ChatCompletionRespChoice,
@@ -11,13 +11,13 @@ from OAI.types.common import UsageStats
 from OAI.types.lora import LoraList, LoraCard
 from OAI.types.model import ModelList, ModelCard
 from packaging import version
-from typing import Optional, List, Dict
+from typing import Optional, List
 from utils import unwrap

 # Check fastchat
 try:
     import fastchat
-    from fastchat.model.model_adapter import get_conversation_template
+    from fastchat.model.model_adapter import get_conversation_template, get_conv_template
     from fastchat.conversation import SeparatorStyle
     _fastchat_available = True
 except ImportError:
@@ -111,8 +111,9 @@ def get_lora_list(lora_path: pathlib.Path):

     return lora_list

-def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]):
+def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage], prompt_template: Optional[str] = None):

+    # TODO: Replace fastchat with in-house jinja templates
     # Check if fastchat is available
     if not _fastchat_available:
         raise ModuleNotFoundError(
@@ -127,7 +128,11 @@ def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMes
            "pip install -U fschat[model_worker]"
        )

-    conv = get_conversation_template(model_path)
+    if prompt_template:
+        conv = get_conv_template(prompt_template)
+    else:
+        conv = get_conversation_template(model_path)

     if conv.sep_style is None:
         conv.sep_style = SeparatorStyle.LLAMA2
@@ -145,4 +150,5 @@ def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMes
     conv.append_message(conv.roles[1], None)
     prompt = conv.get_prompt()

+    print(prompt)
     return prompt
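The branch added above relies on two fastchat lookups: get_conv_template(name) resolves a template by its registered name and raises KeyError when the name is unknown (which main.py traps further down), while get_conversation_template(model_path) guesses a template from the model path. A standalone sketch of that resolution order:

# Sketch of the template resolution order introduced above.
# get_conv_template() raises KeyError for an unknown template name,
# which is why the caller in main.py catches KeyError.
from fastchat.model.model_adapter import get_conversation_template, get_conv_template

def resolve_conversation(model_path: str, prompt_template: str = None):
    if prompt_template:
        # Explicit name from config.yml or the request body
        return get_conv_template(prompt_template)
    # Fall back to fastchat's path-based detection
    return get_conversation_template(model_path)

try:
    conv = resolve_conversation("MyModel-GPTQ", "llama-2")
except KeyError:
    print("Unknown prompt template name")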
@@ -48,6 +48,10 @@ model:
   # Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16)
   cache_mode: FP16

+  # Set the prompt template for this model. If empty, fastchat will automatically find the best template to use (default: None)
+  # NOTE: Only works with chat completion message lists!
+  prompt_template:
+
   # Options for draft models (speculative decoding). This will use more VRAM!
   draft:
     # Overrides the directory to look for draft (default: models)
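How the new key is read back out of config.yml is not part of this diff, but a rough sketch of the lookup, assuming the model: block shown above and a PyYAML-based loader, would be:

# Rough sketch only: the project's actual config loading code is not
# shown in this diff. Assumes a PyYAML loader and the "model:" block above.
import yaml

with open("config.yml") as config_file:
    config = yaml.safe_load(config_file) or {}

model_block = config.get("model") or {}
prompt_template = model_block.get("prompt_template")  # None when left empty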
main.py (10 changed lines)
@@ -80,6 +80,7 @@ async def get_current_model():
             rope_scale = model_container.config.scale_pos_emb,
             rope_alpha = model_container.config.scale_alpha_value,
             max_seq_len = model_container.config.max_seq_len,
+            prompt_template = unwrap(model_container.prompt_template, "auto")
         )
     )
@@ -302,7 +303,14 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
-        prompt = get_chat_completion_prompt(model_path.name, data.messages)
+        # If the request specified prompt template isn't found, use the one from model container
+        # Otherwise, let fastchat figure it out
+        prompt_template = unwrap(data.prompt_template, model_container.prompt_template)
+
+        try:
+            prompt = get_chat_completion_prompt(model_path.name, data.messages, prompt_template)
+        except KeyError:
+            return HTTPException(400, f"Could not find a Conversation from prompt template '{prompt_template}'. Check your spelling?")

     if data.stream:
         const_id = f"chatcmpl-{uuid4().hex}"
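The precedence above is: template from the request first, then the one stored on the model container (loaded from config.yml), and finally fastchat's automatic detection when both are None. unwrap() is not defined in this diff; its behaviour in the sketch below is inferred from how it is used here.

# Sketch of the fallback precedence used above. unwrap() is inferred
# from its usage: return the value unless it is None, else the default.
def unwrap(value, default=None):
    return value if value is not None else default

request_template = None          # nothing set in the request body
container_template = "llama-2"   # example value loaded from config.yml

prompt_template = unwrap(request_template, container_template)
# -> "llama-2"; if both were None, get_chat_completion_prompt() would
#    let fastchat detect a template from the model path.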
model.py (5 changed lines)
@@ -27,6 +27,7 @@ class ModelContainer:
     draft_cache: Optional[ExLlamaV2Cache] = None
     tokenizer: Optional[ExLlamaV2Tokenizer] = None
     generator: Optional[ExLlamaV2StreamingGenerator] = None
+    prompt_template: Optional[str] = None

     cache_fp8: bool = False
     gpu_split_auto: bool = True
@@ -48,6 +49,7 @@ class ModelContainer:
         'max_seq_len' (int): Override model's default max sequence length (default: 4096)
         'rope_scale' (float): Set RoPE scaling factor for model (default: 1.0)
         'rope_alpha' (float): Set RoPE alpha (NTK) factor for model (default: 1.0)
+        'prompt_template' (str): Manually sets the prompt template for this model (default: None)
         'chunk_size' (int): Sets the maximum chunk size for the model (default: 2048)
             Inferencing in chunks reduces overall VRAM overhead by processing very long sequences in smaller
             batches. This limits the size of temporary buffers needed for the hidden state and attention
@@ -93,6 +95,9 @@ class ModelContainer:
             self.config.set_low_mem()
         """

+        # Set prompt template override if provided
+        self.prompt_template = kwargs.get("prompt_template")
+
         chunk_size = min(unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len)
         self.config.max_input_len = chunk_size
         self.config.max_attn_size = chunk_size ** 2