Implement lora support (#24)

* Model: Implement basic lora support

* Add ability to load loras from config on launch
* Supports loading multiple loras and lora scaling
* Add function to unload loras

* Colab: Update for basic lora support

* Model: Test vram alloc after lora load, add docs

* Git: Add loras folder to .gitignore

* API: Add basic lora-related endpoints

* Add /loras/ endpoint for querying available loras
* Add /model/lora endpoint for querying currently loaded loras
* Add /model/lora/load endpoint for loading loras
* Add /model/lora/unload endpoint for unloading loras
* Move lora config-checking logic to main.py for better compat with API endpoints

* Revert bad CRLF line ending changes

* API: Add basic lora-related endpoints (fixed)

* Add /loras/ endpoint for querying available loras
* Add /model/lora endpoint for querying currently loaded loras
* Add /model/lora/load endpoint for loading loras
* Add /model/lora/unload endpoint for unloading loras
* Move lora config-checking logic to main.py for better compat with API endpoints

* Model: Unload loras first when unloading model

* API + Models: Cleanup lora endpoints and functions

Condenses the endpoint and model-load code. Also makes the lora routes
behave the same way as the model routes so they don't confuse the end user.

Signed-off-by: kingbri <bdashore3@proton.me>

* Loras: Optimize load endpoint

Return successes and failures, and consolidate request handling into the
rewritten load_loras function.

Signed-off-by: kingbri <bdashore3@proton.me>

---------

Co-authored-by: kingbri <bdashore3@proton.me>
Co-authored-by: DocShotgun <126566557+DocShotgun@users.noreply.github.com>
Authored by DocShotgun on 2023-12-08 20:36:40 -08:00, committed by kingbri
parent 161c9d2c19
commit 7380a3b79a
8 changed files with 197 additions and 19 deletions

.gitignore

@@ -182,3 +182,7 @@ api_tokens.yml
 # Models folder
 models/*
 !models/place_your_models_here.txt
+
+# Loras folder
+loras/*
+!loras/place_your_loras_here.txt

OAI/types/lora.py (new file)

@ -0,0 +1,25 @@
from pydantic import BaseModel, Field;
from time import time
from typing import Optional, List
class LoraCard(BaseModel):
id: str = "test"
object: str = "lora"
created: int = Field(default_factory=lambda: int(time()))
owned_by: str = "tabbyAPI"
scaling: Optional[float] = None
class LoraList(BaseModel):
object: str = "list"
data: List[LoraCard] = Field(default_factory=list)
class LoraLoadInfo(BaseModel):
name: str
scaling: Optional[float] = 1.0
class LoraLoadRequest(BaseModel):
loras: List[LoraLoadInfo]
class LoraLoadResponse(BaseModel):
success: List[str] = Field(default_factory=list)
failure: List[str] = Field(default_factory=list)
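For reference, a quick sketch (placeholder lora names; pydantic v1 style, matching the .dict() calls used elsewhere in this commit) of the payload shape these models accept and return:

# Sketch: a hypothetical /v1/lora/load body validated against these models.
from OAI.types.lora import LoraLoadRequest, LoraLoadResponse

request = LoraLoadRequest.parse_obj({
    "loras": [
        {"name": "example-lora", "scaling": 0.8},  # placeholder folder under loras/
        {"name": "another-lora"},                  # scaling defaults to 1.0
    ]
})

response = LoraLoadResponse(success=["example-lora"], failure=[])
print(response.dict())  # {'success': ['example-lora'], 'failure': []}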

OAI/utils.py

@@ -8,9 +8,10 @@ from OAI.types.chat_completion import (
     ChatCompletionStreamChoice
 )
 from OAI.types.common import UsageStats
+from OAI.types.lora import LoraList, LoraCard
 from OAI.types.model import ModelList, ModelCard
 from packaging import version
-from typing import Optional, List
+from typing import Optional, List, Dict

 # Check fastchat
 try:

@@ -100,6 +101,15 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str]):
     return model_card_list

+def get_lora_list(lora_path: pathlib.Path):
+    lora_list = LoraList()
+    for path in lora_path.iterdir():
+        if path.is_dir():
+            lora_card = LoraCard(id = path.name)
+            lora_list.data.append(lora_card)
+
+    return lora_list
+
 def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]):
     # Check if fastchat is available
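As an illustration (assuming a loras/ directory that contains one subfolder per lora), get_lora_list would be used roughly like this:

# Sketch: listing loras from a hypothetical loras/ folder containing
# subdirectories such as loras/example-lora/.
import pathlib
from OAI.utils import get_lora_list

lora_list = get_lora_list(pathlib.Path("loras").resolve())
print([card.id for card in lora_list.data])  # e.g. ['example-lora']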

Colab notebook

@@ -36,10 +36,17 @@
     "# @markdown Select model:\n",
     "repo_id = \"royallab/Noromaid-13b-v0.1.1-exl2\" # @param {type:\"string\"}\n",
     "revision = \"4bpw\" # @param {type:\"string\"}\n",
+    "if revision == \"\": revision = \"main\"\n",
     "# @markdown ---\n",
     "# @markdown Select draft model (optional, for speculative decoding):\n",
     "draft_repo_id = \"\" # @param {type:\"string\"}\n",
     "draft_revision = \"\" # @param {type:\"string\"}\n",
+    "if draft_revision == \"\": draft_revision = \"main\"\n",
+    "# @markdown ---\n",
+    "# @markdown Select lora (optional):\n",
+    "lora_repo_id = \"\" # @param {type:\"string\"}\n",
+    "lora_revision = \"\" # @param {type:\"string\"}\n",
+    "if lora_revision == \"\": lora_revision = \"main\"\n",
     "# @markdown ---\n",
     "\n",
     "# Install tabbyAPI\n",

@@ -62,8 +69,15 @@
     "%cd /content/tabbyAPI/\n",
     "\n",
     "from huggingface_hub import snapshot_download\n",
+    "\n",
     "snapshot_download(repo_id=repo_id, revision=revision, local_dir=f\"./models/{repo_id.replace('/', '_')}\")\n",
-    "if len(draft_repo_id) > 0: snapshot_download(repo_id=draft_repo_id, revision=draft_revision, local_dir=f\"./models/{draft_repo_id.replace('/', '_')}\")"
+    "model = repo_id.replace('/', '_')\n",
+    "\n",
+    "if len(draft_repo_id) > 0: snapshot_download(repo_id=draft_repo_id, revision=draft_revision, local_dir=f\"./models/{draft_repo_id.replace('/', '_')}\")\n",
+    "draft_model = draft_repo_id.replace('/', '_')\n",
+    "\n",
+    "if len(lora_repo_id) > 0: snapshot_download(repo_id=lora_repo_id, revision=lora_revision, local_dir=f\"./loras/{lora_repo_id.replace('/', '_')}\")\n",
+    "lora = lora_repo_id.replace('/', '_')"
    ]
   },
   {

@@ -77,9 +91,6 @@
     "# @title # Configure and launch API { display-mode: \"form\" }\n",
     "# @markdown ---\n",
     "# @markdown Model parameters:\n",
-    "\n",
-    "model = repo_id.replace('/', '_')\n",
-    "draft_model = draft_repo_id.replace('/', '_')\n",
     "ContextSize = 4096 # @param {type:\"integer\"}\n",
     "RopeScale = 1.0 # @param {type:\"number\"}\n",
     "RopeAlpha = 1.0 # @param {type:\"number\"}\n",

@@ -88,6 +99,9 @@
     "DraftRopeScale = 1.0 # @param {type:\"number\"}\n",
     "DraftRopeAlpha = 1.0 # @param {type:\"number\"}\n",
     "# @markdown ---\n",
+    "# @markdown Lora parameters (optional, for loras):\n",
+    "LoraScaling = 1.0 # @param {type:\"number\"}\n",
+    "# @markdown ---\n",
     "# @markdown Misc options:\n",
     "CacheMode = \"FP16\" # @param [\"FP8\", \"FP16\"] {type:\"string\"}\n",
     "UseDummyModels = False # @param {type:\"boolean\"}\n",

@@ -161,6 +175,16 @@
     "  # Rope parameters for draft models (default: 1.0)\n",
     "  draft_rope_scale: {DraftRopeScale}\n",
     "  draft_rope_alpha: {DraftRopeAlpha}\n",
+    "\n",
+    "  # Options for loras\n",
+    "  lora:\n",
+    "    # Overrides the directory to look for loras (default: loras)\n",
+    "    lora_dir: loras\n",
+    "\n",
+    "    # List of loras to load and associated scaling factors (default: 1.0). Comment out unused entries or add more rows as needed.\n",
+    "    loras:\n",
+    "    - name: {lora}\n",
+    "      scaling: {LoraScaling}\n",
     "'''\n",
     "with open(\"./config.yml\", \"w\") as file:\n",
     "    file.write(write)\n",

@@ -188,4 +212,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 0
 }
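Outside of Colab, the same lora download step can be reproduced with huggingface_hub (a sketch; the repository id is a placeholder):

# Sketch: downloading a lora into the loras/ folder, mirroring the notebook cell.
# "user/my-lora-repo" is a placeholder repository id.
from huggingface_hub import snapshot_download

lora_repo_id = "user/my-lora-repo"
snapshot_download(
    repo_id=lora_repo_id,
    revision="main",
    local_dir=f"./loras/{lora_repo_id.replace('/', '_')}",
)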

config_sample.yml

@@ -60,3 +60,17 @@ model:
   # Rope parameters for draft models (default: 1.0)
   draft_rope_scale: 1.0
   draft_rope_alpha: 1.0
+
+  # Options for loras
+  lora:
+    # Overrides the directory to look for loras (default: loras)
+    lora_dir: Your lora directory path
+
+    # List of loras to load and associated scaling factors (default: 1.0). Comment out unused entries or add more rows as needed.
+    loras:
+    - name: lora1
+      scaling: 1.0
+    - name: lora2
+      scaling: 0.9
+    - name: lora3
+      scaling: 0.5
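Parsed at startup, the sample block above becomes a plain dict that main.py forwards to ModelContainer.load_loras as keyword arguments; roughly (a sketch, assuming the directory placeholder is replaced with a real path):

# Sketch: the lora_config dict produced from the sample block above, which
# main.py passes to model_container.load_loras(lora_dir, **lora_config).
lora_config = {
    "lora_dir": "loras",  # placeholder; the sample uses "Your lora directory path"
    "loras": [
        {"name": "lora1", "scaling": 1.0},
        {"name": "lora2", "scaling": 0.9},
        {"name": "lora3", "scaling": 0.5},
    ],
}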


main.py

@@ -11,6 +11,7 @@ from progress.bar import IncrementalBar
 from generators import generate_with_semaphore
 from OAI.types.completion import CompletionRequest
 from OAI.types.chat_completion import ChatCompletionRequest
+from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
 from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
 from OAI.types.token import (
     TokenEncodeRequest,

@@ -21,6 +22,7 @@ from OAI.types.token import (
 from OAI.utils import (
     create_completion_response,
     get_model_list,
+    get_lora_list,
     get_chat_completion_prompt,
     create_chat_completion_response,
     create_chat_completion_stream_chunk

@@ -87,7 +89,6 @@ async def load_model(request: Request, data: ModelLoadRequest):
     if not data.name:
         raise HTTPException(400, "model_name not found.")

-    # TODO: Move this to model_container
     model_config = config.get("model") or {}
     model_path = pathlib.Path(model_config.get("model_dir") or "models")
     model_path = model_path / data.name

@@ -160,7 +161,63 @@ async def unload_model():
     global model_container

     model_container.unload()
     model_container = None

+# Lora list endpoint
+@app.get("/v1/loras", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
+@app.get("/v1/lora/list", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
+async def get_all_loras():
+    model_config = config.get("model") or {}
+    lora_config = model_config.get("lora") or {}
+    lora_path = pathlib.Path(lora_config.get("lora_dir") or "loras")
+
+    loras = get_lora_list(lora_path.resolve())
+
+    return loras
+
+# Currently loaded loras endpoint
+@app.get("/v1/lora", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
+async def get_active_loras():
+    active_loras = LoraList(
+        data = list(map(
+            lambda lora: LoraCard(
+                id = pathlib.Path(lora.lora_path).parent.name,
+                scaling = lora.lora_scaling * lora.lora_r / lora.lora_alpha
+            ),
+            model_container.active_loras
+        )
+    ))
+
+    return active_loras
+
+# Load lora endpoint
+@app.post("/v1/lora/load", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
+async def load_lora(data: LoraLoadRequest):
+    if not data.loras:
+        raise HTTPException(400, "List of loras to load is not found.")
+
+    model_config = config.get("model") or {}
+    lora_config = model_config.get("lora") or {}
+    lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
+    if not lora_dir.exists():
+        raise HTTPException(400, "A parent lora directory does not exist. Check your config.yml?")
+
+    # Clean-up existing loras if present
+    if len(model_container.active_loras) > 0:
+        model_container.unload(True)
+
+    result = model_container.load_loras(lora_dir, **data.dict())
+    return LoraLoadResponse(
+        success = result.get("success") or [],
+        failure = result.get("failure") or []
+    )
+
+# Unload lora endpoint
+@app.get("/v1/lora/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
+async def unload_loras():
+    global model_container
+
+    model_container.unload(True)
+
 # Encode tokens endpoint
 @app.post("/v1/token/encode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@@ -308,6 +365,12 @@ if __name__ == "__main__":
         else:
             loading_bar.next()

+    # Load loras
+    lora_config = model_config.get("lora") or {}
+    if "loras" in lora_config:
+        lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
+        model_container.load_loras(lora_dir.resolve(), **lora_config)
+
     network_config = config.get("network") or {}

     uvicorn.run(
         app,

model.py

@@ -6,6 +6,7 @@ from exllamav2 import(
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Tokenizer,
+    ExLlamaV2Lora
 )
 from exllamav2.generator import(
     ExLlamaV2StreamingGenerator,

@@ -30,6 +31,8 @@ class ModelContainer:
     cache_fp8: bool = False
     gpu_split_auto: bool = True
     gpu_split: list or None = None
+
+    active_loras: List[ExLlamaV2Lora] = []

     def __init__(self, model_directory: pathlib.Path, quiet = False, **kwargs):
         """

@@ -54,6 +57,8 @@ class ModelContainer:
             'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft model.
                 By default, the draft model's alpha value is calculated automatically to scale to the size of the
                 full model.
+            'lora_dir' (str): Lora directory
+            'loras' (list[dict]): List of loras to be loaded, consisting of 'name' and 'scaling'
             'gpu_split_auto' (bool): Automatically split model across available devices (default: True)
             'gpu_split' (list[float]): Allocation for weights and (some) tensors, per device
             'no_flash_attn' (bool): Turns off flash attention (increases vram usage) (default: False)

@@ -141,6 +146,32 @@ class ModelContainer:
         """
         for _ in self.load_gen(progress_callback): pass

+    def load_loras(self, lora_directory: pathlib.Path, **kwargs):
+        """
+        Load loras
+        """
+
+        loras = kwargs.get("loras") or []
+        success: List[str] = []
+        failure: List[str] = []
+
+        for lora in loras:
+            lora_name = lora.get("name") or None
+            lora_scaling = lora.get("scaling") or 1.0
+
+            if lora_name is None:
+                print("One of your loras does not have a name. Please check your config.yml! Skipping lora load.")
+                failure.append(lora_name)
+                continue
+
+            print(f"Loading lora: {lora_name} at scaling {lora_scaling}")
+            lora_path = lora_directory / lora_name
+            self.active_loras.append(ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling))
+            print("Lora successfully loaded.")
+            success.append(lora_name)
+
+        # Return success and failure names
+        return { 'success': success, 'failure': failure }
+
     def load_gen(self, progress_callback = None):
         """
@@ -204,23 +235,30 @@ class ModelContainer:
         print("Model successfully loaded.")

-    def unload(self):
+    def unload(self, loras_only: bool = False):
         """
         Free all VRAM resources used by this model
         """

-        if self.model: self.model.unload()
-        self.model = None
-        if self.draft_model: self.draft_model.unload()
-        self.draft_model = None
-        self.config = None
-        self.cache = None
-        self.tokenizer = None
-        self.generator = None
+        for lora in self.active_loras:
+            lora.unload()
+
+        self.active_loras = []
+
+        # Unload the entire model if not just unloading loras
+        if not loras_only:
+            if self.model: self.model.unload()
+            self.model = None
+            if self.draft_model: self.draft_model.unload()
+            self.draft_model = None
+            self.config = None
+            self.cache = None
+            self.tokenizer = None
+            self.generator = None

         gc.collect()
         torch.cuda.empty_cache()

     # Common function for token operations
     def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs):
         if text:

@@ -381,7 +419,7 @@ class ModelContainer:
             active_ids = ids[:, max(0, overflow):]
             chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]

-            self.generator.begin_stream(active_ids, gen_settings, token_healing = token_healing)
+            self.generator.begin_stream(active_ids, gen_settings, token_healing = token_healing, loras = self.active_loras)

             # Generate
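Because the active loras are passed straight into begin_stream, generation requests need no extra parameters once loras are loaded; a sketch (OAI-style completion fields, host and header name assumed):

# Sketch: after loading loras, an ordinary completion request uses them
# automatically. Host and header name are assumptions.
import requests

resp = requests.post(
    "http://localhost:5000/v1/completions",
    headers={"x-api-key": "YOUR_API_KEY"},
    json={"prompt": "Once upon a time", "max_tokens": 32},
)
print(resp.json()["choices"][0]["text"])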