diff --git a/.gitignore b/.gitignore
index 314c2f8..14163cf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -182,3 +182,7 @@ api_tokens.yml
 # Models folder
 models/*
 !models/place_your_models_here.txt
+
+# Loras folder
+loras/*
+!loras/place_your_loras_here.txt
diff --git a/OAI/types/lora.py b/OAI/types/lora.py
new file mode 100644
index 0000000..891b6ef
--- /dev/null
+++ b/OAI/types/lora.py
@@ -0,0 +1,25 @@
+from pydantic import BaseModel, Field
+from time import time
+from typing import Optional, List
+
+class LoraCard(BaseModel):
+    id: str = "test"
+    object: str = "lora"
+    created: int = Field(default_factory=lambda: int(time()))
+    owned_by: str = "tabbyAPI"
+    scaling: Optional[float] = None
+
+class LoraList(BaseModel):
+    object: str = "list"
+    data: List[LoraCard] = Field(default_factory=list)
+
+class LoraLoadInfo(BaseModel):
+    name: str
+    scaling: Optional[float] = 1.0
+
+class LoraLoadRequest(BaseModel):
+    loras: List[LoraLoadInfo]
+
+class LoraLoadResponse(BaseModel):
+    success: List[str] = Field(default_factory=list)
+    failure: List[str] = Field(default_factory=list)
diff --git a/OAI/utils.py b/OAI/utils.py
index a60e3d9..bb0c93b 100644
--- a/OAI/utils.py
+++ b/OAI/utils.py
@@ -8,9 +8,10 @@ from OAI.types.chat_completion import (
     ChatCompletionStreamChoice
 )
 from OAI.types.common import UsageStats
+from OAI.types.lora import LoraList, LoraCard
 from OAI.types.model import ModelList, ModelCard
 from packaging import version
-from typing import Optional, List
+from typing import Optional, List, Dict

 # Check fastchat
 try:
@@ -100,6 +101,15 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str]):

     return model_card_list

+def get_lora_list(lora_path: pathlib.Path):
+    lora_list = LoraList()
+    for path in lora_path.iterdir():
+        if path.is_dir():
+            lora_card = LoraCard(id = path.name)
+            lora_list.data.append(lora_card)
+
+    return lora_list
+
 def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]):

     # Check if fastchat is available
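For orientation, a minimal sketch of the JSON these new pydantic models serialize to (pydantic v1 style, matching the `data.dict()` call used in `main.py` below); the lora name and timestamp are illustrative, not part of the diff:

```python
# Sketch: JSON shapes produced by the new lora models (pydantic v1).
# "my-lora" is a placeholder; "created" is the current unix time by default.
from OAI.types.lora import LoraCard, LoraList, LoraLoadInfo, LoraLoadRequest

# Response shape for the lora list endpoints
lora_list = LoraList(data=[LoraCard(id="my-lora", scaling=1.0)])
print(lora_list.json())
# -> {"object": "list", "data": [{"id": "my-lora", "object": "lora",
#     "created": 1700000000, "owned_by": "tabbyAPI", "scaling": 1.0}]}

# Request body for POST /v1/lora/load
request = LoraLoadRequest(loras=[LoraLoadInfo(name="my-lora", scaling=0.8)])
print(request.dict())
# -> {"loras": [{"name": "my-lora", "scaling": 0.8}]}
```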
"draft_model = draft_repo_id.replace('/', '_')\n", + "\n", + "if len(lora_repo_id) > 0: snapshot_download(repo_id=lora_repo_id, revision=lora_revision, local_dir=f\"./loras/{lora_repo_id.replace('/', '_')}\")\n", + "lora = lora_repo_id.replace('/', '_')" ] }, { @@ -77,9 +91,6 @@ "# @title # Configure and launch API { display-mode: \"form\" }\n", "# @markdown ---\n", "# @markdown Model parameters:\n", - "\n", - "model = repo_id.replace('/', '_')\n", - "draft_model = draft_repo_id.replace('/', '_')\n", "ContextSize = 4096 # @param {type:\"integer\"}\n", "RopeScale = 1.0 # @param {type:\"number\"}\n", "RopeAlpha = 1.0 # @param {type:\"number\"}\n", @@ -88,6 +99,9 @@ "DraftRopeScale = 1.0 # @param {type:\"number\"}\n", "DraftRopeAlpha = 1.0 # @param {type:\"number\"}\n", "# @markdown ---\n", + "# @markdown Lora parameters (optional, for loras):\n", + "LoraScaling = 1.0 # @param {type:\"number\"}\n", + "# @markdown ---\n", "# @markdown Misc options:\n", "CacheMode = \"FP16\" # @param [\"FP8\", \"FP16\"] {type:\"string\"}\n", "UseDummyModels = False # @param {type:\"boolean\"}\n", @@ -161,6 +175,16 @@ " # Rope parameters for draft models (default: 1.0)\n", " draft_rope_scale: {DraftRopeScale}\n", " draft_rope_alpha: {DraftRopeAlpha}\n", + "\n", + " # Options for loras\n", + " lora:\n", + " # Overrides the directory to look for loras (default: loras)\n", + " lora_dir: loras\n", + "\n", + " # List of loras to load and associated scaling factors (default: 1.0). Comment out unused entries or add more rows as needed.\n", + " loras:\n", + " - name: {lora}\n", + " scaling: {LoraScaling}\n", "'''\n", "with open(\"./config.yml\", \"w\") as file:\n", " file.write(write)\n", @@ -188,4 +212,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file diff --git a/config_sample.yml b/config_sample.yml index f1524cc..04eda77 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -60,3 +60,17 @@ model: # Rope parameters for draft models (default: 1.0) draft_rope_scale: 1.0 draft_rope_alpha: 1.0 + + # Options for loras + lora: + # Overrides the directory to look for loras (default: loras) + lora_dir: Your lora directory path + + # List of loras to load and associated scaling factors (default: 1.0). Comment out unused entries or add more rows as needed. 
diff --git a/main.py b/main.py
index c9ddc5b..5627f53 100644
--- a/main.py
+++ b/main.py
@@ -11,6 +11,7 @@ from progress.bar import IncrementalBar
 from generators import generate_with_semaphore
 from OAI.types.completion import CompletionRequest
 from OAI.types.chat_completion import ChatCompletionRequest
+from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
 from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
 from OAI.types.token import (
     TokenEncodeRequest,
@@ -21,6 +22,7 @@ from OAI.types.token import (
 from OAI.utils import (
     create_completion_response,
     get_model_list,
+    get_lora_list,
     get_chat_completion_prompt,
     create_chat_completion_response,
     create_chat_completion_stream_chunk
@@ -87,7 +89,6 @@ async def load_model(request: Request, data: ModelLoadRequest):
     if not data.name:
         raise HTTPException(400, "model_name not found.")

-    # TODO: Move this to model_container
     model_config = config.get("model") or {}
     model_path = pathlib.Path(model_config.get("model_dir") or "models")
     model_path = model_path / data.name
@@ -160,7 +161,63 @@ async def unload_model():
     global model_container

     model_container.unload()
-    model_container = None
+    model_container = None
+
+# Lora list endpoint
+@app.get("/v1/loras", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
+@app.get("/v1/lora/list", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
+async def get_all_loras():
+    model_config = config.get("model") or {}
+    lora_config = model_config.get("lora") or {}
+    lora_path = pathlib.Path(lora_config.get("lora_dir") or "loras")
+
+    loras = get_lora_list(lora_path.resolve())
+
+    return loras
+
+# Currently loaded loras endpoint
+@app.get("/v1/lora", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
+async def get_active_loras():
+    active_loras = LoraList(
+        data = list(map(
+            lambda lora: LoraCard(
+                id = pathlib.Path(lora.lora_path).parent.name,
+                scaling = lora.lora_scaling * lora.lora_r / lora.lora_alpha
+            ),
+            model_container.active_loras
+        ))
+    )
+
+    return active_loras
+
+# Load lora endpoint
+@app.post("/v1/lora/load", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
+async def load_lora(data: LoraLoadRequest):
+    if not data.loras:
+        raise HTTPException(400, "List of loras to load was not found.")
+
+    model_config = config.get("model") or {}
+    lora_config = model_config.get("lora") or {}
+    lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
+    if not lora_dir.exists():
+        raise HTTPException(400, "The parent lora directory does not exist. Check your config.yml.")
+
+    # Clean up existing loras if present
+    if len(model_container.active_loras) > 0:
+        model_container.unload(loras_only = True)
+
+    result = model_container.load_loras(lora_dir, **data.dict())
+    return LoraLoadResponse(
+        success = result.get("success") or [],
+        failure = result.get("failure") or []
+    )
+
+# Unload lora endpoint
+@app.get("/v1/lora/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
+async def unload_loras():
+    global model_container
+
+    model_container.unload(loras_only = True)

 # Encode tokens endpoint
 @app.post("/v1/token/encode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@@ -308,6 +365,12 @@ if __name__ == "__main__":
     else:
         loading_bar.next()

+    # Load loras
+    lora_config = model_config.get("lora") or {}
+    if "loras" in lora_config:
+        lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
+        model_container.load_loras(lora_dir.resolve(), **lora_config)
+
     network_config = config.get("network") or {}
     uvicorn.run(
         app,
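A sketch of exercising these endpoints from a client: the payload and response shapes come from the code above, but the host, port, and auth header names are assumptions, since the auth module is not part of this diff.

```python
# Sketch: calling the new lora endpoints. Host, port, and header names are
# assumptions; adjust to your network config and auth setup.
import requests

BASE = "http://localhost:5000"
API = {"x-api-key": "your-api-key"}        # assumed header name
ADMIN = {"x-admin-key": "your-admin-key"}  # assumed header name

# Load loras (the endpoint unloads any active loras first)
resp = requests.post(
    f"{BASE}/v1/lora/load",
    headers=ADMIN,
    json={"loras": [{"name": "lora1", "scaling": 1.0}, {"name": "lora2", "scaling": 0.9}]},
)
print(resp.json())  # e.g. {"success": ["lora1", "lora2"], "failure": []}

# List loras available on disk, then the currently active ones
print(requests.get(f"{BASE}/v1/loras", headers=API).json())
print(requests.get(f"{BASE}/v1/lora", headers=API).json())

# Unload all active loras
requests.get(f"{BASE}/v1/lora/unload", headers=ADMIN)
```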
diff --git a/model.py b/model.py
index 19031d7..7190cd0 100644
--- a/model.py
+++ b/model.py
@@ -6,6 +6,7 @@ from exllamav2 import(
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Tokenizer,
+    ExLlamaV2Lora
 )
 from exllamav2.generator import(
     ExLlamaV2StreamingGenerator,
@@ -30,6 +31,8 @@ class ModelContainer:
     cache_fp8: bool = False
     gpu_split_auto: bool = True
     gpu_split: list or None = None
+
+    active_loras: List[ExLlamaV2Lora] = []

     def __init__(self, model_directory: pathlib.Path, quiet = False, **kwargs):
         """
@@ -54,6 +57,8 @@
             'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft model. By default, the draft model's
                 alpha value is calculated automatically to scale to the size of the full model.
+            'lora_dir' (str): Lora directory
+            'loras' (list[dict]): List of loras to be loaded, consisting of 'name' and 'scaling'
             'gpu_split_auto' (bool): Automatically split model across available devices (default: True)
             'gpu_split' (list[float]): Allocation for weights and (some) tensors, per device
             'no_flash_attn' (bool): Turns off flash attention (increases vram usage) (default: False)
@@ -141,6 +146,32 @@ class ModelContainer:
         """
         for _ in self.load_gen(progress_callback): pass

+    def load_loras(self, lora_directory: pathlib.Path, **kwargs):
+        """
+        Load loras
+        """
+
+        loras = kwargs.get("loras") or []
+        success: List[str] = []
+        failure: List[str] = []
+
+        for lora in loras:
+            lora_name = lora.get("name") or None
+            lora_scaling = lora.get("scaling") or 1.0
+
+            if lora_name is None:
+                print("One of your loras does not have a name. Please check your config.yml! Skipping lora load.")
+                failure.append(lora_name)
+                continue
+
+            print(f"Loading lora: {lora_name} at scaling {lora_scaling}")
+            lora_path = lora_directory / lora_name
+            self.active_loras.append(ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling))
+            print("Lora successfully loaded.")
+            success.append(lora_name)
+
+        # Return success and failure names
+        return { 'success': success, 'failure': failure }
+
     def load_gen(self, progress_callback = None):
         """
@@ -204,23 +235,30 @@ class ModelContainer:

         print("Model successfully loaded.")

-    def unload(self):
+    def unload(self, loras_only: bool = False):
         """
         Free all VRAM resources used by this model
         """

-        if self.model: self.model.unload()
-        self.model = None
-        if self.draft_model: self.draft_model.unload()
-        self.draft_model = None
-        self.config = None
-        self.cache = None
-        self.tokenizer = None
-        self.generator = None
+        for lora in self.active_loras:
+            lora.unload()
+
+        self.active_loras = []
+
+        # Unload the entire model if not just unloading loras
+        if not loras_only:
+            if self.model: self.model.unload()
+            self.model = None
+            if self.draft_model: self.draft_model.unload()
+            self.draft_model = None
+            self.config = None
+            self.cache = None
+            self.tokenizer = None
+            self.generator = None
+
         gc.collect()
         torch.cuda.empty_cache()

-    # Common function for token operations
     def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs):
         if text:
@@ -381,7 +419,7 @@ class ModelContainer:
             active_ids = ids[:, max(0, overflow):]
             chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]

-            self.generator.begin_stream(active_ids, gen_settings, token_healing = token_healing)
+            self.generator.begin_stream(active_ids, gen_settings, token_healing = token_healing, loras = self.active_loras)

             # Generate
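Putting the container-side pieces together, a sketch of the lora lifecycle (the model directory and lora name are placeholders): `load_loras` returns the success/failure names that `main.py` forwards to the client, and every `begin_stream` call now passes the active loras.

```python
# Sketch: container-side lora lifecycle. Directory and lora names are
# placeholders; assumes a valid exl2 model in models/my-model.
import pathlib
from model import ModelContainer

container = ModelContainer(pathlib.Path("models/my-model"))
container.load()

# Load loras from the configured directory; returns the names that
# loaded or failed, as used by the /v1/lora/load endpoint.
result = container.load_loras(
    pathlib.Path("loras").resolve(),
    loras=[{"name": "my-lora", "scaling": 1.0}],
)
print(result)  # {'success': ['my-lora'], 'failure': []}

# Generation applies the active loras via begin_stream(..., loras=...).
# Drop the loras without tearing down the model:
container.unload(loras_only=True)
```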