* Model: Implement basic lora support
* Add ability to load loras from config on launch
* Supports loading multiple loras and lora scaling
* Add function to unload loras
* Colab: Update for basic lora support
* Model: Test vram alloc after lora load, add docs
* Git: Add loras folder to .gitignore
* API: Add basic lora-related endpoints
* Add /loras/ endpoint for querying available loras
* Add /model/lora endpoint for querying currently loaded loras
* Add /model/lora/load endpoint for loading loras
* Add /model/lora/unload endpoint for unloading loras
* Move lora config-checking logic to main.py for better compat with API endpoints
* Revert bad CRLF line ending changes
* API: Add basic lora-related endpoints (fixed)
* Add /loras/ endpoint for querying available loras
* Add /model/lora endpoint for querying currently loaded loras
* Add /model/lora/load endpoint for loading loras
* Add /model/lora/unload endpoint for unloading loras
* Move lora config-checking logic to main.py for better compat with API endpoints
* Model: Unload loras first when unloading model
* API + Models: Cleanup lora endpoints and functions

Condenses down endpoint and model load code. Also makes the routes behave the same way as model routes to help not confuse the end user.

Signed-off-by: kingbri <bdashore3@proton.me>

* Loras: Optimize load endpoint

Return successes and failures along with consolidating the request to the rewritten load_loras function.

Signed-off-by: kingbri <bdashore3@proton.me>

---------

Co-authored-by: kingbri <bdashore3@proton.me>
Co-authored-by: DocShotgun <126566557+DocShotgun@users.noreply.github.com>
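The new lora endpoints can be exercised from any HTTP client. The sketch below is illustrative only: the /model/lora/load route comes from the commit notes above, but the base URL and the field names ("loras", "name", "scaling") are assumptions rather than the confirmed request schema.

import requests

BASE_URL = "http://localhost:5000"  # placeholder host/port; the actual route prefix may differ

# Assumed request shape: a list of loras with per-lora scaling factors
payload = {
    "loras": [
        {"name": "my-style-lora", "scaling": 0.8},
        {"name": "my-task-lora", "scaling": 1.0},
    ]
}

resp = requests.post(f"{BASE_URL}/model/lora/load", json = payload)
print(resp.json())  # per the commit notes, the response reports successes and failures
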
import os, pathlib
from OAI.types.completion import CompletionResponse, CompletionRespChoice, UsageStats
from OAI.types.chat_completion import (
    ChatCompletionMessage,
    ChatCompletionRespChoice,
    ChatCompletionStreamChunk,
    ChatCompletionResponse,
    ChatCompletionStreamChoice
)
from OAI.types.common import UsageStats
from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard
from packaging import version
from typing import Optional, List, Dict

# Check fastchat
try:
    import fastchat
    from fastchat.model.model_adapter import get_conversation_template
    from fastchat.conversation import SeparatorStyle
    _fastchat_available = True
except ImportError:
    _fastchat_available = False

def create_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
    choice = CompletionRespChoice(
        finish_reason = "Generated",
        text = text
    )

    response = CompletionResponse(
        choices = [choice],
        model = model_name or "",
        usage = UsageStats(prompt_tokens = prompt_tokens,
                           completion_tokens = completion_tokens,
                           total_tokens = prompt_tokens + completion_tokens)
    )

    return response

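# Hypothetical usage sketch (not part of the original module): how an API route
# might wrap raw generator output in an OpenAI-style completion response. The
# token counts and model name below are placeholder values.
def _example_completion_response_usage():
    generated_text = "Hello there!"
    return create_completion_response(
        text = generated_text,
        prompt_tokens = 12,
        completion_tokens = 4,
        model_name = "my-exl2-model"
    )
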
def create_chat_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
    message = ChatCompletionMessage(
        role = "assistant",
        content = text
    )

    choice = ChatCompletionRespChoice(
        finish_reason = "Generated",
        message = message
    )

    response = ChatCompletionResponse(
        choices = [choice],
        model = model_name or "",
        usage = UsageStats(prompt_tokens = prompt_tokens,
                           completion_tokens = completion_tokens,
                           total_tokens = prompt_tokens + completion_tokens)
    )

    return response

def create_chat_completion_stream_chunk(const_id: str,
                                        text: Optional[str] = None,
                                        model_name: Optional[str] = None,
                                        finish_reason: Optional[str] = None):
    if finish_reason:
        message = {}
    else:
        message = ChatCompletionMessage(
            role = "assistant",
            content = text
        )

    # The finish reason can be None
    choice = ChatCompletionStreamChoice(
        finish_reason = finish_reason,
        delta = message
    )

    chunk = ChatCompletionStreamChunk(
        id = const_id,
        choices = [choice],
        model = model_name or ""
    )

    return chunk

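# Hypothetical streaming sketch (not part of the original module): emit one chunk
# per generated text piece, then a final chunk that only carries the finish
# reason. The id, text pieces, and model name are placeholder values.
def _example_stream_chunks():
    const_id = "chatcmpl-example"
    chunks = [
        create_chat_completion_stream_chunk(const_id, text = piece, model_name = "my-exl2-model")
        for piece in ["Hel", "lo", "!"]
    ]

    # Final chunk: no new text, just the finish reason
    chunks.append(
        create_chat_completion_stream_chunk(const_id, finish_reason = "Generated", model_name = "my-exl2-model")
    )
    return chunks
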
def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str]):

    # Convert the draft model path to a pathlib path for equality comparisons
    if draft_model_path:
        draft_model_path = pathlib.Path(draft_model_path).resolve()

    model_card_list = ModelList()
    for path in model_path.iterdir():

        # Don't include the draft models path
        if path.is_dir() and path != draft_model_path:
            model_card = ModelCard(id = path.name)
            model_card_list.data.append(model_card)

    return model_card_list

def get_lora_list(lora_path: pathlib.Path):
    lora_list = LoraList()
    for path in lora_path.iterdir():
        if path.is_dir():
            lora_card = LoraCard(id = path.name)
            lora_list.data.append(lora_card)

    return lora_list

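# Hypothetical usage sketch (not part of the original module): build the payloads
# for the model and lora listing endpoints from directories on disk. The
# directory names below are placeholders.
def _example_list_models_and_loras():
    models = get_model_list(pathlib.Path("models"), draft_model_path = None)
    loras = get_lora_list(pathlib.Path("loras"))
    return models, loras
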
def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]):

    # Check if fastchat is available
    if not _fastchat_available:
        raise ModuleNotFoundError(
            "Fastchat must be installed to parse these chat completion messages.\n"
            "Please run the following command: pip install fschat[model_worker]"
        )
    if version.parse(fastchat.__version__) < version.parse("0.2.23"):
        raise ImportError(
            "Parsing these chat completion messages requires fastchat 0.2.23 or greater. "
            f"Current version: {fastchat.__version__}\n"
            "Please upgrade fastchat by running the following command: "
            "pip install -U fschat[model_worker]"
        )

    conv = get_conversation_template(model_path)
    if conv.sep_style is None:
        conv.sep_style = SeparatorStyle.LLAMA2

    for message in messages:
        msg_role = message.role
        if msg_role == "system":
            conv.set_system_message(message.content)
        elif msg_role == "user":
            conv.append_message(conv.roles[0], message.content)
        elif msg_role == "assistant":
            conv.append_message(conv.roles[1], message.content)
        else:
            raise ValueError(f"Unknown role: {msg_role}")

    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    return prompt
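# Hypothetical usage sketch (not part of the original module): turn a short chat
# history into a single prompt string via fastchat's conversation templates.
# The model path is a placeholder, and fastchat must be installed for this to run.
def _example_chat_prompt():
    messages = [
        ChatCompletionMessage(role = "system", content = "You are a helpful assistant."),
        ChatCompletionMessage(role = "user", content = "Hello!"),
    ]
    return get_chat_completion_prompt("models/my-exl2-model", messages)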