feat: workflows for formatting/linting (#35)

* add github workflows for pylint and yapf

* yapf

* docstrings for auth

* fix auth.py

* fix generators.py

* fix gen_logging.py

* fix main.py

* fix model.py

* fix templating.py

* fix utils.py

* update formatting.sh to include subdirs for pylint

* fix model_test.py

* fix wheel_test.py

* rename utils to utils_oai

* fix OAI/utils_oai.py

* fix completion.py

* fix token.py

* fix lora.py

* fix common.py

* add pylintrc and fix model.py

* finish up pylint

* fix attribute error

* main.py formatting

* add formatting batch script

* Main: Remove unnecessary global

Linter suggestion.

Signed-off-by: kingbri <bdashore3@proton.me>

* switch to ruff

* Formatting + Linting: Add ruff.toml

Signed-off-by: kingbri <bdashore3@proton.me>

* Formatting + Linting: Switch scripts to use ruff

Also remove the file and recent file change functions from both
scripts.

Signed-off-by: kingbri <bdashore3@proton.me>

* Tree: Format and lint

Signed-off-by: kingbri <bdashore3@proton.me>

* Scripts + Workflows: Format

Signed-off-by: kingbri <bdashore3@proton.me>

* Tree: Remove pylint flags

We use ruff now

Signed-off-by: kingbri <bdashore3@proton.me>

* Tree: Format

Signed-off-by: kingbri <bdashore3@proton.me>

* Formatting: Line length is 88

Use the same value as Black.

Signed-off-by: kingbri <bdashore3@proton.me>

* Tree: Format

Update to new line length rules.

Signed-off-by: kingbri <bdashore3@proton.me>

---------

Authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com>
Co-authored-by: kingbri <bdashore3@proton.me>
commit fa47f51f85 (parent a14abfe21c), committed via GitHub on 2023-12-22 16:20:35 +00:00
22 changed files with 1210 additions and 511 deletions

.github/workflows/ruff.yml (new file, 32 lines)

@ -0,0 +1,32 @@
name: ruff
on:
push:
branches:
- main
pull_request:
branches:
- main
jobs:
ruff:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements-dev.txt
- name: Format and show diff with ruff
run: |
ruff format --diff
- name: Lint code with ruff
run: |
ruff check

.ruff.toml (new file, 111 lines)

@ -0,0 +1,111 @@
# Exclude a variety of commonly ignored directories.
exclude = [
".git",
".git-rewrite",
".mypy_cache",
".pyenv",
".pytest_cache",
".ruff_cache",
".venv",
".vscode",
"__pypackages__",
"_build",
"build",
"dist",
"node_modules",
"site-packages",
"venv",
]
# Same as Black.
line-length = 88
indent-width = 4
# Assume Python 3.10
target-version = "py310"
[lint]
# Enable preview
preview = true
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
# McCabe complexity (`C901`) by default.
# Enable flake8-bugbear (`B`) rules, in addition to the defaults.
select = ["E4", "E7", "E9", "F", "B"]
extend-select = [
"D419", # empty-docstring
"PLC2401", # non-ascii-name
"E501", # line-too-long
"W291", # trailing-whitespace
"PLC0414", # useless-import-alias
"E999", # syntax-error
"PLE0101", # return-in-init
"F706", # return-outside-function
"F704", # yield-outside-function
"PLE0116", # continue-in-finally
"PLE0117", # nonlocal-without-binding
"PLE0241", # duplicate-bases
"PLE0302", # unexpected-special-method-signature
"PLE0604", # invalid-all-object
"PLE0704", # misplaced-bare-raise
"PLE1205", # logging-too-many-args
"PLE1206", # logging-too-few-args
"PLE1307", # bad-string-format-type
"PLE1310", # bad-str-strip-call
"PLE1507", # invalid-envvar-value
"PLR0124", # comparison-with-itself
"PLR0202", # no-classmethod-decorator
"PLR0203", # no-staticmethod-decorator
"PLR0206", # property-with-parameters
"PLR1704", # redefined-argument-from-local
"PLR1711", # useless-return
"C416", # unnecessary-comprehension
"PLW0108", # unnecessary-lambda
"PLW0127", # self-assigning-variable
"PLW0129", # assert-on-string-literal
"PLW0602", # global-variable-not-assigned
"PLW0604", # global-at-module-level
"F401", # unused-import
"F841", # unused-variable
"E722", # bare-except
"PLW0711", # binary-op-exception
"PLW1501", # bad-open-mode
"PLW1508", # invalid-envvar-default
"PLW1509", # subprocess-popen-preexec-fn
]
ignore = [
"PLR6301", # no-self-use
"UP004", # useless-object-inheritance
"PLR0904", # too-many-public-methods
"PLR0911", # too-many-return-statements
"PLR0912", # too-many-branches
"PLR0913", # too-many-arguments
"PLR0914", # too-many-locals
"PLR0915", # too-many-statements
"PLR0916", # too-many-boolean-expressions
"PLW0120", # useless-else-on-loop
"PLW0406", # import-self
"PLW0603", # global-statement
"PLW1641", # eq-without-hash
]
# Allow fix for all enabled rules (when `--fix`) is provided.
fixable = ["ALL"]
unfixable = ["B"]
# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
[format]
# Like Black, use double quotes for strings.
quote-style = "double"
# Like Black, indent with spaces, rather than tabs.
indent-style = "space"
# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = false
# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"

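As a quick, illustrative sketch (not part of this commit): a small, hypothetical module showing how a few of the rules selected above behave when run through ruff check with this configuration. Lines longer than 88 characters would additionally trigger E501.

"""Hypothetical module illustrating the lint rules selected above."""
import json


def load_config(path):
    """Read a JSON config file, tripping a couple of the selected rules."""
    try:
        raw = open(path, encoding="utf8").read()
        parsed = json.loads(raw)  # F841: assigned but never used
        _cache = raw  # allowed: matches dummy-variable-rgx (leading underscore)
    except:  # E722: bare except is in extend-select
        return None
    return raw
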
OAI/types/chat_completion.py

@ -4,22 +4,26 @@ from pydantic import BaseModel, Field
from typing import Union, List, Optional, Dict
from OAI.types.common import UsageStats, CommonCompletionRequest
class ChatCompletionMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
class ChatCompletionRespChoice(BaseModel):
# Index is 0 since we aren't using multiple choices
index: int = 0
finish_reason: str
message: ChatCompletionMessage
class ChatCompletionStreamChoice(BaseModel):
# Index is 0 since we aren't using multiple choices
index: int = 0
finish_reason: Optional[str]
delta: Union[ChatCompletionMessage, dict] = {}
# Inherited from common request
class ChatCompletionRequest(CommonCompletionRequest):
# Messages
@ -28,6 +32,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
prompt_template: Optional[str] = None
add_generation_prompt: Optional[bool] = True
class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
choices: List[ChatCompletionRespChoice]
@ -36,6 +41,7 @@ class ChatCompletionResponse(BaseModel):
object: str = "chat.completion"
usage: Optional[UsageStats] = None
class ChatCompletionStreamChunk(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
choices: List[ChatCompletionStreamChoice]

OAI/types/common.py

@ -1,30 +1,54 @@
from pydantic import BaseModel, Field, AliasChoices
""" Common types for OAI. """
from typing import List, Dict, Optional, Union
from pydantic import BaseModel, Field, AliasChoices
from utils import unwrap
class LogProbs(BaseModel):
"""Represents log probabilities."""
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[float] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Dict[str, float]] = Field(default_factory=list)
class UsageStats(BaseModel):
"""Represents usage stats."""
prompt_tokens: int
completion_tokens: int
total_tokens: int
class CommonCompletionRequest(BaseModel):
"""Represents a common completion request."""
# Model information
# This parameter is not used, the loaded model is used instead
model: Optional[str] = None
# Extra OAI request stuff
best_of: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
echo: Optional[bool] = Field(description = "Not parsed. Only used for OAI compliance.", default = False)
logprobs: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
n: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = 1)
suffix: Optional[str] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
user: Optional[str] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
best_of: Optional[int] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)
echo: Optional[bool] = Field(
description="Not parsed. Only used for OAI compliance.", default=False
)
logprobs: Optional[int] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)
n: Optional[int] = Field(
description="Not parsed. Only used for OAI compliance.", default=1
)
suffix: Optional[str] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)
user: Optional[str] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)
# Generation info
# seed: Optional[int] = -1
@ -35,8 +59,9 @@ class CommonCompletionRequest(BaseModel):
max_tokens: Optional[int] = 150
# Aliased to repetition_penalty
# TODO: Maybe make this an alias to rep pen
frequency_penalty: Optional[float] = Field(description = "Aliased to Repetition Penalty", default = 0.0)
frequency_penalty: Optional[float] = Field(
description="Aliased to Repetition Penalty", default=0.0
)
# Sampling params
token_healing: Optional[bool] = False
@ -58,18 +83,21 @@ class CommonCompletionRequest(BaseModel):
# Aliased variables
repetition_range: Optional[int] = Field(
default = None,
validation_alias = AliasChoices('repetition_range', 'repetition_penalty_range')
default=None,
validation_alias=AliasChoices("repetition_range", "repetition_penalty_range"),
)
# Converts to internal generation parameters
def to_gen_params(self):
"""Converts to internal generation parameters."""
# Convert stop to an array of strings
if isinstance(self.stop, str):
self.stop = [self.stop]
# Set repetition_penalty to frequency_penalty if repetition_penalty isn't already defined
if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty:
# Set repetition_penalty to frequency_penalty if repetition_penalty
# isn't already defined
if (
self.repetition_penalty is None or self.repetition_penalty == 1.0
) and self.frequency_penalty:
self.repetition_penalty = self.frequency_penalty
return {
@ -87,7 +115,7 @@ class CommonCompletionRequest(BaseModel):
"min_p": self.min_p,
"tfs": self.tfs,
"repetition_penalty": self.repetition_penalty,
"repetition_range": unwrap(self.repetition_range, -1),
"repetition_range": unwrap(self.repetition_range, -1),
"repetition_decay": self.repetition_decay,
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,

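A hedged usage sketch (not from the commit) of the aliasing logic above, assuming the sampling fields omitted from this hunk all carry defaults, as their use in to_gen_params() suggests:

from OAI.types.common import CommonCompletionRequest

request = CommonCompletionRequest(
    frequency_penalty=1.15,         # OAI-style name
    repetition_penalty_range=1024,  # accepted through the validation alias
)
params = request.to_gen_params()
# repetition_penalty picks up frequency_penalty because it was left at its
# default, and the aliased range is unwrapped to a concrete value (1024).
print(params["repetition_penalty"], params["repetition_range"])
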
OAI/types/completion.py

@ -1,22 +1,35 @@
from uuid import uuid4
""" Completion API protocols """
from time import time
from pydantic import BaseModel, Field
from typing import List, Optional, Union
from OAI.types.common import LogProbs, UsageStats, CommonCompletionRequest
from uuid import uuid4
from pydantic import BaseModel, Field
from OAI.types.common import CommonCompletionRequest, LogProbs, UsageStats
class CompletionRespChoice(BaseModel):
"""Represents a single choice in a completion response."""
# Index is 0 since we aren't using multiple choices
index: int = 0
finish_reason: str
logprobs: Optional[LogProbs] = None
text: str
# Inherited from common request
class CompletionRequest(CommonCompletionRequest):
# Prompt can also contain token ids, but that's out of scope for this project.
"""Represents a completion request."""
# Prompt can also contain token ids, but that's out of scope
# for this project.
prompt: Union[str, List[str]]
class CompletionResponse(BaseModel):
"""Represents a completion response."""
id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}")
choices: List[CompletionRespChoice]
created: int = Field(default_factory=lambda: int(time()))

OAI/types/lora.py

@ -1,25 +1,42 @@
from pydantic import BaseModel, Field
""" Lora types """
from time import time
from typing import Optional, List
from pydantic import BaseModel, Field
class LoraCard(BaseModel):
"""Represents a single Lora card."""
id: str = "test"
object: str = "lora"
created: int = Field(default_factory=lambda: int(time()))
owned_by: str = "tabbyAPI"
scaling: Optional[float] = None
class LoraList(BaseModel):
"""Represents a list of Lora cards."""
object: str = "list"
data: List[LoraCard] = Field(default_factory=list)
class LoraLoadInfo(BaseModel):
"""Represents a single Lora load info."""
name: str
scaling: Optional[float] = 1.0
class LoraLoadRequest(BaseModel):
"""Represents a Lora load request."""
loras: List[LoraLoadInfo]
class LoraLoadResponse(BaseModel):
"""Represents a Lora load response."""
success: List[str] = Field(default_factory=list)
failure: List[str] = Field(default_factory=list)

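Illustrative sketch (not part of the commit): building the body that the /v1/lora/load endpoint in main.py expects from these types. The lora name is a placeholder for a directory under lora_dir.

from OAI.types.lora import LoraLoadInfo, LoraLoadRequest

request = LoraLoadRequest(loras=[LoraLoadInfo(name="my-lora", scaling=0.8)])
print(request.model_dump_json())
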
OAI/types/model.py

@ -1,19 +1,29 @@
from pydantic import BaseModel, Field, ConfigDict
""" Contains model card types. """
from time import time
from typing import List, Optional
from pydantic import BaseModel, Field, ConfigDict
from gen_logging import LogConfig
class ModelCardParameters(BaseModel):
# Safe to do this since it's guaranteed to fetch a max seq len from model_container
"""Represents model card parameters."""
# Safe to do this since it's guaranteed to fetch a max seq len
# from model_container
max_seq_len: Optional[int] = None
rope_scale: Optional[float] = 1.0
rope_alpha: Optional[float] = 1.0
cache_mode: Optional[str] = "FP16"
prompt_template: Optional[str] = None
num_experts_per_token: Optional[int] = None
draft: Optional['ModelCard'] = None
draft: Optional["ModelCard"] = None
class ModelCard(BaseModel):
"""Represents a single model card."""
id: str = "test"
object: str = "model"
created: int = Field(default_factory=lambda: int(time()))
@ -21,26 +31,47 @@ class ModelCard(BaseModel):
logging: Optional[LogConfig] = None
parameters: Optional[ModelCardParameters] = None
class ModelList(BaseModel):
"""Represents a list of model cards."""
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class DraftModelLoadRequest(BaseModel):
"""Represents a draft model load request."""
draft_model_name: str
draft_rope_scale: Optional[float] = 1.0
draft_rope_alpha: Optional[float] = Field(description = "Automatically calculated if not present", default = None)
draft_rope_alpha: Optional[float] = Field(
description="Automatically calculated if not present", default=None
)
# TODO: Unify this with ModelCardParams
class ModelLoadRequest(BaseModel):
"""Represents a model load request."""
name: str
# Max seq len is fetched from config.json of the model by default
max_seq_len: Optional[int] = Field(description = "Leave this blank to use the model's base sequence length", default = None)
override_base_seq_len: Optional[int] = Field(description = "Overrides the model's base sequence length. Leave blank if unsure", default = None)
max_seq_len: Optional[int] = Field(
description="Leave this blank to use the model's base sequence length",
default=None,
)
override_base_seq_len: Optional[int] = Field(
description=(
"Overrides the model's base sequence length. " "Leave blank if unsure"
),
default=None,
)
gpu_split_auto: Optional[bool] = True
gpu_split: Optional[List[float]] = Field(default_factory=list)
rope_scale: Optional[float] = 1.0
rope_alpha: Optional[float] = Field(description = "Automatically calculated if not present", default = None)
rope_alpha: Optional[float] = Field(
description="Automatically calculated if not present", default=None
)
no_flash_attention: Optional[bool] = False
# low_mem: Optional[bool] = False
cache_mode: Optional[str] = "FP16"
@ -48,9 +79,12 @@ class ModelLoadRequest(BaseModel):
num_experts_per_token: Optional[int] = None
draft: Optional[DraftModelLoadRequest] = None
class ModelLoadResponse(BaseModel):
"""Represents a model load response."""
# Avoids pydantic namespace warning
model_config = ConfigDict(protected_namespaces = [])
model_config = ConfigDict(protected_namespaces=[])
model_type: str = "model"
module: int

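Illustrative sketch (not part of the commit): a minimal /v1/model/load body built from ModelLoadRequest. The model name is a placeholder for a folder under model_dir; omitted fields fall back to the defaults above.

from OAI.types.model import ModelLoadRequest

request = ModelLoadRequest(name="my-model", max_seq_len=4096, gpu_split_auto=True)
print(request.model_dump_json())
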
OAI/types/token.py

@ -1,30 +1,51 @@
from pydantic import BaseModel
""" Tokenization types """
from typing import List
from pydantic import BaseModel
class CommonTokenRequest(BaseModel):
"""Represents a common tokenization request."""
add_bos_token: bool = True
encode_special_tokens: bool = True
decode_special_tokens: bool = True
def get_params(self):
"""Get the parameters for tokenization."""
return {
"add_bos_token": self.add_bos_token,
"encode_special_tokens": self.encode_special_tokens,
"decode_special_tokens": self.decode_special_tokens
"decode_special_tokens": self.decode_special_tokens,
}
class TokenEncodeRequest(CommonTokenRequest):
"""Represents a tokenization request."""
text: str
class TokenEncodeResponse(BaseModel):
"""Represents a tokenization response."""
tokens: List[int]
length: int
class TokenDecodeRequest(CommonTokenRequest):
""" " Represents a detokenization request."""
tokens: List[int]
class TokenDecodeResponse(BaseModel):
"""Represents a detokenization response."""
text: str
class TokenCountResponse(BaseModel):
length: int
"""Represents a token count response."""
length: int

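Illustrative sketch (not part of the commit): building an encode request and the parameter dict that get_params() hands to the model container.

from OAI.types.token import TokenEncodeRequest

request = TokenEncodeRequest(text="Hello world", add_bos_token=False)
print(request.get_params())
# {'add_bos_token': False, 'encode_special_tokens': True, 'decode_special_tokens': True}
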
OAI/utils.py (deleted; renamed to OAI/utils_oai.py)

@ -1,103 +0,0 @@
import pathlib
from OAI.types.completion import CompletionResponse, CompletionRespChoice
from OAI.types.chat_completion import (
ChatCompletionMessage,
ChatCompletionRespChoice,
ChatCompletionStreamChunk,
ChatCompletionResponse,
ChatCompletionStreamChoice
)
from OAI.types.common import UsageStats
from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard
from typing import Optional
from utils import unwrap
def create_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
choice = CompletionRespChoice(
finish_reason = "Generated",
text = text
)
response = CompletionResponse(
choices = [choice],
model = unwrap(model_name, ""),
usage = UsageStats(prompt_tokens = prompt_tokens,
completion_tokens = completion_tokens,
total_tokens = prompt_tokens + completion_tokens)
)
return response
def create_chat_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
message = ChatCompletionMessage(
role = "assistant",
content = text
)
choice = ChatCompletionRespChoice(
finish_reason = "Generated",
message = message
)
response = ChatCompletionResponse(
choices = [choice],
model = unwrap(model_name, ""),
usage = UsageStats(prompt_tokens = prompt_tokens,
completion_tokens = completion_tokens,
total_tokens = prompt_tokens + completion_tokens)
)
return response
def create_chat_completion_stream_chunk(const_id: str,
text: Optional[str] = None,
model_name: Optional[str] = None,
finish_reason: Optional[str] = None):
if finish_reason:
message = {}
else:
message = ChatCompletionMessage(
role = "assistant",
content = text
)
# The finish reason can be None
choice = ChatCompletionStreamChoice(
finish_reason = finish_reason,
delta = message
)
chunk = ChatCompletionStreamChunk(
id = const_id,
choices = [choice],
model = unwrap(model_name, "")
)
return chunk
def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
# Convert the provided draft model path to a pathlib path for equality comparisons
if draft_model_path:
draft_model_path = pathlib.Path(draft_model_path).resolve()
model_card_list = ModelList()
for path in model_path.iterdir():
# Don't include the draft models path
if path.is_dir() and path != draft_model_path:
model_card = ModelCard(id = path.name)
model_card_list.data.append(model_card)
return model_card_list
def get_lora_list(lora_path: pathlib.Path):
lora_list = LoraList()
for path in lora_path.iterdir():
if path.is_dir():
lora_card = LoraCard(id = path.name)
lora_list.data.append(lora_card)
return lora_list

OAI/utils_oai.py (new file, 114 lines)

@ -0,0 +1,114 @@
""" Utility functions for the OpenAI server. """
import pathlib
from typing import Optional
from OAI.types.chat_completion import (
ChatCompletionMessage,
ChatCompletionRespChoice,
ChatCompletionStreamChunk,
ChatCompletionResponse,
ChatCompletionStreamChoice,
)
from OAI.types.completion import CompletionResponse, CompletionRespChoice
from OAI.types.common import UsageStats
from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard
from utils import unwrap
def create_completion_response(
text: str,
prompt_tokens: int,
completion_tokens: int,
model_name: Optional[str],
):
"""Create a completion response from the provided text."""
choice = CompletionRespChoice(finish_reason="Generated", text=text)
response = CompletionResponse(
choices=[choice],
model=unwrap(model_name, ""),
usage=UsageStats(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
return response
def create_chat_completion_response(
text: str,
prompt_tokens: int,
completion_tokens: int,
model_name: Optional[str],
):
"""Create a chat completion response from the provided text."""
message = ChatCompletionMessage(role="assistant", content=text)
choice = ChatCompletionRespChoice(finish_reason="Generated", message=message)
response = ChatCompletionResponse(
choices=[choice],
model=unwrap(model_name, ""),
usage=UsageStats(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
return response
def create_chat_completion_stream_chunk(
const_id: str,
text: Optional[str] = None,
model_name: Optional[str] = None,
finish_reason: Optional[str] = None,
):
"""Create a chat completion stream chunk from the provided text."""
if finish_reason:
message = {}
else:
message = ChatCompletionMessage(role="assistant", content=text)
# The finish reason can be None
choice = ChatCompletionStreamChoice(finish_reason=finish_reason, delta=message)
chunk = ChatCompletionStreamChunk(
id=const_id, choices=[choice], model=unwrap(model_name, "")
)
return chunk
def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
"""Get the list of models from the provided path."""
# Convert the provided draft model path to a pathlib path for
# equality comparisons
if draft_model_path:
draft_model_path = pathlib.Path(draft_model_path).resolve()
model_card_list = ModelList()
for path in model_path.iterdir():
# Don't include the draft models path
if path.is_dir() and path != draft_model_path:
model_card = ModelCard(id=path.name)
model_card_list.data.append(model_card) # pylint: disable=no-member
return model_card_list
def get_lora_list(lora_path: pathlib.Path):
"""Get the list of Lora cards from the provided path."""
lora_list = LoraList()
for path in lora_path.iterdir():
if path.is_dir():
lora_card = LoraCard(id=path.name)
lora_list.data.append(lora_card) # pylint: disable=no-member
return lora_list

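Illustrative sketch (not part of the commit): how the /v1/models endpoint uses get_model_list, assuming a local "models" directory with one subdirectory per model.

import pathlib

from OAI.utils_oai import get_model_list

cards = get_model_list(pathlib.Path("models").resolve())
print([card.id for card in cards.data])
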
auth.py (120 lines changed)

@ -1,105 +1,125 @@
import secrets
import yaml
from fastapi import Header, HTTPException
from pydantic import BaseModel
from typing import Optional
"""
This method of authorization is pretty insecure, but since TabbyAPI is a local
application, it should be fine.
"""
import secrets
from typing import Optional
from fastapi import Header, HTTPException
from pydantic import BaseModel
import yaml
class AuthKeys(BaseModel):
"""
This class represents the authentication keys for the application.
It contains two types of keys: 'api_key' and 'admin_key'.
The 'api_key' is used for general API calls, while the 'admin_key'
is used for administrative tasks. The class also provides a method
to verify if a given key matches the stored 'api_key' or 'admin_key'.
"""
api_key: str
admin_key: str
def verify_key(self, test_key: str, key_type: str):
# Match statements are only available in python 3.10 and up
"""Verify if a given key matches the stored key."""
if key_type == "admin_key":
return test_key == self.admin_key
elif key_type == "api_key":
if key_type == "api_key":
# Admin keys are valid for all API calls
return test_key == self.api_key or test_key == self.admin_key
else:
return False
return False
AUTH_KEYS: Optional[AuthKeys] = None
DISABLE_AUTH: bool = False
auth_keys: Optional[AuthKeys] = None
disable_auth: bool = False
def load_auth_keys(disable_from_config: bool):
global auth_keys
global disable_auth
"""Load the authentication keys from api_tokens.yml. If the file does not
exist, generate new keys and save them to api_tokens.yml."""
global AUTH_KEYS
global DISABLE_AUTH
disable_auth = disable_from_config
DISABLE_AUTH = disable_from_config
if disable_from_config:
print(
"!! Warning: Disabling authentication makes your instance vulnerable.",
"Set the \"disable_auth\" flag to False in config.yml if you want to share this",
"instance with others."
"!! Warning: Disabling authentication",
"makes your instance vulnerable.",
"Set the 'disable_auth' flag to False in config.yml",
"if you want to share this instance with others.",
)
return
try:
with open("api_tokens.yml", "r", encoding = 'utf8') as auth_file:
with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
auth_keys_dict = yaml.safe_load(auth_file)
auth_keys = AuthKeys.model_validate(auth_keys_dict)
AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
except OSError:
new_auth_keys = AuthKeys(
api_key = secrets.token_hex(16),
admin_key = secrets.token_hex(16)
api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16)
)
auth_keys = new_auth_keys
AUTH_KEYS = new_auth_keys
with open("api_tokens.yml", "w", encoding = "utf8") as auth_file:
yaml.safe_dump(auth_keys.model_dump(), auth_file, default_flow_style=False)
with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)
print(
f"Your API key is: {auth_keys.api_key}\n"
f"Your admin key is: {auth_keys.admin_key}\n\n"
"If these keys get compromised, make sure to delete api_tokens.yml and restart the server. Have fun!"
f"Your API key is: {AUTH_KEYS.api_key}\n"
f"Your admin key is: {AUTH_KEYS.admin_key}\n\n"
"If these keys get compromised, make sure to delete api_tokens.yml "
"and restart the server. Have fun!"
)
def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)):
"""Check if the API key is valid."""
# Allow request if auth is disabled
if disable_auth:
if DISABLE_AUTH:
return
if x_api_key:
if auth_keys.verify_key(x_api_key, "api_key"):
return x_api_key
else:
if not AUTH_KEYS.verify_key(x_api_key, "api_key"):
raise HTTPException(401, "Invalid API key")
elif authorization:
split_key = authorization.split(" ")
return x_api_key
if authorization:
split_key = authorization.split(" ")
if len(split_key) < 2:
raise HTTPException(401, "Invalid API key")
elif split_key[0].lower() == "bearer" and auth_keys.verify_key(split_key[1], "api_key"):
return authorization
else:
if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key(
split_key[1], "api_key"
):
raise HTTPException(401, "Invalid API key")
else:
raise HTTPException(401, "Please provide an API key")
return authorization
raise HTTPException(401, "Please provide an API key")
def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)):
"""Check if the admin key is valid."""
# Allow request if auth is disabled
if disable_auth:
if DISABLE_AUTH:
return
if x_admin_key:
if auth_keys.verify_key(x_admin_key, "admin_key"):
return x_admin_key
else:
if not AUTH_KEYS.verify_key(x_admin_key, "admin_key"):
raise HTTPException(401, "Invalid admin key")
elif authorization:
split_key = authorization.split(" ")
return x_admin_key
if authorization:
split_key = authorization.split(" ")
if len(split_key) < 2:
raise HTTPException(401, "Invalid admin key")
elif split_key[0].lower() == "bearer" and auth_keys.verify_key(split_key[1], "admin_key"):
return authorization
else:
if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key(
split_key[1], "admin_key"
):
raise HTTPException(401, "Invalid admin key")
else:
raise HTTPException(401, "Please provide an admin key")
return authorization
raise HTTPException(401, "Please provide an admin key")

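Hypothetical client sketch (not part of the commit) showing how a request authenticates against the checks above. It assumes the server is running on the default 127.0.0.1:5000 bind from main.py and that api_tokens.yml holds the generated keys.

import requests
import yaml

with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
    keys = yaml.safe_load(auth_file)

# Either header style passes check_api_key:
print(
    requests.get(
        "http://127.0.0.1:5000/v1/models",
        headers={"x-api-key": keys["api_key"]},
    ).json()
)
# An admin key sent as a bearer token also works, since admin keys are
# valid for all API calls.
print(
    requests.get(
        "http://127.0.0.1:5000/v1/models",
        headers={"Authorization": f"Bearer {keys['admin_key']}"},
    ).json()
)
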
formatting.bat (new file, 36 lines)

@ -0,0 +1,36 @@
@echo off
setlocal
::Change to script's directory
cd /d %~dp0
::Get tool versions
for /f "tokens=2" %%i in ('ruff --version') do set RUFF_VERSION="%%i"
::Check tool versions
call :tool_version_check "ruff" %RUFF_VERSION% "0.1.9"
::Format and lint files
call ruff format
call ruff check
echo tabbyAPI ruff lint and format: Done
::Check if any files were changed
git diff --quiet
if errorlevel 1 (
echo Reformatted files. Please review and stage the changes.
echo Changes not staged for commit:
echo.
git --no-pager diff --name-only
exit /b 1
)
exit /b 0
:tool_version_check
if not "%2"=="%3" (
echo Wrong %1 version installed: %3 is required, not %2.
exit /b 1
)
exit /b 0

formatting.sh (new executable file, 53 lines)

@ -0,0 +1,53 @@
#!/usr/bin/env bash
# YAPF formatter, adapted from ray and skypilot.
#
# Usage:
# # Do work and commit your work.
# # Format files that differ from origin/main.
# bash formatting.sh
# # Commit changed files with message 'Run yapf and ruff'
#
#
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.
# Cause the script to exit if a single command fails
set -eo pipefail
# this stops git rev-parse from failing if we run this from the .git directory
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
ROOT="$(git rev-parse --show-toplevel)"
builtin cd "$ROOT" || exit 1
RUFF_VERSION=$(ruff --version | head -n 1 | awk '{print $2}')
# params: tool name, tool version, required version
tool_version_check() {
if [[ $2 != $3 ]]; then
echo "Wrong $1 version installed: $3 is required, not $2."
exit 1
fi
}
tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | cut -d'=' -f3)"
# Format and lint all files
format_and_lint() {
ruff format
ruff check
}
# Call format command
format_and_lint
echo 'tabbyAPI ruff lint and format: Done'
if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:'
echo
git --no-pager diff --name-only
exit 1
fi

gen_logging.py

@ -1,31 +1,40 @@
"""
Functions for logging generation events.
"""
from typing import Dict
from pydantic import BaseModel
# Logging preference config
class LogConfig(BaseModel):
"""Logging preference config."""
prompt: bool = False
generation_params: bool = False
# Global reference to logging preferences
config = LogConfig()
# Wrapper to set the logging config for generations
# Global reference to logging preferences
CONFIG = LogConfig()
def update_from_dict(options_dict: Dict[str, bool]):
global config
"""Wrapper to set the logging config for generations"""
global CONFIG
# Force bools on the dict
for value in options_dict.values():
if value is None:
value = False
config = LogConfig.model_validate(options_dict)
CONFIG = LogConfig.model_validate(options_dict)
def broadcast_status():
"""Broadcasts the current logging status"""
enabled = []
if config.prompt:
if CONFIG.prompt:
enabled.append("prompts")
if config.generation_params:
if CONFIG.generation_params:
enabled.append("generation params")
if len(enabled) > 0:
@ -33,15 +42,20 @@ def broadcast_status():
else:
print("Generation logging is disabled")
# Logs generation parameters to console
def log_generation_params(**kwargs):
if config.generation_params:
"""Logs generation parameters to console."""
if CONFIG.generation_params:
print(f"Generation options: {kwargs}\n")
def log_prompt(prompt: str):
if config.prompt:
"""Logs the prompt to console."""
if CONFIG.prompt:
print(f"Prompt: {prompt if prompt else 'Empty'}\n")
def log_response(response: str):
if config.prompt:
"""Logs the response to console."""
if CONFIG.prompt:
print(f"Response: {response if response else 'Empty'}\n")

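Illustrative sketch (not part of the commit): enabling prompt logging from a config-style dict, then logging a prompt/response pair.

import gen_logging

gen_logging.update_from_dict({"prompt": True, "generation_params": False})
gen_logging.broadcast_status()  # prints which logging options are enabled
gen_logging.log_prompt("Hello, world")
gen_logging.log_response("Hi there!")
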
generators.py

@ -1,3 +1,4 @@
"""Generator functions for the tabbyAPI."""
import inspect
from asyncio import Semaphore
from functools import partialmethod
@ -5,8 +6,10 @@ from typing import AsyncGenerator
generate_semaphore = Semaphore(1)
# Async generation that blocks on a semaphore
async def generate_with_semaphore(generator: AsyncGenerator):
"""Generate with a semaphore."""
async with generate_semaphore:
if inspect.isasyncgenfunction(generator):
async for result in generator():
@ -15,6 +18,7 @@ async def generate_with_semaphore(generator: AsyncGenerator):
for result in generator():
yield result
# Block a function with semaphore
async def call_with_semaphore(callback: partialmethod):
if inspect.iscoroutinefunction(callback):

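Illustrative sketch (not part of the commit): streaming through generate_with_semaphore so only one generation runs at a time; the dummy async generator stands in for the model's token stream.

import asyncio

from generators import generate_with_semaphore


async def main():
    async def tokens():
        for piece in ["Hello", ", ", "world"]:
            yield piece

    async for piece in generate_with_semaphore(tokens):
        print(piece, end="")


asyncio.run(main())
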
main.py (363 lines changed)

@ -1,14 +1,16 @@
import uvicorn
import yaml
"""The main tabbyAPI module. Contains the FastAPI server and endpoints."""
import pathlib
from asyncio import CancelledError
from fastapi import FastAPI, Request, HTTPException, Depends
from typing import Optional
from uuid import uuid4
import uvicorn
import yaml
from fastapi import FastAPI, Depends, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from functools import partial
from progress.bar import IncrementalBar
from typing import Optional
from uuid import uuid4
import gen_logging
from auth import check_admin_key, check_api_key, load_auth_keys
@ -17,19 +19,24 @@ from model import ModelContainer
from OAI.types.completion import CompletionRequest
from OAI.types.chat_completion import ChatCompletionRequest
from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse, ModelCardParameters
from OAI.types.model import (
ModelCard,
ModelLoadRequest,
ModelLoadResponse,
ModelCardParameters,
)
from OAI.types.token import (
TokenEncodeRequest,
TokenEncodeResponse,
TokenDecodeRequest,
TokenDecodeResponse
TokenDecodeResponse,
)
from OAI.utils import (
from OAI.utils_oai import (
create_completion_response,
get_model_list,
get_lora_list,
create_chat_completion_response,
create_chat_completion_stream_chunk
create_chat_completion_response,
create_chat_completion_stream_chunk,
)
from templating import get_prompt_from_template
from utils import get_generator_error, get_sse_packet, load_progress, unwrap
@ -37,13 +44,15 @@ from utils import get_generator_error, get_sse_packet, load_progress, unwrap
app = FastAPI()
# Globally scoped variables. Undefined until initialized in main
model_container: Optional[ModelContainer] = None
MODEL_CONTAINER: Optional[ModelContainer] = None
config: dict = {}
def _check_model_container():
if model_container is None or model_container.model is None:
if MODEL_CONTAINER is None or MODEL_CONTAINER.model is None:
raise HTTPException(400, "No models are loaded.")
# Allow CORS requests
app.add_middleware(
CORSMiddleware,
@ -53,10 +62,12 @@ app.add_middleware(
allow_headers=["*"],
)
# Model list endpoint
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models():
"""Lists all models in the model directory."""
model_config = unwrap(config.get("model"), {})
model_dir = unwrap(model_config.get("model_dir"), "models")
model_path = pathlib.Path(model_dir)
@ -66,43 +77,53 @@ async def list_models():
models = get_model_list(model_path.resolve(), draft_model_dir)
if unwrap(model_config.get("use_dummy_models"), False):
models.data.insert(0, ModelCard(id = "gpt-3.5-turbo"))
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
return models
# Currently loaded model endpoint
@app.get("/v1/model", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@app.get(
"/v1/model",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
@app.get(
"/v1/internal/model/info",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
async def get_current_model():
model_name = model_container.get_model_path().name
prompt_template = model_container.prompt_template
"""Returns the currently loaded model."""
model_name = MODEL_CONTAINER.get_model_path().name
prompt_template = MODEL_CONTAINER.prompt_template
model_card = ModelCard(
id = model_name,
parameters = ModelCardParameters(
rope_scale = model_container.config.scale_pos_emb,
rope_alpha = model_container.config.scale_alpha_value,
max_seq_len = model_container.config.max_seq_len,
cache_mode = "FP8" if model_container.cache_fp8 else "FP16",
prompt_template = prompt_template.name if prompt_template else None
id=model_name,
parameters=ModelCardParameters(
rope_scale=MODEL_CONTAINER.config.scale_pos_emb,
rope_alpha=MODEL_CONTAINER.config.scale_alpha_value,
max_seq_len=MODEL_CONTAINER.config.max_seq_len,
cache_mode="FP8" if MODEL_CONTAINER.cache_fp8 else "FP16",
prompt_template=prompt_template.name if prompt_template else None,
),
logging = gen_logging.config
logging=gen_logging.CONFIG,
)
if model_container.draft_config:
if MODEL_CONTAINER.draft_config:
draft_card = ModelCard(
id = model_container.get_model_path(True).name,
parameters = ModelCardParameters(
rope_scale = model_container.draft_config.scale_pos_emb,
rope_alpha = model_container.draft_config.scale_alpha_value,
max_seq_len = model_container.draft_config.max_seq_len
)
id=MODEL_CONTAINER.get_model_path(True).name,
parameters=ModelCardParameters(
rope_scale=MODEL_CONTAINER.draft_config.scale_pos_emb,
rope_alpha=MODEL_CONTAINER.draft_config.scale_alpha_value,
max_seq_len=MODEL_CONTAINER.draft_config.max_seq_len,
),
)
model_card.parameters.draft = draft_card
return model_card
@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models():
"""Lists all draft models in the model directory."""
model_config = unwrap(config.get("model"), {})
draft_config = unwrap(model_config.get("draft"), {})
draft_model_dir = unwrap(draft_config.get("draft_model_dir"), "models")
@ -112,12 +133,14 @@ async def list_draft_models():
return models
# Load model endpoint
@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(request: Request, data: ModelLoadRequest):
global model_container
"""Loads a model into the model container."""
global MODEL_CONTAINER
if model_container and model_container.model:
if MODEL_CONTAINER and MODEL_CONTAINER.model:
raise HTTPException(400, "A model is already loaded! Please unload it first.")
if not data.name:
@ -129,32 +152,35 @@ async def load_model(request: Request, data: ModelLoadRequest):
load_data = data.model_dump()
# TODO: Add API exception if draft directory isn't found
draft_config = unwrap(model_config.get("draft"), {})
if data.draft:
if not data.draft.draft_model_name:
raise HTTPException(400, "draft_model_name was not found inside the draft object.")
raise HTTPException(
400, "draft_model_name was not found inside the draft object."
)
load_data["draft"]["draft_model_dir"] = unwrap(draft_config.get("draft_model_dir"), "models")
load_data["draft"]["draft_model_dir"] = unwrap(
draft_config.get("draft_model_dir"), "models"
)
if not model_path.exists():
raise HTTPException(400, "model_path does not exist. Check model_name?")
model_container = ModelContainer(model_path.resolve(), False, **load_data)
MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **load_data)
async def generator():
global model_container
"""Generator for the loading process."""
model_type = "draft" if model_container.draft_config else "model"
load_status = model_container.load_gen(load_progress)
model_type = "draft" if MODEL_CONTAINER.draft_config else "model"
load_status = MODEL_CONTAINER.load_gen(load_progress)
try:
for (module, modules) in load_status:
for module, modules in load_status:
if await request.is_disconnected():
break
if module == 0:
loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
loading_bar: IncrementalBar = IncrementalBar("Modules", max=modules)
elif module == modules:
loading_bar.next()
loading_bar.finish()
@ -163,13 +189,13 @@ async def load_model(request: Request, data: ModelLoadRequest):
model_type=model_type,
module=module,
modules=modules,
status="finished"
status="finished",
)
yield get_sse_packet(response.model_dump_json())
# Switch to model progress if the draft model is loaded
if model_container.draft_config:
if MODEL_CONTAINER.draft_config:
model_type = "model"
else:
loading_bar.next()
@ -178,29 +204,39 @@ async def load_model(request: Request, data: ModelLoadRequest):
model_type=model_type,
module=module,
modules=modules,
status="processing"
status="processing",
)
yield get_sse_packet(response.model_dump_json())
except CancelledError:
print("\nError: Model load cancelled by user. Please make sure to run unload to free up resources.")
except Exception as e:
yield get_generator_error(str(e))
print(
"\nError: Model load cancelled by user. "
"Please make sure to run unload to free up resources."
)
except Exception as exc:
yield get_generator_error(str(exc))
return StreamingResponse(generator(), media_type="text/event-stream")
return StreamingResponse(generator(), media_type = "text/event-stream")
# Unload model endpoint
@app.get("/v1/model/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
@app.get(
"/v1/model/unload",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
async def unload_model():
global model_container
"""Unloads the currently loaded model."""
global MODEL_CONTAINER
MODEL_CONTAINER.unload()
MODEL_CONTAINER = None
model_container.unload()
model_container = None
# Lora list endpoint
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
async def get_all_loras():
"""Lists all LoRAs in the lora directory."""
model_config = unwrap(config.get("model"), {})
lora_config = unwrap(model_config.get("lora"), {})
lora_path = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
@ -209,191 +245,240 @@ async def get_all_loras():
return loras
# Currently loaded loras endpoint
@app.get("/v1/lora", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@app.get(
"/v1/lora",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
async def get_active_loras():
"""Returns the currently loaded loras."""
active_loras = LoraList(
data = list(map(
lambda lora: LoraCard(
id = pathlib.Path(lora.lora_path).parent.name,
scaling = lora.lora_scaling * lora.lora_r / lora.lora_alpha
),
model_container.active_loras
data=list(
map(
lambda lora: LoraCard(
id=pathlib.Path(lora.lora_path).parent.name,
scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
),
MODEL_CONTAINER.active_loras,
)
)
))
)
return active_loras
# Load lora endpoint
@app.post("/v1/lora/load", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
@app.post(
"/v1/lora/load",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
async def load_lora(data: LoraLoadRequest):
"""Loads a LoRA into the model container."""
if not data.loras:
raise HTTPException(400, "List of loras to load is not found.")
model_config = unwrap(config.get("model"), {})
lora_config = unwrap(model_config.get("lora"), {})
lora_config = unwrap(model_config.get("lora"), {})
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
if not lora_dir.exists():
raise HTTPException(400, "A parent lora directory does not exist. Check your config.yml?")
raise HTTPException(
400,
"A parent lora directory does not exist. Check your config.yml?",
)
# Clean-up existing loras if present
if len(model_container.active_loras) > 0:
model_container.unload(True)
if len(MODEL_CONTAINER.active_loras) > 0:
MODEL_CONTAINER.unload(True)
result = model_container.load_loras(lora_dir, **data.model_dump())
result = MODEL_CONTAINER.load_loras(lora_dir, **data.model_dump())
return LoraLoadResponse(
success = unwrap(result.get("success"), []),
failure = unwrap(result.get("failure"), [])
success=unwrap(result.get("success"), []),
failure=unwrap(result.get("failure"), []),
)
# Unload lora endpoint
@app.get("/v1/lora/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
@app.get(
"/v1/lora/unload",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
async def unload_loras():
model_container.unload(True)
"""Unloads the currently loaded loras."""
MODEL_CONTAINER.unload(True)
# Encode tokens endpoint
@app.post("/v1/token/encode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@app.post(
"/v1/token/encode",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
async def encode_tokens(data: TokenEncodeRequest):
raw_tokens = model_container.get_tokens(data.text, None, **data.get_params())
"""Encodes a string into tokens."""
raw_tokens = MODEL_CONTAINER.get_tokens(data.text, None, **data.get_params())
# Have to use this if check otherwise Torch's tensors error out with a boolean issue
# Have to use this if check otherwise Torch's tensors error out
# with a boolean issue
tokens = raw_tokens[0].tolist() if raw_tokens is not None else []
response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
return response
# Decode tokens endpoint
@app.post("/v1/token/decode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@app.post(
"/v1/token/decode",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
async def decode_tokens(data: TokenDecodeRequest):
message = model_container.get_tokens(None, data.tokens, **data.get_params())
response = TokenDecodeResponse(text = unwrap(message, ""))
"""Decodes tokens into a string."""
message = MODEL_CONTAINER.get_tokens(None, data.tokens, **data.get_params())
response = TokenDecodeResponse(text=unwrap(message, ""))
return response
# Completions endpoint
@app.post("/v1/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@app.post(
"/v1/completions",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
async def generate_completion(request: Request, data: CompletionRequest):
model_path = model_container.get_model_path()
"""Generates a completion from a prompt."""
model_path = MODEL_CONTAINER.get_model_path()
if isinstance(data.prompt, list):
data.prompt = "\n".join(data.prompt)
if data.stream:
async def generator():
"""Generator for the generation process."""
try:
new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params())
for (part, prompt_tokens, completion_tokens) in new_generation:
new_generation = MODEL_CONTAINER.generate_gen(
data.prompt, **data.to_gen_params()
)
for part, prompt_tokens, completion_tokens in new_generation:
if await request.is_disconnected():
break
response = create_completion_response(part,
prompt_tokens,
completion_tokens,
model_path.name)
response = create_completion_response(
part, prompt_tokens, completion_tokens, model_path.name
)
yield get_sse_packet(response.model_dump_json())
except CancelledError:
print("Error: Completion request cancelled by user.")
except Exception as e:
yield get_generator_error(str(e))
except Exception as exc:
yield get_generator_error(str(exc))
return StreamingResponse(
generate_with_semaphore(generator),
media_type = "text/event-stream"
generate_with_semaphore(generator), media_type="text/event-stream"
)
else:
response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
partial(model_container.generate, data.prompt, **data.to_gen_params())
)
response = create_completion_response(response_text,
prompt_tokens,
completion_tokens,
model_path.name)
return response
response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
partial(MODEL_CONTAINER.generate, data.prompt, **data.to_gen_params())
)
response = create_completion_response(
response_text, prompt_tokens, completion_tokens, model_path.name
)
return response
# Chat completions endpoint
@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@app.post(
"/v1/chat/completions",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
if model_container.prompt_template is None:
return HTTPException(422, "This endpoint is disabled because a prompt template is not set.")
"""Generates a chat completion from a prompt."""
if MODEL_CONTAINER.prompt_template is None:
return HTTPException(
422,
"This endpoint is disabled because a prompt template is not set.",
)
model_path = model_container.get_model_path()
model_path = MODEL_CONTAINER.get_model_path()
if isinstance(data.messages, str):
prompt = data.messages
else:
try:
special_tokens_dict = model_container.get_special_tokens(
special_tokens_dict = MODEL_CONTAINER.get_special_tokens(
unwrap(data.add_bos_token, True),
unwrap(data.ban_eos_token, False)
unwrap(data.ban_eos_token, False),
)
prompt = get_prompt_from_template(
data.messages,
model_container.prompt_template,
MODEL_CONTAINER.prompt_template,
data.add_generation_prompt,
special_tokens_dict,
)
except KeyError:
return HTTPException(
400,
f"Could not find a Conversation from prompt template '{model_container.prompt_template.name}'. Check your spelling?"
"Could not find a Conversation from prompt template "
f"'{MODEL_CONTAINER.prompt_template.name}'. "
"Check your spelling?",
)
if data.stream:
const_id = f"chatcmpl-{uuid4().hex}"
async def generator():
"""Generator for the generation process."""
try:
new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
for (part, _, _) in new_generation:
new_generation = MODEL_CONTAINER.generate_gen(
prompt, **data.to_gen_params()
)
for part, _, _ in new_generation:
if await request.is_disconnected():
break
response = create_chat_completion_stream_chunk(
const_id,
part,
model_path.name
const_id, part, model_path.name
)
yield get_sse_packet(response.model_dump_json())
# Yield a finish response on successful generation
finish_response = create_chat_completion_stream_chunk(
const_id,
finish_reason = "stop"
const_id, finish_reason="stop"
)
yield get_sse_packet(finish_response.model_dump_json())
except CancelledError:
print("Error: Chat completion cancelled by user.")
except Exception as e:
yield get_generator_error(str(e))
except Exception as exc:
yield get_generator_error(str(exc))
return StreamingResponse(
generate_with_semaphore(generator),
media_type = "text/event-stream"
generate_with_semaphore(generator), media_type="text/event-stream"
)
else:
response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
partial(model_container.generate, prompt, **data.to_gen_params())
)
response = create_chat_completion_response(response_text,
prompt_tokens,
completion_tokens,
model_path.name)
return response
response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
partial(MODEL_CONTAINER.generate, prompt, **data.to_gen_params())
)
response = create_chat_completion_response(
response_text, prompt_tokens, completion_tokens, model_path.name
)
return response
if __name__ == "__main__":
# Load from YAML config. Possibly add a config -> kwargs conversion function
try:
with open('config.yml', 'r', encoding = "utf8") as config_file:
with open("config.yml", "r", encoding="utf8") as config_file:
config = unwrap(yaml.safe_load(config_file), {})
except Exception as e:
except Exception as exc:
print(
"The YAML config couldn't load because of the following error:",
f"\n\n{e}",
"\n\nTabbyAPI will start anyway and not parse this config file."
f"\n\n{exc}",
"\n\nTabbyAPI will start anyway and not parse this config file.",
)
config = {}
@ -409,18 +494,18 @@ if __name__ == "__main__":
gen_logging.broadcast_status()
# If an initial model name is specified, create a container and load the model
# If an initial model name is specified, create a container
# and load the model
model_config = unwrap(config.get("model"), {})
if "model_name" in model_config:
# TODO: Move this to model_container
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
model_path = model_path / model_config.get("model_name")
model_container = ModelContainer(model_path.resolve(), False, **model_config)
load_status = model_container.load_gen(load_progress)
for (module, modules) in load_status:
MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **model_config)
load_status = MODEL_CONTAINER.load_gen(load_progress)
for module, modules in load_status:
if module == 0:
loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
loading_bar: IncrementalBar = IncrementalBar("Modules", max=modules)
elif module == modules:
loading_bar.next()
loading_bar.finish()
@ -431,11 +516,11 @@ if __name__ == "__main__":
lora_config = unwrap(model_config.get("lora"), {})
if "loras" in lora_config:
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
model_container.load_loras(lora_dir.resolve(), **lora_config)
MODEL_CONTAINER.load_loras(lora_dir.resolve(), **lora_config)
uvicorn.run(
app,
host=network_config.get("host", "127.0.0.1"),
port=network_config.get("port", 5000),
log_level="debug"
log_level="debug",
)

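Hypothetical client sketch (not part of the commit) for the streaming completions endpoint above. It assumes the default 127.0.0.1:5000 bind, a valid key in api_tokens.yml, and that get_sse_packet (defined in utils.py, not shown in this diff) emits standard "data: <json>" server-sent-event lines.

import json

import requests
import yaml

with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
    api_key = yaml.safe_load(auth_file)["api_key"]

with requests.post(
    "http://127.0.0.1:5000/v1/completions",
    headers={"x-api-key": api_key},
    json={"prompt": "Once upon a time", "max_tokens": 32, "stream": True},
    stream=True,
) as response:
    for line in response.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            chunk = json.loads(line[len("data: "):])
            print(chunk["choices"][0]["text"], end="", flush=True)
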
model.py (399 lines changed)

@ -1,29 +1,36 @@
"""The model container class for ExLlamaV2 models."""
import gc
import pathlib
import time
import torch
from exllamav2 import(
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Config,
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Tokenizer,
ExLlamaV2Lora
)
from exllamav2.generator import(
ExLlamaV2StreamingGenerator,
ExLlamaV2Sampler
ExLlamaV2Lora,
)
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
from gen_logging import log_generation_params, log_prompt, log_response
from typing import List, Optional, Union
from templating import PromptTemplate, find_template_from_model, get_template_from_model_json, get_template_from_file
from templating import (
PromptTemplate,
find_template_from_model,
get_template_from_model_json,
get_template_from_file,
)
from utils import coalesce, unwrap
# Bytes to reserve on first device when loading with auto split
auto_split_reserve_bytes = 96 * 1024**2
AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2
class ModelContainer:
"""The model container class for ExLlamaV2 models."""
config: Optional[ExLlamaV2Config] = None
draft_config: Optional[ExLlamaV2Config] = None
model: Optional[ExLlamaV2] = None
@ -40,35 +47,51 @@ class ModelContainer:
active_loras: List[ExLlamaV2Lora] = []
def __init__(self, model_directory: pathlib.Path, quiet = False, **kwargs):
def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
"""
Create model container
Args:
model_dir (int): Model directory containing config.json, tokenizer.model etc.
model_dir (int): Model directory containing config.json,
tokenizer.model etc.
quiet (bool): Suppress console output
load_progress_callback (function, optional): A function to call for each module loaded. Prototype:
def progress(loaded_modules: int, total_modules: int, loading_draft: bool)
load_progress_callback (function, optional): A function to call for
each module loaded. Prototype:
def progress(loaded_modules: int, total_modules: int,
loading_draft: bool)
**kwargs:
`cache_mode` (str): Sets cache mode, "FP16" or "FP8" (defaulf: "FP16")
'max_seq_len' (int): Override model's default max sequence length (default: 4096)
'rope_scale' (float): Set RoPE scaling factor for model (default: 1.0)
'rope_alpha' (float): Set RoPE alpha (NTK) factor for model (default: 1.0)
'prompt_template' (str): Manually sets the prompt template for this model (default: None)
'chunk_size' (int): Sets the maximum chunk size for the model (default: 2048)
Inferencing in chunks reduces overall VRAM overhead by processing very long sequences in smaller
batches. This limits the size of temporary buffers needed for the hidden state and attention
weights.
`cache_mode` (str): Sets cache mode, "FP16" or "FP8"
(default: "FP16")
'max_seq_len' (int): Override model's default max sequence
length (default: 4096)
'rope_scale' (float): Set RoPE scaling factor for model
(default: 1.0)
'rope_alpha' (float): Set RoPE alpha (NTK) factor for model
(default: 1.0)
'prompt_template' (str): Manually sets the prompt template for
this model (default: None)
'chunk_size' (int): Sets the maximum chunk size for the model
(default: 2048)
Inferencing in chunks reduces overall VRAM overhead by
processing very long sequences in smaller batches. This
limits the size of temporary buffers needed for the hidden
state and attention weights.
'draft_model_dir' (str): Draft model directory
'draft_rope_scale' (float): Set RoPE scaling factor for draft model (default: 1.0)
'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft model.
By default, the draft model's alpha value is calculated automatically to scale to the size of the
'draft_rope_scale' (float): Set RoPE scaling factor for draft
model (default: 1.0)
'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft
model. By default, the draft model's alpha value is
calculated automatically to scale to the size of the
full model.
'lora_dir' (str): Lora directory
'loras' (list[dict]): List of loras to be loaded, consisting of 'name' and 'scaling'
'gpu_split_auto' (bool): Automatically split model across available devices (default: True)
'gpu_split' (list[float]): Allocation for weights and (some) tensors, per device
'no_flash_attn' (bool): Turns off flash attention (increases vram usage) (default: False)
'lora_dir' (str): LoRA directory
'loras' (list[dict]): List of loras to be loaded, consisting of
'name' and 'scaling'
'gpu_split_auto' (bool): Automatically split model across
available devices (default: True)
'gpu_split' (list[float]): Allocation for weights and (some)
tensors, per device
'no_flash_attn' (bool): Turns off flash attention
(increases vram usage) (default: False)
"""
self.quiet = quiet
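# Illustrative sketch of how the kwargs above are consumed; the path and
# values below are placeholders, not settings shipped with this commit:
#     container = ModelContainer(
#         pathlib.Path("models/my-exl2-model"),
#         quiet=True,
#         cache_mode="FP16",
#         max_seq_len=8192,
#         rope_alpha=2.5,
#         prompt_template="chatml",
#         gpu_split_auto=True,
#     )
#     container.load()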
@ -90,7 +113,8 @@ class ModelContainer:
if override_base_seq_len:
self.config.max_seq_len = override_base_seq_len
# Grab the base model's sequence length before overrides for rope calculations
# Grab the base model's sequence length before overrides for
# rope calculations
base_seq_len = self.config.max_seq_len
# Set the target seq len if present
@ -103,14 +127,14 @@ class ModelContainer:
# Automatically calculate rope alpha
self.config.scale_alpha_value = unwrap(
kwargs.get("rope_alpha"),
self.calculate_rope_alpha(base_seq_len)
kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
)
# Turn off flash attention?
self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attention"), False)
# low_mem is currently broken in exllamav2. Don't use it until it's fixed.
# low_mem is currently broken in exllamav2. Don't use it until it's
# fixed.
"""
if "low_mem" in kwargs and kwargs["low_mem"]:
self.config.set_low_mem()
@ -119,7 +143,10 @@ class ModelContainer:
# Set prompt template override if provided
prompt_template_name = kwargs.get("prompt_template")
if prompt_template_name:
print(f"Attempting to load prompt template with name {prompt_template_name}")
print(
"Attempting to load prompt template with name",
prompt_template_name,
)
# Read the template
self.prompt_template = get_template_from_file(prompt_template_name)
else:
@ -127,16 +154,17 @@ class ModelContainer:
self.prompt_template = get_template_from_model_json(
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
"chat_template",
"from_tokenizer_config"
"from_tokenizer_config",
)
# Try finding the chat template from the model's config.json
# TODO: This may not even be used with huggingface models, mark for removal.
# TODO: This may not even be used with huggingface models,
# mark for removal.
if self.prompt_template is None:
self.prompt_template = get_template_from_model_json(
pathlib.Path(self.config.model_config),
"chat_template",
"from_model_config"
"from_model_config",
)
# If that fails, attempt fetching from model name
@ -147,10 +175,13 @@ class ModelContainer:
# Catch all for template lookup errors
if self.prompt_template:
print(f"Using template {self.prompt_template.name} for chat completions.")
print(
f"Using template {self.prompt_template.name} for chat " "completions."
)
else:
print(
"Chat completions are disabled because a prompt template wasn't provided or auto-detected."
"Chat completions are disabled because a prompt template",
"wasn't provided or auto-detected.",
)
# Set num of experts per token if provided
@ -159,11 +190,16 @@ class ModelContainer:
if hasattr(self.config, "num_experts_per_token"):
self.config.num_experts_per_token = num_experts_override
else:
print(" !! Warning: Currently installed ExLlamaV2 does not support overriding MoE experts")
print(
" !! Warning: Currently installed ExLlamaV2 does not "
"support overriding MoE experts"
)
chunk_size = min(unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len)
chunk_size = min(
unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len
)
self.config.max_input_len = chunk_size
self.config.max_attn_size = chunk_size ** 2
self.config.max_attn_size = chunk_size**2
draft_args = unwrap(kwargs.get("draft"), {})
draft_model_name = draft_args.get("draft_model_name")
@ -171,47 +207,63 @@ class ModelContainer:
# Always disable draft if params are incorrectly configured
if draft_args and draft_model_name is None:
print("A draft config was found but a model name was not given. Please check your config.yml! Skipping draft load.")
print(
"A draft config was found but a model name was not given. "
"Please check your config.yml! Skipping draft load."
)
enable_draft = False
if enable_draft:
self.draft_config = ExLlamaV2Config()
draft_model_path = pathlib.Path(unwrap(draft_args.get("draft_model_dir"), "models"))
draft_model_path = pathlib.Path(
unwrap(draft_args.get("draft_model_dir"), "models")
)
draft_model_path = draft_model_path / draft_model_name
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()
self.draft_config.scale_pos_emb = unwrap(draft_args.get("draft_rope_scale"), 1.0)
self.draft_config.scale_pos_emb = unwrap(
draft_args.get("draft_rope_scale"), 1.0
)
# Automatically calculate draft rope alpha
self.draft_config.scale_alpha_value = unwrap(
draft_args.get("draft_rope_alpha"),
self.calculate_rope_alpha(self.draft_config.max_seq_len)
self.calculate_rope_alpha(self.draft_config.max_seq_len),
)
self.draft_config.max_seq_len = self.config.max_seq_len
if "chunk_size" in kwargs:
self.draft_config.max_input_len = kwargs["chunk_size"]
self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2
def calculate_rope_alpha(self, base_seq_len):
"""Calculate the rope alpha value for a given sequence length."""
ratio = self.config.max_seq_len / base_seq_len
# Default to a 1 alpha if the sequence length is ever less than or equal to 1
alpha = 1 if ratio <= 1.0 else -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
# Default to a 1 alpha if the sequence length ratio is ever
# less than or equal to 1
if ratio <= 1.0:
alpha = 1
else:
alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
return alpha
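# Worked example of the polynomial above (values rounded, assuming a base
# sequence length of 4096):
#     target 4096  -> ratio 1.0 -> alpha = 1
#     target 8192  -> ratio 2.0 -> alpha = -0.13436 + 0.80541*2 + 0.28833*4  ~ 2.63
#     target 16384 -> ratio 4.0 -> alpha = -0.13436 + 0.80541*4 + 0.28833*16 ~ 7.70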
def get_model_path(self, is_draft: bool = False):
model_path = pathlib.Path(self.draft_config.model_dir if is_draft else self.config.model_dir)
"""Get the path for this model."""
model_path = pathlib.Path(
self.draft_config.model_dir if is_draft else self.config.model_dir
)
return model_path
def load(self, progress_callback = None):
def load(self, progress_callback=None):
"""
Load model
Args:
progress_callback (function, optional): A function to call for each module loaded. Prototype:
progress_callback (function, optional): A function to call for each
module loaded. Prototype:
def progress(loaded_modules: int, total_modules: int)
"""
for _ in self.load_gen(progress_callback):
@ -231,25 +283,32 @@ class ModelContainer:
lora_scaling = unwrap(lora.get("scaling"), 1.0)
if lora_name is None:
print("One of your loras does not have a name. Please check your config.yml! Skipping lora load.")
print(
"One of your loras does not have a name. Please check your "
"config.yml! Skipping lora load."
)
failure.append(lora_name)
continue
print(f"Loading lora: {lora_name} at scaling {lora_scaling}")
lora_path = lora_directory / lora_name
self.active_loras.append(ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling))
# FIXME(alpin): Does self.model need to be passed here?
self.active_loras.append(
ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
)
print("Lora successfully loaded.")
success.append(lora_name)
# Return success and failure names
return { 'success': success, 'failure': failure }
return {"success": success, "failure": failure}
def load_gen(self, progress_callback = None):
def load_gen(self, progress_callback=None):
"""
Load model, generator function
Args:
progress_callback (function, optional): A function to call for each module loaded. Prototype:
progress_callback (function, optional): A function to call for each
module loaded. Prototype:
def progress(loaded_modules: int, total_modules: int)
"""
@ -262,13 +321,18 @@ class ModelContainer:
if not self.quiet:
print("Loading draft model: " + self.draft_config.model_dir)
self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy = True)
reserve = [auto_split_reserve_bytes] + [0] * 16
yield from self.draft_model.load_autosplit_gen(self.draft_cache, reserve_vram = reserve, last_id_only = True, callback_gen = progress_callback)
self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True)
reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
yield from self.draft_model.load_autosplit_gen(
self.draft_cache,
reserve_vram=reserve,
last_id_only=True,
callback_gen=progress_callback,
)
# Test VRAM allocation with a full-length forward pass
input_ids = torch.zeros((1, self.config.max_input_len), dtype = torch.long)
self.draft_model.forward(input_ids, cache = self.cache, preprocess_only = True)
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)
# Load model
self.model = ExLlamaV2(self.config)
@ -276,29 +340,41 @@ class ModelContainer:
print("Loading model: " + self.config.model_dir)
if not self.gpu_split_auto:
for value in self.model.load_gen(self.gpu_split, callback_gen = progress_callback):
for value in self.model.load_gen(
self.gpu_split, callback_gen=progress_callback
):
if isinstance(value, str):
yield value
if self.cache_fp8:
self.cache = ExLlamaV2Cache_8bit(self.model, lazy = self.gpu_split_auto)
self.cache = ExLlamaV2Cache_8bit(self.model, lazy=self.gpu_split_auto)
else:
self.cache = ExLlamaV2Cache(self.model, lazy = self.gpu_split_auto)
self.cache = ExLlamaV2Cache(self.model, lazy=self.gpu_split_auto)
if self.gpu_split_auto:
reserve = [auto_split_reserve_bytes] + [0] * 16
yield from self.model.load_autosplit_gen(self.cache, reserve_vram = reserve, last_id_only = True, callback_gen = progress_callback)
reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
yield from self.model.load_autosplit_gen(
self.cache,
reserve_vram=reserve,
last_id_only=True,
callback_gen=progress_callback,
)
# Test VRAM allocation with a full-length forward pass
input_ids = torch.zeros((1, self.config.max_input_len), dtype = torch.long)
self.model.forward(input_ids, cache = self.cache, preprocess_only = True)
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
# Create generator
self.generator = ExLlamaV2StreamingGenerator(self.model, self.cache, self.tokenizer, self.draft_model, self.draft_cache)
self.generator = ExLlamaV2StreamingGenerator(
self.model,
self.cache,
self.tokenizer,
self.draft_model,
self.draft_cache,
)
print("Model successfully loaded.")
def unload(self, loras_only: bool = False):
"""
Free all VRAM resources used by this model
@ -327,19 +403,24 @@ class ModelContainer:
gc.collect()
torch.cuda.empty_cache()
# Common function for token operations
def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs):
"""Common function for token operations"""
if text:
# Assume token encoding
return self.tokenizer.encode(
text,
add_bos = unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens = unwrap(kwargs.get("encode_special_tokens"), True)
add_bos=unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
)
if ids:
# Assume token decoding
ids = torch.tensor([ids])
return self.tokenizer.decode(ids, decode_special_tokens = unwrap(kwargs.get("decode_special_tokens"), True))[0]
return self.tokenizer.decode(
ids,
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
)[0]
return None
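# Hypothetical round trip through the helper above (the ids shown are
# placeholders, not real vocabulary entries):
#     ids = container.get_tokens("Hello there", None, add_bos_token=True)
#     text = container.get_tokens(None, [1, 29871, 15043])
# Passing `text` encodes, passing `ids` decodes, and the call returns None
# when neither is supplied.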
def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool):
return {
@ -350,13 +431,15 @@ class ModelContainer:
}
def generate(self, prompt: str, **kwargs):
"""Generate a response to a prompt"""
generation = list(self.generate_gen(prompt, **kwargs))
if generation:
response = "".join(map(lambda chunk: chunk[0], generation))
return response, generation[-1][1], generation[-1][2]
else:
return "", 0, 0
return "", 0, 0
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
def generate_gen(self, prompt: str, **kwargs):
"""
Create generator function for prompt completion
@ -366,7 +449,8 @@ class ModelContainer:
**kwargs:
'token_healing' (bool): Use token healing (default: False)
'temperature' (float): Sampling temperature (default: 1.0)
'temperature_last' (bool): Apply temperature after all other samplers (default: False)
'temperature_last' (bool): Apply temperature after all other
samplers (default: False)
'top_k' (int): Sampling top-K (default: 0)
'top_p' (float): Sampling top-P (default: 1.0)
'min_p' (float): Sampling min-P (default: 0.0)
@ -375,19 +459,27 @@ class ModelContainer:
'mirostat' (bool): Use Mirostat (default: False)
'mirostat_tau' (float): Mirostat tau parameter (default: 1.5)
'mirostat_eta' (float): Mirostat eta parameter (default: 0.1)
'repetition_penalty' (float): Token repetition/presence penalty (default: 1.15)
'repetition_range' (int): Repetition penalty range (default: whole context)
'repetition_decay' (int): Repetition penalty range (default: same as range)
'stop' (List[Union[str, int]]): List of stop strings/tokens to end response (default: [EOS])
'repetition_penalty' (float): Token repetition/presence penalty
(default: 1.15)
'repetition_range' (int): Repetition penalty range
(default: whole context)
'repetition_decay' (int): Repetition penalty decay
(default: same as range)
'stop' (List[Union[str, int]]): List of stop strings/tokens to
end response (default: [EOS])
'max_tokens' (int): Max no. tokens in response (default: 150)
'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True)
'ban_eos_token' (bool): Bans the EOS token from generation (default: False)
'logit_bias' (Dict[int, float]): Biases specific tokens to either show up more or less (default: None)
'stream_interval' (float): Interval in seconds between each output chunk (default: immediate)
'generate_window' (int): Space to reserve at the end of the model's context when generating.
Rolls context window by the same amount if context length is exceeded to allow generating past
the models max_seq_len.
'add_bos_token' (bool): Adds the BOS token to the start of the
prompt (default: True)
'ban_eos_token' (bool): Bans the EOS token from generation
(default: False)
'logit_bias' (Dict[int, float]): Biases specific tokens to
either show up more or less (default: None)
'stream_interval' (float): Interval in seconds between each
output chunk (default: immediate)
'generate_window' (int): Space to reserve at the end of the
model's context when generating. Rolls context window by
the same amount if context length is exceeded to allow
generating past the model's max_seq_len.
"""
token_healing = unwrap(kwargs.get("token_healing"), False)
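# Illustrative call using a few of the sampler kwargs documented above; the
# prompt and values are placeholders rather than recommended settings:
#     for chunk, prompt_tokens, gen_tokens in container.generate_gen(
#         "Once upon a time",
#         temperature=0.8,
#         top_p=0.9,
#         repetition_penalty=1.15,
#         max_tokens=200,
#         stop=["\n\n"],
#     ):
#         print(chunk, end="")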
@ -399,17 +491,37 @@ class ModelContainer:
gen_settings = ExLlamaV2Sampler.Settings()
# Warn of unsupported settings if the setting is enabled
if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(gen_settings, "mirostat"):
print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling")
if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(
gen_settings, "mirostat"
):
print(
" !! Warning: Currently installed ExLlamaV2 does not support "
"Mirostat sampling"
)
if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
print(" !! Warning: Currently installed ExLlamaV2 does not support min-P sampling")
if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(
gen_settings, "min_p"
):
print(
" !! Warning: Currently installed ExLlamaV2 does not "
"support min-P sampling"
)
if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)")
if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(
gen_settings, "tfs"
):
print(
" !! Warning: Currently installed ExLlamaV2 does not support "
"tail-free sampling (TFS)"
)
if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(gen_settings, "temperature_last"):
print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last")
if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(
gen_settings, "temperature_last"
):
print(
" !! Warning: Currently installed ExLlamaV2 does not support "
"temperature_last"
)
# Apply settings
gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
@ -424,14 +536,24 @@ class ModelContainer:
# Default tau and eta fallbacks don't matter if mirostat is off
gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
gen_settings.token_repetition_penalty = unwrap(kwargs.get("repetition_penalty"), 1.0)
gen_settings.token_repetition_range = unwrap(kwargs.get("repetition_range"), self.config.max_seq_len)
gen_settings.token_repetition_penalty = unwrap(
kwargs.get("repetition_penalty"), 1.0
)
gen_settings.token_repetition_range = unwrap(
kwargs.get("repetition_range"), self.config.max_seq_len
)
# Always make sure the fallback is 0 if range < 0
# It's technically fine to use -1, but this just validates the passed fallback
# It's technically fine to use -1, but this just validates the passed
# fallback
# Always default to 0 if something goes wrong
fallback_decay = 0 if gen_settings.token_repetition_range <= 0 else gen_settings.token_repetition_range
gen_settings.token_repetition_decay = coalesce(kwargs.get("repetition_decay"), fallback_decay, 0)
if gen_settings.token_repetition_range <= 0:
fallback_decay = 0
else:
fallback_decay = gen_settings.token_repetition_range
gen_settings.token_repetition_decay = coalesce(
kwargs.get("repetition_decay"), fallback_decay, 0
)
stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
@ -448,13 +570,13 @@ class ModelContainer:
# Log generation options to console
# Some options are too large, so log the args instead
log_generation_params(
max_tokens = max_tokens,
max_tokens=max_tokens,
**vars(gen_settings),
token_healing = token_healing,
add_bos_token = add_bos_token,
ban_eos_token = ban_eos_token,
stop_conditions = stop_conditions,
logit_bias = logit_bias
token_healing=token_healing,
add_bos_token=add_bos_token,
ban_eos_token=ban_eos_token,
stop_conditions=stop_conditions,
logit_bias=logit_bias,
)
# Log prompt to console
@ -465,13 +587,17 @@ class ModelContainer:
# Create a vocab tensor if it doesn't exist for token biasing
if gen_settings.token_bias is None:
padding = -self.tokenizer.config.vocab_size % 32
gen_settings.token_bias = torch.zeros((self.tokenizer.config.vocab_size + padding,), dtype = torch.float)
gen_settings.token_bias = torch.zeros(
(self.tokenizer.config.vocab_size + padding,),
dtype=torch.float,
)
# Map logits to the tensor with their biases
for token, bias in logit_bias.items():
gen_settings.token_bias[token] = bias
# Ban the EOS token if specified. If not, append to stop conditions as well.
# Ban the EOS token if specified. If not, append to stop conditions
# as well.
# Set this below logging to avoid polluting the stop strings array
if ban_eos_token:
gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
@ -483,16 +609,15 @@ class ModelContainer:
# Tokenized context
ids = self.tokenizer.encode(
prompt,
add_bos = add_bos_token,
encode_special_tokens = True
prompt, add_bos=add_bos_token, encode_special_tokens=True
)
context_len = len(ids[0])
if context_len > self.config.max_seq_len:
print(
f"WARNING: The context length {context_len} is greater than the max_seq_len {self.config.max_seq_len}.",
"Generation is truncated and metrics may not be accurate."
f"WARNING: The context length {context_len} is greater than "
f"the max_seq_len {self.config.max_seq_len}.",
"Generation is truncated and metrics may not be accurate.",
)
prompt_tokens = ids.shape[-1]
@ -503,26 +628,32 @@ class ModelContainer:
start_time = time.time()
last_chunk_time = start_time
save_tokens = torch.empty((1, 0), dtype = torch.bool)
save_tokens = torch.empty((1, 0), dtype=torch.bool)
chunk_buffer = ""
chunk_tokens = 0
while True:
# Ingest prompt
if chunk_tokens == 0:
ids = torch.cat((ids, save_tokens), dim = - 1)
save_tokens = torch.empty((1, 0), dtype = torch.bool)
ids = torch.cat((ids, save_tokens), dim=-1)
save_tokens = torch.empty((1, 0), dtype=torch.bool)
overflow = ids.shape[-1] + generate_window - self.config.max_seq_len
active_ids = ids[:, max(0, overflow):]
active_ids = ids[:, max(0, overflow) :]
chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]
self.generator.begin_stream(active_ids, gen_settings, token_healing = token_healing, loras = self.active_loras)
self.generator.begin_stream(
active_ids,
gen_settings,
token_healing=token_healing,
loras=self.active_loras,
)
# Generate
chunk, eos, tokens = self.generator.stream()
if token_healing:
ids[:, -1] = self.generator.sequence_ids[:, -2] # Extract healed token
# Extract healed token
ids[:, -1] = self.generator.sequence_ids[:, -2]
token_healing = False
save_tokens = torch.cat((save_tokens, tokens), dim=-1)
@ -535,7 +666,9 @@ class ModelContainer:
now = time.time()
elapsed = now - last_chunk_time
if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_tokens):
if chunk_buffer != "" and (
elapsed > stream_interval or eos or generated_tokens == max_tokens
):
yield chunk_buffer, prompt_tokens, generated_tokens
full_response += chunk_buffer
chunk_buffer = ""
@ -549,12 +682,20 @@ class ModelContainer:
elapsed_time = last_chunk_time - start_time
initial_response = f"Metrics: {generated_tokens} tokens generated in {round(elapsed_time, 2)} seconds"
initial_response = (
f"Metrics: {generated_tokens} tokens generated in "
f"{round(elapsed_time, 2)} seconds"
)
itemization = []
extra_parts = []
# Add tokens per second
itemization.append(f"{'Indeterminate' if elapsed_time == 0 else round(generated_tokens / elapsed_time, 2)} T/s")
tokens_per_second = (
"Indeterminate"
if elapsed_time == 0
else round(generated_tokens / elapsed_time, 2)
)
itemization.append(f"{tokens_per_second} T/s")
# Add context (original token count)
if ids is not None:
@ -564,4 +705,10 @@ class ModelContainer:
extra_parts.append("<-- Not accurate (truncated)")
# Print output
print(initial_response + " (" + ", ".join(itemization) + ") " + " ".join(extra_parts))
print(
initial_response
+ " ("
+ ", ".join(itemization)
+ ") "
+ " ".join(extra_parts)
)

15
requirements-dev.txt Normal file
View file

@ -0,0 +1,15 @@
# formatting
ruff==0.1.9
## Implement below dependencies when support is added
# type checking
# mypy==0.991
# types-PyYAML
# types-requests
# types-setuptools
# testing
# pytest
# pytest-forked
# pytest-asyncio

View file

@ -1,46 +1,56 @@
"""Small replication of AutoTokenizer's chat template system for efficiency"""
import json
import pathlib
from functools import lru_cache
from importlib.metadata import version as package_version
from jinja2.sandbox import ImmutableSandboxedEnvironment
from packaging import version
from pydantic import BaseModel
from typing import Optional, Dict
# Small replication of AutoTokenizer's chat template system for efficiency
class PromptTemplate(BaseModel):
"""A template for chat completion prompts."""
name: str
template: str
def get_prompt_from_template(messages,
prompt_template: PromptTemplate,
add_generation_prompt: bool,
special_tokens: Optional[Dict[str, str]] = None):
def get_prompt_from_template(
messages,
prompt_template: PromptTemplate,
add_generation_prompt: bool,
special_tokens: Optional[Dict[str, str]] = None,
):
"""Get a prompt from a template and a list of messages."""
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
raise ImportError(
"Parsing these chat completion messages requires jinja2 3.0.0 or greater. "
f"Current version: {version('jinja2')}\n"
"Parsing these chat completion messages requires jinja2 3.0.0 "
f"or greater. Current version: {package_version('jinja2')}\n"
"Please upgrade jinja by running the following command: "
"pip install --upgrade jinja2"
)
compiled_template = _compile_template(prompt_template.template)
return compiled_template.render(
messages = messages,
add_generation_prompt = add_generation_prompt,
messages=messages,
add_generation_prompt=add_generation_prompt,
**special_tokens,
)
# Inspired from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
# Inspired from
# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
@lru_cache
def _compile_template(template: str):
jinja_env = ImmutableSandboxedEnvironment(trim_blocks = True, lstrip_blocks = True)
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
jinja_template = jinja_env.from_string(template)
return jinja_template
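# Minimal sketch of the template flow above. The ChatML-style template body
# is hypothetical, not one of the repo's bundled .jinja files:
#     template = PromptTemplate(
#         name="example",
#         template=(
#             "{% for message in messages %}"
#             "<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n"
#             "{% endfor %}"
#             "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
#         ),
#     )
#     prompt = get_prompt_from_template(
#         [{"role": "user", "content": "Hello!"}],
#         template,
#         add_generation_prompt=True,
#         special_tokens={},
#     )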
# Find a matching template name from a model path
def find_template_from_model(model_path: pathlib.Path):
"""Find a matching template name from a model path."""
model_name = model_path.name
template_directory = pathlib.Path("templates")
for filepath in template_directory.glob("*.jinja"):
@ -50,14 +60,16 @@ def find_template_from_model(model_path: pathlib.Path):
if template_name in model_name.lower():
return template_name
# Get a template from a jinja file
return None
def get_template_from_file(prompt_template_name: str):
"""Get a template from a jinja file."""
template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
if template_path.exists():
with open(template_path, "r", encoding = "utf8") as raw_template:
with open(template_path, "r", encoding="utf8") as raw_template:
return PromptTemplate(
name = prompt_template_name,
template = raw_template.read()
name=prompt_template_name, template=raw_template.read()
)
return None
@ -66,15 +78,12 @@ def get_template_from_file(prompt_template_name: str):
# Get a template from a JSON file
# Requires a key and template name
def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
"""Get a template from a JSON file. Requires a key and template name"""
if json_path.exists():
with open(json_path, "r", encoding = "utf8") as config_file:
with open(json_path, "r", encoding="utf8") as config_file:
model_config = json.load(config_file)
chat_template = model_config.get(key)
if chat_template:
return PromptTemplate(
name = name,
template = chat_template
)
return PromptTemplate(name=name, template=chat_template)
return None

View file

@ -1,22 +1,49 @@
""" Test the model container. """
from model import ModelContainer
def progress(module, modules):
"""Wrapper callback for load progress."""
yield module, modules
container = ModelContainer("/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.0bpw/")
loader = container.load_gen(progress)
for (module, modules) in loader:
print(module, modules)
generator = container.generate_gen("Once upon a tim", token_healing = True)
for g in generator:
print(g, end = "")
def test_load_gen(model_path):
"""Test loading a model."""
container = ModelContainer(model_path)
loader = container.load_gen(progress)
for module, modules in loader:
print(module, modules)
container.unload()
del container
container.unload()
del container
mc = ModelContainer("/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.65bpw/")
mc.load(progress)
def test_generate_gen(model_path):
"""Test generating from a model."""
container = ModelContainer(model_path)
generator = container.generate_gen("Once upon a tim", token_healing=True)
for chunk in generator:
print(chunk, end="")
container.unload()
del container
response = mc.generate("All work and no play makes turbo a derpy cat.\nAll work and no play makes turbo a derpy cat.\nAll", top_k = 1, max_new_tokens = 1000, stream_interval = 0.5)
print (response)
def test_generate(model_path):
"""Test generating from a model."""
model_container = ModelContainer(model_path)
model_container.load(progress)
prompt = (
"All work and no play makes turbo a derpy cat.\n"
"All work and no play makes turbo a derpy cat.\nAll"
)
response = model_container.generate(
prompt, top_k=1, max_new_tokens=1000, stream_interval=0.5
)
print(response)
if __name__ == "__main__":
MODEL1 = "/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.0bpw/"
MODEL2 = "/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.65bpw/"
test_load_gen(MODEL1)
test_generate_gen(MODEL1)
test_generate(MODEL2)

View file

@ -1,3 +1,4 @@
""" Test if the wheels are installed correctly. """
from importlib.metadata import version
from importlib.util import find_spec
@ -34,8 +35,12 @@ else:
print(
f"\nSuccessful imports: {', '.join(successful_packages)}",
f"\nErrored imports: {''.join(errored_packages)}"
f"\nErrored imports: {''.join(errored_packages)}",
)
if len(errored_packages) > 0:
print("\nIf packages are installed, but not found on this test, please check the wheel versions for the correct python version and CUDA version (if applicable).")
print(
"\nIf packages are installed, but not found on this test, please "
"check the wheel versions for the correct python version and CUDA "
"version (if applicable)."
)

View file

@ -1,43 +1,54 @@
"""Common utilities for the tabbyAPI"""
import traceback
from pydantic import BaseModel
from typing import Optional
# Wrapper callback for load progress
from pydantic import BaseModel
def load_progress(module, modules):
"""Wrapper callback for load progress."""
yield module, modules
# Common error types
class TabbyGeneratorErrorMessage(BaseModel):
"""Common error types."""
message: str
trace: Optional[str] = None
class TabbyGeneratorError(BaseModel):
"""Common error types."""
error: TabbyGeneratorErrorMessage
def get_generator_error(message: str):
"""Get a generator error."""
error_message = TabbyGeneratorErrorMessage(
message = message,
trace = traceback.format_exc()
message=message, trace=traceback.format_exc()
)
generator_error = TabbyGeneratorError(
error = error_message
)
generator_error = TabbyGeneratorError(error=error_message)
# Log and send the exception
print(f"\n{generator_error.error.trace}")
return get_sse_packet(generator_error.model_dump_json())
def get_sse_packet(json_data: str):
"""Get an SSE packet."""
return f"data: {json_data}\n\n"
# Unwrap function for Optionals
def unwrap(wrapped, default = None):
def unwrap(wrapped, default=None):
"""Unwrap function for Optionals."""
if wrapped is None:
return default
else:
return wrapped
# Coalesce function for multiple unwraps
return wrapped
def coalesce(*args):
"""Coalesce function for multiple unwraps."""
return next((arg for arg in args if arg is not None), None)
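# Quick examples of the two helpers above (values are illustrative):
#     unwrap(None, "FP16")     -> "FP16"
#     unwrap("FP8", "FP16")    -> "FP8"
#     coalesce(None, None, 5)  -> 5
#     coalesce(None, 0, 5)     -> 0   (first non-None wins, even if falsy)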