Implement lora support (#24)

* Model: Implement basic lora support

* Add ability to load loras from config on launch
* Supports loading multiple loras and lora scaling
* Add function to unload loras

* Colab: Update for basic lora support

* Model: Test vram alloc after lora load, add docs

* Git: Add loras folder to .gitignore

* API: Add basic lora-related endpoints

* Add /loras/ endpoint for querying available loras
* Add /model/lora endpoint for querying currently loaded loras
* Add /model/lora/load endpoint for loading loras
* Add /model/lora/unload endpoint for unloading loras
* Move lora config-checking logic to main.py for better compat with API endpoints

* Revert bad CRLF line ending changes

* API: Add basic lora-related endpoints (fixed)

* Add /loras/ endpoint for querying available loras
* Add /model/lora endpoint for querying currently loaded loras
* Add /model/lora/load endpoint for loading loras
* Add /model/lora/unload endpoint for unloading loras
* Move lora config-checking logic to main.py for better compat with API endpoints

* Model: Unload loras first when unloading model

* API + Models: Cleanup lora endpoints and functions

Condenses the endpoint and model-load code. Also makes the lora routes
behave the same way as the model routes so they don't confuse the end user.

Signed-off-by: kingbri <bdashore3@proton.me>

* Loras: Optimize load endpoint

Returns successes and failures, and consolidates the request handling
into the rewritten load_loras function.

Signed-off-by: kingbri <bdashore3@proton.me>

---------

Co-authored-by: kingbri <bdashore3@proton.me>
Co-authored-by: DocShotgun <126566557+DocShotgun@users.noreply.github.com>
DocShotgun authored on 2023-12-08 20:36:40 -08:00; committed by kingbri
parent 161c9d2c19
commit 7380a3b79a
8 changed files with 197 additions and 19 deletions

.gitignore

@@ -182,3 +182,7 @@ api_tokens.yml
# Models folder
models/*
!models/place_your_models_here.txt

# Loras folder
loras/*
!loras/place_your_loras_here.txt

OAI/types/lora.py (new file)

@@ -0,0 +1,25 @@
from pydantic import BaseModel, Field
from time import time
from typing import Optional, List

class LoraCard(BaseModel):
    id: str = "test"
    object: str = "lora"
    created: int = Field(default_factory=lambda: int(time()))
    owned_by: str = "tabbyAPI"
    scaling: Optional[float] = None

class LoraList(BaseModel):
    object: str = "list"
    data: List[LoraCard] = Field(default_factory=list)

class LoraLoadInfo(BaseModel):
    name: str
    scaling: Optional[float] = 1.0

class LoraLoadRequest(BaseModel):
    loras: List[LoraLoadInfo]

class LoraLoadResponse(BaseModel):
    success: List[str] = Field(default_factory=list)
    failure: List[str] = Field(default_factory=list)

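For orientation, a minimal sketch of how these request/response models might be exercised (pydantic v1 style, matching the .dict() usage elsewhere in this commit; the lora name and scaling below are illustrative):

from OAI.types.lora import LoraLoadInfo, LoraLoadRequest, LoraLoadResponse

# Build the payload a client would POST to the load endpoint (name/scaling are made up)
request = LoraLoadRequest(loras=[LoraLoadInfo(name="my_test_lora", scaling=0.8)])
print(request.dict())   # {'loras': [{'name': 'my_test_lora', 'scaling': 0.8}]}

# Shape of the response the load endpoint returns
response = LoraLoadResponse(success=["my_test_lora"], failure=[])
print(response.dict())  # {'success': ['my_test_lora'], 'failure': []}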

@@ -8,9 +8,10 @@ from OAI.types.chat_completion import (
ChatCompletionStreamChoice
)
from OAI.types.common import UsageStats
from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard
from packaging import version
from typing import Optional, List
from typing import Optional, List, Dict
# Check fastchat
try:
@@ -100,6 +101,15 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str]):
    return model_card_list

def get_lora_list(lora_path: pathlib.Path):
    lora_list = LoraList()
    for path in lora_path.iterdir():
        if path.is_dir():
            lora_card = LoraCard(id = path.name)
            lora_list.data.append(lora_card)

    return lora_list

def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]):
# Check if fastchat is available

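As a quick illustration, the new get_lora_list helper can be pointed at the loras directory to enumerate available loras (a sketch run from the repo root; the directory contents are hypothetical):

import pathlib

from OAI.utils import get_lora_list

# Each subdirectory of loras/ becomes a LoraCard with its folder name as the id
lora_list = get_lora_list(pathlib.Path("loras").resolve())
print([card.id for card in lora_list.data])  # e.g. ['my_test_lora']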

@@ -36,10 +36,17 @@
"# @markdown Select model:\n",
"repo_id = \"royallab/Noromaid-13b-v0.1.1-exl2\" # @param {type:\"string\"}\n",
"revision = \"4bpw\" # @param {type:\"string\"}\n",
"if revision == \"\": revision = \"main\"\n",
"# @markdown ---\n",
"# @markdown Select draft model (optional, for speculative decoding):\n",
"draft_repo_id = \"\" # @param {type:\"string\"}\n",
"draft_revision = \"\" # @param {type:\"string\"}\n",
"if draft_revision == \"\": draft_revision = \"main\"\n",
"# @markdown ---\n",
"# @markdown Select lora (optional):\n",
"lora_repo_id = \"\" # @param {type:\"string\"}\n",
"lora_revision = \"\" # @param {type:\"string\"}\n",
"if lora_revision == \"\": lora_revision = \"main\"\n",
"# @markdown ---\n",
"\n",
"# Install tabbyAPI\n",
@@ -62,8 +69,15 @@
"%cd /content/tabbyAPI/\n",
"\n",
"from huggingface_hub import snapshot_download\n",
"\n",
"snapshot_download(repo_id=repo_id, revision=revision, local_dir=f\"./models/{repo_id.replace('/', '_')}\")\n",
"if len(draft_repo_id) > 0: snapshot_download(repo_id=draft_repo_id, revision=draft_revision, local_dir=f\"./models/{draft_repo_id.replace('/', '_')}\")"
"model = repo_id.replace('/', '_')\n",
"\n",
"if len(draft_repo_id) > 0: snapshot_download(repo_id=draft_repo_id, revision=draft_revision, local_dir=f\"./models/{draft_repo_id.replace('/', '_')}\")\n",
"draft_model = draft_repo_id.replace('/', '_')\n",
"\n",
"if len(lora_repo_id) > 0: snapshot_download(repo_id=lora_repo_id, revision=lora_revision, local_dir=f\"./loras/{lora_repo_id.replace('/', '_')}\")\n",
"lora = lora_repo_id.replace('/', '_')"
]
},
{
@@ -77,9 +91,6 @@
"# @title # Configure and launch API { display-mode: \"form\" }\n",
"# @markdown ---\n",
"# @markdown Model parameters:\n",
"\n",
"model = repo_id.replace('/', '_')\n",
"draft_model = draft_repo_id.replace('/', '_')\n",
"ContextSize = 4096 # @param {type:\"integer\"}\n",
"RopeScale = 1.0 # @param {type:\"number\"}\n",
"RopeAlpha = 1.0 # @param {type:\"number\"}\n",
@@ -88,6 +99,9 @@
"DraftRopeScale = 1.0 # @param {type:\"number\"}\n",
"DraftRopeAlpha = 1.0 # @param {type:\"number\"}\n",
"# @markdown ---\n",
"# @markdown Lora parameters (optional, for loras):\n",
"LoraScaling = 1.0 # @param {type:\"number\"}\n",
"# @markdown ---\n",
"# @markdown Misc options:\n",
"CacheMode = \"FP16\" # @param [\"FP8\", \"FP16\"] {type:\"string\"}\n",
"UseDummyModels = False # @param {type:\"boolean\"}\n",
@@ -161,6 +175,16 @@
" # Rope parameters for draft models (default: 1.0)\n",
" draft_rope_scale: {DraftRopeScale}\n",
" draft_rope_alpha: {DraftRopeAlpha}\n",
"\n",
" # Options for loras\n",
" lora:\n",
" # Overrides the directory to look for loras (default: loras)\n",
" lora_dir: loras\n",
"\n",
" # List of loras to load and associated scaling factors (default: 1.0). Comment out unused entries or add more rows as needed.\n",
" loras:\n",
" - name: {lora}\n",
" scaling: {LoraScaling}\n",
"'''\n",
"with open(\"./config.yml\", \"w\") as file:\n",
" file.write(write)\n",


@@ -60,3 +60,17 @@ model:
# Rope parameters for draft models (default: 1.0)
draft_rope_scale: 1.0
draft_rope_alpha: 1.0
  # Options for loras
  lora:
    # Overrides the directory to look for loras (default: loras)
    lora_dir: Your lora directory path

    # List of loras to load and associated scaling factors (default: 1.0). Comment out unused entries or add more rows as needed.
    loras:
      - name: lora1
        scaling: 1.0
      - name: lora2
        scaling: 0.9
      - name: lora3
        scaling: 0.5

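To make the mapping concrete, here is a rough sketch of how a lora block like the one above flows into ModelContainer.load_loras at launch, mirroring the startup logic in main.py below (the wrapper function and the plain yaml.safe_load config read are assumptions for illustration):

import pathlib

import yaml

def load_loras_from_config(model_container, config_path="config.yml"):
    # Read the same nesting used in the sample config above: model -> lora -> loras
    with open(config_path) as config_file:
        config = yaml.safe_load(config_file) or {}

    model_config = config.get("model") or {}
    lora_config = model_config.get("lora") or {}

    if "loras" not in lora_config:
        return {"success": [], "failure": []}

    lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")

    # load_loras (defined in model.py below) returns {'success': [...], 'failure': [...]}
    return model_container.load_loras(lora_dir.resolve(), **lora_config)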
main.py

@@ -11,6 +11,7 @@ from progress.bar import IncrementalBar
from generators import generate_with_semaphore
from OAI.types.completion import CompletionRequest
from OAI.types.chat_completion import ChatCompletionRequest
from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
from OAI.types.token import (
TokenEncodeRequest,
@@ -21,6 +22,7 @@ from OAI.types.token import (
from OAI.utils import (
create_completion_response,
get_model_list,
get_lora_list,
get_chat_completion_prompt,
create_chat_completion_response,
create_chat_completion_stream_chunk
@@ -87,7 +89,6 @@ async def load_model(request: Request, data: ModelLoadRequest):
if not data.name:
raise HTTPException(400, "model_name not found.")
# TODO: Move this to model_container
model_config = config.get("model") or {}
model_path = pathlib.Path(model_config.get("model_dir") or "models")
model_path = model_path / data.name
@@ -162,6 +163,62 @@ async def unload_model():
model_container.unload()
model_container = None
# Lora list endpoint
@app.get("/v1/loras", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def get_all_loras():
    model_config = config.get("model") or {}
    lora_config = model_config.get("lora") or {}
    lora_path = pathlib.Path(lora_config.get("lora_dir") or "loras")

    loras = get_lora_list(lora_path.resolve())
    return loras

# Currently loaded loras endpoint
@app.get("/v1/lora", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def get_active_loras():
    active_loras = LoraList(
        data = list(map(
            lambda lora: LoraCard(
                id = pathlib.Path(lora.lora_path).parent.name,
                scaling = lora.lora_scaling * lora.lora_r / lora.lora_alpha
            ),
            model_container.active_loras
        ))
    )

    return active_loras

# Load lora endpoint
@app.post("/v1/lora/load", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
async def load_lora(data: LoraLoadRequest):
    if not data.loras:
        raise HTTPException(400, "List of loras to load is not found.")

    model_config = config.get("model") or {}
    lora_config = model_config.get("lora") or {}
    lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
    if not lora_dir.exists():
        raise HTTPException(400, "A parent lora directory does not exist. Check your config.yml?")

    # Clean-up existing loras if present
    if len(model_container.active_loras) > 0:
        model_container.unload(True)

    result = model_container.load_loras(lora_dir, **data.dict())
    return LoraLoadResponse(
        success = result.get("success") or [],
        failure = result.get("failure") or []
    )

# Unload lora endpoint
@app.get("/v1/lora/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
async def unload_loras():
    global model_container

    model_container.unload(True)
# Encode tokens endpoint
@app.post("/v1/token/encode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def encode_tokens(data: TokenEncodeRequest):
@@ -308,6 +365,12 @@ if __name__ == "__main__":
else:
loading_bar.next()
# Load loras
lora_config = model_config.get("lora") or {}
if "loras" in lora_config:
lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
model_container.load_loras(lora_dir.resolve(), **lora_config)
network_config = config.get("network") or {}
uvicorn.run(
app,

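For completeness, a client-side sketch of the routes added above (endpoint paths are taken from main.py; the host/port, key values, and auth header names are assumptions, so check the project's auth implementation for the exact headers):

import requests

BASE_URL = "http://127.0.0.1:5000"                  # assumed host/port
API_HEADERS = {"x-api-key": "your-api-key"}         # assumed header name
ADMIN_HEADERS = {"x-admin-key": "your-admin-key"}   # assumed header name

# List loras available in the lora directory
print(requests.get(f"{BASE_URL}/v1/loras", headers=API_HEADERS).json())

# Load a lora with a scaling factor (the name is illustrative)
payload = {"loras": [{"name": "my_test_lora", "scaling": 0.8}]}
print(requests.post(f"{BASE_URL}/v1/lora/load", headers=ADMIN_HEADERS, json=payload).json())

# Show currently loaded loras, then unload them
print(requests.get(f"{BASE_URL}/v1/lora", headers=API_HEADERS).json())
requests.get(f"{BASE_URL}/v1/lora/unload", headers=ADMIN_HEADERS)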

@@ -6,6 +6,7 @@ from exllamav2 import(
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Tokenizer,
ExLlamaV2Lora
)
from exllamav2.generator import(
ExLlamaV2StreamingGenerator,
@@ -31,6 +32,8 @@ class ModelContainer:
gpu_split_auto: bool = True
gpu_split: list or None = None
active_loras: List[ExLlamaV2Lora] = []
def __init__(self, model_directory: pathlib.Path, quiet = False, **kwargs):
"""
Create model container
@@ -54,6 +57,8 @@ 'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft model.
'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft model.
By default, the draft model's alpha value is calculated automatically to scale to the size of the
full model.
'lora_dir' (str): Lora directory
'loras' (list[dict]): List of loras to be loaded, consisting of 'name' and 'scaling'
'gpu_split_auto' (bool): Automatically split model across available devices (default: True)
'gpu_split' (list[float]): Allocation for weights and (some) tensors, per device
'no_flash_attn' (bool): Turns off flash attention (increases vram usage) (default: False)
@@ -141,6 +146,32 @@
"""
for _ in self.load_gen(progress_callback): pass
    def load_loras(self, lora_directory: pathlib.Path, **kwargs):
        """
        Load loras
        """
        loras = kwargs.get("loras") or []
        success: List[str] = []
        failure: List[str] = []

        for lora in loras:
            lora_name = lora.get("name") or None
            lora_scaling = lora.get("scaling") or 1.0

            if lora_name is None:
                print("One of your loras does not have a name. Please check your config.yml! Skipping lora load.")
                failure.append(lora_name)
                continue

            print(f"Loading lora: {lora_name} at scaling {lora_scaling}")
            lora_path = lora_directory / lora_name
            self.active_loras.append(ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling))
            print("Lora successfully loaded.")
            success.append(lora_name)

        # Return success and failure names
        return { 'success': success, 'failure': failure }
def load_gen(self, progress_callback = None):
"""
@@ -204,11 +235,18 @@
print("Model successfully loaded.")
def unload(self):
def unload(self, loras_only: bool = False):
"""
Free all VRAM resources used by this model
"""
for lora in self.active_loras:
lora.unload()
self.active_loras = []
# Unload the entire model if not just unloading loras
if not loras_only:
if self.model: self.model.unload()
self.model = None
if self.draft_model: self.draft_model.unload()
@@ -217,10 +255,10 @@
self.cache = None
self.tokenizer = None
self.generator = None
gc.collect()
torch.cuda.empty_cache()
# Common function for token operations
def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs):
if text:
@@ -381,7 +419,7 @@
active_ids = ids[:, max(0, overflow):]
chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]
self.generator.begin_stream(active_ids, gen_settings, token_healing = token_healing)
self.generator.begin_stream(active_ids, gen_settings, token_healing = token_healing, loras = self.active_loras)
# Generate