feat: workflows for formatting/linting (#35)

* add github workflows for pylint and yapf

* yapf

* docstrings for auth

* fix auth.py

* fix generators.py

* fix gen_logging.py

* fix main.py

* fix model.py

* fix templating.py

* fix utils.py

* update formatting.sh to include subdirs for pylint

* fix model_test.py

* fix wheel_test.py

* rename utils to utils_oai

* fix OAI/utils_oai.py

* fix completion.py

* fix token.py

* fix lora.py

* fix common.py

* add pylintrc and fix model.py

* finish up pylint

* fix attribute error

* main.py formatting

* add formatting batch script

* Main: Remove unnecessary global

Linter suggestion.

Signed-off-by: kingbri <bdashore3@proton.me>

* switch to ruff

* Formatting + Linting: Add ruff.toml

Signed-off-by: kingbri <bdashore3@proton.me>

* Formatting + Linting: Switch scripts to use ruff

Also remove the file and recent file change functions from both
scripts.

Signed-off-by: kingbri <bdashore3@proton.me>

* Tree: Format and lint

Signed-off-by: kingbri <bdashore3@proton.me>

* Scripts + Workflows: Format

Signed-off-by: kingbri <bdashore3@proton.me>

* Tree: Remove pylint flags

We use ruff now

Signed-off-by: kingbri <bdashore3@proton.me>

* Tree: Format

Signed-off-by: kingbri <bdashore3@proton.me>

* Formatting: Line length is 88

Use the same value as Black.

Signed-off-by: kingbri <bdashore3@proton.me>

* Tree: Format

Update to new line length rules.

Signed-off-by: kingbri <bdashore3@proton.me>

---------

Authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com>
Co-authored-by: kingbri <bdashore3@proton.me>
AlpinDale 2023-12-22 16:20:35 +00:00 committed by GitHub
parent a14abfe21c
commit fa47f51f85
22 changed files with 1210 additions and 511 deletions

.github/workflows/ruff.yml (new file, 32 lines)

@@ -0,0 +1,32 @@
name: ruff

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  ruff:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements-dev.txt
      - name: Format and show diff with ruff
        run: |
          ruff format --diff
      - name: Lint code with ruff
        run: |
          ruff check

.ruff.toml (new file, 111 lines)

@@ -0,0 +1,111 @@
# Exclude a variety of commonly ignored directories.
exclude = [
".git",
".git-rewrite",
".mypy_cache",
".pyenv",
".pytest_cache",
".ruff_cache",
".venv",
".vscode",
"__pypackages__",
"_build",
"build",
"dist",
"node_modules",
"site-packages",
"venv",
]
# Same as Black.
line-length = 88
indent-width = 4
# Assume Python 3.10
target-version = "py310"
[lint]
# Enable preview
preview = true
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
# McCabe complexity (`C901`) by default.
# Enable flake8-bugbear (`B`) rules, in addition to the defaults.
select = ["E4", "E7", "E9", "F", "B"]
extend-select = [
"D419", # empty-docstring
"PLC2401", # non-ascii-name
"E501", # line-too-long
"W291", # trailing-whitespace
"PLC0414", # useless-import-alias
"E999", # syntax-error
"PLE0101", # return-in-init
"F706", # return-outside-function
"F704", # yield-outside-function
"PLE0116", # continue-in-finally
"PLE0117", # nonlocal-without-binding
"PLE0241", # duplicate-bases
"PLE0302", # unexpected-special-method-signature
"PLE0604", # invalid-all-object
"PLE0704", # misplaced-bare-raise
"PLE1205", # logging-too-many-args
"PLE1206", # logging-too-few-args
"PLE1307", # bad-string-format-type
"PLE1310", # bad-str-strip-call
"PLE1507", # invalid-envvar-value
"PLR0124", # comparison-with-itself
"PLR0202", # no-classmethod-decorator
"PLR0203", # no-staticmethod-decorator
"PLR0206", # property-with-parameters
"PLR1704", # redefined-argument-from-local
"PLR1711", # useless-return
"C416", # unnecessary-comprehension
"PLW0108", # unnecessary-lambda
"PLW0127", # self-assigning-variable
"PLW0129", # assert-on-string-literal
"PLW0602", # global-variable-not-assigned
"PLW0604", # global-at-module-level
"F401", # unused-import
"F841", # unused-variable
"E722", # bare-except
"PLW0711", # binary-op-exception
"PLW1501", # bad-open-mode
"PLW1508", # invalid-envvar-default
"PLW1509", # subprocess-popen-preexec-fn
]
ignore = [
"PLR6301", # no-self-use
"UP004", # useless-object-inheritance
"PLR0904", # too-many-public-methods
"PLR0911", # too-many-return-statements
"PLR0912", # too-many-branches
"PLR0913", # too-many-arguments
"PLR0914", # too-many-locals
"PLR0915", # too-many-statements
"PLR0916", # too-many-boolean-expressions
"PLW0120", # useless-else-on-loop
"PLW0406", # import-self
"PLW0603", # global-statement
"PLW1641", # eq-without-hash
]
# Allow fix for all enabled rules (when `--fix`) is provided.
fixable = ["ALL"]
unfixable = ["B"]
# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
[format]
# Like Black, use double quotes for strings.
quote-style = "double"
# Like Black, indent with spaces, rather than tabs.
indent-style = "space"
# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = false
# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"
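
As a rough, hypothetical illustration (not part of the commit), the snippet below shows the kind of code this rule set flags: the unused import trips F401, the bare except trips E722, and a targeted noqa comment silences a single finding without disabling the rule globally.

"""Toy example only: demonstrates a few of the rules selected above."""
import os  # F401: 'os' imported but unused


def read_or_default(path):
    """Return the file contents, or an empty string on any error."""
    try:
        with open(path, "r", encoding="utf8") as infile:
            return infile.read()
    except:  # noqa: E722 -- silences the bare-except finding for this line only
        return ""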

OAI/types/chat_completion.py

@@ -4,22 +4,26 @@ from pydantic import BaseModel, Field
from typing import Union, List, Optional, Dict
from OAI.types.common import UsageStats, CommonCompletionRequest


class ChatCompletionMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionRespChoice(BaseModel):
    # Index is 0 since we aren't using multiple choices
    index: int = 0
    finish_reason: str
    message: ChatCompletionMessage


class ChatCompletionStreamChoice(BaseModel):
    # Index is 0 since we aren't using multiple choices
    index: int = 0
    finish_reason: Optional[str]
    delta: Union[ChatCompletionMessage, dict] = {}


# Inherited from common request
class ChatCompletionRequest(CommonCompletionRequest):
    # Messages

@@ -28,6 +32,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
    prompt_template: Optional[str] = None
    add_generation_prompt: Optional[bool] = True


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
    choices: List[ChatCompletionRespChoice]

@@ -36,6 +41,7 @@ class ChatCompletionResponse(BaseModel):
    object: str = "chat.completion"
    usage: Optional[UsageStats] = None


class ChatCompletionStreamChunk(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
    choices: List[ChatCompletionStreamChoice]

OAI/types/common.py

@@ -1,30 +1,54 @@
-from pydantic import BaseModel, Field, AliasChoices
+""" Common types for OAI. """
from typing import List, Dict, Optional, Union
+from pydantic import BaseModel, Field, AliasChoices
from utils import unwrap


class LogProbs(BaseModel):
+    """Represents log probabilities."""

    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[float] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Dict[str, float]] = Field(default_factory=list)


class UsageStats(BaseModel):
+    """Represents usage stats."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class CommonCompletionRequest(BaseModel):
+    """Represents a common completion request."""

    # Model information
    # This parameter is not used, the loaded model is used instead
    model: Optional[str] = None

    # Extra OAI request stuff
-    best_of: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
-    echo: Optional[bool] = Field(description = "Not parsed. Only used for OAI compliance.", default = False)
-    logprobs: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
-    n: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = 1)
-    suffix: Optional[str] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
-    user: Optional[str] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
+    best_of: Optional[int] = Field(
+        description="Not parsed. Only used for OAI compliance.", default=None
+    )
+    echo: Optional[bool] = Field(
+        description="Not parsed. Only used for OAI compliance.", default=False
+    )
+    logprobs: Optional[int] = Field(
+        description="Not parsed. Only used for OAI compliance.", default=None
+    )
+    n: Optional[int] = Field(
+        description="Not parsed. Only used for OAI compliance.", default=1
+    )
+    suffix: Optional[str] = Field(
+        description="Not parsed. Only used for OAI compliance.", default=None
+    )
+    user: Optional[str] = Field(
+        description="Not parsed. Only used for OAI compliance.", default=None
+    )

    # Generation info
    # seed: Optional[int] = -1

@@ -35,8 +59,9 @@ class CommonCompletionRequest(BaseModel):
    max_tokens: Optional[int] = 150

    # Aliased to repetition_penalty
-    # TODO: Maybe make this an alias to rep pen
-    frequency_penalty: Optional[float] = Field(description = "Aliased to Repetition Penalty", default = 0.0)
+    frequency_penalty: Optional[float] = Field(
+        description="Aliased to Repetition Penalty", default=0.0
+    )

    # Sampling params
    token_healing: Optional[bool] = False

@@ -58,18 +83,21 @@ class CommonCompletionRequest(BaseModel):
    # Aliased variables
    repetition_range: Optional[int] = Field(
-        default = None,
-        validation_alias = AliasChoices('repetition_range', 'repetition_penalty_range')
+        default=None,
+        validation_alias=AliasChoices("repetition_range", "repetition_penalty_range"),
    )

-    # Converts to internal generation parameters
    def to_gen_params(self):
+        """Converts to internal generation parameters."""
        # Convert stop to an array of strings
        if isinstance(self.stop, str):
            self.stop = [self.stop]

-        # Set repetition_penalty to frequency_penalty if repetition_penalty isn't already defined
-        if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty:
+        # Set repetition_penalty to frequency_penalty if repetition_penalty
+        # isn't already defined
+        if (
+            self.repetition_penalty is None or self.repetition_penalty == 1.0
+        ) and self.frequency_penalty:
            self.repetition_penalty = self.frequency_penalty

        return {

OAI/types/completion.py

@@ -1,22 +1,35 @@
-from uuid import uuid4
+""" Completion API protocols """
from time import time
-from pydantic import BaseModel, Field
from typing import List, Optional, Union
-from OAI.types.common import LogProbs, UsageStats, CommonCompletionRequest
+from uuid import uuid4

+from pydantic import BaseModel, Field

+from OAI.types.common import CommonCompletionRequest, LogProbs, UsageStats


class CompletionRespChoice(BaseModel):
+    """Represents a single choice in a completion response."""

    # Index is 0 since we aren't using multiple choices
    index: int = 0
    finish_reason: str
    logprobs: Optional[LogProbs] = None
    text: str


# Inherited from common request
class CompletionRequest(CommonCompletionRequest):
-    # Prompt can also contain token ids, but that's out of scope for this project.
+    """Represents a completion request."""

+    # Prompt can also contain token ids, but that's out of scope
+    # for this project.
    prompt: Union[str, List[str]]


class CompletionResponse(BaseModel):
+    """Represents a completion response."""

    id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}")
    choices: List[CompletionRespChoice]
    created: int = Field(default_factory=lambda: int(time()))

OAI/types/lora.py

@@ -1,25 +1,42 @@
-from pydantic import BaseModel, Field
+""" Lora types """
from time import time
from typing import Optional, List
+from pydantic import BaseModel, Field


class LoraCard(BaseModel):
+    """Represents a single Lora card."""

    id: str = "test"
    object: str = "lora"
    created: int = Field(default_factory=lambda: int(time()))
    owned_by: str = "tabbyAPI"
    scaling: Optional[float] = None


class LoraList(BaseModel):
+    """Represents a list of Lora cards."""

    object: str = "list"
    data: List[LoraCard] = Field(default_factory=list)


class LoraLoadInfo(BaseModel):
+    """Represents a single Lora load info."""

    name: str
    scaling: Optional[float] = 1.0


class LoraLoadRequest(BaseModel):
+    """Represents a Lora load request."""

    loras: List[LoraLoadInfo]


class LoraLoadResponse(BaseModel):
+    """Represents a Lora load response."""

    success: List[str] = Field(default_factory=list)
    failure: List[str] = Field(default_factory=list)

OAI/types/model.py

@@ -1,19 +1,29 @@
-from pydantic import BaseModel, Field, ConfigDict
+""" Contains model card types. """
from time import time
from typing import List, Optional
+from pydantic import BaseModel, Field, ConfigDict
from gen_logging import LogConfig


class ModelCardParameters(BaseModel):
-    # Safe to do this since it's guaranteed to fetch a max seq len from model_container
+    """Represents model card parameters."""

+    # Safe to do this since it's guaranteed to fetch a max seq len
+    # from model_container
    max_seq_len: Optional[int] = None
    rope_scale: Optional[float] = 1.0
    rope_alpha: Optional[float] = 1.0
    cache_mode: Optional[str] = "FP16"
    prompt_template: Optional[str] = None
    num_experts_per_token: Optional[int] = None
-    draft: Optional['ModelCard'] = None
+    draft: Optional["ModelCard"] = None


class ModelCard(BaseModel):
+    """Represents a single model card."""

    id: str = "test"
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time()))

@@ -21,26 +31,47 @@ class ModelCard(BaseModel):
    logging: Optional[LogConfig] = None
    parameters: Optional[ModelCardParameters] = None


class ModelList(BaseModel):
+    """Represents a list of model cards."""

    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class DraftModelLoadRequest(BaseModel):
+    """Represents a draft model load request."""

    draft_model_name: str
    draft_rope_scale: Optional[float] = 1.0
-    draft_rope_alpha: Optional[float] = Field(description = "Automatically calculated if not present", default = None)
+    draft_rope_alpha: Optional[float] = Field(
+        description="Automatically calculated if not present", default=None
+    )


# TODO: Unify this with ModelCardParams
class ModelLoadRequest(BaseModel):
+    """Represents a model load request."""

    name: str

    # Max seq len is fetched from config.json of the model by default
-    max_seq_len: Optional[int] = Field(description = "Leave this blank to use the model's base sequence length", default = None)
-    override_base_seq_len: Optional[int] = Field(description = "Overrides the model's base sequence length. Leave blank if unsure", default = None)
+    max_seq_len: Optional[int] = Field(
+        description="Leave this blank to use the model's base sequence length",
+        default=None,
+    )
+    override_base_seq_len: Optional[int] = Field(
+        description=(
+            "Overrides the model's base sequence length. " "Leave blank if unsure"
+        ),
+        default=None,
+    )
    gpu_split_auto: Optional[bool] = True
    gpu_split: Optional[List[float]] = Field(default_factory=list)
    rope_scale: Optional[float] = 1.0
-    rope_alpha: Optional[float] = Field(description = "Automatically calculated if not present", default = None)
+    rope_alpha: Optional[float] = Field(
+        description="Automatically calculated if not present", default=None
+    )
    no_flash_attention: Optional[bool] = False
    # low_mem: Optional[bool] = False
    cache_mode: Optional[str] = "FP16"

@@ -48,9 +79,12 @@ class ModelLoadRequest(BaseModel):
    num_experts_per_token: Optional[int] = None
    draft: Optional[DraftModelLoadRequest] = None


class ModelLoadResponse(BaseModel):
+    """Represents a model load response."""

    # Avoids pydantic namespace warning
-    model_config = ConfigDict(protected_namespaces = [])
+    model_config = ConfigDict(protected_namespaces=[])

    model_type: str = "model"
    module: int

OAI/types/token.py

@@ -1,30 +1,51 @@
-from pydantic import BaseModel
+""" Tokenization types """
from typing import List
+from pydantic import BaseModel


class CommonTokenRequest(BaseModel):
+    """Represents a common tokenization request."""

    add_bos_token: bool = True
    encode_special_tokens: bool = True
    decode_special_tokens: bool = True

    def get_params(self):
+        """Get the parameters for tokenization."""
        return {
            "add_bos_token": self.add_bos_token,
            "encode_special_tokens": self.encode_special_tokens,
-            "decode_special_tokens": self.decode_special_tokens
+            "decode_special_tokens": self.decode_special_tokens,
        }


class TokenEncodeRequest(CommonTokenRequest):
+    """Represents a tokenization request."""

    text: str


class TokenEncodeResponse(BaseModel):
+    """Represents a tokenization response."""

    tokens: List[int]
    length: int


class TokenDecodeRequest(CommonTokenRequest):
+    """ " Represents a detokenization request."""

    tokens: List[int]


class TokenDecodeResponse(BaseModel):
+    """Represents a detokenization response."""

    text: str


class TokenCountResponse(BaseModel):
+    """Represents a token count response."""

    length: int

OAI/utils.py (deleted)

@@ -1,103 +0,0 @@
import pathlib
from OAI.types.completion import CompletionResponse, CompletionRespChoice
from OAI.types.chat_completion import (
ChatCompletionMessage,
ChatCompletionRespChoice,
ChatCompletionStreamChunk,
ChatCompletionResponse,
ChatCompletionStreamChoice
)
from OAI.types.common import UsageStats
from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard
from typing import Optional
from utils import unwrap
def create_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
choice = CompletionRespChoice(
finish_reason = "Generated",
text = text
)
response = CompletionResponse(
choices = [choice],
model = unwrap(model_name, ""),
usage = UsageStats(prompt_tokens = prompt_tokens,
completion_tokens = completion_tokens,
total_tokens = prompt_tokens + completion_tokens)
)
return response
def create_chat_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
message = ChatCompletionMessage(
role = "assistant",
content = text
)
choice = ChatCompletionRespChoice(
finish_reason = "Generated",
message = message
)
response = ChatCompletionResponse(
choices = [choice],
model = unwrap(model_name, ""),
usage = UsageStats(prompt_tokens = prompt_tokens,
completion_tokens = completion_tokens,
total_tokens = prompt_tokens + completion_tokens)
)
return response
def create_chat_completion_stream_chunk(const_id: str,
text: Optional[str] = None,
model_name: Optional[str] = None,
finish_reason: Optional[str] = None):
if finish_reason:
message = {}
else:
message = ChatCompletionMessage(
role = "assistant",
content = text
)
# The finish reason can be None
choice = ChatCompletionStreamChoice(
finish_reason = finish_reason,
delta = message
)
chunk = ChatCompletionStreamChunk(
id = const_id,
choices = [choice],
model = unwrap(model_name, "")
)
return chunk
def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
# Convert the provided draft model path to a pathlib path for equality comparisons
if draft_model_path:
draft_model_path = pathlib.Path(draft_model_path).resolve()
model_card_list = ModelList()
for path in model_path.iterdir():
# Don't include the draft models path
if path.is_dir() and path != draft_model_path:
model_card = ModelCard(id = path.name)
model_card_list.data.append(model_card)
return model_card_list
def get_lora_list(lora_path: pathlib.Path):
lora_list = LoraList()
for path in lora_path.iterdir():
if path.is_dir():
lora_card = LoraCard(id = path.name)
lora_list.data.append(lora_card)
return lora_list

OAI/utils_oai.py (new file, 114 lines)

@@ -0,0 +1,114 @@
""" Utility functions for the OpenAI server. """
import pathlib
from typing import Optional
from OAI.types.chat_completion import (
ChatCompletionMessage,
ChatCompletionRespChoice,
ChatCompletionStreamChunk,
ChatCompletionResponse,
ChatCompletionStreamChoice,
)
from OAI.types.completion import CompletionResponse, CompletionRespChoice
from OAI.types.common import UsageStats
from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard
from utils import unwrap
def create_completion_response(
text: str,
prompt_tokens: int,
completion_tokens: int,
model_name: Optional[str],
):
"""Create a completion response from the provided text."""
choice = CompletionRespChoice(finish_reason="Generated", text=text)
response = CompletionResponse(
choices=[choice],
model=unwrap(model_name, ""),
usage=UsageStats(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
return response
def create_chat_completion_response(
text: str,
prompt_tokens: int,
completion_tokens: int,
model_name: Optional[str],
):
"""Create a chat completion response from the provided text."""
message = ChatCompletionMessage(role="assistant", content=text)
choice = ChatCompletionRespChoice(finish_reason="Generated", message=message)
response = ChatCompletionResponse(
choices=[choice],
model=unwrap(model_name, ""),
usage=UsageStats(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
return response
def create_chat_completion_stream_chunk(
const_id: str,
text: Optional[str] = None,
model_name: Optional[str] = None,
finish_reason: Optional[str] = None,
):
"""Create a chat completion stream chunk from the provided text."""
if finish_reason:
message = {}
else:
message = ChatCompletionMessage(role="assistant", content=text)
# The finish reason can be None
choice = ChatCompletionStreamChoice(finish_reason=finish_reason, delta=message)
chunk = ChatCompletionStreamChunk(
id=const_id, choices=[choice], model=unwrap(model_name, "")
)
return chunk
def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
"""Get the list of models from the provided path."""
# Convert the provided draft model path to a pathlib path for
# equality comparisons
if draft_model_path:
draft_model_path = pathlib.Path(draft_model_path).resolve()
model_card_list = ModelList()
for path in model_path.iterdir():
# Don't include the draft models path
if path.is_dir() and path != draft_model_path:
model_card = ModelCard(id=path.name)
model_card_list.data.append(model_card) # pylint: disable=no-member
return model_card_list
def get_lora_list(lora_path: pathlib.Path):
"""Get the list of Lora cards from the provided path."""
lora_list = LoraList()
for path in lora_path.iterdir():
if path.is_dir():
lora_card = LoraCard(id=path.name)
lora_list.data.append(lora_card) # pylint: disable=no-member
return lora_list
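
For orientation, here is a small, hypothetical usage sketch of the helpers above. It assumes the tabbyAPI repository is on the import path; the model names and directory paths are placeholders, not values from the commit.

"""Hypothetical usage sketch for OAI.utils_oai (placeholder names and paths)."""
import pathlib

from OAI.utils_oai import create_completion_response, get_model_list

# Build an OAI-style completion response from generated text and token counts
response = create_completion_response(
    text="Hello!", prompt_tokens=5, completion_tokens=2, model_name="my-model"
)
print(response.model_dump_json())

# Enumerate model directories, skipping a placeholder draft model path
models = get_model_list(pathlib.Path("models"), draft_model_path="draft_models")
print([card.id for card in models.data])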

auth.py (modified)

@@ -1,105 +1,125 @@
-import secrets
-import yaml
-from fastapi import Header, HTTPException
-from pydantic import BaseModel
-from typing import Optional
-
"""
This method of authorization is pretty insecure, but since TabbyAPI is a local
application, it should be fine.
"""
+import secrets
+from typing import Optional
+
+from fastapi import Header, HTTPException
+from pydantic import BaseModel
+import yaml


class AuthKeys(BaseModel):
+    """
+    This class represents the authentication keys for the application.
+    It contains two types of keys: 'api_key' and 'admin_key'.
+    The 'api_key' is used for general API calls, while the 'admin_key'
+    is used for administrative tasks. The class also provides a method
+    to verify if a given key matches the stored 'api_key' or 'admin_key'.
+    """

    api_key: str
    admin_key: str

    def verify_key(self, test_key: str, key_type: str):
-        # Match statements are only available in python 3.10 and up
+        """Verify if a given key matches the stored key."""
        if key_type == "admin_key":
            return test_key == self.admin_key
-        elif key_type == "api_key":
+        if key_type == "api_key":
            # Admin keys are valid for all API calls
            return test_key == self.api_key or test_key == self.admin_key
-        else:
-            return False
+        return False


-auth_keys: Optional[AuthKeys] = None
-disable_auth: bool = False
+AUTH_KEYS: Optional[AuthKeys] = None
+DISABLE_AUTH: bool = False


def load_auth_keys(disable_from_config: bool):
-    global auth_keys
-    global disable_auth
+    """Load the authentication keys from api_tokens.yml. If the file does not
+    exist, generate new keys and save them to api_tokens.yml."""
+    global AUTH_KEYS
+    global DISABLE_AUTH

-    disable_auth = disable_from_config
+    DISABLE_AUTH = disable_from_config
    if disable_from_config:
        print(
-            "!! Warning: Disabling authentication makes your instance vulnerable.",
-            "Set the \"disable_auth\" flag to False in config.yml if you want to share this",
-            "instance with others."
+            "!! Warning: Disabling authentication",
+            "makes your instance vulnerable.",
+            "Set the 'disable_auth' flag to False in config.yml",
+            "if you want to share this instance with others.",
        )
        return

    try:
-        with open("api_tokens.yml", "r", encoding = 'utf8') as auth_file:
+        with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
            auth_keys_dict = yaml.safe_load(auth_file)
-            auth_keys = AuthKeys.model_validate(auth_keys_dict)
+            AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
    except OSError:
        new_auth_keys = AuthKeys(
-            api_key = secrets.token_hex(16),
-            admin_key = secrets.token_hex(16)
+            api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16)
        )
-        auth_keys = new_auth_keys
+        AUTH_KEYS = new_auth_keys

-        with open("api_tokens.yml", "w", encoding = "utf8") as auth_file:
-            yaml.safe_dump(auth_keys.model_dump(), auth_file, default_flow_style=False)
+        with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
+            yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)

    print(
-        f"Your API key is: {auth_keys.api_key}\n"
-        f"Your admin key is: {auth_keys.admin_key}\n\n"
-        "If these keys get compromised, make sure to delete api_tokens.yml and restart the server. Have fun!"
+        f"Your API key is: {AUTH_KEYS.api_key}\n"
+        f"Your admin key is: {AUTH_KEYS.admin_key}\n\n"
+        "If these keys get compromised, make sure to delete api_tokens.yml "
+        "and restart the server. Have fun!"
    )


def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)):
+    """Check if the API key is valid."""
    # Allow request if auth is disabled
-    if disable_auth:
+    if DISABLE_AUTH:
        return

    if x_api_key:
-        if auth_keys.verify_key(x_api_key, "api_key"):
-            return x_api_key
-        else:
+        if not AUTH_KEYS.verify_key(x_api_key, "api_key"):
            raise HTTPException(401, "Invalid API key")
-    elif authorization:
-        split_key = authorization.split(" ")
+        return x_api_key
+
+    if authorization:
+        split_key = authorization.split(" ")
        if len(split_key) < 2:
            raise HTTPException(401, "Invalid API key")
-        elif split_key[0].lower() == "bearer" and auth_keys.verify_key(split_key[1], "api_key"):
-            return authorization
-        else:
+        if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key(
+            split_key[1], "api_key"
+        ):
            raise HTTPException(401, "Invalid API key")
-    else:
-        raise HTTPException(401, "Please provide an API key")
+
+        return authorization
+
+    raise HTTPException(401, "Please provide an API key")


def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)):
+    """Check if the admin key is valid."""
    # Allow request if auth is disabled
-    if disable_auth:
+    if DISABLE_AUTH:
        return

    if x_admin_key:
-        if auth_keys.verify_key(x_admin_key, "admin_key"):
-            return x_admin_key
-        else:
+        if not AUTH_KEYS.verify_key(x_admin_key, "admin_key"):
            raise HTTPException(401, "Invalid admin key")
-    elif authorization:
-        split_key = authorization.split(" ")
+        return x_admin_key
+
+    if authorization:
+        split_key = authorization.split(" ")
        if len(split_key) < 2:
            raise HTTPException(401, "Invalid admin key")
-        elif split_key[0].lower() == "bearer" and auth_keys.verify_key(split_key[1], "admin_key"):
-            return authorization
-        else:
+        if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key(
+            split_key[1], "admin_key"
+        ):
            raise HTTPException(401, "Invalid admin key")
+        return authorization

-    else:
-        raise HTTPException(401, "Please provide an admin key")
+    raise HTTPException(401, "Please provide an admin key")
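
To make the refactored key checks concrete, a hypothetical client call might look like the sketch below. The base URL and key value are placeholders (use whatever host and port your server listens on and the key printed from api_tokens.yml), and `requests` is just one convenient HTTP client; both the dedicated header and the `Authorization: Bearer` form handled by `check_api_key` are shown.

"""Hypothetical client sketch for the auth dependencies above (placeholder URL and key)."""
import requests

BASE_URL = "http://localhost:5000"  # assumption: wherever the tabbyAPI server listens
API_KEY = "paste-the-key-from-api_tokens.yml-here"

# Style 1: dedicated header, maps to the x_api_key parameter of check_api_key
resp = requests.get(f"{BASE_URL}/v1/models", headers={"x-api-key": API_KEY})
print(resp.status_code)

# Style 2: standard bearer token, parsed by the authorization branch
resp = requests.get(
    f"{BASE_URL}/v1/models", headers={"Authorization": f"Bearer {API_KEY}"}
)
print(resp.status_code)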

formatting.bat (new file, 36 lines)

@@ -0,0 +1,36 @@
@echo off
setlocal
::Change to script's directory
cd /d %~dp0
::Get tool versions
for /f "tokens=2" %%i in ('ruff --version') do set RUFF_VERSION="%%i"
::Check tool versions
call :tool_version_check "ruff" %RUFF_VERSION% "0.1.9"
::Format and lint files
call ruff format
call ruff check
echo tabbyAPI ruff lint and format: Done
::Check if any files were changed
git diff --quiet
if errorlevel 1 (
echo Reformatted files. Please review and stage the changes.
echo Changes not staged for commit:
echo.
git --no-pager diff --name-only
exit /b 1
)
exit /b 0
:tool_version_check
if not "%2"=="%3" (
echo Wrong %1 version installed: %3 is required, not %2.
exit /b 1
)
exit /b 0

formatting.sh (new executable file, 53 lines)

@@ -0,0 +1,53 @@
#!/usr/bin/env bash
# YAPF formatter, adapted from ray and skypilot.
#
# Usage:
# # Do work and commit your work.
# # Format files that differ from origin/main.
# bash formatting.sh
# # Commit changed files with message 'Run yapf and ruff'
#
#
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.
# Cause the script to exit if a single command fails
set -eo pipefail
# this stops git rev-parse from failing if we run this from the .git directory
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
ROOT="$(git rev-parse --show-toplevel)"
builtin cd "$ROOT" || exit 1
RUFF_VERSION=$(ruff --version | head -n 1 | awk '{print $2}')
# params: tool name, tool version, required version
tool_version_check() {
if [[ $2 != $3 ]]; then
echo "Wrong $1 version installed: $3 is required, not $2."
exit 1
fi
}
tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | cut -d'=' -f3)"
# Format and lint all files
format_and_lint() {
ruff format
ruff check
}
# Call format command
format_and_lint
echo 'tabbyAPI ruff lint and format: Done'
if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:'
echo
git --no-pager diff --name-only
exit 1
fi

gen_logging.py

@@ -1,31 +1,40 @@
+"""
+Functions for logging generation events.
+"""
from typing import Dict
from pydantic import BaseModel


-# Logging preference config
class LogConfig(BaseModel):
+    """Logging preference config."""

    prompt: bool = False
    generation_params: bool = False


-# Global reference to logging preferences
-config = LogConfig()
-
-# Wrapper to set the logging config for generations
+# Global reference to logging preferences
+CONFIG = LogConfig()


def update_from_dict(options_dict: Dict[str, bool]):
-    global config
+    """Wrapper to set the logging config for generations"""
+    global CONFIG

    # Force bools on the dict
    for value in options_dict.values():
        if value is None:
            value = False

-    config = LogConfig.model_validate(options_dict)
+    CONFIG = LogConfig.model_validate(options_dict)


def broadcast_status():
+    """Broadcasts the current logging status"""
    enabled = []
-    if config.prompt:
+    if CONFIG.prompt:
        enabled.append("prompts")

-    if config.generation_params:
+    if CONFIG.generation_params:
        enabled.append("generation params")

    if len(enabled) > 0:

@@ -33,15 +42,20 @@ def broadcast_status():
    else:
        print("Generation logging is disabled")


-# Logs generation parameters to console
def log_generation_params(**kwargs):
-    if config.generation_params:
+    """Logs generation parameters to console."""
+    if CONFIG.generation_params:
        print(f"Generation options: {kwargs}\n")


def log_prompt(prompt: str):
-    if config.prompt:
+    """Logs the prompt to console."""
+    if CONFIG.prompt:
        print(f"Prompt: {prompt if prompt else 'Empty'}\n")


def log_response(response: str):
-    if config.prompt:
+    """Logs the response to console."""
+    if CONFIG.prompt:
        print(f"Response: {response if response else 'Empty'}\n")

generators.py

@@ -1,3 +1,4 @@
+"""Generator functions for the tabbyAPI."""
import inspect
from asyncio import Semaphore
from functools import partialmethod

@@ -5,8 +6,10 @@ from typing import AsyncGenerator
generate_semaphore = Semaphore(1)


-# Async generation that blocks on a semaphore
async def generate_with_semaphore(generator: AsyncGenerator):
+    """Generate with a semaphore."""
    async with generate_semaphore:
        if inspect.isasyncgenfunction:
            async for result in generator():

@@ -15,6 +18,7 @@ async def generate_with_semaphore(generator: AsyncGenerator):
            for result in generator():
                yield result


# Block a function with semaphore
async def call_with_semaphore(callback: partialmethod):
    if inspect.iscoroutinefunction(callback):
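
As a hypothetical sketch of how these helpers are meant to be driven (the `my_stream` and `my_task` names are stand-ins, not project code; the rest of call_with_semaphore is not shown in this hunk), the shared semaphore serializes generation so only one request touches the model at a time:

# Hypothetical usage sketch; assumes generators.py from this commit is importable.
import asyncio
from functools import partial

from generators import call_with_semaphore, generate_with_semaphore


async def my_stream():
    # Stand-in async generator, e.g. streaming tokens
    for token in ("Hello", " ", "world"):
        yield token


async def my_task():
    # Stand-in one-shot coroutine, e.g. a blocking generate call
    return "done"


async def main():
    # Stream results while holding the shared semaphore
    async for chunk in generate_with_semaphore(my_stream):
        print(chunk, end="")
    # Run a one-shot awaitable under the same semaphore
    print(await call_with_semaphore(partial(my_task)))


asyncio.run(main())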

main.py (modified)

@ -1,14 +1,16 @@
import uvicorn """The main tabbyAPI module. Contains the FastAPI server and endpoints."""
import yaml
import pathlib import pathlib
from asyncio import CancelledError from asyncio import CancelledError
from fastapi import FastAPI, Request, HTTPException, Depends from typing import Optional
from uuid import uuid4
import uvicorn
import yaml
from fastapi import FastAPI, Depends, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from functools import partial from functools import partial
from progress.bar import IncrementalBar from progress.bar import IncrementalBar
from typing import Optional
from uuid import uuid4
import gen_logging import gen_logging
from auth import check_admin_key, check_api_key, load_auth_keys from auth import check_admin_key, check_api_key, load_auth_keys
@ -17,19 +19,24 @@ from model import ModelContainer
from OAI.types.completion import CompletionRequest from OAI.types.completion import CompletionRequest
from OAI.types.chat_completion import ChatCompletionRequest from OAI.types.chat_completion import ChatCompletionRequest
from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse, ModelCardParameters from OAI.types.model import (
ModelCard,
ModelLoadRequest,
ModelLoadResponse,
ModelCardParameters,
)
from OAI.types.token import ( from OAI.types.token import (
TokenEncodeRequest, TokenEncodeRequest,
TokenEncodeResponse, TokenEncodeResponse,
TokenDecodeRequest, TokenDecodeRequest,
TokenDecodeResponse TokenDecodeResponse,
) )
from OAI.utils import ( from OAI.utils_oai import (
create_completion_response, create_completion_response,
get_model_list, get_model_list,
get_lora_list, get_lora_list,
create_chat_completion_response, create_chat_completion_response,
create_chat_completion_stream_chunk create_chat_completion_stream_chunk,
) )
from templating import get_prompt_from_template from templating import get_prompt_from_template
from utils import get_generator_error, get_sse_packet, load_progress, unwrap from utils import get_generator_error, get_sse_packet, load_progress, unwrap
@ -37,13 +44,15 @@ from utils import get_generator_error, get_sse_packet, load_progress, unwrap
app = FastAPI() app = FastAPI()
# Globally scoped variables. Undefined until initalized in main # Globally scoped variables. Undefined until initalized in main
model_container: Optional[ModelContainer] = None MODEL_CONTAINER: Optional[ModelContainer] = None
config: dict = {} config: dict = {}
def _check_model_container(): def _check_model_container():
if model_container is None or model_container.model is None: if MODEL_CONTAINER is None or MODEL_CONTAINER.model is None:
raise HTTPException(400, "No models are loaded.") raise HTTPException(400, "No models are loaded.")
# ALlow CORS requests # ALlow CORS requests
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@ -53,10 +62,12 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
# Model list endpoint # Model list endpoint
@app.get("/v1/models", dependencies=[Depends(check_api_key)]) @app.get("/v1/models", dependencies=[Depends(check_api_key)])
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)]) @app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models(): async def list_models():
"""Lists all models in the model directory."""
model_config = unwrap(config.get("model"), {}) model_config = unwrap(config.get("model"), {})
model_dir = unwrap(model_config.get("model_dir"), "models") model_dir = unwrap(model_config.get("model_dir"), "models")
model_path = pathlib.Path(model_dir) model_path = pathlib.Path(model_dir)
@ -66,43 +77,53 @@ async def list_models():
models = get_model_list(model_path.resolve(), draft_model_dir) models = get_model_list(model_path.resolve(), draft_model_dir)
if unwrap(model_config.get("use_dummy_models"), False): if unwrap(model_config.get("use_dummy_models"), False):
models.data.insert(0, ModelCard(id = "gpt-3.5-turbo")) models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
return models return models
# Currently loaded model endpoint # Currently loaded model endpoint
@app.get("/v1/model", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) @app.get(
@app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) "/v1/model",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
@app.get(
"/v1/internal/model/info",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
async def get_current_model(): async def get_current_model():
model_name = model_container.get_model_path().name """Returns the currently loaded model."""
prompt_template = model_container.prompt_template model_name = MODEL_CONTAINER.get_model_path().name
prompt_template = MODEL_CONTAINER.prompt_template
model_card = ModelCard( model_card = ModelCard(
id = model_name, id=model_name,
parameters = ModelCardParameters( parameters=ModelCardParameters(
rope_scale = model_container.config.scale_pos_emb, rope_scale=MODEL_CONTAINER.config.scale_pos_emb,
rope_alpha = model_container.config.scale_alpha_value, rope_alpha=MODEL_CONTAINER.config.scale_alpha_value,
max_seq_len = model_container.config.max_seq_len, max_seq_len=MODEL_CONTAINER.config.max_seq_len,
cache_mode = "FP8" if model_container.cache_fp8 else "FP16", cache_mode="FP8" if MODEL_CONTAINER.cache_fp8 else "FP16",
prompt_template = prompt_template.name if prompt_template else None prompt_template=prompt_template.name if prompt_template else None,
), ),
logging = gen_logging.config logging=gen_logging.CONFIG,
) )
if model_container.draft_config: if MODEL_CONTAINER.draft_config:
draft_card = ModelCard( draft_card = ModelCard(
id = model_container.get_model_path(True).name, id=MODEL_CONTAINER.get_model_path(True).name,
parameters = ModelCardParameters( parameters=ModelCardParameters(
rope_scale = model_container.draft_config.scale_pos_emb, rope_scale=MODEL_CONTAINER.draft_config.scale_pos_emb,
rope_alpha = model_container.draft_config.scale_alpha_value, rope_alpha=MODEL_CONTAINER.draft_config.scale_alpha_value,
max_seq_len = model_container.draft_config.max_seq_len max_seq_len=MODEL_CONTAINER.draft_config.max_seq_len,
) ),
) )
model_card.parameters.draft = draft_card model_card.parameters.draft = draft_card
return model_card return model_card
@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) @app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models(): async def list_draft_models():
"""Lists all draft models in the model directory."""
model_config = unwrap(config.get("model"), {}) model_config = unwrap(config.get("model"), {})
draft_config = unwrap(model_config.get("draft"), {}) draft_config = unwrap(model_config.get("draft"), {})
draft_model_dir = unwrap(draft_config.get("draft_model_dir"), "models") draft_model_dir = unwrap(draft_config.get("draft_model_dir"), "models")
@ -112,12 +133,14 @@ async def list_draft_models():
return models return models
# Load model endpoint # Load model endpoint
@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)]) @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(request: Request, data: ModelLoadRequest): async def load_model(request: Request, data: ModelLoadRequest):
global model_container """Loads a model into the model container."""
global MODEL_CONTAINER
if model_container and model_container.model: if MODEL_CONTAINER and MODEL_CONTAINER.model:
raise HTTPException(400, "A model is already loaded! Please unload it first.") raise HTTPException(400, "A model is already loaded! Please unload it first.")
if not data.name: if not data.name:
@ -129,32 +152,35 @@ async def load_model(request: Request, data: ModelLoadRequest):
load_data = data.model_dump() load_data = data.model_dump()
# TODO: Add API exception if draft directory isn't found
draft_config = unwrap(model_config.get("draft"), {}) draft_config = unwrap(model_config.get("draft"), {})
if data.draft: if data.draft:
if not data.draft.draft_model_name: if not data.draft.draft_model_name:
raise HTTPException(400, "draft_model_name was not found inside the draft object.") raise HTTPException(
400, "draft_model_name was not found inside the draft object."
)
load_data["draft"]["draft_model_dir"] = unwrap(draft_config.get("draft_model_dir"), "models") load_data["draft"]["draft_model_dir"] = unwrap(
draft_config.get("draft_model_dir"), "models"
)
if not model_path.exists(): if not model_path.exists():
raise HTTPException(400, "model_path does not exist. Check model_name?") raise HTTPException(400, "model_path does not exist. Check model_name?")
model_container = ModelContainer(model_path.resolve(), False, **load_data) MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **load_data)
async def generator(): async def generator():
global model_container """Generator for the loading process."""
model_type = "draft" if model_container.draft_config else "model" model_type = "draft" if MODEL_CONTAINER.draft_config else "model"
load_status = model_container.load_gen(load_progress) load_status = MODEL_CONTAINER.load_gen(load_progress)
try: try:
for (module, modules) in load_status: for module, modules in load_status:
if await request.is_disconnected(): if await request.is_disconnected():
break break
if module == 0: if module == 0:
loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules) loading_bar: IncrementalBar = IncrementalBar("Modules", max=modules)
elif module == modules: elif module == modules:
loading_bar.next() loading_bar.next()
loading_bar.finish() loading_bar.finish()
@ -163,13 +189,13 @@ async def load_model(request: Request, data: ModelLoadRequest):
model_type=model_type, model_type=model_type,
module=module, module=module,
modules=modules, modules=modules,
status="finished" status="finished",
) )
yield get_sse_packet(response.model_dump_json()) yield get_sse_packet(response.model_dump_json())
# Switch to model progress if the draft model is loaded # Switch to model progress if the draft model is loaded
if model_container.draft_config: if MODEL_CONTAINER.draft_config:
model_type = "model" model_type = "model"
else: else:
loading_bar.next() loading_bar.next()
@ -178,29 +204,39 @@ async def load_model(request: Request, data: ModelLoadRequest):
model_type=model_type, model_type=model_type,
module=module, module=module,
modules=modules, modules=modules,
status="processing" status="processing",
) )
yield get_sse_packet(response.model_dump_json()) yield get_sse_packet(response.model_dump_json())
except CancelledError: except CancelledError:
print("\nError: Model load cancelled by user. Please make sure to run unload to free up resources.") print(
except Exception as e: "\nError: Model load cancelled by user. "
yield get_generator_error(str(e)) "Please make sure to run unload to free up resources."
)
except Exception as exc:
yield get_generator_error(str(exc))
return StreamingResponse(generator(), media_type="text/event-stream")
return StreamingResponse(generator(), media_type = "text/event-stream")
# Unload model endpoint # Unload model endpoint
@app.get("/v1/model/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)]) @app.get(
"/v1/model/unload",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
async def unload_model(): async def unload_model():
global model_container """Unloads the currently loaded model."""
global MODEL_CONTAINER
MODEL_CONTAINER.unload()
MODEL_CONTAINER = None
model_container.unload()
model_container = None
# Lora list endpoint # Lora list endpoint
@app.get("/v1/loras", dependencies=[Depends(check_api_key)]) @app.get("/v1/loras", dependencies=[Depends(check_api_key)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) @app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
async def get_all_loras(): async def get_all_loras():
"""Lists all LoRAs in the lora directory."""
model_config = unwrap(config.get("model"), {}) model_config = unwrap(config.get("model"), {})
lora_config = unwrap(model_config.get("lora"), {}) lora_config = unwrap(model_config.get("lora"), {})
lora_path = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras")) lora_path = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
@ -209,24 +245,36 @@ async def get_all_loras():
return loras return loras
# Currently loaded loras endpoint # Currently loaded loras endpoint
@app.get("/v1/lora", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) @app.get(
"/v1/lora",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
async def get_active_loras(): async def get_active_loras():
"""Returns the currently loaded loras."""
active_loras = LoraList( active_loras = LoraList(
data = list(map( data=list(
map(
lambda lora: LoraCard( lambda lora: LoraCard(
id = pathlib.Path(lora.lora_path).parent.name, id=pathlib.Path(lora.lora_path).parent.name,
scaling = lora.lora_scaling * lora.lora_r / lora.lora_alpha scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
), ),
model_container.active_loras MODEL_CONTAINER.active_loras,
)
)
) )
))
return active_loras return active_loras
# Load lora endpoint # Load lora endpoint
@app.post("/v1/lora/load", dependencies=[Depends(check_admin_key), Depends(_check_model_container)]) @app.post(
"/v1/lora/load",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
async def load_lora(data: LoraLoadRequest): async def load_lora(data: LoraLoadRequest):
"""Loads a LoRA into the model container."""
if not data.loras: if not data.loras:
raise HTTPException(400, "List of loras to load is not found.") raise HTTPException(400, "List of loras to load is not found.")
@ -234,166 +282,203 @@ async def load_lora(data: LoraLoadRequest):
lora_config = unwrap(model_config.get("lora"), {}) lora_config = unwrap(model_config.get("lora"), {})
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras")) lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
if not lora_dir.exists(): if not lora_dir.exists():
raise HTTPException(400, "A parent lora directory does not exist. Check your config.yml?") raise HTTPException(
400,
# Clean-up existing loras if present "A parent lora directory does not exist. Check your config.yml?",
if len(model_container.active_loras) > 0:
model_container.unload(True)
result = model_container.load_loras(lora_dir, **data.model_dump())
return LoraLoadResponse(
success = unwrap(result.get("success"), []),
failure = unwrap(result.get("failure"), [])
) )
# Clean-up existing loras if present
if len(MODEL_CONTAINER.active_loras) > 0:
MODEL_CONTAINER.unload(True)
result = MODEL_CONTAINER.load_loras(lora_dir, **data.model_dump())
return LoraLoadResponse(
success=unwrap(result.get("success"), []),
failure=unwrap(result.get("failure"), []),
)
# Unload lora endpoint # Unload lora endpoint
@app.get("/v1/lora/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)]) @app.get(
"/v1/lora/unload",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
)
async def unload_loras(): async def unload_loras():
model_container.unload(True) """Unloads the currently loaded loras."""
MODEL_CONTAINER.unload(True)
# Encode tokens endpoint # Encode tokens endpoint
@app.post("/v1/token/encode", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) @app.post(
"/v1/token/encode",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
async def encode_tokens(data: TokenEncodeRequest): async def encode_tokens(data: TokenEncodeRequest):
raw_tokens = model_container.get_tokens(data.text, None, **data.get_params()) """Encodes a string into tokens."""
raw_tokens = MODEL_CONTAINER.get_tokens(data.text, None, **data.get_params())
# Have to use this if check otherwise Torch's tensors error out with a boolean issue # Have to use this if check otherwise Torch's tensors error out
# with a boolean issue
tokens = raw_tokens[0].tolist() if raw_tokens is not None else [] tokens = raw_tokens[0].tolist() if raw_tokens is not None else []
response = TokenEncodeResponse(tokens=tokens, length=len(tokens)) response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
return response return response
# Decode tokens endpoint
@app.post(
    "/v1/token/decode",
    dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
async def decode_tokens(data: TokenDecodeRequest):
    """Decodes tokens into a string."""
    message = MODEL_CONTAINER.get_tokens(None, data.tokens, **data.get_params())
    response = TokenDecodeResponse(text=unwrap(message, ""))

    return response
# Completions endpoint
@app.post(
    "/v1/completions",
    dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
async def generate_completion(request: Request, data: CompletionRequest):
    """Generates a completion from a prompt."""
    model_path = MODEL_CONTAINER.get_model_path()

    if isinstance(data.prompt, list):
        data.prompt = "\n".join(data.prompt)

    if data.stream:

        async def generator():
            """Generator for the generation process."""
            try:
                new_generation = MODEL_CONTAINER.generate_gen(
                    data.prompt, **data.to_gen_params()
                )
                for part, prompt_tokens, completion_tokens in new_generation:
                    if await request.is_disconnected():
                        break

                    response = create_completion_response(
                        part, prompt_tokens, completion_tokens, model_path.name
                    )

                    yield get_sse_packet(response.model_dump_json())
            except CancelledError:
                print("Error: Completion request cancelled by user.")
            except Exception as exc:
                yield get_generator_error(str(exc))

        return StreamingResponse(
            generate_with_semaphore(generator), media_type="text/event-stream"
        )

    response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
        partial(MODEL_CONTAINER.generate, data.prompt, **data.to_gen_params())
    )
    response = create_completion_response(
        response_text, prompt_tokens, completion_tokens, model_path.name
    )

    return response
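A sketch of consuming the streamed completion from the client side. The host, port, and API-key header name are assumptions (the header is defined in auth.py); the endpoint emits `data: <json>` SSE packets via get_sse_packet.

import json
import requests

with requests.post(
    "http://127.0.0.1:5000/v1/completions",
    headers={"x-api-key": "<api key>"},  # header name assumed from auth.py
    json={"prompt": "Once upon a time", "stream": True, "max_tokens": 64},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data: "):
            print(json.loads(line[len(b"data: "):]))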
# Chat completions endpoint
@app.post(
    "/v1/chat/completions",
    dependencies=[Depends(check_api_key), Depends(_check_model_container)],
)
async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
    """Generates a chat completion from a prompt."""
    if MODEL_CONTAINER.prompt_template is None:
        return HTTPException(
            422,
            "This endpoint is disabled because a prompt template is not set.",
        )

    model_path = MODEL_CONTAINER.get_model_path()

    if isinstance(data.messages, str):
        prompt = data.messages
    else:
        try:
            special_tokens_dict = MODEL_CONTAINER.get_special_tokens(
                unwrap(data.add_bos_token, True),
                unwrap(data.ban_eos_token, False),
            )
            prompt = get_prompt_from_template(
                data.messages,
                MODEL_CONTAINER.prompt_template,
                data.add_generation_prompt,
                special_tokens_dict,
            )
        except KeyError:
            return HTTPException(
                400,
                "Could not find a Conversation from prompt template "
                f"'{MODEL_CONTAINER.prompt_template.name}'. "
                "Check your spelling?",
            )

    if data.stream:
        const_id = f"chatcmpl-{uuid4().hex}"

        async def generator():
            """Generator for the generation process."""
            try:
                new_generation = MODEL_CONTAINER.generate_gen(
                    prompt, **data.to_gen_params()
                )
                for part, _, _ in new_generation:
                    if await request.is_disconnected():
                        break

                    response = create_chat_completion_stream_chunk(
                        const_id, part, model_path.name
                    )

                    yield get_sse_packet(response.model_dump_json())

                # Yield a finish response on successful generation
                finish_response = create_chat_completion_stream_chunk(
                    const_id, finish_reason="stop"
                )

                yield get_sse_packet(finish_response.model_dump_json())
            except CancelledError:
                print("Error: Chat completion cancelled by user.")
            except Exception as exc:
                yield get_generator_error(str(exc))

        return StreamingResponse(
            generate_with_semaphore(generator), media_type="text/event-stream"
        )

    response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
        partial(MODEL_CONTAINER.generate, prompt, **data.to_gen_params())
    )
    response = create_chat_completion_response(
        response_text, prompt_tokens, completion_tokens, model_path.name
    )

    return response
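For comparison, a sketch of the request body this endpoint expects; field names mirror the ChatCompletionRequest attributes used above, and the values are purely illustrative.

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write a haiku about cats."},
    ],
    "add_generation_prompt": True,
    "max_tokens": 100,
    "stream": False,
}
# POST this to /v1/chat/completions with the API key header to receive a single
# (non-streaming) chat completion response.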
if __name__ == "__main__": if __name__ == "__main__":
# Load from YAML config. Possibly add a config -> kwargs conversion function # Load from YAML config. Possibly add a config -> kwargs conversion function
try: try:
with open('config.yml', 'r', encoding = "utf8") as config_file: with open("config.yml", "r", encoding="utf8") as config_file:
config = unwrap(yaml.safe_load(config_file), {}) config = unwrap(yaml.safe_load(config_file), {})
except Exception as e: except Exception as exc:
print( print(
"The YAML config couldn't load because of the following error:", "The YAML config couldn't load because of the following error:",
f"\n\n{e}", f"\n\n{exc}",
"\n\nTabbyAPI will start anyway and not parse this config file." "\n\nTabbyAPI will start anyway and not parse this config file.",
) )
config = {} config = {}
@ -409,18 +494,18 @@ if __name__ == "__main__":
gen_logging.broadcast_status() gen_logging.broadcast_status()
# If an initial model name is specified, create a container and load the model # If an initial model name is specified, create a container
# and load the model
model_config = unwrap(config.get("model"), {}) model_config = unwrap(config.get("model"), {})
if "model_name" in model_config: if "model_name" in model_config:
# TODO: Move this to model_container
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models")) model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
model_path = model_path / model_config.get("model_name") model_path = model_path / model_config.get("model_name")
model_container = ModelContainer(model_path.resolve(), False, **model_config) MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **model_config)
load_status = model_container.load_gen(load_progress) load_status = MODEL_CONTAINER.load_gen(load_progress)
for (module, modules) in load_status: for module, modules in load_status:
if module == 0: if module == 0:
loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules) loading_bar: IncrementalBar = IncrementalBar("Modules", max=modules)
elif module == modules: elif module == modules:
loading_bar.next() loading_bar.next()
loading_bar.finish() loading_bar.finish()
@ -431,11 +516,11 @@ if __name__ == "__main__":
lora_config = unwrap(model_config.get("lora"), {}) lora_config = unwrap(model_config.get("lora"), {})
if "loras" in lora_config: if "loras" in lora_config:
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras")) lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
model_container.load_loras(lora_dir.resolve(), **lora_config) MODEL_CONTAINER.load_loras(lora_dir.resolve(), **lora_config)
uvicorn.run( uvicorn.run(
app, app,
host=network_config.get("host", "127.0.0.1"), host=network_config.get("host", "127.0.0.1"),
port=network_config.get("port", 5000), port=network_config.get("port", 5000),
log_level="debug" log_level="debug",
) )

model.py
@@ -1,29 +1,36 @@
"""The model container class for ExLlamaV2 models."""
import gc
import pathlib
import time
import torch
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Tokenizer,
    ExLlamaV2Lora,
)
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
from gen_logging import log_generation_params, log_prompt, log_response
from typing import List, Optional, Union
from templating import (
    PromptTemplate,
    find_template_from_model,
    get_template_from_model_json,
    get_template_from_file,
)
from utils import coalesce, unwrap

# Bytes to reserve on first device when loading with auto split
AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2


class ModelContainer:
    """The model container class for ExLlamaV2 models."""

    config: Optional[ExLlamaV2Config] = None
    draft_config: Optional[ExLlamaV2Config] = None
    model: Optional[ExLlamaV2] = None

@@ -40,35 +47,51 @@ class ModelContainer:
    active_loras: List[ExLlamaV2Lora] = []

    def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
""" """
Create model container Create model container
Args: Args:
model_dir (int): Model directory containing config.json, tokenizer.model etc. model_dir (int): Model directory containing config.json,
tokenizer.model etc.
quiet (bool): Suppress console output quiet (bool): Suppress console output
load_progress_callback (function, optional): A function to call for each module loaded. Prototype: load_progress_callback (function, optional): A function to call for
def progress(loaded_modules: int, total_modules: int, loading_draft: bool) each module loaded. Prototype:
def progress(loaded_modules: int, total_modules: int,
loading_draft: bool)
**kwargs: **kwargs:
`cache_mode` (str): Sets cache mode, "FP16" or "FP8" (defaulf: "FP16") `cache_mode` (str): Sets cache mode, "FP16" or "FP8"
'max_seq_len' (int): Override model's default max sequence length (default: 4096) (defaulf: "FP16")
'rope_scale' (float): Set RoPE scaling factor for model (default: 1.0) 'max_seq_len' (int): Override model's default max sequence
'rope_alpha' (float): Set RoPE alpha (NTK) factor for model (default: 1.0) length (default: 4096)
'prompt_template' (str): Manually sets the prompt template for this model (default: None) 'rope_scale' (float): Set RoPE scaling factor for model
'chunk_size' (int): Sets the maximum chunk size for the model (default: 2048) (default: 1.0)
Inferencing in chunks reduces overall VRAM overhead by processing very long sequences in smaller 'rope_alpha' (float): Set RoPE alpha (NTK) factor for model
batches. This limits the size of temporary buffers needed for the hidden state and attention (default: 1.0)
weights. 'prompt_template' (str): Manually sets the prompt template for
this model (default: None)
'chunk_size' (int): Sets the maximum chunk size for the model
(default: 2048)
Inferencing in chunks reduces overall VRAM overhead by
processing very long sequences in smaller batches. This
limits the size of temporary buffers needed for the hidden
state and attention weights.
'draft_model_dir' (str): Draft model directory 'draft_model_dir' (str): Draft model directory
'draft_rope_scale' (float): Set RoPE scaling factor for draft model (default: 1.0) 'draft_rope_scale' (float): Set RoPE scaling factor for draft
'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft model. model (default: 1.0)
By default, the draft model's alpha value is calculated automatically to scale to the size of the 'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft
model. By default, the draft model's alpha value is
calculated automatically to scale to the size of the
full model. full model.
'lora_dir' (str): Lora directory 'lora_dir' (str): LoRA directory
'loras' (list[dict]): List of loras to be loaded, consisting of 'name' and 'scaling' 'loras' (list[dict]): List of loras to be loaded, consisting of
'gpu_split_auto' (bool): Automatically split model across available devices (default: True) 'name' and 'scaling'
'gpu_split' (list[float]): Allocation for weights and (some) tensors, per device 'gpu_split_auto' (bool): Automatically split model across
'no_flash_attn' (bool): Turns off flash attention (increases vram usage) (default: False) available devices (default: True)
'gpu_split' (list[float]): Allocation for weights and (some)
tensors, per device
'no_flash_attn' (bool): Turns off flash attention
(increases vram usage) (default: False)
""" """
self.quiet = quiet self.quiet = quiet
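A minimal construction sketch using a few of the kwargs documented above; the model directory and values are hypothetical.

import pathlib

container = ModelContainer(
    pathlib.Path("models/my-model-exl2"),  # hypothetical model directory
    quiet=False,
    cache_mode="FP16",
    max_seq_len=4096,
    chunk_size=2048,
    gpu_split_auto=True,
)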
@@ -90,7 +113,8 @@ class ModelContainer:
        if override_base_seq_len:
            self.config.max_seq_len = override_base_seq_len

        # Grab the base model's sequence length before overrides for
        # rope calculations
        base_seq_len = self.config.max_seq_len

        # Set the target seq len if present

@@ -103,14 +127,14 @@ class ModelContainer:
        # Automatically calculate rope alpha
        self.config.scale_alpha_value = unwrap(
            kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
        )

        # Turn off flash attention?
        self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attention"), False)

        # low_mem is currently broken in exllamav2. Don't use it until it's
        # fixed.
        """
        if "low_mem" in kwargs and kwargs["low_mem"]:
            self.config.set_low_mem()
@@ -119,7 +143,10 @@ class ModelContainer:
        # Set prompt template override if provided
        prompt_template_name = kwargs.get("prompt_template")
        if prompt_template_name:
            print(
                "Attempting to load prompt template with name",
                prompt_template_name,
            )
            # Read the template
            self.prompt_template = get_template_from_file(prompt_template_name)
        else:
@@ -127,16 +154,17 @@ class ModelContainer:
            self.prompt_template = get_template_from_model_json(
                pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
                "chat_template",
                "from_tokenizer_config",
            )

            # Try finding the chat template from the model's config.json
            # TODO: This may not even be used with huggingface models,
            # mark for removal.
            if self.prompt_template is None:
                self.prompt_template = get_template_from_model_json(
                    pathlib.Path(self.config.model_config),
                    "chat_template",
                    "from_model_config",
                )

            # If that fails, attempt fetching from model name

@@ -147,10 +175,13 @@ class ModelContainer:
        # Catch all for template lookup errors
        if self.prompt_template:
            print(
                f"Using template {self.prompt_template.name} for chat completions."
            )
        else:
            print(
                "Chat completions are disabled because a prompt template",
                "wasn't provided or auto-detected.",
            )

        # Set num of experts per token if provided
@@ -159,11 +190,16 @@ class ModelContainer:
            if hasattr(self.config, "num_experts_per_token"):
                self.config.num_experts_per_token = num_experts_override
            else:
                print(
                    " !! Warning: Currently installed ExLlamaV2 does not "
                    "support overriding MoE experts"
                )

        chunk_size = min(
            unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len
        )
        self.config.max_input_len = chunk_size
        self.config.max_attn_size = chunk_size**2

        draft_args = unwrap(kwargs.get("draft"), {})
        draft_model_name = draft_args.get("draft_model_name")
@@ -171,23 +207,30 @@ class ModelContainer:
        # Always disable draft if params are incorrectly configured
        if draft_args and draft_model_name is None:
            print(
                "A draft config was found but a model name was not given. "
                "Please check your config.yml! Skipping draft load."
            )
            enable_draft = False

        if enable_draft:
            self.draft_config = ExLlamaV2Config()
            draft_model_path = pathlib.Path(
                unwrap(draft_args.get("draft_model_dir"), "models")
            )
            draft_model_path = draft_model_path / draft_model_name

            self.draft_config.model_dir = str(draft_model_path.resolve())
            self.draft_config.prepare()

            self.draft_config.scale_pos_emb = unwrap(
                draft_args.get("draft_rope_scale"), 1.0
            )

            # Automatically calculate draft rope alpha
            self.draft_config.scale_alpha_value = unwrap(
                draft_args.get("draft_rope_alpha"),
                self.calculate_rope_alpha(self.draft_config.max_seq_len),
            )
            self.draft_config.max_seq_len = self.config.max_seq_len

@@ -196,22 +239,31 @@ class ModelContainer:
                self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2
    def calculate_rope_alpha(self, base_seq_len):
        """Calculate the rope alpha value for a given sequence length."""
        ratio = self.config.max_seq_len / base_seq_len

        # Default to a 1 alpha if the sequence length is ever less
        # than or equal to 1
        if ratio <= 1.0:
            alpha = 1
        else:
            alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
        return alpha
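As a quick sanity check of the quadratic above, doubling the base context (ratio = 2.0) yields an alpha of roughly 2.63:

ratio = 8192 / 4096  # target / base sequence length
alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
print(round(alpha, 5))  # 2.62978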
    def get_model_path(self, is_draft: bool = False):
        """Get the path for this model."""
        model_path = pathlib.Path(
            self.draft_config.model_dir if is_draft else self.config.model_dir
        )
        return model_path

    def load(self, progress_callback=None):
        """
        Load model

        Args:
            progress_callback (function, optional): A function to call for each
                module loaded. Prototype:
                def progress(loaded_modules: int, total_modules: int)
        """
        for _ in self.load_gen(progress_callback):
@@ -231,25 +283,32 @@ class ModelContainer:
            lora_scaling = unwrap(lora.get("scaling"), 1.0)

            if lora_name is None:
                print(
                    "One of your loras does not have a name. Please check your "
                    "config.yml! Skipping lora load."
                )
                failure.append(lora_name)
                continue

            print(f"Loading lora: {lora_name} at scaling {lora_scaling}")
            lora_path = lora_directory / lora_name
            # FIXME(alpin): Does self.model need to be passed here?
            self.active_loras.append(
                ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
            )
            print("Lora successfully loaded.")
            success.append(lora_name)

        # Return success and failure names
        return {"success": success, "failure": failure}
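A sketch of calling load_loras directly, mirroring the 'loras' entries described in the __init__ docstring; `container` is assumed to be a loaded ModelContainer, and the directory and lora name are hypothetical.

import pathlib

result = container.load_loras(
    pathlib.Path("loras"),
    loras=[{"name": "my-lora", "scaling": 0.8}],  # hypothetical entry
)
print(result)  # {"success": [...], "failure": [...]}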
    def load_gen(self, progress_callback=None):
        """
        Load model, generator function

        Args:
            progress_callback (function, optional): A function to call for each
                module loaded. Prototype:
                def progress(loaded_modules: int, total_modules: int)
        """
@@ -262,13 +321,18 @@ class ModelContainer:
            if not self.quiet:
                print("Loading draft model: " + self.draft_config.model_dir)

            self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True)
            reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
            yield from self.draft_model.load_autosplit_gen(
                self.draft_cache,
                reserve_vram=reserve,
                last_id_only=True,
                callback_gen=progress_callback,
            )

            # Test VRAM allocation with a full-length forward pass
            input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
            self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)

        # Load model
        self.model = ExLlamaV2(self.config)
@@ -276,29 +340,41 @@ class ModelContainer:
            print("Loading model: " + self.config.model_dir)

        if not self.gpu_split_auto:
            for value in self.model.load_gen(
                self.gpu_split, callback_gen=progress_callback
            ):
                if isinstance(value, str):
                    yield value

        if self.cache_fp8:
            self.cache = ExLlamaV2Cache_8bit(self.model, lazy=self.gpu_split_auto)
        else:
            self.cache = ExLlamaV2Cache(self.model, lazy=self.gpu_split_auto)

        if self.gpu_split_auto:
            reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
            yield from self.model.load_autosplit_gen(
                self.cache,
                reserve_vram=reserve,
                last_id_only=True,
                callback_gen=progress_callback,
            )

        # Test VRAM allocation with a full-length forward pass
        input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
        self.model.forward(input_ids, cache=self.cache, preprocess_only=True)

        # Create generator
        self.generator = ExLlamaV2StreamingGenerator(
            self.model,
            self.cache,
            self.tokenizer,
            self.draft_model,
            self.draft_cache,
        )

        print("Model successfully loaded.")

    def unload(self, loras_only: bool = False):
        """
        Free all VRAM resources used by this model
@@ -327,19 +403,24 @@ class ModelContainer:
        gc.collect()
        torch.cuda.empty_cache()

    def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs):
        """Common function for token operations"""
        if text:
            # Assume token encoding
            return self.tokenizer.encode(
                text,
                add_bos=unwrap(kwargs.get("add_bos_token"), True),
                encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
            )

        if ids:
            # Assume token decoding
            ids = torch.tensor([ids])
            return self.tokenizer.decode(
                ids,
                decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
            )[0]

        return None

    def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool):
        return {
@@ -350,13 +431,15 @@ class ModelContainer:
        }

    def generate(self, prompt: str, **kwargs):
        """Generate a response to a prompt"""
        generation = list(self.generate_gen(prompt, **kwargs))
        if generation:
            response = "".join(map(lambda chunk: chunk[0], generation))
            return response, generation[-1][1], generation[-1][2]

        return "", 0, 0

    # pylint: disable=too-many-locals,too-many-branches,too-many-statements
    def generate_gen(self, prompt: str, **kwargs):
        """
        Create generator function for prompt completion
@@ -366,7 +449,8 @@ class ModelContainer:
        **kwargs:
            'token_healing' (bool): Use token healing (default: False)
            'temperature' (float): Sampling temperature (default: 1.0)
            'temperature_last' (bool): Apply temperature after all other
                samplers (default: False)
            'top_k' (int): Sampling top-K (default: 0)
            'top_p' (float): Sampling top-P (default: 1.0)
            'min_p' (float): Sampling min-P (default: 0.0)
@@ -375,19 +459,27 @@ class ModelContainer:
            'mirostat' (bool): Use Mirostat (default: False)
            'mirostat_tau' (float): Mirostat tau parameter (default: 1.5)
            'mirostat_eta' (float): Mirostat eta parameter (default: 0.1)
            'repetition_penalty' (float): Token repetition/presence penalty
                (default: 1.15)
            'repetition_range' (int): Repetition penalty range
                (default: whole context)
            'repetition_decay' (int): Repetition penalty decay
                (default: same as range)
            'stop' (List[Union[str, int]]): List of stop strings/tokens to
                end response (default: [EOS])
            'max_tokens' (int): Max no. tokens in response (default: 150)
            'add_bos_token' (bool): Adds the BOS token to the start of the
                prompt (default: True)
            'ban_eos_token' (bool): Bans the EOS token from generation
                (default: False)
            'logit_bias' (Dict[int, float]): Biases specific tokens to
                either show up more or less (default: None)
            'stream_interval' (float): Interval in seconds between each
                output chunk (default: immediate)
            'generate_window' (int): Space to reserve at the end of the
                model's context when generating. Rolls the context window
                by the same amount if context length is exceeded to allow
                generating past the model's max_seq_len.
        """

        token_healing = unwrap(kwargs.get("token_healing"), False)
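A sketch of driving generate_gen directly with a few of the sampling kwargs listed above; `container` is assumed to be a loaded ModelContainer and the values are illustrative.

for chunk, prompt_tokens, completion_tokens in container.generate_gen(
    "Once upon a time",
    temperature=0.8,
    top_p=0.9,
    repetition_penalty=1.15,
    max_tokens=200,
    stop=["\n\n"],
):
    print(chunk, end="")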
@@ -399,17 +491,37 @@ class ModelContainer:
        gen_settings = ExLlamaV2Sampler.Settings()

        # Warn of unsupported settings if the setting is enabled
        if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(
            gen_settings, "mirostat"
        ):
            print(
                " !! Warning: Currently installed ExLlamaV2 does not support "
                "Mirostat sampling"
            )

        if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(
            gen_settings, "min_p"
        ):
            print(
                " !! Warning: Currently installed ExLlamaV2 does not "
                "support min-P sampling"
            )

        if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(
            gen_settings, "tfs"
        ):
            print(
                " !! Warning: Currently installed ExLlamaV2 does not support "
                "tail-free sampling (TFS)"
            )

        if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(
            gen_settings, "temperature_last"
        ):
            print(
                " !! Warning: Currently installed ExLlamaV2 does not support "
                "temperature_last"
            )

        # Apply settings
        gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
@@ -424,14 +536,24 @@ class ModelContainer:
        # Default tau and eta fallbacks don't matter if mirostat is off
        gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
        gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
        gen_settings.token_repetition_penalty = unwrap(
            kwargs.get("repetition_penalty"), 1.0
        )
        gen_settings.token_repetition_range = unwrap(
            kwargs.get("repetition_range"), self.config.max_seq_len
        )

        # Always make sure the fallback is 0 if range < 0
        # It's technically fine to use -1, but this just validates the passed
        # fallback
        # Always default to 0 if something goes wrong
        if gen_settings.token_repetition_range <= 0:
            fallback_decay = 0
        else:
            fallback_decay = gen_settings.token_repetition_range
        gen_settings.token_repetition_decay = coalesce(
            kwargs.get("repetition_decay"), fallback_decay, 0
        )

        stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
        add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
@@ -448,13 +570,13 @@ class ModelContainer:
        # Log generation options to console
        # Some options are too large, so log the args instead
        log_generation_params(
            max_tokens=max_tokens,
            **vars(gen_settings),
            token_healing=token_healing,
            add_bos_token=add_bos_token,
            ban_eos_token=ban_eos_token,
            stop_conditions=stop_conditions,
            logit_bias=logit_bias,
        )

        # Log prompt to console
@@ -465,13 +587,17 @@ class ModelContainer:
        # Create a vocab tensor if it doesn't exist for token biasing
        if gen_settings.token_bias is None:
            padding = -self.tokenizer.config.vocab_size % 32
            gen_settings.token_bias = torch.zeros(
                (self.tokenizer.config.vocab_size + padding,),
                dtype=torch.float,
            )

        # Map logits to the tensor with their biases
        for token, bias in logit_bias.items():
            gen_settings.token_bias[token] = bias

        # Ban the EOS token if specified. If not, append to stop conditions
        # as well.
        # Set this below logging to avoid polluting the stop strings array
        if ban_eos_token:
            gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
@@ -483,16 +609,15 @@ class ModelContainer:
        # Tokenized context
        ids = self.tokenizer.encode(
            prompt, add_bos=add_bos_token, encode_special_tokens=True
        )
        context_len = len(ids[0])

        if context_len > self.config.max_seq_len:
            print(
                f"WARNING: The context length {context_len} is greater than "
                f"the max_seq_len {self.config.max_seq_len}.",
                "Generation is truncated and metrics may not be accurate.",
            )

        prompt_tokens = ids.shape[-1]
@@ -503,26 +628,32 @@ class ModelContainer:
        start_time = time.time()
        last_chunk_time = start_time

        save_tokens = torch.empty((1, 0), dtype=torch.bool)
        chunk_buffer = ""
        chunk_tokens = 0

        while True:
            # Ingest prompt
            if chunk_tokens == 0:
                ids = torch.cat((ids, save_tokens), dim=-1)
                save_tokens = torch.empty((1, 0), dtype=torch.bool)
                overflow = ids.shape[-1] + generate_window - self.config.max_seq_len
                active_ids = ids[:, max(0, overflow) :]
                chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]

                self.generator.begin_stream(
                    active_ids,
                    gen_settings,
                    token_healing=token_healing,
                    loras=self.active_loras,
                )

            # Generate
            chunk, eos, tokens = self.generator.stream()

            if token_healing:
                # Extract healed token
                ids[:, -1] = self.generator.sequence_ids[:, -2]
                token_healing = False

            save_tokens = torch.cat((save_tokens, tokens), dim=-1)
@@ -535,7 +666,9 @@ class ModelContainer:
            now = time.time()
            elapsed = now - last_chunk_time

            if chunk_buffer != "" and (
                elapsed > stream_interval or eos or generated_tokens == max_tokens
            ):
                yield chunk_buffer, prompt_tokens, generated_tokens
                full_response += chunk_buffer
                chunk_buffer = ""
@@ -549,12 +682,20 @@ class ModelContainer:
        elapsed_time = last_chunk_time - start_time

        initial_response = (
            f"Metrics: {generated_tokens} tokens generated in "
            f"{round(elapsed_time, 2)} seconds"
        )
        itemization = []
        extra_parts = []

        # Add tokens per second
        tokens_per_second = (
            "Indeterminate"
            if elapsed_time == 0
            else round(generated_tokens / elapsed_time, 2)
        )
        itemization.append(f"{tokens_per_second} T/s")

        # Add context (original token count)
        if ids is not None:
@@ -564,4 +705,10 @@ class ModelContainer:
            extra_parts.append("<-- Not accurate (truncated)")

        # Print output
        print(
            initial_response
            + " ("
            + ", ".join(itemization)
            + ") "
            + " ".join(extra_parts)
        )
requirements-dev.txt
@@ -0,0 +1,15 @@
# formatting
ruff==0.1.9
## Implement below dependencies when support is added
# type checking
# mypy==0.991
# types-PyYAML
# types-requests
# types-setuptools
# testing
# pytest
# pytest-forked
# pytest-asyncio

templating.py
@@ -1,46 +1,56 @@
"""Small replication of AutoTokenizer's chat template system for efficiency"""
import json
import pathlib
from functools import lru_cache
from importlib.metadata import version as package_version
from jinja2.sandbox import ImmutableSandboxedEnvironment
from packaging import version
from pydantic import BaseModel
from typing import Optional, Dict


class PromptTemplate(BaseModel):
    """A template for chat completion prompts."""

    name: str
    template: str
def get_prompt_from_template(
    messages,
    prompt_template: PromptTemplate,
    add_generation_prompt: bool,
    special_tokens: Optional[Dict[str, str]] = None,
):
    """Get a prompt from a template and a list of messages."""
    if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
        raise ImportError(
            "Parsing these chat completion messages requires jinja2 3.0.0 "
            f"or greater. Current version: {package_version('jinja2')}\n"
            "Please upgrade jinja by running the following command: "
            "pip install --upgrade jinja2"
        )

    compiled_template = _compile_template(prompt_template.template)
    return compiled_template.render(
        messages=messages,
        add_generation_prompt=add_generation_prompt,
        **special_tokens,
    )
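A small usage sketch with a hypothetical inline template; real templates live in the templates/ directory and may rely on model-specific special tokens.

template = PromptTemplate(
    name="example",
    template=(
        "{% for message in messages %}"
        "{{ message['role'] }}: {{ message['content'] }}\n"
        "{% endfor %}"
        "{% if add_generation_prompt %}assistant:{% endif %}"
    ),
)
messages = [{"role": "user", "content": "Hello!"}]
print(get_prompt_from_template(messages, template, True, special_tokens={}))
# user: Hello!
# assistant: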
# Inspired from
# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
@lru_cache
def _compile_template(template: str):
    jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    jinja_template = jinja_env.from_string(template)
    return jinja_template


def find_template_from_model(model_path: pathlib.Path):
    """Find a matching template name from a model path."""
    model_name = model_path.name
    template_directory = pathlib.Path("templates")
    for filepath in template_directory.glob("*.jinja"):
@@ -50,14 +60,16 @@ def find_template_from_model(model_path: pathlib.Path):
        if template_name in model_name.lower():
            return template_name

    return None


def get_template_from_file(prompt_template_name: str):
    """Get a template from a jinja file."""
    template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
    if template_path.exists():
        with open(template_path, "r", encoding="utf8") as raw_template:
            return PromptTemplate(
                name=prompt_template_name, template=raw_template.read()
            )

    return None
@@ -66,15 +78,12 @@ def get_template_from_file(prompt_template_name: str):
# Get a template from a JSON file
# Requires a key and template name
def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
    """Get a template from a JSON file. Requires a key and template name"""
    if json_path.exists():
        with open(json_path, "r", encoding="utf8") as config_file:
            model_config = json.load(config_file)
            chat_template = model_config.get(key)
            if chat_template:
                return PromptTemplate(name=name, template=chat_template)

    return None
model_test.py
@@ -1,22 +1,49 @@
""" Test the model container. """
from model import ModelContainer


def progress(module, modules):
    """Wrapper callback for load progress."""
    yield module, modules


def test_load_gen(model_path):
    """Test loading a model."""
    container = ModelContainer(model_path)
    loader = container.load_gen(progress)
    for module, modules in loader:
        print(module, modules)

    container.unload()
    del container


def test_generate_gen(model_path):
    """Test generating from a model."""
    container = ModelContainer(model_path)
    generator = container.generate_gen("Once upon a tim", token_healing=True)
    for chunk in generator:
        print(chunk, end="")

    container.unload()
    del container


def test_generate(model_path):
    """Test generating from a model."""
    model_container = ModelContainer(model_path)
    model_container.load(progress)
    prompt = (
        "All work and no play makes turbo a derpy cat.\n"
        "All work and no play makes turbo a derpy cat.\nAll"
    )
    response = model_container.generate(
        prompt, top_k=1, max_new_tokens=1000, stream_interval=0.5
    )
    print(response)


if __name__ == "__main__":
    MODEL1 = "/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.0bpw/"
    MODEL2 = "/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.65bpw/"
    test_load_gen(MODEL1)
    test_generate_gen(MODEL1)
    test_generate(MODEL2)
wheel_test.py
@@ -1,3 +1,4 @@
""" Test if the wheels are installed correctly. """
from importlib.metadata import version
from importlib.util import find_spec

@@ -34,8 +35,12 @@ else:
    print(
        f"\nSuccessful imports: {', '.join(successful_packages)}",
        f"\nErrored imports: {''.join(errored_packages)}",
    )

    if len(errored_packages) > 0:
        print(
            "\nIf packages are installed, but not found on this test, please "
            "check the wheel versions for the correct python version and CUDA "
            "version (if applicable)."
        )
utils.py
@@ -1,43 +1,54 @@
"""Common utilities for the tabbyAPI"""
import traceback
from typing import Optional

from pydantic import BaseModel


def load_progress(module, modules):
    """Wrapper callback for load progress."""
    yield module, modules


class TabbyGeneratorErrorMessage(BaseModel):
    """Common error types."""

    message: str
    trace: Optional[str] = None


class TabbyGeneratorError(BaseModel):
    """Common error types."""

    error: TabbyGeneratorErrorMessage


def get_generator_error(message: str):
    """Get a generator error."""
    error_message = TabbyGeneratorErrorMessage(
        message=message, trace=traceback.format_exc()
    )

    generator_error = TabbyGeneratorError(error=error_message)

    # Log and send the exception
    print(f"\n{generator_error.error.trace}")
    return get_sse_packet(generator_error.model_dump_json())


def get_sse_packet(json_data: str):
    """Get an SSE packet."""
    return f"data: {json_data}\n\n"


def unwrap(wrapped, default=None):
    """Unwrap function for Optionals."""
    if wrapped is None:
        return default

    return wrapped


def coalesce(*args):
    """Coalesce function for multiple unwraps."""
    return next((arg for arg in args if arg is not None), None)
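A quick illustration of the two fallback helpers above:

assert unwrap(None, 5) == 5
assert unwrap(0, 5) == 0  # falsy-but-present values are kept
assert coalesce(None, None, "a") == "a"
assert coalesce(None, None) is None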