From fa47f51f8584f3adc10acac3e8b62f599fe4685d Mon Sep 17 00:00:00 2001
From: AlpinDale <52078762+AlpinDale@users.noreply.github.com>
Date: Fri, 22 Dec 2023 16:20:35 +0000
Subject: [PATCH] feat: workflows for formatting/linting (#35)

* add github workflows for pylint and yapf
* yapf
* docstrings for auth
* fix auth.py
* fix generators.py
* fix gen_logging.py
* fix main.py
* fix model.py
* fix templating.py
* fix utils.py
* update formatting.sh to include subdirs for pylint
* fix model_test.py
* fix wheel_test.py
* rename utils to utils_oai
* fix OAI/utils_oai.py
* fix completion.py
* fix token.py
* fix lora.py
* fix common.py
* add pylintrc and fix model.py
* finish up pylint
* fix attribute error
* main.py formatting
* add formatting batch script

* Main: Remove unnecessary global

Linter suggestion.

Signed-off-by: kingbri

* switch to ruff

* Formatting + Linting: Add ruff.toml

Signed-off-by: kingbri

* Formatting + Linting: Switch scripts to use ruff

Also remove the file and recent file change functions from both scripts.

Signed-off-by: kingbri

* Tree: Format and lint

Signed-off-by: kingbri

* Scripts + Workflows: Format

Signed-off-by: kingbri

* Tree: Remove pylint flags

We use ruff now

Signed-off-by: kingbri

* Tree: Format

Signed-off-by: kingbri

* Formatting: Line length is 88

Use the same value as Black.

Signed-off-by: kingbri

* Tree: Format

Update to new line length rules.

Signed-off-by: kingbri

---------

Authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com>
Co-authored-by: kingbri
---
 .github/workflows/ruff.yml   |  32 +++
 .ruff.toml                   | 111 ++++++++++
 OAI/types/chat_completion.py |   6 +
 OAI/types/common.py          |  58 +++--
 OAI/types/completion.py      |  21 +-
 OAI/types/lora.py            |  19 +-
 OAI/types/model.py           |  50 ++++-
 OAI/types/token.py           |  27 ++-
 OAI/utils.py                 | 103 ---------
 OAI/utils_oai.py             | 114 ++++++++++
 auth.py                      | 120 ++++++-----
 formatting.bat               |  36 ++++
 formatting.sh                |  53 +++++
 gen_logging.py               |  38 ++--
 generators.py                |   4 +
 main.py                      | 363 +++++++++++++++++++------------
 model.py                     | 399 ++++++++++++++++++++++++-----------
 requirements-dev.txt         |  15 ++
 templating.py                |  53 +++--
 tests/model_test.py          |  53 +++--
 tests/wheel_test.py          |   9 +-
 utils.py                     |  37 ++--
 22 files changed, 1210 insertions(+), 511 deletions(-)
 create mode 100644 .github/workflows/ruff.yml
 create mode 100644 .ruff.toml
 delete mode 100644 OAI/utils.py
 create mode 100644 OAI/utils_oai.py
 create mode 100644 formatting.bat
 create mode 100755 formatting.sh
 create mode 100644 requirements-dev.txt

diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
new file mode 100644
index 0000000..138e716
--- /dev/null
+++ b/.github/workflows/ruff.yml
@@ -0,0 +1,32 @@
+name: ruff
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  ruff:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements-dev.txt
+      - name: Format and show diff with ruff
+        run: |
+          ruff format --diff
+      - name: Lint code with ruff
+        run: |
+          ruff check
diff --git a/.ruff.toml b/.ruff.toml
new file mode 100644
index 0000000..53c7bd4
--- /dev/null
+++ b/.ruff.toml
@@ -0,0 +1,111 @@
+# Exclude a variety of commonly ignored directories.
+exclude = [
+    ".git",
+    ".git-rewrite",
+    ".mypy_cache",
+    ".pyenv",
+    ".pytest_cache",
+    ".ruff_cache",
+    ".venv",
+    ".vscode",
+    "__pypackages__",
+    "_build",
+    "build",
+    "dist",
+    "node_modules",
+    "site-packages",
+    "venv",
+]
+
+# Same as Black.
+line-length = 88
+indent-width = 4
+
+# Assume Python 3.10
+target-version = "py310"
+
+[lint]
+# Enable preview
+preview = true
+
+# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
+# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
+# McCabe complexity (`C901`) by default.
+# Enable flake8-bugbear (`B`) rules, in addition to the defaults.
+select = ["E4", "E7", "E9", "F", "B"]
+extend-select = [
+    "D419", # empty-docstring
+    "PLC2401", # non-ascii-name
+    "E501", # line-too-long
+    "W291", # trailing-whitespace
+    "PLC0414", # useless-import-alias
+    "E999", # syntax-error
+    "PLE0101", # return-in-init
+    "F706", # return-outside-function
+    "F704", # yield-outside-function
+    "PLE0116", # continue-in-finally
+    "PLE0117", # nonlocal-without-binding
+    "PLE0241", # duplicate-bases
+    "PLE0302", # unexpected-special-method-signature
+    "PLE0604", # invalid-all-object
+    "PLE0704", # misplaced-bare-raise
+    "PLE1205", # logging-too-many-args
+    "PLE1206", # logging-too-few-args
+    "PLE1307", # bad-string-format-type
+    "PLE1310", # bad-str-strip-call
+    "PLE1507", # invalid-envvar-value
+    "PLR0124", # comparison-with-itself
+    "PLR0202", # no-classmethod-decorator
+    "PLR0203", # no-staticmethod-decorator
+    "PLR0206", # property-with-parameters
+    "PLR1704", # redefined-argument-from-local
+    "PLR1711", # useless-return
+    "C416", # unnecessary-comprehension
+    "PLW0108", # unnecessary-lambda
+    "PLW0127", # self-assigning-variable
+    "PLW0129", # assert-on-string-literal
+    "PLW0602", # global-variable-not-assigned
+    "PLW0604", # global-at-module-level
+    "F401", # unused-import
+    "F841", # unused-variable
+    "E722", # bare-except
+    "PLW0711", # binary-op-exception
+    "PLW1501", # bad-open-mode
+    "PLW1508", # invalid-envvar-default
+    "PLW1509", # subprocess-popen-preexec-fn
+]
+ignore = [
+    "PLR6301", # no-self-use
+    "UP004", # useless-object-inheritance
+    "PLR0904", # too-many-public-methods
+    "PLR0911", # too-many-return-statements
+    "PLR0912", # too-many-branches
+    "PLR0913", # too-many-arguments
+    "PLR0914", # too-many-locals
+    "PLR0915", # too-many-statements
+    "PLR0916", # too-many-boolean-expressions
+    "PLW0120", # useless-else-on-loop
+    "PLW0406", # import-self
+    "PLW0603", # global-statement
+    "PLW1641", # eq-without-hash
+]
+
+# Allow fix for all enabled rules (when `--fix`) is provided.
+fixable = ["ALL"]
+unfixable = ["B"]
+
+# Allow unused variables when underscore-prefixed.
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+
+[format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false
+
+# Like Black, automatically detect the appropriate line ending.
+line-ending = "auto" diff --git a/OAI/types/chat_completion.py b/OAI/types/chat_completion.py index 891f548..5e0e80b 100644 --- a/OAI/types/chat_completion.py +++ b/OAI/types/chat_completion.py @@ -4,22 +4,26 @@ from pydantic import BaseModel, Field from typing import Union, List, Optional, Dict from OAI.types.common import UsageStats, CommonCompletionRequest + class ChatCompletionMessage(BaseModel): role: Optional[str] = None content: Optional[str] = None + class ChatCompletionRespChoice(BaseModel): # Index is 0 since we aren't using multiple choices index: int = 0 finish_reason: str message: ChatCompletionMessage + class ChatCompletionStreamChoice(BaseModel): # Index is 0 since we aren't using multiple choices index: int = 0 finish_reason: Optional[str] delta: Union[ChatCompletionMessage, dict] = {} + # Inherited from common request class ChatCompletionRequest(CommonCompletionRequest): # Messages @@ -28,6 +32,7 @@ class ChatCompletionRequest(CommonCompletionRequest): prompt_template: Optional[str] = None add_generation_prompt: Optional[bool] = True + class ChatCompletionResponse(BaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}") choices: List[ChatCompletionRespChoice] @@ -36,6 +41,7 @@ class ChatCompletionResponse(BaseModel): object: str = "chat.completion" usage: Optional[UsageStats] = None + class ChatCompletionStreamChunk(BaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}") choices: List[ChatCompletionStreamChoice] diff --git a/OAI/types/common.py b/OAI/types/common.py index ca636b9..f9dfcb0 100644 --- a/OAI/types/common.py +++ b/OAI/types/common.py @@ -1,30 +1,54 @@ -from pydantic import BaseModel, Field, AliasChoices +""" Common types for OAI. """ from typing import List, Dict, Optional, Union + +from pydantic import BaseModel, Field, AliasChoices + from utils import unwrap + class LogProbs(BaseModel): + """Represents log probabilities.""" + text_offset: List[int] = Field(default_factory=list) token_logprobs: List[float] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list) top_logprobs: List[Dict[str, float]] = Field(default_factory=list) + class UsageStats(BaseModel): + """Represents usage stats.""" + prompt_tokens: int completion_tokens: int total_tokens: int + class CommonCompletionRequest(BaseModel): + """Represents a common completion request.""" + # Model information # This parameter is not used, the loaded model is used instead model: Optional[str] = None # Extra OAI request stuff - best_of: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None) - echo: Optional[bool] = Field(description = "Not parsed. Only used for OAI compliance.", default = False) - logprobs: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None) - n: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = 1) - suffix: Optional[str] = Field(description = "Not parsed. Only used for OAI compliance.", default = None) - user: Optional[str] = Field(description = "Not parsed. Only used for OAI compliance.", default = None) + best_of: Optional[int] = Field( + description="Not parsed. Only used for OAI compliance.", default=None + ) + echo: Optional[bool] = Field( + description="Not parsed. Only used for OAI compliance.", default=False + ) + logprobs: Optional[int] = Field( + description="Not parsed. Only used for OAI compliance.", default=None + ) + n: Optional[int] = Field( + description="Not parsed. 
Only used for OAI compliance.", default=1 + ) + suffix: Optional[str] = Field( + description="Not parsed. Only used for OAI compliance.", default=None + ) + user: Optional[str] = Field( + description="Not parsed. Only used for OAI compliance.", default=None + ) # Generation info # seed: Optional[int] = -1 @@ -35,8 +59,9 @@ class CommonCompletionRequest(BaseModel): max_tokens: Optional[int] = 150 # Aliased to repetition_penalty - # TODO: Maybe make this an alias to rep pen - frequency_penalty: Optional[float] = Field(description = "Aliased to Repetition Penalty", default = 0.0) + frequency_penalty: Optional[float] = Field( + description="Aliased to Repetition Penalty", default=0.0 + ) # Sampling params token_healing: Optional[bool] = False @@ -58,18 +83,21 @@ class CommonCompletionRequest(BaseModel): # Aliased variables repetition_range: Optional[int] = Field( - default = None, - validation_alias = AliasChoices('repetition_range', 'repetition_penalty_range') + default=None, + validation_alias=AliasChoices("repetition_range", "repetition_penalty_range"), ) - # Converts to internal generation parameters def to_gen_params(self): + """Converts to internal generation parameters.""" # Convert stop to an array of strings if isinstance(self.stop, str): self.stop = [self.stop] - # Set repetition_penalty to frequency_penalty if repetition_penalty isn't already defined - if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty: + # Set repetition_penalty to frequency_penalty if repetition_penalty + # isn't already defined + if ( + self.repetition_penalty is None or self.repetition_penalty == 1.0 + ) and self.frequency_penalty: self.repetition_penalty = self.frequency_penalty return { @@ -87,7 +115,7 @@ class CommonCompletionRequest(BaseModel): "min_p": self.min_p, "tfs": self.tfs, "repetition_penalty": self.repetition_penalty, - "repetition_range": unwrap(self.repetition_range, -1), + "repetition_range": unwrap(self.repetition_range, -1), "repetition_decay": self.repetition_decay, "mirostat": self.mirostat_mode == 2, "mirostat_tau": self.mirostat_tau, diff --git a/OAI/types/completion.py b/OAI/types/completion.py index 55f54a3..15e84a7 100644 --- a/OAI/types/completion.py +++ b/OAI/types/completion.py @@ -1,22 +1,35 @@ -from uuid import uuid4 +""" Completion API protocols """ from time import time -from pydantic import BaseModel, Field from typing import List, Optional, Union -from OAI.types.common import LogProbs, UsageStats, CommonCompletionRequest +from uuid import uuid4 + +from pydantic import BaseModel, Field + +from OAI.types.common import CommonCompletionRequest, LogProbs, UsageStats + class CompletionRespChoice(BaseModel): + """Represents a single choice in a completion response.""" + # Index is 0 since we aren't using multiple choices index: int = 0 finish_reason: str logprobs: Optional[LogProbs] = None text: str + # Inherited from common request class CompletionRequest(CommonCompletionRequest): - # Prompt can also contain token ids, but that's out of scope for this project. + """Represents a completion request.""" + + # Prompt can also contain token ids, but that's out of scope + # for this project. 
prompt: Union[str, List[str]] + class CompletionResponse(BaseModel): + """Represents a completion response.""" + id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}") choices: List[CompletionRespChoice] created: int = Field(default_factory=lambda: int(time())) diff --git a/OAI/types/lora.py b/OAI/types/lora.py index 11baab9..841c3a8 100644 --- a/OAI/types/lora.py +++ b/OAI/types/lora.py @@ -1,25 +1,42 @@ -from pydantic import BaseModel, Field +""" Lora types """ from time import time from typing import Optional, List +from pydantic import BaseModel, Field + + class LoraCard(BaseModel): + """Represents a single Lora card.""" + id: str = "test" object: str = "lora" created: int = Field(default_factory=lambda: int(time())) owned_by: str = "tabbyAPI" scaling: Optional[float] = None + class LoraList(BaseModel): + """Represents a list of Lora cards.""" + object: str = "list" data: List[LoraCard] = Field(default_factory=list) + class LoraLoadInfo(BaseModel): + """Represents a single Lora load info.""" + name: str scaling: Optional[float] = 1.0 + class LoraLoadRequest(BaseModel): + """Represents a Lora load request.""" + loras: List[LoraLoadInfo] + class LoraLoadResponse(BaseModel): + """Represents a Lora load response.""" + success: List[str] = Field(default_factory=list) failure: List[str] = Field(default_factory=list) diff --git a/OAI/types/model.py b/OAI/types/model.py index 3ee0f02..3721477 100644 --- a/OAI/types/model.py +++ b/OAI/types/model.py @@ -1,19 +1,29 @@ -from pydantic import BaseModel, Field, ConfigDict +""" Contains model card types. """ from time import time from typing import List, Optional + +from pydantic import BaseModel, Field, ConfigDict + from gen_logging import LogConfig + class ModelCardParameters(BaseModel): - # Safe to do this since it's guaranteed to fetch a max seq len from model_container + """Represents model card parameters.""" + + # Safe to do this since it's guaranteed to fetch a max seq len + # from model_container max_seq_len: Optional[int] = None rope_scale: Optional[float] = 1.0 rope_alpha: Optional[float] = 1.0 cache_mode: Optional[str] = "FP16" prompt_template: Optional[str] = None num_experts_per_token: Optional[int] = None - draft: Optional['ModelCard'] = None + draft: Optional["ModelCard"] = None + class ModelCard(BaseModel): + """Represents a single model card.""" + id: str = "test" object: str = "model" created: int = Field(default_factory=lambda: int(time())) @@ -21,26 +31,47 @@ class ModelCard(BaseModel): logging: Optional[LogConfig] = None parameters: Optional[ModelCardParameters] = None + class ModelList(BaseModel): + """Represents a list of model cards.""" + object: str = "list" data: List[ModelCard] = Field(default_factory=list) + class DraftModelLoadRequest(BaseModel): + """Represents a draft model load request.""" + draft_model_name: str draft_rope_scale: Optional[float] = 1.0 - draft_rope_alpha: Optional[float] = Field(description = "Automatically calculated if not present", default = None) + draft_rope_alpha: Optional[float] = Field( + description="Automatically calculated if not present", default=None + ) + # TODO: Unify this with ModelCardParams class ModelLoadRequest(BaseModel): + """Represents a model load request.""" + name: str # Max seq len is fetched from config.json of the model by default - max_seq_len: Optional[int] = Field(description = "Leave this blank to use the model's base sequence length", default = None) - override_base_seq_len: Optional[int] = Field(description = "Overrides the model's base sequence length. 
Leave blank if unsure", default = None) + max_seq_len: Optional[int] = Field( + description="Leave this blank to use the model's base sequence length", + default=None, + ) + override_base_seq_len: Optional[int] = Field( + description=( + "Overrides the model's base sequence length. " "Leave blank if unsure" + ), + default=None, + ) gpu_split_auto: Optional[bool] = True gpu_split: Optional[List[float]] = Field(default_factory=list) rope_scale: Optional[float] = 1.0 - rope_alpha: Optional[float] = Field(description = "Automatically calculated if not present", default = None) + rope_alpha: Optional[float] = Field( + description="Automatically calculated if not present", default=None + ) no_flash_attention: Optional[bool] = False # low_mem: Optional[bool] = False cache_mode: Optional[str] = "FP16" @@ -48,9 +79,12 @@ class ModelLoadRequest(BaseModel): num_experts_per_token: Optional[int] = None draft: Optional[DraftModelLoadRequest] = None + class ModelLoadResponse(BaseModel): + """Represents a model load response.""" + # Avoids pydantic namespace warning - model_config = ConfigDict(protected_namespaces = []) + model_config = ConfigDict(protected_namespaces=[]) model_type: str = "model" module: int diff --git a/OAI/types/token.py b/OAI/types/token.py index a0bf3f9..98cbc98 100644 --- a/OAI/types/token.py +++ b/OAI/types/token.py @@ -1,30 +1,51 @@ -from pydantic import BaseModel +""" Tokenization types """ from typing import List +from pydantic import BaseModel + + class CommonTokenRequest(BaseModel): + """Represents a common tokenization request.""" + add_bos_token: bool = True encode_special_tokens: bool = True decode_special_tokens: bool = True def get_params(self): + """Get the parameters for tokenization.""" return { "add_bos_token": self.add_bos_token, "encode_special_tokens": self.encode_special_tokens, - "decode_special_tokens": self.decode_special_tokens + "decode_special_tokens": self.decode_special_tokens, } + class TokenEncodeRequest(CommonTokenRequest): + """Represents a tokenization request.""" + text: str + class TokenEncodeResponse(BaseModel): + """Represents a tokenization response.""" + tokens: List[int] length: int + class TokenDecodeRequest(CommonTokenRequest): + """ " Represents a detokenization request.""" + tokens: List[int] + class TokenDecodeResponse(BaseModel): + """Represents a detokenization response.""" + text: str + class TokenCountResponse(BaseModel): - length: int + """Represents a token count response.""" + + length: int diff --git a/OAI/utils.py b/OAI/utils.py deleted file mode 100644 index 7421de3..0000000 --- a/OAI/utils.py +++ /dev/null @@ -1,103 +0,0 @@ -import pathlib -from OAI.types.completion import CompletionResponse, CompletionRespChoice -from OAI.types.chat_completion import ( - ChatCompletionMessage, - ChatCompletionRespChoice, - ChatCompletionStreamChunk, - ChatCompletionResponse, - ChatCompletionStreamChoice -) -from OAI.types.common import UsageStats -from OAI.types.lora import LoraList, LoraCard -from OAI.types.model import ModelList, ModelCard -from typing import Optional - -from utils import unwrap - -def create_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]): - choice = CompletionRespChoice( - finish_reason = "Generated", - text = text - ) - - response = CompletionResponse( - choices = [choice], - model = unwrap(model_name, ""), - usage = UsageStats(prompt_tokens = prompt_tokens, - completion_tokens = completion_tokens, - total_tokens = prompt_tokens + completion_tokens) - ) - - return 
response - -def create_chat_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]): - message = ChatCompletionMessage( - role = "assistant", - content = text - ) - - choice = ChatCompletionRespChoice( - finish_reason = "Generated", - message = message - ) - - response = ChatCompletionResponse( - choices = [choice], - model = unwrap(model_name, ""), - usage = UsageStats(prompt_tokens = prompt_tokens, - completion_tokens = completion_tokens, - total_tokens = prompt_tokens + completion_tokens) - ) - - return response - -def create_chat_completion_stream_chunk(const_id: str, - text: Optional[str] = None, - model_name: Optional[str] = None, - finish_reason: Optional[str] = None): - if finish_reason: - message = {} - else: - message = ChatCompletionMessage( - role = "assistant", - content = text - ) - - # The finish reason can be None - choice = ChatCompletionStreamChoice( - finish_reason = finish_reason, - delta = message - ) - - chunk = ChatCompletionStreamChunk( - id = const_id, - choices = [choice], - model = unwrap(model_name, "") - ) - - return chunk - -def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None): - - # Convert the provided draft model path to a pathlib path for equality comparisons - if draft_model_path: - draft_model_path = pathlib.Path(draft_model_path).resolve() - - model_card_list = ModelList() - for path in model_path.iterdir(): - - # Don't include the draft models path - if path.is_dir() and path != draft_model_path: - model_card = ModelCard(id = path.name) - model_card_list.data.append(model_card) - - return model_card_list - -def get_lora_list(lora_path: pathlib.Path): - lora_list = LoraList() - for path in lora_path.iterdir(): - if path.is_dir(): - lora_card = LoraCard(id = path.name) - lora_list.data.append(lora_card) - - return lora_list diff --git a/OAI/utils_oai.py b/OAI/utils_oai.py new file mode 100644 index 0000000..b3c59d6 --- /dev/null +++ b/OAI/utils_oai.py @@ -0,0 +1,114 @@ +""" Utility functions for the OpenAI server. 
""" +import pathlib +from typing import Optional + +from OAI.types.chat_completion import ( + ChatCompletionMessage, + ChatCompletionRespChoice, + ChatCompletionStreamChunk, + ChatCompletionResponse, + ChatCompletionStreamChoice, +) +from OAI.types.completion import CompletionResponse, CompletionRespChoice +from OAI.types.common import UsageStats +from OAI.types.lora import LoraList, LoraCard +from OAI.types.model import ModelList, ModelCard + +from utils import unwrap + + +def create_completion_response( + text: str, + prompt_tokens: int, + completion_tokens: int, + model_name: Optional[str], +): + """Create a completion response from the provided text.""" + choice = CompletionRespChoice(finish_reason="Generated", text=text) + + response = CompletionResponse( + choices=[choice], + model=unwrap(model_name, ""), + usage=UsageStats( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + + return response + + +def create_chat_completion_response( + text: str, + prompt_tokens: int, + completion_tokens: int, + model_name: Optional[str], +): + """Create a chat completion response from the provided text.""" + message = ChatCompletionMessage(role="assistant", content=text) + + choice = ChatCompletionRespChoice(finish_reason="Generated", message=message) + + response = ChatCompletionResponse( + choices=[choice], + model=unwrap(model_name, ""), + usage=UsageStats( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + + return response + + +def create_chat_completion_stream_chunk( + const_id: str, + text: Optional[str] = None, + model_name: Optional[str] = None, + finish_reason: Optional[str] = None, +): + """Create a chat completion stream chunk from the provided text.""" + if finish_reason: + message = {} + else: + message = ChatCompletionMessage(role="assistant", content=text) + + # The finish reason can be None + choice = ChatCompletionStreamChoice(finish_reason=finish_reason, delta=message) + + chunk = ChatCompletionStreamChunk( + id=const_id, choices=[choice], model=unwrap(model_name, "") + ) + + return chunk + + +def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None): + """Get the list of models from the provided path.""" + + # Convert the provided draft model path to a pathlib path for + # equality comparisons + if draft_model_path: + draft_model_path = pathlib.Path(draft_model_path).resolve() + + model_card_list = ModelList() + for path in model_path.iterdir(): + # Don't include the draft models path + if path.is_dir() and path != draft_model_path: + model_card = ModelCard(id=path.name) + model_card_list.data.append(model_card) # pylint: disable=no-member + + return model_card_list + + +def get_lora_list(lora_path: pathlib.Path): + """Get the list of Lora cards from the provided path.""" + lora_list = LoraList() + for path in lora_path.iterdir(): + if path.is_dir(): + lora_card = LoraCard(id=path.name) + lora_list.data.append(lora_card) # pylint: disable=no-member + + return lora_list diff --git a/auth.py b/auth.py index 80611e6..451ba0e 100644 --- a/auth.py +++ b/auth.py @@ -1,105 +1,125 @@ -import secrets -import yaml -from fastapi import Header, HTTPException -from pydantic import BaseModel -from typing import Optional - """ This method of authorization is pretty insecure, but since TabbyAPI is a local application, it should be fine. 
""" +import secrets +from typing import Optional + +from fastapi import Header, HTTPException +from pydantic import BaseModel +import yaml + class AuthKeys(BaseModel): + """ + This class represents the authentication keys for the application. + It contains two types of keys: 'api_key' and 'admin_key'. + The 'api_key' is used for general API calls, while the 'admin_key' + is used for administrative tasks. The class also provides a method + to verify if a given key matches the stored 'api_key' or 'admin_key'. + """ + api_key: str admin_key: str def verify_key(self, test_key: str, key_type: str): - # Match statements are only available in python 3.10 and up + """Verify if a given key matches the stored key.""" if key_type == "admin_key": return test_key == self.admin_key - elif key_type == "api_key": + if key_type == "api_key": # Admin keys are valid for all API calls return test_key == self.api_key or test_key == self.admin_key - else: - return False + return False + + +AUTH_KEYS: Optional[AuthKeys] = None +DISABLE_AUTH: bool = False -auth_keys: Optional[AuthKeys] = None -disable_auth: bool = False def load_auth_keys(disable_from_config: bool): - global auth_keys - global disable_auth + """Load the authentication keys from api_tokens.yml. If the file does not + exist, generate new keys and save them to api_tokens.yml.""" + global AUTH_KEYS + global DISABLE_AUTH - disable_auth = disable_from_config + DISABLE_AUTH = disable_from_config if disable_from_config: print( - "!! Warning: Disabling authentication makes your instance vulnerable.", - "Set the \"disable_auth\" flag to False in config.yml if you want to share this", - "instance with others." + "!! Warning: Disabling authentication", + "makes your instance vulnerable.", + "Set the 'disable_auth' flag to False in config.yml", + "if you want to share this instance with others.", ) return try: - with open("api_tokens.yml", "r", encoding = 'utf8') as auth_file: + with open("api_tokens.yml", "r", encoding="utf8") as auth_file: auth_keys_dict = yaml.safe_load(auth_file) - auth_keys = AuthKeys.model_validate(auth_keys_dict) + AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict) except OSError: new_auth_keys = AuthKeys( - api_key = secrets.token_hex(16), - admin_key = secrets.token_hex(16) + api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16) ) - auth_keys = new_auth_keys + AUTH_KEYS = new_auth_keys - with open("api_tokens.yml", "w", encoding = "utf8") as auth_file: - yaml.safe_dump(auth_keys.model_dump(), auth_file, default_flow_style=False) + with open("api_tokens.yml", "w", encoding="utf8") as auth_file: + yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False) print( - f"Your API key is: {auth_keys.api_key}\n" - f"Your admin key is: {auth_keys.admin_key}\n\n" - "If these keys get compromised, make sure to delete api_tokens.yml and restart the server. Have fun!" + f"Your API key is: {AUTH_KEYS.api_key}\n" + f"Your admin key is: {AUTH_KEYS.admin_key}\n\n" + "If these keys get compromised, make sure to delete api_tokens.yml " + "and restart the server. Have fun!" 
) + def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)): + """Check if the API key is valid.""" + # Allow request if auth is disabled - if disable_auth: + if DISABLE_AUTH: return if x_api_key: - if auth_keys.verify_key(x_api_key, "api_key"): - return x_api_key - else: + if not AUTH_KEYS.verify_key(x_api_key, "api_key"): raise HTTPException(401, "Invalid API key") - elif authorization: - split_key = authorization.split(" ") + return x_api_key + if authorization: + split_key = authorization.split(" ") if len(split_key) < 2: raise HTTPException(401, "Invalid API key") - elif split_key[0].lower() == "bearer" and auth_keys.verify_key(split_key[1], "api_key"): - return authorization - else: + if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key( + split_key[1], "api_key" + ): raise HTTPException(401, "Invalid API key") - else: - raise HTTPException(401, "Please provide an API key") + + return authorization + + raise HTTPException(401, "Please provide an API key") + def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)): + """Check if the admin key is valid.""" + # Allow request if auth is disabled - if disable_auth: + if DISABLE_AUTH: return if x_admin_key: - if auth_keys.verify_key(x_admin_key, "admin_key"): - return x_admin_key - else: + if not AUTH_KEYS.verify_key(x_admin_key, "admin_key"): raise HTTPException(401, "Invalid admin key") - elif authorization: - split_key = authorization.split(" ") + return x_admin_key + if authorization: + split_key = authorization.split(" ") if len(split_key) < 2: raise HTTPException(401, "Invalid admin key") - elif split_key[0].lower() == "bearer" and auth_keys.verify_key(split_key[1], "admin_key"): - return authorization - else: + if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key( + split_key[1], "admin_key" + ): raise HTTPException(401, "Invalid admin key") - else: - raise HTTPException(401, "Please provide an admin key") + return authorization + + raise HTTPException(401, "Please provide an admin key") diff --git a/formatting.bat b/formatting.bat new file mode 100644 index 0000000..b510aea --- /dev/null +++ b/formatting.bat @@ -0,0 +1,36 @@ +@echo off +setlocal + +::Change to script's directory +cd /d %~dp0 + +::Get tool versions +for /f "tokens=2" %%i in ('ruff --version') do set RUFF_VERSION="%%i" + +::Check tool versions +call :tool_version_check "ruff" %RUFF_VERSION% "0.1.9" + +::Format and lint files +call ruff format +call ruff check + +echo tabbyAPI ruff lint and format: Done + +::Check if any files were changed +git diff --quiet +if errorlevel 1 ( + echo Reformatted files. Please review and stage the changes. + echo Changes not staged for commit: + echo. + git --no-pager diff --name-only + exit /b 1 +) + +exit /b 0 + +:tool_version_check +if not "%2"=="%3" ( + echo Wrong %1 version installed: %3 is required, not %2. + exit /b 1 +) +exit /b 0 diff --git a/formatting.sh b/formatting.sh new file mode 100755 index 0000000..2351641 --- /dev/null +++ b/formatting.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash +# YAPF formatter, adapted from ray and skypilot. +# +# Usage: +# # Do work and commit your work. + +# # Format files that differ from origin/main. +# bash formatting.sh + +# # Commit changed files with message 'Run yapf and ruff' +# +# +# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase. +# You are encouraged to run this locally before pushing changes for review. 
+ +# Cause the script to exit if a single command fails +set -eo pipefail + +# this stops git rev-parse from failing if we run this from the .git directory +builtin cd "$(dirname "${BASH_SOURCE:-$0}")" +ROOT="$(git rev-parse --show-toplevel)" +builtin cd "$ROOT" || exit 1 + +RUFF_VERSION=$(ruff --version | head -n 1 | awk '{print $2}') + +# params: tool name, tool version, required version +tool_version_check() { + if [[ $2 != $3 ]]; then + echo "Wrong $1 version installed: $3 is required, not $2." + exit 1 + fi +} + +tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | cut -d'=' -f3)" + +# Format and lint all files +format_and_lint() { + ruff format + ruff check +} + +# Call format command +format_and_lint +echo 'tabbyAPI ruff lint and format: Done' + +if ! git diff --quiet &>/dev/null; then + echo 'Reformatted files. Please review and stage the changes.' + echo 'Changes not staged for commit:' + echo + git --no-pager diff --name-only + + exit 1 +fi diff --git a/gen_logging.py b/gen_logging.py index e0986e4..c8d3620 100644 --- a/gen_logging.py +++ b/gen_logging.py @@ -1,31 +1,40 @@ +""" +Functions for logging generation events. +""" from typing import Dict from pydantic import BaseModel -# Logging preference config + class LogConfig(BaseModel): + """Logging preference config.""" + prompt: bool = False generation_params: bool = False -# Global reference to logging preferences -config = LogConfig() -# Wrapper to set the logging config for generations +# Global reference to logging preferences +CONFIG = LogConfig() + + def update_from_dict(options_dict: Dict[str, bool]): - global config + """Wrapper to set the logging config for generations""" + global CONFIG # Force bools on the dict for value in options_dict.values(): if value is None: value = False - config = LogConfig.model_validate(options_dict) + CONFIG = LogConfig.model_validate(options_dict) + def broadcast_status(): + """Broadcasts the current logging status""" enabled = [] - if config.prompt: + if CONFIG.prompt: enabled.append("prompts") - if config.generation_params: + if CONFIG.generation_params: enabled.append("generation params") if len(enabled) > 0: @@ -33,15 +42,20 @@ def broadcast_status(): else: print("Generation logging is disabled") -# Logs generation parameters to console + def log_generation_params(**kwargs): - if config.generation_params: + """Logs generation parameters to console.""" + if CONFIG.generation_params: print(f"Generation options: {kwargs}\n") + def log_prompt(prompt: str): - if config.prompt: + """Logs the prompt to console.""" + if CONFIG.prompt: print(f"Prompt: {prompt if prompt else 'Empty'}\n") + def log_response(response: str): - if config.prompt: + """Logs the response to console.""" + if CONFIG.prompt: print(f"Response: {response if response else 'Empty'}\n") diff --git a/generators.py b/generators.py index 285f26a..d2cad0a 100644 --- a/generators.py +++ b/generators.py @@ -1,3 +1,4 @@ +"""Generator functions for the tabbyAPI.""" import inspect from asyncio import Semaphore from functools import partialmethod @@ -5,8 +6,10 @@ from typing import AsyncGenerator generate_semaphore = Semaphore(1) + # Async generation that blocks on a semaphore async def generate_with_semaphore(generator: AsyncGenerator): + """Generate with a semaphore.""" async with generate_semaphore: if inspect.isasyncgenfunction: async for result in generator(): @@ -15,6 +18,7 @@ async def generate_with_semaphore(generator: AsyncGenerator): for result in generator(): yield result + # Block a function with 
semaphore async def call_with_semaphore(callback: partialmethod): if inspect.iscoroutinefunction(callback): diff --git a/main.py b/main.py index e39ed1e..c497999 100644 --- a/main.py +++ b/main.py @@ -1,14 +1,16 @@ -import uvicorn -import yaml +"""The main tabbyAPI module. Contains the FastAPI server and endpoints.""" import pathlib from asyncio import CancelledError -from fastapi import FastAPI, Request, HTTPException, Depends +from typing import Optional +from uuid import uuid4 + +import uvicorn +import yaml +from fastapi import FastAPI, Depends, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from functools import partial from progress.bar import IncrementalBar -from typing import Optional -from uuid import uuid4 import gen_logging from auth import check_admin_key, check_api_key, load_auth_keys @@ -17,19 +19,24 @@ from model import ModelContainer from OAI.types.completion import CompletionRequest from OAI.types.chat_completion import ChatCompletionRequest from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse -from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse, ModelCardParameters +from OAI.types.model import ( + ModelCard, + ModelLoadRequest, + ModelLoadResponse, + ModelCardParameters, +) from OAI.types.token import ( TokenEncodeRequest, TokenEncodeResponse, TokenDecodeRequest, - TokenDecodeResponse + TokenDecodeResponse, ) -from OAI.utils import ( +from OAI.utils_oai import ( create_completion_response, get_model_list, get_lora_list, - create_chat_completion_response, - create_chat_completion_stream_chunk + create_chat_completion_response, + create_chat_completion_stream_chunk, ) from templating import get_prompt_from_template from utils import get_generator_error, get_sse_packet, load_progress, unwrap @@ -37,13 +44,15 @@ from utils import get_generator_error, get_sse_packet, load_progress, unwrap app = FastAPI() # Globally scoped variables. 
Undefined until initalized in main -model_container: Optional[ModelContainer] = None +MODEL_CONTAINER: Optional[ModelContainer] = None config: dict = {} + def _check_model_container(): - if model_container is None or model_container.model is None: + if MODEL_CONTAINER is None or MODEL_CONTAINER.model is None: raise HTTPException(400, "No models are loaded.") + # ALlow CORS requests app.add_middleware( CORSMiddleware, @@ -53,10 +62,12 @@ app.add_middleware( allow_headers=["*"], ) + # Model list endpoint @app.get("/v1/models", dependencies=[Depends(check_api_key)]) @app.get("/v1/model/list", dependencies=[Depends(check_api_key)]) async def list_models(): + """Lists all models in the model directory.""" model_config = unwrap(config.get("model"), {}) model_dir = unwrap(model_config.get("model_dir"), "models") model_path = pathlib.Path(model_dir) @@ -66,43 +77,53 @@ async def list_models(): models = get_model_list(model_path.resolve(), draft_model_dir) if unwrap(model_config.get("use_dummy_models"), False): - models.data.insert(0, ModelCard(id = "gpt-3.5-turbo")) + models.data.insert(0, ModelCard(id="gpt-3.5-turbo")) return models + # Currently loaded model endpoint -@app.get("/v1/model", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) -@app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) +@app.get( + "/v1/model", + dependencies=[Depends(check_api_key), Depends(_check_model_container)], +) +@app.get( + "/v1/internal/model/info", + dependencies=[Depends(check_api_key), Depends(_check_model_container)], +) async def get_current_model(): - model_name = model_container.get_model_path().name - prompt_template = model_container.prompt_template + """Returns the currently loaded model.""" + model_name = MODEL_CONTAINER.get_model_path().name + prompt_template = MODEL_CONTAINER.prompt_template model_card = ModelCard( - id = model_name, - parameters = ModelCardParameters( - rope_scale = model_container.config.scale_pos_emb, - rope_alpha = model_container.config.scale_alpha_value, - max_seq_len = model_container.config.max_seq_len, - cache_mode = "FP8" if model_container.cache_fp8 else "FP16", - prompt_template = prompt_template.name if prompt_template else None + id=model_name, + parameters=ModelCardParameters( + rope_scale=MODEL_CONTAINER.config.scale_pos_emb, + rope_alpha=MODEL_CONTAINER.config.scale_alpha_value, + max_seq_len=MODEL_CONTAINER.config.max_seq_len, + cache_mode="FP8" if MODEL_CONTAINER.cache_fp8 else "FP16", + prompt_template=prompt_template.name if prompt_template else None, ), - logging = gen_logging.config + logging=gen_logging.CONFIG, ) - if model_container.draft_config: + if MODEL_CONTAINER.draft_config: draft_card = ModelCard( - id = model_container.get_model_path(True).name, - parameters = ModelCardParameters( - rope_scale = model_container.draft_config.scale_pos_emb, - rope_alpha = model_container.draft_config.scale_alpha_value, - max_seq_len = model_container.draft_config.max_seq_len - ) + id=MODEL_CONTAINER.get_model_path(True).name, + parameters=ModelCardParameters( + rope_scale=MODEL_CONTAINER.draft_config.scale_pos_emb, + rope_alpha=MODEL_CONTAINER.draft_config.scale_alpha_value, + max_seq_len=MODEL_CONTAINER.draft_config.max_seq_len, + ), ) model_card.parameters.draft = draft_card return model_card + @app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) async def list_draft_models(): + """Lists all draft models in the model directory.""" model_config = unwrap(config.get("model"), 
{}) draft_config = unwrap(model_config.get("draft"), {}) draft_model_dir = unwrap(draft_config.get("draft_model_dir"), "models") @@ -112,12 +133,14 @@ async def list_draft_models(): return models + # Load model endpoint @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)]) async def load_model(request: Request, data: ModelLoadRequest): - global model_container + """Loads a model into the model container.""" + global MODEL_CONTAINER - if model_container and model_container.model: + if MODEL_CONTAINER and MODEL_CONTAINER.model: raise HTTPException(400, "A model is already loaded! Please unload it first.") if not data.name: @@ -129,32 +152,35 @@ async def load_model(request: Request, data: ModelLoadRequest): load_data = data.model_dump() - # TODO: Add API exception if draft directory isn't found draft_config = unwrap(model_config.get("draft"), {}) if data.draft: if not data.draft.draft_model_name: - raise HTTPException(400, "draft_model_name was not found inside the draft object.") + raise HTTPException( + 400, "draft_model_name was not found inside the draft object." + ) - load_data["draft"]["draft_model_dir"] = unwrap(draft_config.get("draft_model_dir"), "models") + load_data["draft"]["draft_model_dir"] = unwrap( + draft_config.get("draft_model_dir"), "models" + ) if not model_path.exists(): raise HTTPException(400, "model_path does not exist. Check model_name?") - model_container = ModelContainer(model_path.resolve(), False, **load_data) + MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **load_data) async def generator(): - global model_container + """Generator for the loading process.""" - model_type = "draft" if model_container.draft_config else "model" - load_status = model_container.load_gen(load_progress) + model_type = "draft" if MODEL_CONTAINER.draft_config else "model" + load_status = MODEL_CONTAINER.load_gen(load_progress) try: - for (module, modules) in load_status: + for module, modules in load_status: if await request.is_disconnected(): break if module == 0: - loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules) + loading_bar: IncrementalBar = IncrementalBar("Modules", max=modules) elif module == modules: loading_bar.next() loading_bar.finish() @@ -163,13 +189,13 @@ async def load_model(request: Request, data: ModelLoadRequest): model_type=model_type, module=module, modules=modules, - status="finished" + status="finished", ) yield get_sse_packet(response.model_dump_json()) # Switch to model progress if the draft model is loaded - if model_container.draft_config: + if MODEL_CONTAINER.draft_config: model_type = "model" else: loading_bar.next() @@ -178,29 +204,39 @@ async def load_model(request: Request, data: ModelLoadRequest): model_type=model_type, module=module, modules=modules, - status="processing" + status="processing", ) yield get_sse_packet(response.model_dump_json()) except CancelledError: - print("\nError: Model load cancelled by user. Please make sure to run unload to free up resources.") - except Exception as e: - yield get_generator_error(str(e)) + print( + "\nError: Model load cancelled by user. " + "Please make sure to run unload to free up resources." 
+ ) + except Exception as exc: + yield get_generator_error(str(exc)) + + return StreamingResponse(generator(), media_type="text/event-stream") - return StreamingResponse(generator(), media_type = "text/event-stream") # Unload model endpoint -@app.get("/v1/model/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)]) +@app.get( + "/v1/model/unload", + dependencies=[Depends(check_admin_key), Depends(_check_model_container)], +) async def unload_model(): - global model_container + """Unloads the currently loaded model.""" + global MODEL_CONTAINER + + MODEL_CONTAINER.unload() + MODEL_CONTAINER = None - model_container.unload() - model_container = None # Lora list endpoint @app.get("/v1/loras", dependencies=[Depends(check_api_key)]) @app.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) async def get_all_loras(): + """Lists all LoRAs in the lora directory.""" model_config = unwrap(config.get("model"), {}) lora_config = unwrap(model_config.get("lora"), {}) lora_path = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras")) @@ -209,191 +245,240 @@ async def get_all_loras(): return loras + # Currently loaded loras endpoint -@app.get("/v1/lora", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) +@app.get( + "/v1/lora", + dependencies=[Depends(check_api_key), Depends(_check_model_container)], +) async def get_active_loras(): + """Returns the currently loaded loras.""" active_loras = LoraList( - data = list(map( - lambda lora: LoraCard( - id = pathlib.Path(lora.lora_path).parent.name, - scaling = lora.lora_scaling * lora.lora_r / lora.lora_alpha - ), - model_container.active_loras + data=list( + map( + lambda lora: LoraCard( + id=pathlib.Path(lora.lora_path).parent.name, + scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha, + ), + MODEL_CONTAINER.active_loras, + ) ) - )) + ) return active_loras + # Load lora endpoint -@app.post("/v1/lora/load", dependencies=[Depends(check_admin_key), Depends(_check_model_container)]) +@app.post( + "/v1/lora/load", + dependencies=[Depends(check_admin_key), Depends(_check_model_container)], +) async def load_lora(data: LoraLoadRequest): + """Loads a LoRA into the model container.""" if not data.loras: raise HTTPException(400, "List of loras to load is not found.") model_config = unwrap(config.get("model"), {}) - lora_config = unwrap(model_config.get("lora"), {}) + lora_config = unwrap(model_config.get("lora"), {}) lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras")) if not lora_dir.exists(): - raise HTTPException(400, "A parent lora directory does not exist. Check your config.yml?") + raise HTTPException( + 400, + "A parent lora directory does not exist. 
Check your config.yml?", + ) # Clean-up existing loras if present - if len(model_container.active_loras) > 0: - model_container.unload(True) + if len(MODEL_CONTAINER.active_loras) > 0: + MODEL_CONTAINER.unload(True) - result = model_container.load_loras(lora_dir, **data.model_dump()) + result = MODEL_CONTAINER.load_loras(lora_dir, **data.model_dump()) return LoraLoadResponse( - success = unwrap(result.get("success"), []), - failure = unwrap(result.get("failure"), []) + success=unwrap(result.get("success"), []), + failure=unwrap(result.get("failure"), []), ) + # Unload lora endpoint -@app.get("/v1/lora/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)]) +@app.get( + "/v1/lora/unload", + dependencies=[Depends(check_admin_key), Depends(_check_model_container)], +) async def unload_loras(): - model_container.unload(True) + """Unloads the currently loaded loras.""" + MODEL_CONTAINER.unload(True) + # Encode tokens endpoint -@app.post("/v1/token/encode", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) +@app.post( + "/v1/token/encode", + dependencies=[Depends(check_api_key), Depends(_check_model_container)], +) async def encode_tokens(data: TokenEncodeRequest): - raw_tokens = model_container.get_tokens(data.text, None, **data.get_params()) + """Encodes a string into tokens.""" + raw_tokens = MODEL_CONTAINER.get_tokens(data.text, None, **data.get_params()) - # Have to use this if check otherwise Torch's tensors error out with a boolean issue + # Have to use this if check otherwise Torch's tensors error out + # with a boolean issue tokens = raw_tokens[0].tolist() if raw_tokens is not None else [] response = TokenEncodeResponse(tokens=tokens, length=len(tokens)) return response + # Decode tokens endpoint -@app.post("/v1/token/decode", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) +@app.post( + "/v1/token/decode", + dependencies=[Depends(check_api_key), Depends(_check_model_container)], +) async def decode_tokens(data: TokenDecodeRequest): - message = model_container.get_tokens(None, data.tokens, **data.get_params()) - response = TokenDecodeResponse(text = unwrap(message, "")) + """Decodes tokens into a string.""" + message = MODEL_CONTAINER.get_tokens(None, data.tokens, **data.get_params()) + response = TokenDecodeResponse(text=unwrap(message, "")) return response + # Completions endpoint -@app.post("/v1/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) +@app.post( + "/v1/completions", + dependencies=[Depends(check_api_key), Depends(_check_model_container)], +) async def generate_completion(request: Request, data: CompletionRequest): - model_path = model_container.get_model_path() + """Generates a completion from a prompt.""" + model_path = MODEL_CONTAINER.get_model_path() if isinstance(data.prompt, list): data.prompt = "\n".join(data.prompt) if data.stream: + async def generator(): + """Generator for the generation process.""" try: - new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params()) - for (part, prompt_tokens, completion_tokens) in new_generation: + new_generation = MODEL_CONTAINER.generate_gen( + data.prompt, **data.to_gen_params() + ) + for part, prompt_tokens, completion_tokens in new_generation: if await request.is_disconnected(): break - response = create_completion_response(part, - prompt_tokens, - completion_tokens, - model_path.name) + response = create_completion_response( + part, prompt_tokens, completion_tokens, model_path.name + ) yield 
get_sse_packet(response.model_dump_json()) except CancelledError: print("Error: Completion request cancelled by user.") - except Exception as e: - yield get_generator_error(str(e)) + except Exception as exc: + yield get_generator_error(str(exc)) return StreamingResponse( - generate_with_semaphore(generator), - media_type = "text/event-stream" + generate_with_semaphore(generator), media_type="text/event-stream" ) - else: - response_text, prompt_tokens, completion_tokens = await call_with_semaphore( - partial(model_container.generate, data.prompt, **data.to_gen_params()) - ) - response = create_completion_response(response_text, - prompt_tokens, - completion_tokens, - model_path.name) - return response + response_text, prompt_tokens, completion_tokens = await call_with_semaphore( + partial(MODEL_CONTAINER.generate, data.prompt, **data.to_gen_params()) + ) + + response = create_completion_response( + response_text, prompt_tokens, completion_tokens, model_path.name + ) + + return response + # Chat completions endpoint -@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) +@app.post( + "/v1/chat/completions", + dependencies=[Depends(check_api_key), Depends(_check_model_container)], +) async def generate_chat_completion(request: Request, data: ChatCompletionRequest): - if model_container.prompt_template is None: - return HTTPException(422, "This endpoint is disabled because a prompt template is not set.") + """Generates a chat completion from a prompt.""" + if MODEL_CONTAINER.prompt_template is None: + return HTTPException( + 422, + "This endpoint is disabled because a prompt template is not set.", + ) - model_path = model_container.get_model_path() + model_path = MODEL_CONTAINER.get_model_path() if isinstance(data.messages, str): prompt = data.messages else: try: - special_tokens_dict = model_container.get_special_tokens( + special_tokens_dict = MODEL_CONTAINER.get_special_tokens( unwrap(data.add_bos_token, True), - unwrap(data.ban_eos_token, False) + unwrap(data.ban_eos_token, False), ) prompt = get_prompt_from_template( data.messages, - model_container.prompt_template, + MODEL_CONTAINER.prompt_template, data.add_generation_prompt, special_tokens_dict, ) except KeyError: return HTTPException( 400, - f"Could not find a Conversation from prompt template '{model_container.prompt_template.name}'. Check your spelling?" + "Could not find a Conversation from prompt template " + f"'{MODEL_CONTAINER.prompt_template.name}'. 
" + "Check your spelling?", ) if data.stream: const_id = f"chatcmpl-{uuid4().hex}" + async def generator(): + """Generator for the generation process.""" try: - new_generation = model_container.generate_gen(prompt, **data.to_gen_params()) - for (part, _, _) in new_generation: + new_generation = MODEL_CONTAINER.generate_gen( + prompt, **data.to_gen_params() + ) + for part, _, _ in new_generation: if await request.is_disconnected(): break response = create_chat_completion_stream_chunk( - const_id, - part, - model_path.name + const_id, part, model_path.name ) yield get_sse_packet(response.model_dump_json()) # Yield a finish response on successful generation finish_response = create_chat_completion_stream_chunk( - const_id, - finish_reason = "stop" + const_id, finish_reason="stop" ) yield get_sse_packet(finish_response.model_dump_json()) except CancelledError: print("Error: Chat completion cancelled by user.") - except Exception as e: - yield get_generator_error(str(e)) + except Exception as exc: + yield get_generator_error(str(exc)) return StreamingResponse( - generate_with_semaphore(generator), - media_type = "text/event-stream" + generate_with_semaphore(generator), media_type="text/event-stream" ) - else: - response_text, prompt_tokens, completion_tokens = await call_with_semaphore( - partial(model_container.generate, prompt, **data.to_gen_params()) - ) - response = create_chat_completion_response(response_text, - prompt_tokens, - completion_tokens, - model_path.name) - return response + response_text, prompt_tokens, completion_tokens = await call_with_semaphore( + partial(MODEL_CONTAINER.generate, prompt, **data.to_gen_params()) + ) + + response = create_chat_completion_response( + response_text, prompt_tokens, completion_tokens, model_path.name + ) + + return response + if __name__ == "__main__": # Load from YAML config. Possibly add a config -> kwargs conversion function try: - with open('config.yml', 'r', encoding = "utf8") as config_file: + with open("config.yml", "r", encoding="utf8") as config_file: config = unwrap(yaml.safe_load(config_file), {}) - except Exception as e: + except Exception as exc: print( "The YAML config couldn't load because of the following error:", - f"\n\n{e}", - "\n\nTabbyAPI will start anyway and not parse this config file." 
+ f"\n\n{exc}", + "\n\nTabbyAPI will start anyway and not parse this config file.", ) config = {} @@ -409,18 +494,18 @@ if __name__ == "__main__": gen_logging.broadcast_status() - # If an initial model name is specified, create a container and load the model + # If an initial model name is specified, create a container + # and load the model model_config = unwrap(config.get("model"), {}) if "model_name" in model_config: - # TODO: Move this to model_container model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models")) model_path = model_path / model_config.get("model_name") - model_container = ModelContainer(model_path.resolve(), False, **model_config) - load_status = model_container.load_gen(load_progress) - for (module, modules) in load_status: + MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **model_config) + load_status = MODEL_CONTAINER.load_gen(load_progress) + for module, modules in load_status: if module == 0: - loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules) + loading_bar: IncrementalBar = IncrementalBar("Modules", max=modules) elif module == modules: loading_bar.next() loading_bar.finish() @@ -431,11 +516,11 @@ if __name__ == "__main__": lora_config = unwrap(model_config.get("lora"), {}) if "loras" in lora_config: lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras")) - model_container.load_loras(lora_dir.resolve(), **lora_config) + MODEL_CONTAINER.load_loras(lora_dir.resolve(), **lora_config) uvicorn.run( app, host=network_config.get("host", "127.0.0.1"), port=network_config.get("port", 5000), - log_level="debug" + log_level="debug", ) diff --git a/model.py b/model.py index 767e9ad..e4dde49 100644 --- a/model.py +++ b/model.py @@ -1,29 +1,36 @@ +"""The model container class for ExLlamaV2 models.""" import gc import pathlib import time + import torch -from exllamav2 import( +from exllamav2 import ( ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Cache_8bit, ExLlamaV2Tokenizer, - ExLlamaV2Lora -) -from exllamav2.generator import( - ExLlamaV2StreamingGenerator, - ExLlamaV2Sampler + ExLlamaV2Lora, ) +from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler from gen_logging import log_generation_params, log_prompt, log_response from typing import List, Optional, Union -from templating import PromptTemplate, find_template_from_model, get_template_from_model_json, get_template_from_file +from templating import ( + PromptTemplate, + find_template_from_model, + get_template_from_model_json, + get_template_from_file, +) from utils import coalesce, unwrap # Bytes to reserve on first device when loading with auto split -auto_split_reserve_bytes = 96 * 1024**2 +AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2 + class ModelContainer: + """The model container class for ExLlamaV2 models.""" + config: Optional[ExLlamaV2Config] = None draft_config: Optional[ExLlamaV2Config] = None model: Optional[ExLlamaV2] = None @@ -40,35 +47,51 @@ class ModelContainer: active_loras: List[ExLlamaV2Lora] = [] - def __init__(self, model_directory: pathlib.Path, quiet = False, **kwargs): + def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs): """ Create model container Args: - model_dir (int): Model directory containing config.json, tokenizer.model etc. + model_dir (int): Model directory containing config.json, + tokenizer.model etc. quiet (bool): Suppress console output - load_progress_callback (function, optional): A function to call for each module loaded. 
Prototype: - def progress(loaded_modules: int, total_modules: int, loading_draft: bool) + load_progress_callback (function, optional): A function to call for + each module loaded. Prototype: + def progress(loaded_modules: int, total_modules: int, + loading_draft: bool) **kwargs: - `cache_mode` (str): Sets cache mode, "FP16" or "FP8" (defaulf: "FP16") - 'max_seq_len' (int): Override model's default max sequence length (default: 4096) - 'rope_scale' (float): Set RoPE scaling factor for model (default: 1.0) - 'rope_alpha' (float): Set RoPE alpha (NTK) factor for model (default: 1.0) - 'prompt_template' (str): Manually sets the prompt template for this model (default: None) - 'chunk_size' (int): Sets the maximum chunk size for the model (default: 2048) - Inferencing in chunks reduces overall VRAM overhead by processing very long sequences in smaller - batches. This limits the size of temporary buffers needed for the hidden state and attention - weights. + `cache_mode` (str): Sets cache mode, "FP16" or "FP8" + (default: "FP16") + 'max_seq_len' (int): Override model's default max sequence + length (default: 4096) + 'rope_scale' (float): Set RoPE scaling factor for model + (default: 1.0) + 'rope_alpha' (float): Set RoPE alpha (NTK) factor for model + (default: 1.0) + 'prompt_template' (str): Manually sets the prompt template for + this model (default: None) + 'chunk_size' (int): Sets the maximum chunk size for the model + (default: 2048) + Inferencing in chunks reduces overall VRAM overhead by + processing very long sequences in smaller batches. This + limits the size of temporary buffers needed for the hidden + state and attention weights. 'draft_model_dir' (str): Draft model directory - 'draft_rope_scale' (float): Set RoPE scaling factor for draft model (default: 1.0) - 'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft model. - By default, the draft model's alpha value is calculated automatically to scale to the size of the + 'draft_rope_scale' (float): Set RoPE scaling factor for draft + model (default: 1.0) + 'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft + model. By default, the draft model's alpha value is + calculated automatically to scale to the size of the full model. 
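As a usage sketch of the keyword arguments documented so far (the directory name and override values below are hypothetical, not project defaults), a container could be constructed directly like this; in TabbyAPI these values normally come from the "model" section of config.yml and any omitted kwargs fall back to the documented defaults via `unwrap`.

    import pathlib

    from model import ModelContainer

    # Hypothetical model directory and overrides, for illustration only.
    model_path = pathlib.Path("models") / "my-exl2-model"
    container = ModelContainer(
        model_path.resolve(),
        quiet=False,
        cache_mode="FP16",
        max_seq_len=8192,
        rope_alpha=2.5,
        chunk_size=2048,
    )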
- 'lora_dir' (str): Lora directory - 'loras' (list[dict]): List of loras to be loaded, consisting of 'name' and 'scaling' - 'gpu_split_auto' (bool): Automatically split model across available devices (default: True) - 'gpu_split' (list[float]): Allocation for weights and (some) tensors, per device - 'no_flash_attn' (bool): Turns off flash attention (increases vram usage) (default: False) + 'lora_dir' (str): LoRA directory + 'loras' (list[dict]): List of loras to be loaded, consisting of + 'name' and 'scaling' + 'gpu_split_auto' (bool): Automatically split model across + available devices (default: True) + 'gpu_split' (list[float]): Allocation for weights and (some) + tensors, per device + 'no_flash_attn' (bool): Turns off flash attention + (increases vram usage) (default: False) """ self.quiet = quiet @@ -90,7 +113,8 @@ class ModelContainer: if override_base_seq_len: self.config.max_seq_len = override_base_seq_len - # Grab the base model's sequence length before overrides for rope calculations + # Grab the base model's sequence length before overrides for + # rope calculations base_seq_len = self.config.max_seq_len # Set the target seq len if present @@ -103,14 +127,14 @@ class ModelContainer: # Automatically calculate rope alpha self.config.scale_alpha_value = unwrap( - kwargs.get("rope_alpha"), - self.calculate_rope_alpha(base_seq_len) + kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len) ) # Turn off flash attention? self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attention"), False) - # low_mem is currently broken in exllamav2. Don't use it until it's fixed. + # low_mem is currently broken in exllamav2. Don't use it until it's + # fixed. """ if "low_mem" in kwargs and kwargs["low_mem"]: self.config.set_low_mem() @@ -119,7 +143,10 @@ class ModelContainer: # Set prompt template override if provided prompt_template_name = kwargs.get("prompt_template") if prompt_template_name: - print(f"Attempting to load prompt template with name {prompt_template_name}") + print( + "Attempting to load prompt template with name", + f"{prompt_template_name}", + ) # Read the template self.prompt_template = get_template_from_file(prompt_template_name) else: @@ -127,16 +154,17 @@ class ModelContainer: self.prompt_template = get_template_from_model_json( pathlib.Path(self.config.model_dir) / "tokenizer_config.json", "chat_template", - "from_tokenizer_config" + "from_tokenizer_config", ) # Try finding the chat template from the model's config.json - # TODO: This may not even be used with huggingface models, mark for removal. + # TODO: This may not even be used with huggingface models, + # mark for removal. if self.prompt_template is None: self.prompt_template = get_template_from_model_json( pathlib.Path(self.config.model_config), "chat_template", - "from_model_config" + "from_model_config", ) # If that fails, attempt fetching from model name @@ -147,10 +175,13 @@ class ModelContainer: # Catch all for template lookup errors if self.prompt_template: - print(f"Using template {self.prompt_template.name} for chat completions.") + print( + f"Using template {self.prompt_template.name} for chat " "completions." + ) else: print( - "Chat completions are disabled because a prompt template wasn't provided or auto-detected." 
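The template lookup above is a first-match-wins fallback chain: an explicit `prompt_template` override, then the model's tokenizer_config.json, then its config.json, then a filename match against the model name. The standalone sketch below mirrors that order using the templating.py helpers; the function name and the literal config.json path are illustrative assumptions, since the container actually resolves that path from its ExLlamaV2Config.

    import pathlib

    from templating import (
        find_template_from_model,
        get_template_from_file,
        get_template_from_model_json,
    )


    def resolve_prompt_template(model_dir: pathlib.Path, override_name=None):
        """Hypothetical helper mirroring the container's lookup order."""
        if override_name:
            return get_template_from_file(override_name)

        template = get_template_from_model_json(
            model_dir / "tokenizer_config.json",
            "chat_template",
            "from_tokenizer_config",
        )
        if template is None:
            # Assumed location; the container reads this path from
            # ExLlamaV2Config.model_config instead.
            template = get_template_from_model_json(
                model_dir / "config.json", "chat_template", "from_model_config"
            )
        if template is None:
            template_name = find_template_from_model(model_dir)
            if template_name:
                template = get_template_from_file(template_name)
        return template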
+ "Chat completions are disabled because a prompt template", + "wasn't provided or auto-detected.", ) # Set num of experts per token if provided @@ -159,11 +190,16 @@ class ModelContainer: if hasattr(self.config, "num_experts_per_token"): self.config.num_experts_per_token = num_experts_override else: - print(" !! Warning: Currently installed ExLlamaV2 does not support overriding MoE experts") + print( + " !! Warning: Currently installed ExLlamaV2 does not " + "support overriding MoE experts" + ) - chunk_size = min(unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len) + chunk_size = min( + unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len + ) self.config.max_input_len = chunk_size - self.config.max_attn_size = chunk_size ** 2 + self.config.max_attn_size = chunk_size**2 draft_args = unwrap(kwargs.get("draft"), {}) draft_model_name = draft_args.get("draft_model_name") @@ -171,47 +207,63 @@ class ModelContainer: # Always disable draft if params are incorrectly configured if draft_args and draft_model_name is None: - print("A draft config was found but a model name was not given. Please check your config.yml! Skipping draft load.") + print( + "A draft config was found but a model name was not given. " + "Please check your config.yml! Skipping draft load." + ) enable_draft = False if enable_draft: self.draft_config = ExLlamaV2Config() - draft_model_path = pathlib.Path(unwrap(draft_args.get("draft_model_dir"), "models")) + draft_model_path = pathlib.Path( + unwrap(draft_args.get("draft_model_dir"), "models") + ) draft_model_path = draft_model_path / draft_model_name self.draft_config.model_dir = str(draft_model_path.resolve()) self.draft_config.prepare() - self.draft_config.scale_pos_emb = unwrap(draft_args.get("draft_rope_scale"), 1.0) + self.draft_config.scale_pos_emb = unwrap( + draft_args.get("draft_rope_scale"), 1.0 + ) # Automatically calculate draft rope alpha self.draft_config.scale_alpha_value = unwrap( draft_args.get("draft_rope_alpha"), - self.calculate_rope_alpha(self.draft_config.max_seq_len) + self.calculate_rope_alpha(self.draft_config.max_seq_len), ) - self.draft_config.max_seq_len = self.config.max_seq_len + self.draft_config.max_seq_len = self.config.max_seq_len if "chunk_size" in kwargs: self.draft_config.max_input_len = kwargs["chunk_size"] self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2 def calculate_rope_alpha(self, base_seq_len): + """Calculate the rope alpha value for a given sequence length.""" ratio = self.config.max_seq_len / base_seq_len - # Default to a 1 alpha if the sequence length is ever less than or equal to 1 - alpha = 1 if ratio <= 1.0 else -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2 + # Default to a 1 alpha if the sequence length is ever less + # than or equal to 1 + if ratio <= 1.0: + alpha = 1 + else: + alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2 return alpha def get_model_path(self, is_draft: bool = False): - model_path = pathlib.Path(self.draft_config.model_dir if is_draft else self.config.model_dir) + """Get the path for this model.""" + model_path = pathlib.Path( + self.draft_config.model_dir if is_draft else self.config.model_dir + ) return model_path - def load(self, progress_callback = None): + def load(self, progress_callback=None): """ Load model Args: - progress_callback (function, optional): A function to call for each module loaded. Prototype: + progress_callback (function, optional): A function to call for each + module loaded. 
Prototype: def progress(loaded_modules: int, total_modules: int) """ for _ in self.load_gen(progress_callback): @@ -231,25 +283,32 @@ class ModelContainer: lora_scaling = unwrap(lora.get("scaling"), 1.0) if lora_name is None: - print("One of your loras does not have a name. Please check your config.yml! Skipping lora load.") + print( + "One of your loras does not have a name. Please check your " + "config.yml! Skipping lora load." + ) failure.append(lora_name) continue print(f"Loading lora: {lora_name} at scaling {lora_scaling}") lora_path = lora_directory / lora_name - self.active_loras.append(ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)) + # FIXME(alpin): Does self.model need to be passed here? + self.active_loras.append( + ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling) + ) print("Lora successfully loaded.") success.append(lora_name) # Return success and failure names - return { 'success': success, 'failure': failure } + return {"success": success, "failure": failure} - def load_gen(self, progress_callback = None): + def load_gen(self, progress_callback=None): """ Load model, generator function Args: - progress_callback (function, optional): A function to call for each module loaded. Prototype: + progress_callback (function, optional): A function to call for each + module loaded. Prototype: def progress(loaded_modules: int, total_modules: int) """ @@ -262,13 +321,18 @@ class ModelContainer: if not self.quiet: print("Loading draft model: " + self.draft_config.model_dir) - self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy = True) - reserve = [auto_split_reserve_bytes] + [0] * 16 - yield from self.draft_model.load_autosplit_gen(self.draft_cache, reserve_vram = reserve, last_id_only = True, callback_gen = progress_callback) + self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True) + reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16 + yield from self.draft_model.load_autosplit_gen( + self.draft_cache, + reserve_vram=reserve, + last_id_only=True, + callback_gen=progress_callback, + ) # Test VRAM allocation with a full-length forward pass - input_ids = torch.zeros((1, self.config.max_input_len), dtype = torch.long) - self.draft_model.forward(input_ids, cache = self.cache, preprocess_only = True) + input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long) + self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True) # Load model self.model = ExLlamaV2(self.config) @@ -276,29 +340,41 @@ class ModelContainer: print("Loading model: " + self.config.model_dir) if not self.gpu_split_auto: - for value in self.model.load_gen(self.gpu_split, callback_gen = progress_callback): + for value in self.model.load_gen( + self.gpu_split, callback_gen=progress_callback + ): if isinstance(value, str): yield value if self.cache_fp8: - self.cache = ExLlamaV2Cache_8bit(self.model, lazy = self.gpu_split_auto) + self.cache = ExLlamaV2Cache_8bit(self.model, lazy=self.gpu_split_auto) else: - self.cache = ExLlamaV2Cache(self.model, lazy = self.gpu_split_auto) + self.cache = ExLlamaV2Cache(self.model, lazy=self.gpu_split_auto) if self.gpu_split_auto: - reserve = [auto_split_reserve_bytes] + [0] * 16 - yield from self.model.load_autosplit_gen(self.cache, reserve_vram = reserve, last_id_only = True, callback_gen = progress_callback) + reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16 + yield from self.model.load_autosplit_gen( + self.cache, + reserve_vram=reserve, + last_id_only=True, + callback_gen=progress_callback, + ) # Test VRAM allocation 
with a full-length forward pass - input_ids = torch.zeros((1, self.config.max_input_len), dtype = torch.long) - self.model.forward(input_ids, cache = self.cache, preprocess_only = True) + input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long) + self.model.forward(input_ids, cache=self.cache, preprocess_only=True) # Create generator - self.generator = ExLlamaV2StreamingGenerator(self.model, self.cache, self.tokenizer, self.draft_model, self.draft_cache) + self.generator = ExLlamaV2StreamingGenerator( + self.model, + self.cache, + self.tokenizer, + self.draft_model, + self.draft_cache, + ) print("Model successfully loaded.") - def unload(self, loras_only: bool = False): """ Free all VRAM resources used by this model @@ -327,19 +403,24 @@ class ModelContainer: gc.collect() torch.cuda.empty_cache() - # Common function for token operations def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs): + """Common function for token operations""" if text: # Assume token encoding return self.tokenizer.encode( text, - add_bos = unwrap(kwargs.get("add_bos_token"), True), - encode_special_tokens = unwrap(kwargs.get("encode_special_tokens"), True) + add_bos=unwrap(kwargs.get("add_bos_token"), True), + encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True), ) if ids: # Assume token decoding ids = torch.tensor([ids]) - return self.tokenizer.decode(ids, decode_special_tokens = unwrap(kwargs.get("decode_special_tokens"), True))[0] + return self.tokenizer.decode( + ids, + decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True), + )[0] + + return None def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool): return { @@ -350,13 +431,15 @@ class ModelContainer: } def generate(self, prompt: str, **kwargs): + """Generate a response to a prompt""" generation = list(self.generate_gen(prompt, **kwargs)) if generation: response = "".join(map(lambda chunk: chunk[0], generation)) return response, generation[-1][1], generation[-1][2] - else: - return "", 0, 0 + return "", 0, 0 + + # pylint: disable=too-many-locals,too-many-branches,too-many-statements def generate_gen(self, prompt: str, **kwargs): """ Create generator function for prompt completion @@ -366,7 +449,8 @@ class ModelContainer: **kwargs: 'token_healing' (bool): Use token healing (default: False) 'temperature' (float): Sampling temperature (default: 1.0) - 'temperature_last' (bool): Apply temperature after all other samplers (default: False) + 'temperature_last' (bool): Apply temperature after all other + samplers (default: False) 'top_k' (int): Sampling top-K (default: 0) 'top_p' (float): Sampling top-P (default: 1.0) 'min_p' (float): Sampling min-P (default: 0.0) @@ -375,19 +459,27 @@ class ModelContainer: 'mirostat' (bool): Use Mirostat (default: False) 'mirostat_tau' (float) Mirostat tau parameter (default: 1.5) 'mirostat_eta' (float) Mirostat eta parameter (default: 0.1) - 'repetition_penalty' (float): Token repetition/presence penalty (default: 1.15) - 'repetition_range' (int): Repetition penalty range (default: whole context) - 'repetition_decay' (int): Repetition penalty range (default: same as range) - 'stop' (List[Union[str, int]]): List of stop strings/tokens to end response (default: [EOS]) + 'repetition_penalty' (float): Token repetition/presence penalty + (default: 1.15) + 'repetition_range' (int): Repetition penalty range + (default: whole context) + 'repetition_decay' (int): Repetition penalty range + (default: same as range) + 'stop' (List[Union[str, int]]): 
List of stop strings/tokens to + end response (default: [EOS]) 'max_tokens' (int): Max no. tokens in response (default: 150) - 'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True) - 'ban_eos_token' (bool): Bans the EOS token from generation (default: False) - 'logit_bias' (Dict[int, float]): Biases specific tokens to either show up more or less (default: None) - 'stream_interval' (float): Interval in seconds between each output chunk (default: immediate) - 'generate_window' (int): Space to reserve at the end of the model's context when generating. - Rolls context window by the same amount if context length is exceeded to allow generating past - the models max_seq_len. - + 'add_bos_token' (bool): Adds the BOS token to the start of the + prompt (default: True) + 'ban_eos_token' (bool): Bans the EOS token from generation + (default: False) + 'logit_bias' (Dict[int, float]): Biases specific tokens to + either show up more or less (default: None) + 'stream_interval' (float): Interval in seconds between each + output chunk (default: immediate) + 'generate_window' (int): Space to reserve at the end of the + model's context when generating. Rolls context window by + the same amount if context length is exceeded to allow + generating past the model's max_seq_len. """ token_healing = unwrap(kwargs.get("token_healing"), False) @@ -399,17 +491,37 @@ class ModelContainer: gen_settings = ExLlamaV2Sampler.Settings() # Warn of unsupported settings if the setting is enabled - if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(gen_settings, "mirostat"): - print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling") + if (unwrap(kwargs.get("mirostat"), False)) and not hasattr( + gen_settings, "mirostat" + ): + print( + " !! Warning: Currently installed ExLlamaV2 does not support " + "Mirostat sampling" + ) - if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"): - print(" !! Warning: Currently installed ExLlamaV2 does not support min-P sampling") + if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr( + gen_settings, "min_p" + ): + print( + " !! Warning: Currently installed ExLlamaV2 does not " + "support min-P sampling" + ) - if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"): - print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)") + if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr( + gen_settings, "tfs" + ): + print( + " !! Warning: Currently installed ExLlamaV2 does not support " + "tail-free sampling (TFS)" + ) - if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(gen_settings, "temperature_last"): - print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last") + if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr( + gen_settings, "temperature_last" + ): + print( + " !! 
Warning: Currently installed ExLlamaV2 does not support " + "temperature_last" + ) # Apply settings gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0) @@ -424,14 +536,24 @@ class ModelContainer: # Default tau and eta fallbacks don't matter if mirostat is off gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5) gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1) - gen_settings.token_repetition_penalty = unwrap(kwargs.get("repetition_penalty"), 1.0) - gen_settings.token_repetition_range = unwrap(kwargs.get("repetition_range"), self.config.max_seq_len) + gen_settings.token_repetition_penalty = unwrap( + kwargs.get("repetition_penalty"), 1.0 + ) + gen_settings.token_repetition_range = unwrap( + kwargs.get("repetition_range"), self.config.max_seq_len + ) # Always make sure the fallback is 0 if range < 0 - # It's technically fine to use -1, but this just validates the passed fallback + # It's technically fine to use -1, but this just validates the passed + # fallback # Always default to 0 if something goes wrong - fallback_decay = 0 if gen_settings.token_repetition_range <= 0 else gen_settings.token_repetition_range - gen_settings.token_repetition_decay = coalesce(kwargs.get("repetition_decay"), fallback_decay, 0) + if gen_settings.token_repetition_range <= 0: + fallback_decay = 0 + else: + fallback_decay = gen_settings.token_repetition_range + gen_settings.token_repetition_decay = coalesce( + kwargs.get("repetition_decay"), fallback_decay, 0 + ) stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), []) add_bos_token = unwrap(kwargs.get("add_bos_token"), True) @@ -448,13 +570,13 @@ class ModelContainer: # Log generation options to console # Some options are too large, so log the args instead log_generation_params( - max_tokens = max_tokens, + max_tokens=max_tokens, **vars(gen_settings), - token_healing = token_healing, - add_bos_token = add_bos_token, - ban_eos_token = ban_eos_token, - stop_conditions = stop_conditions, - logit_bias = logit_bias + token_healing=token_healing, + add_bos_token=add_bos_token, + ban_eos_token=ban_eos_token, + stop_conditions=stop_conditions, + logit_bias=logit_bias, ) # Log prompt to console @@ -465,13 +587,17 @@ class ModelContainer: # Create a vocab tensor if it doesn't exist for token biasing if gen_settings.token_bias is None: padding = -self.tokenizer.config.vocab_size % 32 - gen_settings.token_bias = torch.zeros((self.tokenizer.config.vocab_size + padding,), dtype = torch.float) + gen_settings.token_bias = torch.zeros( + (self.tokenizer.config.vocab_size + padding,), + dtype=torch.float, + ) # Map logits to the tensor with their biases for token, bias in logit_bias.items(): gen_settings.token_bias[token] = bias - # Ban the EOS token if specified. If not, append to stop conditions as well. + # Ban the EOS token if specified. If not, append to stop conditions + # as well. # Set this below logging to avoid polluting the stop strings array if ban_eos_token: gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id]) @@ -483,16 +609,15 @@ class ModelContainer: # Tokenized context ids = self.tokenizer.encode( - prompt, - add_bos = add_bos_token, - encode_special_tokens = True + prompt, add_bos=add_bos_token, encode_special_tokens=True ) context_len = len(ids[0]) if context_len > self.config.max_seq_len: print( - f"WARNING: The context length {context_len} is greater than the max_seq_len {self.config.max_seq_len}.", - "Generation is truncated and metrics may not be accurate." 
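One detail worth calling out from the sampler setup above: logit biases are applied by building a dense, vocab-sized bias tensor whose length is padded up to a multiple of 32, then writing each requested bias into it by token id. A standalone sketch of the same idea, with a made-up vocab size and bias map:

    import torch

    vocab_size = 32003                    # assumed vocab size, for illustration only
    logit_bias = {13: -5.0, 29871: 2.0}   # token id -> bias, hypothetical values

    # Pad the tensor length up to the next multiple of 32, as the container does.
    padding = -vocab_size % 32
    token_bias = torch.zeros((vocab_size + padding,), dtype=torch.float)

    # Map each requested bias onto the dense tensor by token id.
    for token, bias in logit_bias.items():
        token_bias[token] = bias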
+ f"WARNING: The context length {context_len} is greater than " + f"the max_seq_len {self.config.max_seq_len}.", + "Generation is truncated and metrics may not be accurate.", ) prompt_tokens = ids.shape[-1] @@ -503,26 +628,32 @@ class ModelContainer: start_time = time.time() last_chunk_time = start_time - save_tokens = torch.empty((1, 0), dtype = torch.bool) + save_tokens = torch.empty((1, 0), dtype=torch.bool) chunk_buffer = "" chunk_tokens = 0 while True: # Ingest prompt if chunk_tokens == 0: - ids = torch.cat((ids, save_tokens), dim = - 1) - save_tokens = torch.empty((1, 0), dtype = torch.bool) + ids = torch.cat((ids, save_tokens), dim=-1) + save_tokens = torch.empty((1, 0), dtype=torch.bool) overflow = ids.shape[-1] + generate_window - self.config.max_seq_len - active_ids = ids[:, max(0, overflow):] + active_ids = ids[:, max(0, overflow) :] chunk_tokens = self.config.max_seq_len - active_ids.shape[-1] - self.generator.begin_stream(active_ids, gen_settings, token_healing = token_healing, loras = self.active_loras) + self.generator.begin_stream( + active_ids, + gen_settings, + token_healing=token_healing, + loras=self.active_loras, + ) # Generate chunk, eos, tokens = self.generator.stream() if token_healing: - ids[:, -1] = self.generator.sequence_ids[:, -2] # Extract healed token + # Extract healed token + ids[:, -1] = self.generator.sequence_ids[:, -2] token_healing = False save_tokens = torch.cat((save_tokens, tokens), dim=-1) @@ -535,7 +666,9 @@ class ModelContainer: now = time.time() elapsed = now - last_chunk_time - if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_tokens): + if chunk_buffer != "" and ( + elapsed > stream_interval or eos or generated_tokens == max_tokens + ): yield chunk_buffer, prompt_tokens, generated_tokens full_response += chunk_buffer chunk_buffer = "" @@ -549,12 +682,20 @@ class ModelContainer: elapsed_time = last_chunk_time - start_time - initial_response = f"Metrics: {generated_tokens} tokens generated in {round(elapsed_time, 2)} seconds" + initial_response = ( + f"Metrics: {generated_tokens} tokens generated in " + f"{round(elapsed_time, 2)} seconds" + ) itemization = [] extra_parts = [] # Add tokens per second - itemization.append(f"{'Indeterminate' if elapsed_time == 0 else round(generated_tokens / elapsed_time, 2)} T/s") + tokens_per_second = ( + "Indeterminate" + if elapsed_time == 0 + else round(generated_tokens / elapsed_time, 2) + ) + itemization.append(f"{tokens_per_second} T/s") # Add context (original token count) if ids is not None: @@ -564,4 +705,10 @@ class ModelContainer: extra_parts.append("<-- Not accurate (truncated)") # Print output - print(initial_response + " (" + ", ".join(itemization) + ") " + " ".join(extra_parts)) + print( + initial_response + + " (" + + ", ".join(itemization) + + ") " + + " ".join(extra_parts) + ) diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..319245a --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,15 @@ +# formatting +ruff==0.1.9 + +## Implement below dependencies when support is added + +# type checking +# mypy==0.991 +# types-PyYAML +# types-requests +# types-setuptools + +# testing +# pytest +# pytest-forked +# pytest-asyncio \ No newline at end of file diff --git a/templating.py b/templating.py index c7fde76..f1a86e3 100644 --- a/templating.py +++ b/templating.py @@ -1,46 +1,56 @@ +"""Small replication of AutoTokenizer's chat template system for efficiency""" import json import pathlib from functools import lru_cache from 
importlib.metadata import version as package_version + from jinja2.sandbox import ImmutableSandboxedEnvironment from packaging import version from pydantic import BaseModel from typing import Optional, Dict -# Small replication of AutoTokenizer's chat template system for efficiency class PromptTemplate(BaseModel): + """A template for chat completion prompts.""" + name: str template: str -def get_prompt_from_template(messages, - prompt_template: PromptTemplate, - add_generation_prompt: bool, - special_tokens: Optional[Dict[str, str]] = None): + +def get_prompt_from_template( + messages, + prompt_template: PromptTemplate, + add_generation_prompt: bool, + special_tokens: Optional[Dict[str, str]] = None, +): + """Get a prompt from a template and a list of messages.""" if version.parse(package_version("jinja2")) < version.parse("3.0.0"): raise ImportError( - "Parsing these chat completion messages requires jinja2 3.0.0 or greater. " - f"Current version: {version('jinja2')}\n" + "Parsing these chat completion messages requires jinja2 3.0.0 " + f"or greater. Current version: {package_version('jinja2')}\n" "Please upgrade jinja by running the following command: " "pip install --upgrade jinja2" ) compiled_template = _compile_template(prompt_template.template) return compiled_template.render( - messages = messages, - add_generation_prompt = add_generation_prompt, + messages=messages, + add_generation_prompt=add_generation_prompt, **special_tokens, ) -# Inspired from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761 + +# Inspired from +# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761 @lru_cache def _compile_template(template: str): - jinja_env = ImmutableSandboxedEnvironment(trim_blocks = True, lstrip_blocks = True) + jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True) jinja_template = jinja_env.from_string(template) return jinja_template -# Find a matching template name from a model path + def find_template_from_model(model_path: pathlib.Path): + """Find a matching template name from a model path.""" model_name = model_path.name template_directory = pathlib.Path("templates") for filepath in template_directory.glob("*.jinja"): @@ -50,14 +60,16 @@ def find_template_from_model(model_path: pathlib.Path): if template_name in model_name.lower(): return template_name -# Get a template from a jinja file + return None + + def get_template_from_file(prompt_template_name: str): + """Get a template from a jinja file.""" template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja") if template_path.exists(): - with open(template_path, "r", encoding = "utf8") as raw_template: + with open(template_path, "r", encoding="utf8") as raw_template: return PromptTemplate( - name = prompt_template_name, - template = raw_template.read() + name=prompt_template_name, template=raw_template.read() ) return None @@ -66,15 +78,12 @@ def get_template_from_file(prompt_template_name: str): # Get a template from a JSON file # Requires a key and template name def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str): + """Get a template from a JSON file. 
Requires a key and template name""" if json_path.exists(): - with open(json_path, "r", encoding = "utf8") as config_file: + with open(json_path, "r", encoding="utf8") as config_file: model_config = json.load(config_file) chat_template = model_config.get(key) if chat_template: - return PromptTemplate( - name = name, - template = chat_template - ) + return PromptTemplate(name=name, template=chat_template) return None - diff --git a/tests/model_test.py b/tests/model_test.py index ee5375d..b4ac158 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -1,22 +1,49 @@ +""" Test the model container. """ from model import ModelContainer + def progress(module, modules): + """Wrapper callback for load progress.""" yield module, modules -container = ModelContainer("/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.0bpw/") -loader = container.load_gen(progress) -for (module, modules) in loader: - print(module, modules) -generator = container.generate_gen("Once upon a tim", token_healing = True) -for g in generator: - print(g, end = "") +def test_load_gen(model_path): + """Test loading a model.""" + container = ModelContainer(model_path) + loader = container.load_gen(progress) + for module, modules in loader: + print(module, modules) + container.unload() + del container -container.unload() -del container -mc = ModelContainer("/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.65bpw/") -mc.load(progress) +def test_generate_gen(model_path): + """Test generating from a model.""" + container = ModelContainer(model_path) + generator = container.generate_gen("Once upon a tim", token_healing=True) + for chunk in generator: + print(chunk, end="") + container.unload() + del container -response = mc.generate("All work and no play makes turbo a derpy cat.\nAll work and no play makes turbo a derpy cat.\nAll", top_k = 1, max_new_tokens = 1000, stream_interval = 0.5) -print (response) + +def test_generate(model_path): + """Test generating from a model.""" + model_container = ModelContainer(model_path) + model_container.load(progress) + prompt = ( + "All work and no play makes turbo a derpy cat.\n" + "All work and no play makes turbo a derpy cat.\nAll" + ) + response = model_container.generate( + prompt, top_k=1, max_new_tokens=1000, stream_interval=0.5 + ) + print(response) + + +if __name__ == "__main__": + MODEL1 = "/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.0bpw/" + MODEL2 = "/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.65bpw/" + test_load_gen(MODEL1) + test_generate_gen(MODEL1) + test_generate(MODEL2) diff --git a/tests/wheel_test.py b/tests/wheel_test.py index 58343bf..c343ed5 100644 --- a/tests/wheel_test.py +++ b/tests/wheel_test.py @@ -1,3 +1,4 @@ +""" Test if the wheels are installed correctly. """ from importlib.metadata import version from importlib.util import find_spec @@ -34,8 +35,12 @@ else: print( f"\nSuccessful imports: {', '.join(successful_packages)}", - f"\nErrored imports: {''.join(errored_packages)}" + f"\nErrored imports: {''.join(errored_packages)}", ) if len(errored_packages) > 0: - print("\nIf packages are installed, but not found on this test, please check the wheel versions for the correct python version and CUDA version (if applicable).") + print( + "\nIf packages are installed, but not found on this test, please " + "check the wheel versions for the correct python version and CUDA " + "version (if applicable)." 
+ ) diff --git a/utils.py b/utils.py index a94fcc9..6f00d3e 100644 --- a/utils.py +++ b/utils.py @@ -1,43 +1,54 @@ +"""Common utilities for the tabbyAPI""" import traceback -from pydantic import BaseModel from typing import Optional -# Wrapper callback for load progress +from pydantic import BaseModel + + def load_progress(module, modules): + """Wrapper callback for load progress.""" yield module, modules -# Common error types + class TabbyGeneratorErrorMessage(BaseModel): + """Common error types.""" + message: str trace: Optional[str] = None + class TabbyGeneratorError(BaseModel): + """Common error types.""" + error: TabbyGeneratorErrorMessage + def get_generator_error(message: str): + """Get a generator error.""" error_message = TabbyGeneratorErrorMessage( - message = message, - trace = traceback.format_exc() + message=message, trace=traceback.format_exc() ) - generator_error = TabbyGeneratorError( - error = error_message - ) + generator_error = TabbyGeneratorError(error=error_message) # Log and send the exception print(f"\n{generator_error.error.trace}") return get_sse_packet(generator_error.model_dump_json()) + def get_sse_packet(json_data: str): + """Get an SSE packet.""" return f"data: {json_data}\n\n" -# Unwrap function for Optionals -def unwrap(wrapped, default = None): + +def unwrap(wrapped, default=None): + """Unwrap function for Optionals.""" if wrapped is None: return default - else: - return wrapped -# Coalesce function for multiple unwraps + return wrapped + + def coalesce(*args): + """Coalesce function for multiple unwraps.""" return next((arg for arg in args if arg is not None), None)
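Finally, a quick usage sketch of the two small helpers at the bottom of utils.py that the rest of the tree leans on; the config values here are made up.

    from utils import coalesce, unwrap

    config = {"port": None}

    # unwrap substitutes the default only when the value is None.
    port = unwrap(config.get("port"), 5000)          # -> 5000
    host = unwrap(config.get("host"), "127.0.0.1")   # -> "127.0.0.1"

    # coalesce returns the first argument that is not None.
    repetition_decay = coalesce(None, 0, 512)        # -> 0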