Implement lora support (#24)
* Model: Implement basic lora support * Add ability to load loras from config on launch * Supports loading multiple loras and lora scaling * Add function to unload loras * Colab: Update for basic lora support * Model: Test vram alloc after lora load, add docs * Git: Add loras folder to .gitignore * API: Add basic lora-related endpoints * Add /loras/ endpoint for querying available loras * Add /model/lora endpoint for querying currently loaded loras * Add /model/lora/load endpoint for loading loras * Add /model/lora/unload endpoint for unloading loras * Move lora config-checking logic to main.py for better compat with API endpoints * Revert bad CRLF line ending changes * API: Add basic lora-related endpoints (fixed) * Add /loras/ endpoint for querying available loras * Add /model/lora endpoint for querying currently loaded loras * Add /model/lora/load endpoint for loading loras * Add /model/lora/unload endpoint for unloading loras * Move lora config-checking logic to main.py for better compat with API endpoints * Model: Unload loras first when unloading model * API + Models: Cleanup lora endpoints and functions Condenses down endpoint and model load code. Also makes the routes behave the same way as model routes to help not confuse the end user. Signed-off-by: kingbri <bdashore3@proton.me> * Loras: Optimize load endpoint Return successes and failures along with consolidating the request to the rewritten load_loras function. Signed-off-by: kingbri <bdashore3@proton.me> --------- Co-authored-by: kingbri <bdashore3@proton.me> Co-authored-by: DocShotgun <126566557+DocShotgun@users.noreply.github.com>
This commit is contained in:
parent
161c9d2c19
commit
7380a3b79a
8 changed files with 197 additions and 19 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -182,3 +182,7 @@ api_tokens.yml
|
|||
# Models folder
|
||||
models/*
|
||||
!models/place_your_models_here.txt
|
||||
|
||||
# Loras folder
|
||||
loras/*
|
||||
!loras/place_your_loras_here.txt
|
||||
|
|
|
|||
25
OAI/types/lora.py
Normal file
25
OAI/types/lora.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
from pydantic import BaseModel, Field;
|
||||
from time import time
|
||||
from typing import Optional, List
|
||||
|
||||
class LoraCard(BaseModel):
|
||||
id: str = "test"
|
||||
object: str = "lora"
|
||||
created: int = Field(default_factory=lambda: int(time()))
|
||||
owned_by: str = "tabbyAPI"
|
||||
scaling: Optional[float] = None
|
||||
|
||||
class LoraList(BaseModel):
|
||||
object: str = "list"
|
||||
data: List[LoraCard] = Field(default_factory=list)
|
||||
|
||||
class LoraLoadInfo(BaseModel):
|
||||
name: str
|
||||
scaling: Optional[float] = 1.0
|
||||
|
||||
class LoraLoadRequest(BaseModel):
|
||||
loras: List[LoraLoadInfo]
|
||||
|
||||
class LoraLoadResponse(BaseModel):
|
||||
success: List[str] = Field(default_factory=list)
|
||||
failure: List[str] = Field(default_factory=list)
|
||||
12
OAI/utils.py
12
OAI/utils.py
|
|
@ -8,9 +8,10 @@ from OAI.types.chat_completion import (
|
|||
ChatCompletionStreamChoice
|
||||
)
|
||||
from OAI.types.common import UsageStats
|
||||
from OAI.types.lora import LoraList, LoraCard
|
||||
from OAI.types.model import ModelList, ModelCard
|
||||
from packaging import version
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Dict
|
||||
|
||||
# Check fastchat
|
||||
try:
|
||||
|
|
@ -100,6 +101,15 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str]):
|
|||
|
||||
return model_card_list
|
||||
|
||||
def get_lora_list(lora_path: pathlib.Path):
|
||||
lora_list = LoraList()
|
||||
for path in lora_path.iterdir():
|
||||
if path.is_dir():
|
||||
lora_card = LoraCard(id = path.name)
|
||||
lora_list.data.append(lora_card)
|
||||
|
||||
return lora_list
|
||||
|
||||
def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]):
|
||||
|
||||
# Check if fastchat is available
|
||||
|
|
|
|||
|
|
@ -36,10 +36,17 @@
|
|||
"# @markdown Select model:\n",
|
||||
"repo_id = \"royallab/Noromaid-13b-v0.1.1-exl2\" # @param {type:\"string\"}\n",
|
||||
"revision = \"4bpw\" # @param {type:\"string\"}\n",
|
||||
"if revision == \"\": revision = \"main\"\n",
|
||||
"# @markdown ---\n",
|
||||
"# @markdown Select draft model (optional, for speculative decoding):\n",
|
||||
"draft_repo_id = \"\" # @param {type:\"string\"}\n",
|
||||
"draft_revision = \"\" # @param {type:\"string\"}\n",
|
||||
"if draft_revision == \"\": draft_revision = \"main\"\n",
|
||||
"# @markdown ---\n",
|
||||
"# @markdown Select lora (optional):\n",
|
||||
"lora_repo_id = \"\" # @param {type:\"string\"}\n",
|
||||
"lora_revision = \"\" # @param {type:\"string\"}\n",
|
||||
"if lora_revision == \"\": lora_revision = \"main\"\n",
|
||||
"# @markdown ---\n",
|
||||
"\n",
|
||||
"# Install tabbyAPI\n",
|
||||
|
|
@ -62,8 +69,15 @@
|
|||
"%cd /content/tabbyAPI/\n",
|
||||
"\n",
|
||||
"from huggingface_hub import snapshot_download\n",
|
||||
"\n",
|
||||
"snapshot_download(repo_id=repo_id, revision=revision, local_dir=f\"./models/{repo_id.replace('/', '_')}\")\n",
|
||||
"if len(draft_repo_id) > 0: snapshot_download(repo_id=draft_repo_id, revision=draft_revision, local_dir=f\"./models/{draft_repo_id.replace('/', '_')}\")"
|
||||
"model = repo_id.replace('/', '_')\n",
|
||||
"\n",
|
||||
"if len(draft_repo_id) > 0: snapshot_download(repo_id=draft_repo_id, revision=draft_revision, local_dir=f\"./models/{draft_repo_id.replace('/', '_')}\")\n",
|
||||
"draft_model = draft_repo_id.replace('/', '_')\n",
|
||||
"\n",
|
||||
"if len(lora_repo_id) > 0: snapshot_download(repo_id=lora_repo_id, revision=lora_revision, local_dir=f\"./loras/{lora_repo_id.replace('/', '_')}\")\n",
|
||||
"lora = lora_repo_id.replace('/', '_')"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -77,9 +91,6 @@
|
|||
"# @title # Configure and launch API { display-mode: \"form\" }\n",
|
||||
"# @markdown ---\n",
|
||||
"# @markdown Model parameters:\n",
|
||||
"\n",
|
||||
"model = repo_id.replace('/', '_')\n",
|
||||
"draft_model = draft_repo_id.replace('/', '_')\n",
|
||||
"ContextSize = 4096 # @param {type:\"integer\"}\n",
|
||||
"RopeScale = 1.0 # @param {type:\"number\"}\n",
|
||||
"RopeAlpha = 1.0 # @param {type:\"number\"}\n",
|
||||
|
|
@ -88,6 +99,9 @@
|
|||
"DraftRopeScale = 1.0 # @param {type:\"number\"}\n",
|
||||
"DraftRopeAlpha = 1.0 # @param {type:\"number\"}\n",
|
||||
"# @markdown ---\n",
|
||||
"# @markdown Lora parameters (optional, for loras):\n",
|
||||
"LoraScaling = 1.0 # @param {type:\"number\"}\n",
|
||||
"# @markdown ---\n",
|
||||
"# @markdown Misc options:\n",
|
||||
"CacheMode = \"FP16\" # @param [\"FP8\", \"FP16\"] {type:\"string\"}\n",
|
||||
"UseDummyModels = False # @param {type:\"boolean\"}\n",
|
||||
|
|
@ -161,6 +175,16 @@
|
|||
" # Rope parameters for draft models (default: 1.0)\n",
|
||||
" draft_rope_scale: {DraftRopeScale}\n",
|
||||
" draft_rope_alpha: {DraftRopeAlpha}\n",
|
||||
"\n",
|
||||
" # Options for loras\n",
|
||||
" lora:\n",
|
||||
" # Overrides the directory to look for loras (default: loras)\n",
|
||||
" lora_dir: loras\n",
|
||||
"\n",
|
||||
" # List of loras to load and associated scaling factors (default: 1.0). Comment out unused entries or add more rows as needed.\n",
|
||||
" loras:\n",
|
||||
" - name: {lora}\n",
|
||||
" scaling: {LoraScaling}\n",
|
||||
"'''\n",
|
||||
"with open(\"./config.yml\", \"w\") as file:\n",
|
||||
" file.write(write)\n",
|
||||
|
|
@ -188,4 +212,4 @@
|
|||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
}
|
||||
|
|
@ -60,3 +60,17 @@ model:
|
|||
# Rope parameters for draft models (default: 1.0)
|
||||
draft_rope_scale: 1.0
|
||||
draft_rope_alpha: 1.0
|
||||
|
||||
# Options for loras
|
||||
lora:
|
||||
# Overrides the directory to look for loras (default: loras)
|
||||
lora_dir: Your lora directory path
|
||||
|
||||
# List of loras to load and associated scaling factors (default: 1.0). Comment out unused entries or add more rows as needed.
|
||||
loras:
|
||||
- name: lora1
|
||||
scaling: 1.0
|
||||
- name: lora2
|
||||
scaling: 0.9
|
||||
- name: lora3
|
||||
scaling: 0.5
|
||||
0
loras/place_your_loras_here.txt
Normal file
0
loras/place_your_loras_here.txt
Normal file
67
main.py
67
main.py
|
|
@ -11,6 +11,7 @@ from progress.bar import IncrementalBar
|
|||
from generators import generate_with_semaphore
|
||||
from OAI.types.completion import CompletionRequest
|
||||
from OAI.types.chat_completion import ChatCompletionRequest
|
||||
from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
|
||||
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
|
||||
from OAI.types.token import (
|
||||
TokenEncodeRequest,
|
||||
|
|
@ -21,6 +22,7 @@ from OAI.types.token import (
|
|||
from OAI.utils import (
|
||||
create_completion_response,
|
||||
get_model_list,
|
||||
get_lora_list,
|
||||
get_chat_completion_prompt,
|
||||
create_chat_completion_response,
|
||||
create_chat_completion_stream_chunk
|
||||
|
|
@ -87,7 +89,6 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
if not data.name:
|
||||
raise HTTPException(400, "model_name not found.")
|
||||
|
||||
# TODO: Move this to model_container
|
||||
model_config = config.get("model") or {}
|
||||
model_path = pathlib.Path(model_config.get("model_dir") or "models")
|
||||
model_path = model_path / data.name
|
||||
|
|
@ -160,7 +161,63 @@ async def unload_model():
|
|||
global model_container
|
||||
|
||||
model_container.unload()
|
||||
model_container = None
|
||||
model_container = None
|
||||
|
||||
# Lora list endpoint
|
||||
@app.get("/v1/loras", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
||||
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
||||
async def get_all_loras():
|
||||
model_config = config.get("model") or {}
|
||||
lora_config = model_config.get("lora") or {}
|
||||
lora_path = pathlib.Path(lora_config.get("lora_dir") or "loras")
|
||||
|
||||
loras = get_lora_list(lora_path.resolve())
|
||||
|
||||
return loras
|
||||
|
||||
# Currently loaded loras endpoint
|
||||
@app.get("/v1/lora", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
||||
async def get_active_loras():
|
||||
active_loras = LoraList(
|
||||
data = list(map(
|
||||
lambda lora: LoraCard(
|
||||
id = pathlib.Path(lora.lora_path).parent.name,
|
||||
scaling = lora.lora_scaling * lora.lora_r / lora.lora_alpha
|
||||
),
|
||||
model_container.active_loras
|
||||
)
|
||||
))
|
||||
|
||||
return active_loras
|
||||
|
||||
# Load lora endpoint
|
||||
@app.post("/v1/lora/load", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
|
||||
async def load_model(data: LoraLoadRequest):
|
||||
if not data.loras:
|
||||
raise HTTPException(400, "List of loras to load is not found.")
|
||||
|
||||
model_config = config.get("model") or {}
|
||||
lora_config = model_config.get("lora") or {}
|
||||
lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
|
||||
if not lora_dir.exists():
|
||||
raise HTTPException(400, "A parent lora directory does not exist. Check your config.yml?")
|
||||
|
||||
# Clean-up existing loras if present
|
||||
if len(model_container.active_loras) > 0:
|
||||
model_container.unload(True)
|
||||
|
||||
result = model_container.load_loras(lora_dir, **data.dict())
|
||||
return LoraLoadResponse(
|
||||
success = result.get("success") or [],
|
||||
failure = result.get("failure") or []
|
||||
)
|
||||
|
||||
# Unload lora endpoint
|
||||
@app.get("/v1/lora/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
|
||||
async def unload_loras():
|
||||
global model_container
|
||||
|
||||
model_container.unload(True)
|
||||
|
||||
# Encode tokens endpoint
|
||||
@app.post("/v1/token/encode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
||||
|
|
@ -308,6 +365,12 @@ if __name__ == "__main__":
|
|||
else:
|
||||
loading_bar.next()
|
||||
|
||||
# Load loras
|
||||
lora_config = model_config.get("lora") or {}
|
||||
if "loras" in lora_config:
|
||||
lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
|
||||
model_container.load_loras(lora_dir.resolve(), **lora_config)
|
||||
|
||||
network_config = config.get("network") or {}
|
||||
uvicorn.run(
|
||||
app,
|
||||
|
|
|
|||
60
model.py
60
model.py
|
|
@ -6,6 +6,7 @@ from exllamav2 import(
|
|||
ExLlamaV2Cache,
|
||||
ExLlamaV2Cache_8bit,
|
||||
ExLlamaV2Tokenizer,
|
||||
ExLlamaV2Lora
|
||||
)
|
||||
from exllamav2.generator import(
|
||||
ExLlamaV2StreamingGenerator,
|
||||
|
|
@ -30,6 +31,8 @@ class ModelContainer:
|
|||
cache_fp8: bool = False
|
||||
gpu_split_auto: bool = True
|
||||
gpu_split: list or None = None
|
||||
|
||||
active_loras: List[ExLlamaV2Lora] = []
|
||||
|
||||
def __init__(self, model_directory: pathlib.Path, quiet = False, **kwargs):
|
||||
"""
|
||||
|
|
@ -54,6 +57,8 @@ class ModelContainer:
|
|||
'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft model.
|
||||
By default, the draft model's alpha value is calculated automatically to scale to the size of the
|
||||
full model.
|
||||
'lora_dir' (str): Lora directory
|
||||
'loras' (list[dict]): List of loras to be loaded, consisting of 'name' and 'scaling'
|
||||
'gpu_split_auto' (bool): Automatically split model across available devices (default: True)
|
||||
'gpu_split' (list[float]): Allocation for weights and (some) tensors, per device
|
||||
'no_flash_attn' (bool): Turns off flash attention (increases vram usage) (default: False)
|
||||
|
|
@ -141,6 +146,32 @@ class ModelContainer:
|
|||
"""
|
||||
for _ in self.load_gen(progress_callback): pass
|
||||
|
||||
def load_loras(self, lora_directory: pathlib.Path, **kwargs):
|
||||
"""
|
||||
Load loras
|
||||
"""
|
||||
|
||||
loras = kwargs.get("loras") or []
|
||||
success: List[str] = []
|
||||
failure: List[str] = []
|
||||
|
||||
for lora in loras:
|
||||
lora_name = lora.get("name") or None
|
||||
lora_scaling = lora.get("scaling") or 1.0
|
||||
|
||||
if lora_name is None:
|
||||
print("One of your loras does not have a name. Please check your config.yml! Skipping lora load.")
|
||||
failure.append(lora_name)
|
||||
continue
|
||||
|
||||
print(f"Loading lora: {lora_name} at scaling {lora_scaling}")
|
||||
lora_path = lora_directory / lora_name
|
||||
self.active_loras.append(ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling))
|
||||
print("Lora successfully loaded.")
|
||||
success.append(lora_name)
|
||||
|
||||
# Return success and failure names
|
||||
return { 'success': success, 'failure': failure }
|
||||
|
||||
def load_gen(self, progress_callback = None):
|
||||
"""
|
||||
|
|
@ -204,23 +235,30 @@ class ModelContainer:
|
|||
print("Model successfully loaded.")
|
||||
|
||||
|
||||
def unload(self):
|
||||
def unload(self, loras_only: bool = False):
|
||||
"""
|
||||
Free all VRAM resources used by this model
|
||||
"""
|
||||
|
||||
if self.model: self.model.unload()
|
||||
self.model = None
|
||||
if self.draft_model: self.draft_model.unload()
|
||||
self.draft_model = None
|
||||
self.config = None
|
||||
self.cache = None
|
||||
self.tokenizer = None
|
||||
self.generator = None
|
||||
for lora in self.active_loras:
|
||||
lora.unload()
|
||||
|
||||
self.active_loras = []
|
||||
|
||||
# Unload the entire model if not just unloading loras
|
||||
if not loras_only:
|
||||
if self.model: self.model.unload()
|
||||
self.model = None
|
||||
if self.draft_model: self.draft_model.unload()
|
||||
self.draft_model = None
|
||||
self.config = None
|
||||
self.cache = None
|
||||
self.tokenizer = None
|
||||
self.generator = None
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
# Common function for token operations
|
||||
def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs):
|
||||
if text:
|
||||
|
|
@ -381,7 +419,7 @@ class ModelContainer:
|
|||
active_ids = ids[:, max(0, overflow):]
|
||||
chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]
|
||||
|
||||
self.generator.begin_stream(active_ids, gen_settings, token_healing = token_healing)
|
||||
self.generator.begin_stream(active_ids, gen_settings, token_healing = token_healing, loras = self.active_loras)
|
||||
|
||||
# Generate
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue