feat: workflows for formatting/linting (#35)
* add github workflows for pylint and yapf
* yapf
* docstrings for auth
* fix auth.py
* fix generators.py
* fix gen_logging.py
* fix main.py
* fix model.py
* fix templating.py
* fix utils.py
* update formatting.sh to include subdirs for pylint
* fix model_test.py
* fix wheel_test.py
* rename utils to utils_oai
* fix OAI/utils_oai.py
* fix completion.py
* fix token.py
* fix lora.py
* fix common.py
* add pylintrc and fix model.py
* finish up pylint
* fix attribute error
* main.py formatting
* add formatting batch script
* Main: Remove unnecessary global (linter suggestion). Signed-off-by: kingbri <bdashore3@proton.me>
* switch to ruff
* Formatting + Linting: Add ruff.toml. Signed-off-by: kingbri <bdashore3@proton.me>
* Formatting + Linting: Switch scripts to use ruff. Also remove the file and recent file change functions from both scripts. Signed-off-by: kingbri <bdashore3@proton.me>
* Tree: Format and lint. Signed-off-by: kingbri <bdashore3@proton.me>
* Scripts + Workflows: Format. Signed-off-by: kingbri <bdashore3@proton.me>
* Tree: Remove pylint flags. We use ruff now. Signed-off-by: kingbri <bdashore3@proton.me>
* Tree: Format. Signed-off-by: kingbri <bdashore3@proton.me>
* Formatting: Line length is 88. Use the same value as Black. Signed-off-by: kingbri <bdashore3@proton.me>
* Tree: Format. Update to new line length rules. Signed-off-by: kingbri <bdashore3@proton.me>

---------

Authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com>
Co-authored-by: kingbri <bdashore3@proton.me>
parent a14abfe21c · commit fa47f51f85

22 changed files with 1210 additions and 511 deletions
32  .github/workflows/ruff.yml  (new file, vendored)
@@ -0,0 +1,32 @@
name: ruff

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  ruff:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements-dev.txt
      - name: Format and show diff with ruff
        run: |
          ruff format --diff
      - name: Lint code with ruff
        run: |
          ruff check
111  .ruff.toml  (new file)
@@ -0,0 +1,111 @@
# Exclude a variety of commonly ignored directories.
exclude = [
    ".git",
    ".git-rewrite",
    ".mypy_cache",
    ".pyenv",
    ".pytest_cache",
    ".ruff_cache",
    ".venv",
    ".vscode",
    "__pypackages__",
    "_build",
    "build",
    "dist",
    "node_modules",
    "site-packages",
    "venv",
]

# Same as Black.
line-length = 88
indent-width = 4

# Assume Python 3.10
target-version = "py310"

[lint]
# Enable preview
preview = true

# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
# McCabe complexity (`C901`) by default.
# Enable flake8-bugbear (`B`) rules, in addition to the defaults.
select = ["E4", "E7", "E9", "F", "B"]
extend-select = [
    "D419",     # empty-docstring
    "PLC2401",  # non-ascii-name
    "E501",     # line-too-long
    "W291",     # trailing-whitespace
    "PLC0414",  # useless-import-alias
    "E999",     # syntax-error
    "PLE0101",  # return-in-init
    "F706",     # return-outside-function
    "F704",     # yield-outside-function
    "PLE0116",  # continue-in-finally
    "PLE0117",  # nonlocal-without-binding
    "PLE0241",  # duplicate-bases
    "PLE0302",  # unexpected-special-method-signature
    "PLE0604",  # invalid-all-object
    "PLE0704",  # misplaced-bare-raise
    "PLE1205",  # logging-too-many-args
    "PLE1206",  # logging-too-few-args
    "PLE1307",  # bad-string-format-type
    "PLE1310",  # bad-str-strip-call
    "PLE1507",  # invalid-envvar-value
    "PLR0124",  # comparison-with-itself
    "PLR0202",  # no-classmethod-decorator
    "PLR0203",  # no-staticmethod-decorator
    "PLR0206",  # property-with-parameters
    "PLR1704",  # redefined-argument-from-local
    "PLR1711",  # useless-return
    "C416",     # unnecessary-comprehension
    "PLW0108",  # unnecessary-lambda
    "PLW0127",  # self-assigning-variable
    "PLW0129",  # assert-on-string-literal
    "PLW0602",  # global-variable-not-assigned
    "PLW0604",  # global-at-module-level
    "F401",     # unused-import
    "F841",     # unused-variable
    "E722",     # bare-except
    "PLW0711",  # binary-op-exception
    "PLW1501",  # bad-open-mode
    "PLW1508",  # invalid-envvar-default
    "PLW1509",  # subprocess-popen-preexec-fn
]
ignore = [
    "PLR6301",  # no-self-use
    "UP004",    # useless-object-inheritance
    "PLR0904",  # too-many-public-methods
    "PLR0911",  # too-many-return-statements
    "PLR0912",  # too-many-branches
    "PLR0913",  # too-many-arguments
    "PLR0914",  # too-many-locals
    "PLR0915",  # too-many-statements
    "PLR0916",  # too-many-boolean-expressions
    "PLW0120",  # useless-else-on-loop
    "PLW0406",  # import-self
    "PLW0603",  # global-statement
    "PLW1641",  # eq-without-hash
]

# Allow fix for all enabled rules (when `--fix` is provided).
fixable = ["ALL"]
unfixable = ["B"]

# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

[format]
# Like Black, use double quotes for strings.
quote-style = "double"

# Like Black, indent with spaces, rather than tabs.
indent-style = "space"

# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = false

# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"
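To illustrate what this configuration does to the codebase, here is a small before/after sketch. It is not part of the commit: the field is lifted from the OAI/types/model.py changes below, and the class wrapper and imports are added only to make the snippet self-contained.

from typing import Optional

from pydantic import BaseModel, Field


class Example(BaseModel):
    # Before this commit, keyword arguments carried spaces around "=" and the
    # whole Field call sat on one long line:
    #   max_seq_len: Optional[int] = Field(description = "Leave this blank to use the model's base sequence length", default = None)
    #
    # `ruff format` with line-length = 88, double quotes, and 4-space indents
    # rewrites it to the wrapped form used throughout this commit:
    max_seq_len: Optional[int] = Field(
        description="Leave this blank to use the model's base sequence length",
        default=None,
    )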
OAI/types/chat_completion.py
@@ -4,22 +4,26 @@ from pydantic import BaseModel, Field
from typing import Union, List, Optional, Dict
from OAI.types.common import UsageStats, CommonCompletionRequest


class ChatCompletionMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionRespChoice(BaseModel):
    # Index is 0 since we aren't using multiple choices
    index: int = 0
    finish_reason: str
    message: ChatCompletionMessage


class ChatCompletionStreamChoice(BaseModel):
    # Index is 0 since we aren't using multiple choices
    index: int = 0
    finish_reason: Optional[str]
    delta: Union[ChatCompletionMessage, dict] = {}


# Inherited from common request
class ChatCompletionRequest(CommonCompletionRequest):
    # Messages
@@ -28,6 +32,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
    prompt_template: Optional[str] = None
    add_generation_prompt: Optional[bool] = True


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
    choices: List[ChatCompletionRespChoice]
@@ -36,6 +41,7 @@ class ChatCompletionResponse(BaseModel):
    object: str = "chat.completion"
    usage: Optional[UsageStats] = None


class ChatCompletionStreamChunk(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
    choices: List[ChatCompletionStreamChoice]
OAI/types/common.py
@@ -1,30 +1,54 @@
-from pydantic import BaseModel, Field, AliasChoices
+""" Common types for OAI. """
from typing import List, Dict, Optional, Union
+
+from pydantic import BaseModel, Field, AliasChoices

from utils import unwrap


class LogProbs(BaseModel):
+    """Represents log probabilities."""
+
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[float] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Dict[str, float]] = Field(default_factory=list)


class UsageStats(BaseModel):
+    """Represents usage stats."""
+
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class CommonCompletionRequest(BaseModel):
+    """Represents a common completion request."""
+
    # Model information
    # This parameter is not used, the loaded model is used instead
    model: Optional[str] = None

    # Extra OAI request stuff
-    best_of: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
-    echo: Optional[bool] = Field(description = "Not parsed. Only used for OAI compliance.", default = False)
-    logprobs: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
-    n: Optional[int] = Field(description = "Not parsed. Only used for OAI compliance.", default = 1)
-    suffix: Optional[str] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
-    user: Optional[str] = Field(description = "Not parsed. Only used for OAI compliance.", default = None)
+    best_of: Optional[int] = Field(
+        description="Not parsed. Only used for OAI compliance.", default=None
+    )
+    echo: Optional[bool] = Field(
+        description="Not parsed. Only used for OAI compliance.", default=False
+    )
+    logprobs: Optional[int] = Field(
+        description="Not parsed. Only used for OAI compliance.", default=None
+    )
+    n: Optional[int] = Field(
+        description="Not parsed. Only used for OAI compliance.", default=1
+    )
+    suffix: Optional[str] = Field(
+        description="Not parsed. Only used for OAI compliance.", default=None
+    )
+    user: Optional[str] = Field(
+        description="Not parsed. Only used for OAI compliance.", default=None
+    )

    # Generation info
    # seed: Optional[int] = -1
@@ -35,8 +59,9 @@ class CommonCompletionRequest(BaseModel):
    max_tokens: Optional[int] = 150

    # Aliased to repetition_penalty
-    # TODO: Maybe make this an alias to rep pen
-    frequency_penalty: Optional[float] = Field(description = "Aliased to Repetition Penalty", default = 0.0)
+    frequency_penalty: Optional[float] = Field(
+        description="Aliased to Repetition Penalty", default=0.0
+    )

    # Sampling params
    token_healing: Optional[bool] = False
@@ -58,18 +83,21 @@ class CommonCompletionRequest(BaseModel):

    # Aliased variables
    repetition_range: Optional[int] = Field(
-        default = None,
-        validation_alias = AliasChoices('repetition_range', 'repetition_penalty_range')
+        default=None,
+        validation_alias=AliasChoices("repetition_range", "repetition_penalty_range"),
    )

-    # Converts to internal generation parameters
    def to_gen_params(self):
+        """Converts to internal generation parameters."""
        # Convert stop to an array of strings
        if isinstance(self.stop, str):
            self.stop = [self.stop]

-        # Set repetition_penalty to frequency_penalty if repetition_penalty isn't already defined
-        if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty:
+        # Set repetition_penalty to frequency_penalty if repetition_penalty
+        # isn't already defined
+        if (
+            self.repetition_penalty is None or self.repetition_penalty == 1.0
+        ) and self.frequency_penalty:
            self.repetition_penalty = self.frequency_penalty

        return {
OAI/types/completion.py
@@ -1,22 +1,35 @@
-from uuid import uuid4
+""" Completion API protocols """
from time import time
-from pydantic import BaseModel, Field
from typing import List, Optional, Union
-from OAI.types.common import LogProbs, UsageStats, CommonCompletionRequest
+from uuid import uuid4
+
+from pydantic import BaseModel, Field
+
+from OAI.types.common import CommonCompletionRequest, LogProbs, UsageStats


class CompletionRespChoice(BaseModel):
+    """Represents a single choice in a completion response."""
+
    # Index is 0 since we aren't using multiple choices
    index: int = 0
    finish_reason: str
    logprobs: Optional[LogProbs] = None
    text: str


# Inherited from common request
class CompletionRequest(CommonCompletionRequest):
-    # Prompt can also contain token ids, but that's out of scope for this project.
+    """Represents a completion request."""
+
+    # Prompt can also contain token ids, but that's out of scope
+    # for this project.
    prompt: Union[str, List[str]]


class CompletionResponse(BaseModel):
+    """Represents a completion response."""
+
    id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}")
    choices: List[CompletionRespChoice]
    created: int = Field(default_factory=lambda: int(time()))
OAI/types/lora.py
@@ -1,25 +1,42 @@
-from pydantic import BaseModel, Field
+""" Lora types """
from time import time
from typing import Optional, List
+
+from pydantic import BaseModel, Field


class LoraCard(BaseModel):
+    """Represents a single Lora card."""
+
    id: str = "test"
    object: str = "lora"
    created: int = Field(default_factory=lambda: int(time()))
    owned_by: str = "tabbyAPI"
    scaling: Optional[float] = None


class LoraList(BaseModel):
+    """Represents a list of Lora cards."""
+
    object: str = "list"
    data: List[LoraCard] = Field(default_factory=list)


class LoraLoadInfo(BaseModel):
+    """Represents a single Lora load info."""
+
    name: str
    scaling: Optional[float] = 1.0


class LoraLoadRequest(BaseModel):
+    """Represents a Lora load request."""
+
    loras: List[LoraLoadInfo]


class LoraLoadResponse(BaseModel):
+    """Represents a Lora load response."""
+
    success: List[str] = Field(default_factory=list)
    failure: List[str] = Field(default_factory=list)
OAI/types/model.py
@@ -1,19 +1,29 @@
-from pydantic import BaseModel, Field, ConfigDict
+""" Contains model card types. """
from time import time
from typing import List, Optional
+
+from pydantic import BaseModel, Field, ConfigDict

from gen_logging import LogConfig


class ModelCardParameters(BaseModel):
-    # Safe to do this since it's guaranteed to fetch a max seq len from model_container
+    """Represents model card parameters."""
+
+    # Safe to do this since it's guaranteed to fetch a max seq len
+    # from model_container
    max_seq_len: Optional[int] = None
    rope_scale: Optional[float] = 1.0
    rope_alpha: Optional[float] = 1.0
    cache_mode: Optional[str] = "FP16"
    prompt_template: Optional[str] = None
    num_experts_per_token: Optional[int] = None
-    draft: Optional['ModelCard'] = None
+    draft: Optional["ModelCard"] = None


class ModelCard(BaseModel):
+    """Represents a single model card."""
+
    id: str = "test"
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time()))
@@ -21,26 +31,47 @@ class ModelCard(BaseModel):
    logging: Optional[LogConfig] = None
    parameters: Optional[ModelCardParameters] = None


class ModelList(BaseModel):
+    """Represents a list of model cards."""
+
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class DraftModelLoadRequest(BaseModel):
+    """Represents a draft model load request."""
+
    draft_model_name: str
    draft_rope_scale: Optional[float] = 1.0
-    draft_rope_alpha: Optional[float] = Field(description = "Automatically calculated if not present", default = None)
+    draft_rope_alpha: Optional[float] = Field(
+        description="Automatically calculated if not present", default=None
+    )


# TODO: Unify this with ModelCardParams
class ModelLoadRequest(BaseModel):
+    """Represents a model load request."""
+
    name: str

    # Max seq len is fetched from config.json of the model by default
-    max_seq_len: Optional[int] = Field(description = "Leave this blank to use the model's base sequence length", default = None)
-    override_base_seq_len: Optional[int] = Field(description = "Overrides the model's base sequence length. Leave blank if unsure", default = None)
+    max_seq_len: Optional[int] = Field(
+        description="Leave this blank to use the model's base sequence length",
+        default=None,
+    )
+    override_base_seq_len: Optional[int] = Field(
+        description=(
+            "Overrides the model's base sequence length. " "Leave blank if unsure"
+        ),
+        default=None,
+    )
    gpu_split_auto: Optional[bool] = True
    gpu_split: Optional[List[float]] = Field(default_factory=list)
    rope_scale: Optional[float] = 1.0
-    rope_alpha: Optional[float] = Field(description = "Automatically calculated if not present", default = None)
+    rope_alpha: Optional[float] = Field(
+        description="Automatically calculated if not present", default=None
+    )
    no_flash_attention: Optional[bool] = False
    # low_mem: Optional[bool] = False
    cache_mode: Optional[str] = "FP16"
@@ -48,9 +79,12 @@ class ModelLoadRequest(BaseModel):
    num_experts_per_token: Optional[int] = None
    draft: Optional[DraftModelLoadRequest] = None


class ModelLoadResponse(BaseModel):
+    """Represents a model load response."""
+
    # Avoids pydantic namespace warning
-    model_config = ConfigDict(protected_namespaces = [])
+    model_config = ConfigDict(protected_namespaces=[])

    model_type: str = "model"
    module: int
OAI/types/token.py
@@ -1,30 +1,51 @@
-from pydantic import BaseModel
+""" Tokenization types """
from typing import List
+
+from pydantic import BaseModel


class CommonTokenRequest(BaseModel):
+    """Represents a common tokenization request."""
+
    add_bos_token: bool = True
    encode_special_tokens: bool = True
    decode_special_tokens: bool = True

    def get_params(self):
+        """Get the parameters for tokenization."""
        return {
            "add_bos_token": self.add_bos_token,
            "encode_special_tokens": self.encode_special_tokens,
-            "decode_special_tokens": self.decode_special_tokens
+            "decode_special_tokens": self.decode_special_tokens,
        }


class TokenEncodeRequest(CommonTokenRequest):
+    """Represents a tokenization request."""
+
    text: str


class TokenEncodeResponse(BaseModel):
+    """Represents a tokenization response."""
+
    tokens: List[int]
    length: int


class TokenDecodeRequest(CommonTokenRequest):
+    """ " Represents a detokenization request."""
+
    tokens: List[int]


class TokenDecodeResponse(BaseModel):
+    """Represents a detokenization response."""
+
    text: str


class TokenCountResponse(BaseModel):
+    """Represents a token count response."""
+
    length: int
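A minimal usage sketch for the request models above (the input text and flag values are illustrative; the printed dictionary follows from the defaults declared on CommonTokenRequest):

from OAI.types.token import TokenEncodeRequest

request = TokenEncodeRequest(text="Hello, world!", add_bos_token=False)

# get_params() collects the tokenizer flags defined on CommonTokenRequest
print(request.get_params())
# {'add_bos_token': False, 'encode_special_tokens': True, 'decode_special_tokens': True}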
103  OAI/utils.py  (file removed; renamed to OAI/utils_oai.py)
@@ -1,103 +0,0 @@
import pathlib
from OAI.types.completion import CompletionResponse, CompletionRespChoice
from OAI.types.chat_completion import (
    ChatCompletionMessage,
    ChatCompletionRespChoice,
    ChatCompletionStreamChunk,
    ChatCompletionResponse,
    ChatCompletionStreamChoice
)
from OAI.types.common import UsageStats
from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard
from typing import Optional

from utils import unwrap

def create_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
    choice = CompletionRespChoice(
        finish_reason = "Generated",
        text = text
    )

    response = CompletionResponse(
        choices = [choice],
        model = unwrap(model_name, ""),
        usage = UsageStats(prompt_tokens = prompt_tokens,
                           completion_tokens = completion_tokens,
                           total_tokens = prompt_tokens + completion_tokens)
    )

    return response

def create_chat_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
    message = ChatCompletionMessage(
        role = "assistant",
        content = text
    )

    choice = ChatCompletionRespChoice(
        finish_reason = "Generated",
        message = message
    )

    response = ChatCompletionResponse(
        choices = [choice],
        model = unwrap(model_name, ""),
        usage = UsageStats(prompt_tokens = prompt_tokens,
                           completion_tokens = completion_tokens,
                           total_tokens = prompt_tokens + completion_tokens)
    )

    return response

def create_chat_completion_stream_chunk(const_id: str,
                                        text: Optional[str] = None,
                                        model_name: Optional[str] = None,
                                        finish_reason: Optional[str] = None):
    if finish_reason:
        message = {}
    else:
        message = ChatCompletionMessage(
            role = "assistant",
            content = text
        )

    # The finish reason can be None
    choice = ChatCompletionStreamChoice(
        finish_reason = finish_reason,
        delta = message
    )

    chunk = ChatCompletionStreamChunk(
        id = const_id,
        choices = [choice],
        model = unwrap(model_name, "")
    )

    return chunk

def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):

    # Convert the provided draft model path to a pathlib path for equality comparisons
    if draft_model_path:
        draft_model_path = pathlib.Path(draft_model_path).resolve()

    model_card_list = ModelList()
    for path in model_path.iterdir():

        # Don't include the draft models path
        if path.is_dir() and path != draft_model_path:
            model_card = ModelCard(id = path.name)
            model_card_list.data.append(model_card)

    return model_card_list

def get_lora_list(lora_path: pathlib.Path):
    lora_list = LoraList()
    for path in lora_path.iterdir():
        if path.is_dir():
            lora_card = LoraCard(id = path.name)
            lora_list.data.append(lora_card)

    return lora_list
114  OAI/utils_oai.py  (new file)
@@ -0,0 +1,114 @@
""" Utility functions for the OpenAI server. """
import pathlib
from typing import Optional

from OAI.types.chat_completion import (
    ChatCompletionMessage,
    ChatCompletionRespChoice,
    ChatCompletionStreamChunk,
    ChatCompletionResponse,
    ChatCompletionStreamChoice,
)
from OAI.types.completion import CompletionResponse, CompletionRespChoice
from OAI.types.common import UsageStats
from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard

from utils import unwrap


def create_completion_response(
    text: str,
    prompt_tokens: int,
    completion_tokens: int,
    model_name: Optional[str],
):
    """Create a completion response from the provided text."""
    choice = CompletionRespChoice(finish_reason="Generated", text=text)

    response = CompletionResponse(
        choices=[choice],
        model=unwrap(model_name, ""),
        usage=UsageStats(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )

    return response


def create_chat_completion_response(
    text: str,
    prompt_tokens: int,
    completion_tokens: int,
    model_name: Optional[str],
):
    """Create a chat completion response from the provided text."""
    message = ChatCompletionMessage(role="assistant", content=text)

    choice = ChatCompletionRespChoice(finish_reason="Generated", message=message)

    response = ChatCompletionResponse(
        choices=[choice],
        model=unwrap(model_name, ""),
        usage=UsageStats(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )

    return response


def create_chat_completion_stream_chunk(
    const_id: str,
    text: Optional[str] = None,
    model_name: Optional[str] = None,
    finish_reason: Optional[str] = None,
):
    """Create a chat completion stream chunk from the provided text."""
    if finish_reason:
        message = {}
    else:
        message = ChatCompletionMessage(role="assistant", content=text)

    # The finish reason can be None
    choice = ChatCompletionStreamChoice(finish_reason=finish_reason, delta=message)

    chunk = ChatCompletionStreamChunk(
        id=const_id, choices=[choice], model=unwrap(model_name, "")
    )

    return chunk


def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
    """Get the list of models from the provided path."""

    # Convert the provided draft model path to a pathlib path for
    # equality comparisons
    if draft_model_path:
        draft_model_path = pathlib.Path(draft_model_path).resolve()

    model_card_list = ModelList()
    for path in model_path.iterdir():
        # Don't include the draft models path
        if path.is_dir() and path != draft_model_path:
            model_card = ModelCard(id=path.name)
            model_card_list.data.append(model_card)  # pylint: disable=no-member

    return model_card_list


def get_lora_list(lora_path: pathlib.Path):
    """Get the list of Lora cards from the provided path."""
    lora_list = LoraList()
    for path in lora_path.iterdir():
        if path.is_dir():
            lora_card = LoraCard(id=path.name)
            lora_list.data.append(lora_card)  # pylint: disable=no-member

    return lora_list
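A small sketch of how these helpers are meant to be called. The token counts, id, and model name are made up, and serialization via model_dump_json assumes pydantic v2, which the model_validate/model_dump calls elsewhere in this commit already imply:

from OAI.utils_oai import create_completion_response, create_chat_completion_stream_chunk

response = create_completion_response(
    text="Hello!", prompt_tokens=5, completion_tokens=2, model_name="example-model"
)
print(response.model_dump_json())

# Streaming: a final chunk with a finish reason carries an empty delta
chunk = create_chat_completion_stream_chunk(
    const_id="chatcmpl-123", finish_reason="stop", model_name="example-model"
)
print(chunk.model_dump_json())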
114  auth.py
@@ -1,105 +1,125 @@
-import secrets
-import yaml
-from fastapi import Header, HTTPException
-from pydantic import BaseModel
-from typing import Optional
-
"""
This method of authorization is pretty insecure, but since TabbyAPI is a local
application, it should be fine.
"""
+import secrets
+from typing import Optional
+
+from fastapi import Header, HTTPException
+from pydantic import BaseModel
+import yaml


class AuthKeys(BaseModel):
+    """
+    This class represents the authentication keys for the application.
+    It contains two types of keys: 'api_key' and 'admin_key'.
+    The 'api_key' is used for general API calls, while the 'admin_key'
+    is used for administrative tasks. The class also provides a method
+    to verify if a given key matches the stored 'api_key' or 'admin_key'.
+    """
+
    api_key: str
    admin_key: str

    def verify_key(self, test_key: str, key_type: str):
-        # Match statements are only available in python 3.10 and up
+        """Verify if a given key matches the stored key."""
        if key_type == "admin_key":
            return test_key == self.admin_key
-        elif key_type == "api_key":
+        if key_type == "api_key":
            # Admin keys are valid for all API calls
            return test_key == self.api_key or test_key == self.admin_key
-        else:
-            return False
+        return False


-auth_keys: Optional[AuthKeys] = None
-disable_auth: bool = False
+AUTH_KEYS: Optional[AuthKeys] = None
+DISABLE_AUTH: bool = False


def load_auth_keys(disable_from_config: bool):
-    global auth_keys
-    global disable_auth
+    """Load the authentication keys from api_tokens.yml. If the file does not
+    exist, generate new keys and save them to api_tokens.yml."""
+    global AUTH_KEYS
+    global DISABLE_AUTH

-    disable_auth = disable_from_config
+    DISABLE_AUTH = disable_from_config
    if disable_from_config:
        print(
-            "!! Warning: Disabling authentication makes your instance vulnerable.",
-            "Set the \"disable_auth\" flag to False in config.yml if you want to share this",
-            "instance with others."
+            "!! Warning: Disabling authentication",
+            "makes your instance vulnerable.",
+            "Set the 'disable_auth' flag to False in config.yml",
+            "if you want to share this instance with others.",
        )

        return

    try:
-        with open("api_tokens.yml", "r", encoding = 'utf8') as auth_file:
+        with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
            auth_keys_dict = yaml.safe_load(auth_file)
-            auth_keys = AuthKeys.model_validate(auth_keys_dict)
+            AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
    except OSError:
        new_auth_keys = AuthKeys(
-            api_key = secrets.token_hex(16),
-            admin_key = secrets.token_hex(16)
+            api_key=secrets.token_hex(16), admin_key=secrets.token_hex(16)
        )
-        auth_keys = new_auth_keys
+        AUTH_KEYS = new_auth_keys

-        with open("api_tokens.yml", "w", encoding = "utf8") as auth_file:
-            yaml.safe_dump(auth_keys.model_dump(), auth_file, default_flow_style=False)
+        with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
+            yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)

    print(
-        f"Your API key is: {auth_keys.api_key}\n"
-        f"Your admin key is: {auth_keys.admin_key}\n\n"
-        "If these keys get compromised, make sure to delete api_tokens.yml and restart the server. Have fun!"
+        f"Your API key is: {AUTH_KEYS.api_key}\n"
+        f"Your admin key is: {AUTH_KEYS.admin_key}\n\n"
+        "If these keys get compromised, make sure to delete api_tokens.yml "
+        "and restart the server. Have fun!"
    )


def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)):
+    """Check if the API key is valid."""
+
    # Allow request if auth is disabled
-    if disable_auth:
+    if DISABLE_AUTH:
        return

    if x_api_key:
-        if auth_keys.verify_key(x_api_key, "api_key"):
-            return x_api_key
-        else:
+        if not AUTH_KEYS.verify_key(x_api_key, "api_key"):
            raise HTTPException(401, "Invalid API key")
-    elif authorization:
+        return x_api_key
+
+    if authorization:
        split_key = authorization.split(" ")
        if len(split_key) < 2:
            raise HTTPException(401, "Invalid API key")
-        elif split_key[0].lower() == "bearer" and auth_keys.verify_key(split_key[1], "api_key"):
-            return authorization
-        else:
+        if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key(
+            split_key[1], "api_key"
+        ):
            raise HTTPException(401, "Invalid API key")
-    else:
-        raise HTTPException(401, "Please provide an API key")
+
+        return authorization
+
+    raise HTTPException(401, "Please provide an API key")


def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)):
+    """Check if the admin key is valid."""
+
    # Allow request if auth is disabled
-    if disable_auth:
+    if DISABLE_AUTH:
        return

    if x_admin_key:
-        if auth_keys.verify_key(x_admin_key, "admin_key"):
-            return x_admin_key
-        else:
+        if not AUTH_KEYS.verify_key(x_admin_key, "admin_key"):
            raise HTTPException(401, "Invalid admin key")
-    elif authorization:
+        return x_admin_key
+
+    if authorization:
        split_key = authorization.split(" ")
        if len(split_key) < 2:
            raise HTTPException(401, "Invalid admin key")
-        elif split_key[0].lower() == "bearer" and auth_keys.verify_key(split_key[1], "admin_key"):
-            return authorization
-        else:
+        if split_key[0].lower() != "bearer" or not AUTH_KEYS.verify_key(
+            split_key[1], "admin_key"
+        ):
            raise HTTPException(401, "Invalid admin key")
-    else:
-        raise HTTPException(401, "Please provide an admin key")
+        return authorization
+
+    raise HTTPException(401, "Please provide an admin key")
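For reference, a client-side sketch of the two header forms that check_api_key() accepts. The base URL and port are assumptions (they depend on local configuration), and the requests package is not part of this repository:

import requests

API_KEY = "paste-the-key-printed-at-startup-or-from-api_tokens.yml"
BASE_URL = "http://localhost:5000"  # assumed host/port

# Option 1: the x-api-key header
requests.get(f"{BASE_URL}/v1/models", headers={"x-api-key": API_KEY})

# Option 2: a standard bearer token in the Authorization header
requests.get(
    f"{BASE_URL}/v1/models",
    headers={"Authorization": f"Bearer {API_KEY}"},
)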
36  formatting.bat  (new file)
@@ -0,0 +1,36 @@
@echo off
setlocal

::Change to script's directory
cd /d %~dp0

::Get tool versions
for /f "tokens=2" %%i in ('ruff --version') do set RUFF_VERSION="%%i"

::Check tool versions
call :tool_version_check "ruff" %RUFF_VERSION% "0.1.9"

::Format and lint files
call ruff format
call ruff check

echo tabbyAPI ruff lint and format: Done

::Check if any files were changed
git diff --quiet
if errorlevel 1 (
    echo Reformatted files. Please review and stage the changes.
    echo Changes not staged for commit:
    echo.
    git --no-pager diff --name-only
    exit /b 1
)

exit /b 0

:tool_version_check
if not "%2"=="%3" (
    echo Wrong %1 version installed: %3 is required, not %2.
    exit /b 1
)
exit /b 0
53  formatting.sh  (new executable file)
@@ -0,0 +1,53 @@
#!/usr/bin/env bash
# YAPF formatter, adapted from ray and skypilot.
#
# Usage:
#    # Do work and commit your work.

#    # Format files that differ from origin/main.
#    bash formatting.sh

#    # Commit changed files with message 'Run yapf and ruff'
#
#
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.

# Cause the script to exit if a single command fails
set -eo pipefail

# this stops git rev-parse from failing if we run this from the .git directory
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
ROOT="$(git rev-parse --show-toplevel)"
builtin cd "$ROOT" || exit 1

RUFF_VERSION=$(ruff --version | head -n 1 | awk '{print $2}')

# params: tool name, tool version, required version
tool_version_check() {
    if [[ $2 != $3 ]]; then
        echo "Wrong $1 version installed: $3 is required, not $2."
        exit 1
    fi
}

tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | cut -d'=' -f3)"

# Format and lint all files
format_and_lint() {
    ruff format
    ruff check
}

# Call format command
format_and_lint
echo 'tabbyAPI ruff lint and format: Done'

if ! git diff --quiet &>/dev/null; then
    echo 'Reformatted files. Please review and stage the changes.'
    echo 'Changes not staged for commit:'
    echo
    git --no-pager diff --name-only

    exit 1
fi

gen_logging.py
@@ -1,31 +1,40 @@
+"""
+Functions for logging generation events.
+"""
from typing import Dict
from pydantic import BaseModel

-# Logging preference config
class LogConfig(BaseModel):
+    """Logging preference config."""
+
    prompt: bool = False
    generation_params: bool = False

-# Global reference to logging preferences
-config = LogConfig()
-
-# Wrapper to set the logging config for generations
+
+# Global reference to logging preferences
+CONFIG = LogConfig()
+
+
def update_from_dict(options_dict: Dict[str, bool]):
-    global config
+    """Wrapper to set the logging config for generations"""
+    global CONFIG

    # Force bools on the dict
    for value in options_dict.values():
        if value is None:
            value = False

-    config = LogConfig.model_validate(options_dict)
+    CONFIG = LogConfig.model_validate(options_dict)


def broadcast_status():
+    """Broadcasts the current logging status"""
    enabled = []
-    if config.prompt:
+    if CONFIG.prompt:
        enabled.append("prompts")

-    if config.generation_params:
+    if CONFIG.generation_params:
        enabled.append("generation params")

    if len(enabled) > 0:
@@ -33,15 +42,20 @@ def broadcast_status():
    else:
        print("Generation logging is disabled")

-# Logs generation parameters to console
def log_generation_params(**kwargs):
-    if config.generation_params:
+    """Logs generation parameters to console."""
+    if CONFIG.generation_params:
        print(f"Generation options: {kwargs}\n")


def log_prompt(prompt: str):
-    if config.prompt:
+    """Logs the prompt to console."""
+    if CONFIG.prompt:
        print(f"Prompt: {prompt if prompt else 'Empty'}\n")


def log_response(response: str):
-    if config.prompt:
+    """Logs the response to console."""
+    if CONFIG.prompt:
        print(f"Response: {response if response else 'Empty'}\n")
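A short usage sketch for the logging helpers above (the option values and keyword argument are illustrative):

import gen_logging

# Enable prompt logging only; generation parameter logging stays off
gen_logging.update_from_dict({"prompt": True, "generation_params": False})
gen_logging.broadcast_status()

gen_logging.log_prompt("Say hello")           # printed, prompt logging is enabled
gen_logging.log_generation_params(top_p=0.9)  # silent, generation_params is False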
generators.py
@@ -1,3 +1,4 @@
+"""Generator functions for the tabbyAPI."""
import inspect
from asyncio import Semaphore
from functools import partialmethod
@@ -5,8 +6,10 @@ from typing import AsyncGenerator

generate_semaphore = Semaphore(1)


# Async generation that blocks on a semaphore
async def generate_with_semaphore(generator: AsyncGenerator):
+    """Generate with a semaphore."""
    async with generate_semaphore:
        if inspect.isasyncgenfunction:
            async for result in generator():
@@ -15,6 +18,7 @@ async def generate_with_semaphore(generator: AsyncGenerator):
            for result in generator():
                yield result


# Block a function with semaphore
async def call_with_semaphore(callback: partialmethod):
    if inspect.iscoroutinefunction(callback):
355
main.py
355
main.py
|
|
@ -1,14 +1,16 @@
|
||||||
import uvicorn
|
"""The main tabbyAPI module. Contains the FastAPI server and endpoints."""
|
||||||
import yaml
|
|
||||||
import pathlib
|
import pathlib
|
||||||
from asyncio import CancelledError
|
from asyncio import CancelledError
|
||||||
from fastapi import FastAPI, Request, HTTPException, Depends
|
from typing import Optional
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
import yaml
|
||||||
|
from fastapi import FastAPI, Depends, HTTPException, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from progress.bar import IncrementalBar
|
from progress.bar import IncrementalBar
|
||||||
from typing import Optional
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import gen_logging
|
import gen_logging
|
||||||
from auth import check_admin_key, check_api_key, load_auth_keys
|
from auth import check_admin_key, check_api_key, load_auth_keys
|
||||||
|
|
@ -17,19 +19,24 @@ from model import ModelContainer
|
||||||
from OAI.types.completion import CompletionRequest
|
from OAI.types.completion import CompletionRequest
|
||||||
from OAI.types.chat_completion import ChatCompletionRequest
|
from OAI.types.chat_completion import ChatCompletionRequest
|
||||||
from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
|
from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
|
||||||
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse, ModelCardParameters
|
from OAI.types.model import (
|
||||||
|
ModelCard,
|
||||||
|
ModelLoadRequest,
|
||||||
|
ModelLoadResponse,
|
||||||
|
ModelCardParameters,
|
||||||
|
)
|
||||||
from OAI.types.token import (
|
from OAI.types.token import (
|
||||||
TokenEncodeRequest,
|
TokenEncodeRequest,
|
||||||
TokenEncodeResponse,
|
TokenEncodeResponse,
|
||||||
TokenDecodeRequest,
|
TokenDecodeRequest,
|
||||||
TokenDecodeResponse
|
TokenDecodeResponse,
|
||||||
)
|
)
|
||||||
from OAI.utils import (
|
from OAI.utils_oai import (
|
||||||
create_completion_response,
|
create_completion_response,
|
||||||
get_model_list,
|
get_model_list,
|
||||||
get_lora_list,
|
get_lora_list,
|
||||||
create_chat_completion_response,
|
create_chat_completion_response,
|
||||||
create_chat_completion_stream_chunk
|
create_chat_completion_stream_chunk,
|
||||||
)
|
)
|
||||||
from templating import get_prompt_from_template
|
from templating import get_prompt_from_template
|
||||||
from utils import get_generator_error, get_sse_packet, load_progress, unwrap
|
from utils import get_generator_error, get_sse_packet, load_progress, unwrap
|
||||||
|
|
@ -37,13 +44,15 @@ from utils import get_generator_error, get_sse_packet, load_progress, unwrap
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
# Globally scoped variables. Undefined until initalized in main
|
# Globally scoped variables. Undefined until initalized in main
|
||||||
model_container: Optional[ModelContainer] = None
|
MODEL_CONTAINER: Optional[ModelContainer] = None
|
||||||
config: dict = {}
|
config: dict = {}
|
||||||
|
|
||||||
|
|
||||||
def _check_model_container():
|
def _check_model_container():
|
||||||
if model_container is None or model_container.model is None:
|
if MODEL_CONTAINER is None or MODEL_CONTAINER.model is None:
|
||||||
raise HTTPException(400, "No models are loaded.")
|
raise HTTPException(400, "No models are loaded.")
|
||||||
|
|
||||||
|
|
||||||
# ALlow CORS requests
|
# ALlow CORS requests
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
|
|
@ -53,10 +62,12 @@ app.add_middleware(
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Model list endpoint
|
# Model list endpoint
|
||||||
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
|
||||||
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
||||||
async def list_models():
|
async def list_models():
|
||||||
|
"""Lists all models in the model directory."""
|
||||||
model_config = unwrap(config.get("model"), {})
|
model_config = unwrap(config.get("model"), {})
|
||||||
model_dir = unwrap(model_config.get("model_dir"), "models")
|
model_dir = unwrap(model_config.get("model_dir"), "models")
|
||||||
model_path = pathlib.Path(model_dir)
|
model_path = pathlib.Path(model_dir)
|
||||||
|
|
@ -66,43 +77,53 @@ async def list_models():
|
     models = get_model_list(model_path.resolve(), draft_model_dir)
     if unwrap(model_config.get("use_dummy_models"), False):
-        models.data.insert(0, ModelCard(id = "gpt-3.5-turbo"))
+        models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))

     return models


 # Currently loaded model endpoint
-@app.get("/v1/model", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
-@app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
+@app.get(
+    "/v1/model",
+    dependencies=[Depends(check_api_key), Depends(_check_model_container)],
+)
+@app.get(
+    "/v1/internal/model/info",
+    dependencies=[Depends(check_api_key), Depends(_check_model_container)],
+)
 async def get_current_model():
-    model_name = model_container.get_model_path().name
-    prompt_template = model_container.prompt_template
+    """Returns the currently loaded model."""
+    model_name = MODEL_CONTAINER.get_model_path().name
+    prompt_template = MODEL_CONTAINER.prompt_template
     model_card = ModelCard(
-        id = model_name,
-        parameters = ModelCardParameters(
-            rope_scale = model_container.config.scale_pos_emb,
-            rope_alpha = model_container.config.scale_alpha_value,
-            max_seq_len = model_container.config.max_seq_len,
-            cache_mode = "FP8" if model_container.cache_fp8 else "FP16",
-            prompt_template = prompt_template.name if prompt_template else None
+        id=model_name,
+        parameters=ModelCardParameters(
+            rope_scale=MODEL_CONTAINER.config.scale_pos_emb,
+            rope_alpha=MODEL_CONTAINER.config.scale_alpha_value,
+            max_seq_len=MODEL_CONTAINER.config.max_seq_len,
+            cache_mode="FP8" if MODEL_CONTAINER.cache_fp8 else "FP16",
+            prompt_template=prompt_template.name if prompt_template else None,
         ),
-        logging = gen_logging.config
+        logging=gen_logging.CONFIG,
     )

-    if model_container.draft_config:
+    if MODEL_CONTAINER.draft_config:
         draft_card = ModelCard(
-            id = model_container.get_model_path(True).name,
-            parameters = ModelCardParameters(
-                rope_scale = model_container.draft_config.scale_pos_emb,
-                rope_alpha = model_container.draft_config.scale_alpha_value,
-                max_seq_len = model_container.draft_config.max_seq_len
-            )
+            id=MODEL_CONTAINER.get_model_path(True).name,
+            parameters=ModelCardParameters(
+                rope_scale=MODEL_CONTAINER.draft_config.scale_pos_emb,
+                rope_alpha=MODEL_CONTAINER.draft_config.scale_alpha_value,
+                max_seq_len=MODEL_CONTAINER.draft_config.max_seq_len,
+            ),
         )
         model_card.parameters.draft = draft_card

     return model_card
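As a quick orientation aid, not part of the diff itself: a client-side sketch of querying the endpoint above. The 127.0.0.1:5000 address comes from the uvicorn defaults at the bottom of this file; the "x-api-key" header name is only an assumption here, the real one lives in auth.py.

import requests

# Sketch only: the header name is hypothetical; check auth.py for the real one.
resp = requests.get(
    "http://127.0.0.1:5000/v1/model",          # uvicorn defaults from main.py
    headers={"x-api-key": "<your api key>"},   # assumed header name
    timeout=30,
)
print(resp.json())  # ModelCard: rope_scale, rope_alpha, max_seq_len, cache_mode, ...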
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
|
||||||
async def list_draft_models():
|
async def list_draft_models():
|
||||||
|
"""Lists all draft models in the model directory."""
|
||||||
model_config = unwrap(config.get("model"), {})
|
model_config = unwrap(config.get("model"), {})
|
||||||
draft_config = unwrap(model_config.get("draft"), {})
|
draft_config = unwrap(model_config.get("draft"), {})
|
||||||
draft_model_dir = unwrap(draft_config.get("draft_model_dir"), "models")
|
draft_model_dir = unwrap(draft_config.get("draft_model_dir"), "models")
|
||||||
|
|
@ -112,12 +133,14 @@ async def list_draft_models():
|
||||||
|
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
# Load model endpoint
|
# Load model endpoint
|
||||||
@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
|
@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
|
||||||
async def load_model(request: Request, data: ModelLoadRequest):
|
async def load_model(request: Request, data: ModelLoadRequest):
|
||||||
global model_container
|
"""Loads a model into the model container."""
|
||||||
|
global MODEL_CONTAINER
|
||||||
|
|
||||||
if model_container and model_container.model:
|
if MODEL_CONTAINER and MODEL_CONTAINER.model:
|
||||||
raise HTTPException(400, "A model is already loaded! Please unload it first.")
|
raise HTTPException(400, "A model is already loaded! Please unload it first.")
|
||||||
|
|
||||||
if not data.name:
|
if not data.name:
|
||||||
|
|
@ -129,32 +152,35 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
||||||
|
|
||||||
load_data = data.model_dump()
|
load_data = data.model_dump()
|
||||||
|
|
||||||
# TODO: Add API exception if draft directory isn't found
|
|
||||||
draft_config = unwrap(model_config.get("draft"), {})
|
draft_config = unwrap(model_config.get("draft"), {})
|
||||||
if data.draft:
|
if data.draft:
|
||||||
if not data.draft.draft_model_name:
|
if not data.draft.draft_model_name:
|
||||||
raise HTTPException(400, "draft_model_name was not found inside the draft object.")
|
raise HTTPException(
|
||||||
|
400, "draft_model_name was not found inside the draft object."
|
||||||
|
)
|
||||||
|
|
||||||
load_data["draft"]["draft_model_dir"] = unwrap(draft_config.get("draft_model_dir"), "models")
|
load_data["draft"]["draft_model_dir"] = unwrap(
|
||||||
|
draft_config.get("draft_model_dir"), "models"
|
||||||
|
)
|
||||||
|
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
raise HTTPException(400, "model_path does not exist. Check model_name?")
|
raise HTTPException(400, "model_path does not exist. Check model_name?")
|
||||||
|
|
||||||
model_container = ModelContainer(model_path.resolve(), False, **load_data)
|
MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **load_data)
|
||||||
|
|
||||||
async def generator():
|
async def generator():
|
||||||
global model_container
|
"""Generator for the loading process."""
|
||||||
|
|
||||||
model_type = "draft" if model_container.draft_config else "model"
|
model_type = "draft" if MODEL_CONTAINER.draft_config else "model"
|
||||||
load_status = model_container.load_gen(load_progress)
|
load_status = MODEL_CONTAINER.load_gen(load_progress)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for (module, modules) in load_status:
|
for module, modules in load_status:
|
||||||
if await request.is_disconnected():
|
if await request.is_disconnected():
|
||||||
break
|
break
|
||||||
|
|
||||||
if module == 0:
|
if module == 0:
|
||||||
loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
|
loading_bar: IncrementalBar = IncrementalBar("Modules", max=modules)
|
||||||
elif module == modules:
|
elif module == modules:
|
||||||
loading_bar.next()
|
loading_bar.next()
|
||||||
loading_bar.finish()
|
loading_bar.finish()
|
||||||
|
|
@ -163,13 +189,13 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
module=module,
|
module=module,
|
||||||
modules=modules,
|
modules=modules,
|
||||||
status="finished"
|
status="finished",
|
||||||
)
|
)
|
||||||
|
|
||||||
yield get_sse_packet(response.model_dump_json())
|
yield get_sse_packet(response.model_dump_json())
|
||||||
|
|
||||||
# Switch to model progress if the draft model is loaded
|
# Switch to model progress if the draft model is loaded
|
||||||
if model_container.draft_config:
|
if MODEL_CONTAINER.draft_config:
|
||||||
model_type = "model"
|
model_type = "model"
|
||||||
else:
|
else:
|
||||||
loading_bar.next()
|
loading_bar.next()
|
||||||
|
|
@ -178,29 +204,39 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
module=module,
|
module=module,
|
||||||
modules=modules,
|
modules=modules,
|
||||||
status="processing"
|
status="processing",
|
||||||
)
|
)
|
||||||
|
|
||||||
yield get_sse_packet(response.model_dump_json())
|
yield get_sse_packet(response.model_dump_json())
|
||||||
except CancelledError:
|
except CancelledError:
|
||||||
print("\nError: Model load cancelled by user. Please make sure to run unload to free up resources.")
|
print(
|
||||||
except Exception as e:
|
"\nError: Model load cancelled by user. "
|
||||||
yield get_generator_error(str(e))
|
"Please make sure to run unload to free up resources."
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
yield get_generator_error(str(exc))
|
||||||
|
|
||||||
|
return StreamingResponse(generator(), media_type="text/event-stream")
|
||||||
|
|
||||||
return StreamingResponse(generator(), media_type = "text/event-stream")
|
|
||||||
|
|
||||||
# Unload model endpoint
|
# Unload model endpoint
|
||||||
@app.get("/v1/model/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
|
@app.get(
|
||||||
|
"/v1/model/unload",
|
||||||
|
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
|
||||||
|
)
|
||||||
async def unload_model():
|
async def unload_model():
|
||||||
global model_container
|
"""Unloads the currently loaded model."""
|
||||||
|
global MODEL_CONTAINER
|
||||||
|
|
||||||
|
MODEL_CONTAINER.unload()
|
||||||
|
MODEL_CONTAINER = None
|
||||||
|
|
||||||
model_container.unload()
|
|
||||||
model_container = None
|
|
||||||
|
|
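For context, a rough sketch of how a client might drive the load/unload pair above. The admin header name is an assumption (see auth.py); the load endpoint streams SSE progress packets, so the response is read incrementally.

import requests

ADMIN_HEADERS = {"x-admin-key": "<your admin key>"}  # assumed header name

# Load a model; progress arrives on the streamed response.
with requests.post(
    "http://127.0.0.1:5000/v1/model/load",
    headers=ADMIN_HEADERS,
    json={"name": "<model folder name>"},  # placeholder, matches ModelLoadRequest.name
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if line:
            print(line.decode())

# Unload it again when done.
requests.get("http://127.0.0.1:5000/v1/model/unload", headers=ADMIN_HEADERS)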
||||||
# Lora list endpoint
|
# Lora list endpoint
|
||||||
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
|
||||||
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
||||||
async def get_all_loras():
|
async def get_all_loras():
|
||||||
|
"""Lists all LoRAs in the lora directory."""
|
||||||
model_config = unwrap(config.get("model"), {})
|
model_config = unwrap(config.get("model"), {})
|
||||||
lora_config = unwrap(model_config.get("lora"), {})
|
lora_config = unwrap(model_config.get("lora"), {})
|
||||||
lora_path = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
lora_path = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
||||||
|
|
@ -209,24 +245,36 @@ async def get_all_loras():
|
||||||
|
|
||||||
return loras
|
return loras
|
||||||
|
|
||||||
|
|
||||||
# Currently loaded loras endpoint
|
# Currently loaded loras endpoint
|
||||||
@app.get("/v1/lora", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
@app.get(
|
||||||
|
"/v1/lora",
|
||||||
|
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
|
||||||
|
)
|
||||||
async def get_active_loras():
|
async def get_active_loras():
|
||||||
|
"""Returns the currently loaded loras."""
|
||||||
active_loras = LoraList(
|
active_loras = LoraList(
|
||||||
data = list(map(
|
data=list(
|
||||||
|
map(
|
||||||
lambda lora: LoraCard(
|
lambda lora: LoraCard(
|
||||||
id = pathlib.Path(lora.lora_path).parent.name,
|
id=pathlib.Path(lora.lora_path).parent.name,
|
||||||
scaling = lora.lora_scaling * lora.lora_r / lora.lora_alpha
|
scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
|
||||||
),
|
),
|
||||||
model_container.active_loras
|
MODEL_CONTAINER.active_loras,
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
))
|
|
||||||
|
|
||||||
return active_loras
|
return active_loras
|
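A small worked example of the scaling math above, with illustrative numbers only: a LoRA saved with rank 16 and alpha 32, loaded at lora_scaling 1.0, reports an effective scale of 0.5.

lora_scaling, lora_r, lora_alpha = 1.0, 16, 32    # illustrative values
effective = lora_scaling * lora_r / lora_alpha    # 1.0 * 16 / 32
assert effective == 0.5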
||||||
|
|
||||||
|
|
||||||
# Load lora endpoint
|
# Load lora endpoint
|
||||||
@app.post("/v1/lora/load", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
|
@app.post(
|
||||||
|
"/v1/lora/load",
|
||||||
|
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
|
||||||
|
)
|
||||||
async def load_lora(data: LoraLoadRequest):
|
async def load_lora(data: LoraLoadRequest):
|
||||||
|
"""Loads a LoRA into the model container."""
|
||||||
if not data.loras:
|
if not data.loras:
|
||||||
raise HTTPException(400, "List of loras to load is not found.")
|
raise HTTPException(400, "List of loras to load is not found.")
|
||||||
|
|
||||||
|
|
@ -234,166 +282,203 @@ async def load_lora(data: LoraLoadRequest):
|
||||||
lora_config = unwrap(model_config.get("lora"), {})
|
lora_config = unwrap(model_config.get("lora"), {})
|
||||||
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
||||||
if not lora_dir.exists():
|
if not lora_dir.exists():
|
||||||
raise HTTPException(400, "A parent lora directory does not exist. Check your config.yml?")
|
raise HTTPException(
|
||||||
|
400,
|
||||||
# Clean-up existing loras if present
|
"A parent lora directory does not exist. Check your config.yml?",
|
||||||
if len(model_container.active_loras) > 0:
|
|
||||||
model_container.unload(True)
|
|
||||||
|
|
||||||
result = model_container.load_loras(lora_dir, **data.model_dump())
|
|
||||||
return LoraLoadResponse(
|
|
||||||
success = unwrap(result.get("success"), []),
|
|
||||||
failure = unwrap(result.get("failure"), [])
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Clean-up existing loras if present
|
||||||
|
if len(MODEL_CONTAINER.active_loras) > 0:
|
||||||
|
MODEL_CONTAINER.unload(True)
|
||||||
|
|
||||||
|
result = MODEL_CONTAINER.load_loras(lora_dir, **data.model_dump())
|
||||||
|
return LoraLoadResponse(
|
||||||
|
success=unwrap(result.get("success"), []),
|
||||||
|
failure=unwrap(result.get("failure"), []),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Unload lora endpoint
|
# Unload lora endpoint
|
||||||
@app.get("/v1/lora/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
|
@app.get(
|
||||||
|
"/v1/lora/unload",
|
||||||
|
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
|
||||||
|
)
|
||||||
async def unload_loras():
|
async def unload_loras():
|
||||||
model_container.unload(True)
|
"""Unloads the currently loaded loras."""
|
||||||
|
MODEL_CONTAINER.unload(True)
|
||||||
|
|
||||||
|
|
||||||
# Encode tokens endpoint
|
# Encode tokens endpoint
|
||||||
@app.post("/v1/token/encode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
@app.post(
|
||||||
|
"/v1/token/encode",
|
||||||
|
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
|
||||||
|
)
|
||||||
async def encode_tokens(data: TokenEncodeRequest):
|
async def encode_tokens(data: TokenEncodeRequest):
|
||||||
raw_tokens = model_container.get_tokens(data.text, None, **data.get_params())
|
"""Encodes a string into tokens."""
|
||||||
|
raw_tokens = MODEL_CONTAINER.get_tokens(data.text, None, **data.get_params())
|
||||||
|
|
||||||
# Have to use this if check otherwise Torch's tensors error out with a boolean issue
|
# Have to use this if check otherwise Torch's tensors error out
|
||||||
|
# with a boolean issue
|
||||||
tokens = raw_tokens[0].tolist() if raw_tokens is not None else []
|
tokens = raw_tokens[0].tolist() if raw_tokens is not None else []
|
||||||
response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
|
response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
# Decode tokens endpoint
|
# Decode tokens endpoint
|
||||||
@app.post("/v1/token/decode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
@app.post(
|
||||||
|
"/v1/token/decode",
|
||||||
|
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
|
||||||
|
)
|
||||||
async def decode_tokens(data: TokenDecodeRequest):
|
async def decode_tokens(data: TokenDecodeRequest):
|
||||||
message = model_container.get_tokens(None, data.tokens, **data.get_params())
|
"""Decodes tokens into a string."""
|
||||||
response = TokenDecodeResponse(text = unwrap(message, ""))
|
message = MODEL_CONTAINER.get_tokens(None, data.tokens, **data.get_params())
|
||||||
|
response = TokenDecodeResponse(text=unwrap(message, ""))
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
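A minimal round-trip sketch for the two token endpoints above. The JSON field names are inferred from the request/response models in the diff, and the API-key header name is again an assumption.

import requests

HEADERS = {"x-api-key": "<your api key>"}  # assumed header name
BASE = "http://127.0.0.1:5000"

encoded = requests.post(f"{BASE}/v1/token/encode", headers=HEADERS,
                        json={"text": "Hello there"}).json()
decoded = requests.post(f"{BASE}/v1/token/decode", headers=HEADERS,
                        json={"tokens": encoded["tokens"]}).json()
print(encoded["length"], decoded["text"])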
|
||||||
|
|
||||||
# Completions endpoint
|
# Completions endpoint
|
||||||
@app.post("/v1/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
@app.post(
|
||||||
|
"/v1/completions",
|
||||||
|
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
|
||||||
|
)
|
||||||
async def generate_completion(request: Request, data: CompletionRequest):
|
async def generate_completion(request: Request, data: CompletionRequest):
|
||||||
model_path = model_container.get_model_path()
|
"""Generates a completion from a prompt."""
|
||||||
|
model_path = MODEL_CONTAINER.get_model_path()
|
||||||
|
|
||||||
if isinstance(data.prompt, list):
|
if isinstance(data.prompt, list):
|
||||||
data.prompt = "\n".join(data.prompt)
|
data.prompt = "\n".join(data.prompt)
|
||||||
|
|
||||||
if data.stream:
|
if data.stream:
|
||||||
|
|
||||||
async def generator():
|
async def generator():
|
||||||
|
"""Generator for the generation process."""
|
||||||
try:
|
try:
|
||||||
new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params())
|
new_generation = MODEL_CONTAINER.generate_gen(
|
||||||
for (part, prompt_tokens, completion_tokens) in new_generation:
|
data.prompt, **data.to_gen_params()
|
||||||
|
)
|
||||||
|
for part, prompt_tokens, completion_tokens in new_generation:
|
||||||
if await request.is_disconnected():
|
if await request.is_disconnected():
|
||||||
break
|
break
|
||||||
|
|
||||||
response = create_completion_response(part,
|
response = create_completion_response(
|
||||||
prompt_tokens,
|
part, prompt_tokens, completion_tokens, model_path.name
|
||||||
completion_tokens,
|
)
|
||||||
model_path.name)
|
|
||||||
|
|
||||||
yield get_sse_packet(response.model_dump_json())
|
yield get_sse_packet(response.model_dump_json())
|
||||||
except CancelledError:
|
except CancelledError:
|
||||||
print("Error: Completion request cancelled by user.")
|
print("Error: Completion request cancelled by user.")
|
||||||
except Exception as e:
|
except Exception as exc:
|
||||||
yield get_generator_error(str(e))
|
yield get_generator_error(str(exc))
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
generate_with_semaphore(generator),
|
generate_with_semaphore(generator), media_type="text/event-stream"
|
||||||
media_type = "text/event-stream"
|
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
|
response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
|
||||||
partial(model_container.generate, data.prompt, **data.to_gen_params())
|
partial(MODEL_CONTAINER.generate, data.prompt, **data.to_gen_params())
|
||||||
|
)
|
||||||
|
|
||||||
|
response = create_completion_response(
|
||||||
|
response_text, prompt_tokens, completion_tokens, model_path.name
|
||||||
)
|
)
|
||||||
response = create_completion_response(response_text,
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
model_path.name)
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
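Sketch of consuming the streamed completion above. It assumes get_sse_packet emits standard "data: <json>" lines; that framing detail lives in utils.py and is not shown in this diff.

import json
import requests

with requests.post(
    "http://127.0.0.1:5000/v1/completions",
    headers={"x-api-key": "<your api key>"},       # assumed header name
    json={"prompt": "Once upon a time", "stream": True, "max_tokens": 64},
    stream=True,
) as resp:
    for raw in resp.iter_lines():
        if raw.startswith(b"data: "):              # assumed SSE framing
            chunk = json.loads(raw[len(b"data: "):])
            print(chunk, flush=True)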
||||||
# Chat completions endpoint
|
|
||||||
@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
|
||||||
async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
|
|
||||||
if model_container.prompt_template is None:
|
|
||||||
return HTTPException(422, "This endpoint is disabled because a prompt template is not set.")
|
|
||||||
|
|
||||||
model_path = model_container.get_model_path()
|
# Chat completions endpoint
|
||||||
|
@app.post(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
|
||||||
|
)
|
||||||
|
async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
|
||||||
|
"""Generates a chat completion from a prompt."""
|
||||||
|
if MODEL_CONTAINER.prompt_template is None:
|
||||||
|
return HTTPException(
|
||||||
|
422,
|
||||||
|
"This endpoint is disabled because a prompt template is not set.",
|
||||||
|
)
|
||||||
|
|
||||||
|
model_path = MODEL_CONTAINER.get_model_path()
|
||||||
|
|
||||||
if isinstance(data.messages, str):
|
if isinstance(data.messages, str):
|
||||||
prompt = data.messages
|
prompt = data.messages
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
special_tokens_dict = model_container.get_special_tokens(
|
special_tokens_dict = MODEL_CONTAINER.get_special_tokens(
|
||||||
unwrap(data.add_bos_token, True),
|
unwrap(data.add_bos_token, True),
|
||||||
unwrap(data.ban_eos_token, False)
|
unwrap(data.ban_eos_token, False),
|
||||||
)
|
)
|
||||||
prompt = get_prompt_from_template(
|
prompt = get_prompt_from_template(
|
||||||
data.messages,
|
data.messages,
|
||||||
model_container.prompt_template,
|
MODEL_CONTAINER.prompt_template,
|
||||||
data.add_generation_prompt,
|
data.add_generation_prompt,
|
||||||
special_tokens_dict,
|
special_tokens_dict,
|
||||||
)
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return HTTPException(
|
return HTTPException(
|
||||||
400,
|
400,
|
||||||
f"Could not find a Conversation from prompt template '{model_container.prompt_template.name}'. Check your spelling?"
|
"Could not find a Conversation from prompt template "
|
||||||
|
f"'{MODEL_CONTAINER.prompt_template.name}'. "
|
||||||
|
"Check your spelling?",
|
||||||
)
|
)
|
||||||
|
|
||||||
if data.stream:
|
if data.stream:
|
||||||
const_id = f"chatcmpl-{uuid4().hex}"
|
const_id = f"chatcmpl-{uuid4().hex}"
|
||||||
|
|
||||||
async def generator():
|
async def generator():
|
||||||
|
"""Generator for the generation process."""
|
||||||
try:
|
try:
|
||||||
new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
|
new_generation = MODEL_CONTAINER.generate_gen(
|
||||||
for (part, _, _) in new_generation:
|
prompt, **data.to_gen_params()
|
||||||
|
)
|
||||||
|
for part, _, _ in new_generation:
|
||||||
if await request.is_disconnected():
|
if await request.is_disconnected():
|
||||||
break
|
break
|
||||||
|
|
||||||
response = create_chat_completion_stream_chunk(
|
response = create_chat_completion_stream_chunk(
|
||||||
const_id,
|
const_id, part, model_path.name
|
||||||
part,
|
|
||||||
model_path.name
|
|
||||||
)
|
)
|
||||||
|
|
||||||
yield get_sse_packet(response.model_dump_json())
|
yield get_sse_packet(response.model_dump_json())
|
||||||
|
|
||||||
# Yield a finish response on successful generation
|
# Yield a finish response on successful generation
|
||||||
finish_response = create_chat_completion_stream_chunk(
|
finish_response = create_chat_completion_stream_chunk(
|
||||||
const_id,
|
const_id, finish_reason="stop"
|
||||||
finish_reason = "stop"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
yield get_sse_packet(finish_response.model_dump_json())
|
yield get_sse_packet(finish_response.model_dump_json())
|
||||||
except CancelledError:
|
except CancelledError:
|
||||||
print("Error: Chat completion cancelled by user.")
|
print("Error: Chat completion cancelled by user.")
|
||||||
except Exception as e:
|
except Exception as exc:
|
||||||
yield get_generator_error(str(e))
|
yield get_generator_error(str(exc))
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
generate_with_semaphore(generator),
|
generate_with_semaphore(generator), media_type="text/event-stream"
|
||||||
media_type = "text/event-stream"
|
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
|
response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
|
||||||
partial(model_container.generate, prompt, **data.to_gen_params())
|
partial(MODEL_CONTAINER.generate, prompt, **data.to_gen_params())
|
||||||
|
)
|
||||||
|
|
||||||
|
response = create_chat_completion_response(
|
||||||
|
response_text, prompt_tokens, completion_tokens, model_path.name
|
||||||
)
|
)
|
||||||
response = create_chat_completion_response(response_text,
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
model_path.name)
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Load from YAML config. Possibly add a config -> kwargs conversion function
|
# Load from YAML config. Possibly add a config -> kwargs conversion function
|
||||||
try:
|
try:
|
||||||
with open('config.yml', 'r', encoding = "utf8") as config_file:
|
with open("config.yml", "r", encoding="utf8") as config_file:
|
||||||
config = unwrap(yaml.safe_load(config_file), {})
|
config = unwrap(yaml.safe_load(config_file), {})
|
||||||
except Exception as e:
|
except Exception as exc:
|
||||||
print(
|
print(
|
||||||
"The YAML config couldn't load because of the following error:",
|
"The YAML config couldn't load because of the following error:",
|
||||||
f"\n\n{e}",
|
f"\n\n{exc}",
|
||||||
"\n\nTabbyAPI will start anyway and not parse this config file."
|
"\n\nTabbyAPI will start anyway and not parse this config file.",
|
||||||
)
|
)
|
||||||
config = {}
|
config = {}
|
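The comment at the top of this block floats a config -> kwargs conversion function; a minimal sketch of what that could look like, reusing unwrap and the defaults already used in this file. The helper name and key set are hypothetical, not part of this commit.

def config_to_model_kwargs(config: dict) -> dict:
    """Hypothetical helper: flatten the YAML 'model' section into
    ModelContainer kwargs, falling back to the documented defaults."""
    model_config = unwrap(config.get("model"), {})
    return {
        "cache_mode": unwrap(model_config.get("cache_mode"), "FP16"),
        "max_seq_len": model_config.get("max_seq_len"),
        "gpu_split_auto": unwrap(model_config.get("gpu_split_auto"), True),
    }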
||||||
|
|
||||||
|
|
@ -409,18 +494,18 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
gen_logging.broadcast_status()
|
gen_logging.broadcast_status()
|
||||||
|
|
||||||
# If an initial model name is specified, create a container and load the model
|
# If an initial model name is specified, create a container
|
||||||
|
# and load the model
|
||||||
model_config = unwrap(config.get("model"), {})
|
model_config = unwrap(config.get("model"), {})
|
||||||
if "model_name" in model_config:
|
if "model_name" in model_config:
|
||||||
# TODO: Move this to model_container
|
|
||||||
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
|
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
|
||||||
model_path = model_path / model_config.get("model_name")
|
model_path = model_path / model_config.get("model_name")
|
||||||
|
|
||||||
model_container = ModelContainer(model_path.resolve(), False, **model_config)
|
MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **model_config)
|
||||||
load_status = model_container.load_gen(load_progress)
|
load_status = MODEL_CONTAINER.load_gen(load_progress)
|
||||||
for (module, modules) in load_status:
|
for module, modules in load_status:
|
||||||
if module == 0:
|
if module == 0:
|
||||||
loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
|
loading_bar: IncrementalBar = IncrementalBar("Modules", max=modules)
|
||||||
elif module == modules:
|
elif module == modules:
|
||||||
loading_bar.next()
|
loading_bar.next()
|
||||||
loading_bar.finish()
|
loading_bar.finish()
|
||||||
|
|
@ -431,11 +516,11 @@ if __name__ == "__main__":
|
||||||
lora_config = unwrap(model_config.get("lora"), {})
|
lora_config = unwrap(model_config.get("lora"), {})
|
||||||
if "loras" in lora_config:
|
if "loras" in lora_config:
|
||||||
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
|
||||||
model_container.load_loras(lora_dir.resolve(), **lora_config)
|
MODEL_CONTAINER.load_loras(lora_dir.resolve(), **lora_config)
|
||||||
|
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app,
|
app,
|
||||||
host=network_config.get("host", "127.0.0.1"),
|
host=network_config.get("host", "127.0.0.1"),
|
||||||
port=network_config.get("port", 5000),
|
port=network_config.get("port", 5000),
|
||||||
log_level="debug"
|
log_level="debug",
|
||||||
)
|
)
|
||||||
|
|
|
395 model.py

@@ -1,29 +1,36 @@
+"""The model container class for ExLlamaV2 models."""
 import gc
 import pathlib
 import time

 import torch
-from exllamav2 import(
+from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Config,
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Tokenizer,
-    ExLlamaV2Lora
-)
-from exllamav2.generator import(
-    ExLlamaV2StreamingGenerator,
-    ExLlamaV2Sampler
+    ExLlamaV2Lora,
 )
+from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler

 from gen_logging import log_generation_params, log_prompt, log_response
 from typing import List, Optional, Union
-from templating import PromptTemplate, find_template_from_model, get_template_from_model_json, get_template_from_file
+from templating import (
+    PromptTemplate,
+    find_template_from_model,
+    get_template_from_model_json,
+    get_template_from_file,
+)
 from utils import coalesce, unwrap

 # Bytes to reserve on first device when loading with auto split
-auto_split_reserve_bytes = 96 * 1024**2
+AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2
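A quick sanity check on the constant above (pure arithmetic, nothing new):

# 96 * 1024**2 == 100_663_296 bytes, i.e. 96 MiB held back on the first
# device before the auto split places the remaining modules.
assert 96 * 1024**2 == 100_663_296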
||||||
|
|
||||||
|
|
||||||
class ModelContainer:
|
class ModelContainer:
|
||||||
|
"""The model container class for ExLlamaV2 models."""
|
||||||
|
|
||||||
config: Optional[ExLlamaV2Config] = None
|
config: Optional[ExLlamaV2Config] = None
|
||||||
draft_config: Optional[ExLlamaV2Config] = None
|
draft_config: Optional[ExLlamaV2Config] = None
|
||||||
model: Optional[ExLlamaV2] = None
|
model: Optional[ExLlamaV2] = None
|
||||||
|
|
@ -40,35 +47,51 @@ class ModelContainer:
|
||||||
|
|
||||||
active_loras: List[ExLlamaV2Lora] = []
|
active_loras: List[ExLlamaV2Lora] = []
|
||||||
|
|
||||||
def __init__(self, model_directory: pathlib.Path, quiet = False, **kwargs):
|
def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
|
||||||
"""
|
"""
|
||||||
Create model container
|
Create model container
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_dir (int): Model directory containing config.json, tokenizer.model etc.
|
model_dir (int): Model directory containing config.json,
|
||||||
|
tokenizer.model etc.
|
||||||
quiet (bool): Suppress console output
|
quiet (bool): Suppress console output
|
||||||
load_progress_callback (function, optional): A function to call for each module loaded. Prototype:
|
load_progress_callback (function, optional): A function to call for
|
||||||
def progress(loaded_modules: int, total_modules: int, loading_draft: bool)
|
each module loaded. Prototype:
|
||||||
|
def progress(loaded_modules: int, total_modules: int,
|
||||||
|
loading_draft: bool)
|
||||||
**kwargs:
|
**kwargs:
|
||||||
`cache_mode` (str): Sets cache mode, "FP16" or "FP8" (defaulf: "FP16")
|
`cache_mode` (str): Sets cache mode, "FP16" or "FP8"
|
||||||
'max_seq_len' (int): Override model's default max sequence length (default: 4096)
|
(default: "FP16")
|
||||||
'rope_scale' (float): Set RoPE scaling factor for model (default: 1.0)
|
'max_seq_len' (int): Override model's default max sequence
|
||||||
'rope_alpha' (float): Set RoPE alpha (NTK) factor for model (default: 1.0)
|
length (default: 4096)
|
||||||
'prompt_template' (str): Manually sets the prompt template for this model (default: None)
|
'rope_scale' (float): Set RoPE scaling factor for model
|
||||||
'chunk_size' (int): Sets the maximum chunk size for the model (default: 2048)
|
(default: 1.0)
|
||||||
Inferencing in chunks reduces overall VRAM overhead by processing very long sequences in smaller
|
'rope_alpha' (float): Set RoPE alpha (NTK) factor for model
|
||||||
batches. This limits the size of temporary buffers needed for the hidden state and attention
|
(default: 1.0)
|
||||||
weights.
|
'prompt_template' (str): Manually sets the prompt template for
|
||||||
|
this model (default: None)
|
||||||
|
'chunk_size' (int): Sets the maximum chunk size for the model
|
||||||
|
(default: 2048)
|
||||||
|
Inferencing in chunks reduces overall VRAM overhead by
|
||||||
|
processing very long sequences in smaller batches. This
|
||||||
|
limits the size of temporary buffers needed for the hidden
|
||||||
|
state and attention weights.
|
||||||
'draft_model_dir' (str): Draft model directory
|
'draft_model_dir' (str): Draft model directory
|
||||||
'draft_rope_scale' (float): Set RoPE scaling factor for draft model (default: 1.0)
|
'draft_rope_scale' (float): Set RoPE scaling factor for draft
|
||||||
'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft model.
|
model (default: 1.0)
|
||||||
By default, the draft model's alpha value is calculated automatically to scale to the size of the
|
'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft
|
||||||
|
model. By default, the draft model's alpha value is
|
||||||
|
calculated automatically to scale to the size of the
|
||||||
full model.
|
full model.
|
||||||
'lora_dir' (str): Lora directory
|
'lora_dir' (str): LoRA directory
|
||||||
'loras' (list[dict]): List of loras to be loaded, consisting of 'name' and 'scaling'
|
'loras' (list[dict]): List of loras to be loaded, consisting of
|
||||||
'gpu_split_auto' (bool): Automatically split model across available devices (default: True)
|
'name' and 'scaling'
|
||||||
'gpu_split' (list[float]): Allocation for weights and (some) tensors, per device
|
'gpu_split_auto' (bool): Automatically split model across
|
||||||
'no_flash_attn' (bool): Turns off flash attention (increases vram usage) (default: False)
|
available devices (default: True)
|
||||||
|
'gpu_split' (list[float]): Allocation for weights and (some)
|
||||||
|
tensors, per device
|
||||||
|
'no_flash_attn' (bool): Turns off flash attention
|
||||||
|
(increases vram usage) (default: False)
|
||||||
"""
|
"""
|
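Based on the kwargs documented above, a minimal construction sketch; the directory name and values are made up for illustration.

import pathlib

container = ModelContainer(
    pathlib.Path("models/MyModel-exl2"),  # hypothetical model directory
    quiet=False,
    cache_mode="FP16",
    max_seq_len=4096,
    gpu_split_auto=True,
)
container.load()  # or iterate load_gen(...) to report per-module progress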
||||||
|
|
||||||
self.quiet = quiet
|
self.quiet = quiet
|
||||||
|
|
@ -90,7 +113,8 @@ class ModelContainer:
|
||||||
if override_base_seq_len:
|
if override_base_seq_len:
|
||||||
self.config.max_seq_len = override_base_seq_len
|
self.config.max_seq_len = override_base_seq_len
|
||||||
|
|
||||||
# Grab the base model's sequence length before overrides for rope calculations
|
# Grab the base model's sequence length before overrides for
|
||||||
|
# rope calculations
|
||||||
base_seq_len = self.config.max_seq_len
|
base_seq_len = self.config.max_seq_len
|
||||||
|
|
||||||
# Set the target seq len if present
|
# Set the target seq len if present
|
||||||
|
|
@ -103,14 +127,14 @@ class ModelContainer:
|
||||||
|
|
||||||
# Automatically calculate rope alpha
|
# Automatically calculate rope alpha
|
||||||
self.config.scale_alpha_value = unwrap(
|
self.config.scale_alpha_value = unwrap(
|
||||||
kwargs.get("rope_alpha"),
|
kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
|
||||||
self.calculate_rope_alpha(base_seq_len)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Turn off flash attention?
|
# Turn off flash attention?
|
||||||
self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attention"), False)
|
self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attention"), False)
|
||||||
|
|
||||||
# low_mem is currently broken in exllamav2. Don't use it until it's fixed.
|
# low_mem is currently broken in exllamav2. Don't use it until it's
|
||||||
|
# fixed.
|
||||||
"""
|
"""
|
||||||
if "low_mem" in kwargs and kwargs["low_mem"]:
|
if "low_mem" in kwargs and kwargs["low_mem"]:
|
||||||
self.config.set_low_mem()
|
self.config.set_low_mem()
|
||||||
|
|
@ -119,7 +143,10 @@ class ModelContainer:
|
||||||
# Set prompt template override if provided
|
# Set prompt template override if provided
|
||||||
prompt_template_name = kwargs.get("prompt_template")
|
prompt_template_name = kwargs.get("prompt_template")
|
||||||
if prompt_template_name:
|
if prompt_template_name:
|
||||||
print(f"Attempting to load prompt template with name {prompt_template_name}")
|
print(
|
||||||
|
"Attempting to load prompt template with name",
|
||||||
|
prompt_template_name,
|
||||||
|
)
|
||||||
# Read the template
|
# Read the template
|
||||||
self.prompt_template = get_template_from_file(prompt_template_name)
|
self.prompt_template = get_template_from_file(prompt_template_name)
|
||||||
else:
|
else:
|
||||||
|
|
@ -127,16 +154,17 @@ class ModelContainer:
|
||||||
self.prompt_template = get_template_from_model_json(
|
self.prompt_template = get_template_from_model_json(
|
||||||
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
|
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
|
||||||
"chat_template",
|
"chat_template",
|
||||||
"from_tokenizer_config"
|
"from_tokenizer_config",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try finding the chat template from the model's config.json
|
# Try finding the chat template from the model's config.json
|
||||||
# TODO: This may not even be used with huggingface models, mark for removal.
|
# TODO: This may not even be used with huggingface models,
|
||||||
|
# mark for removal.
|
||||||
if self.prompt_template is None:
|
if self.prompt_template is None:
|
||||||
self.prompt_template = get_template_from_model_json(
|
self.prompt_template = get_template_from_model_json(
|
||||||
pathlib.Path(self.config.model_config),
|
pathlib.Path(self.config.model_config),
|
||||||
"chat_template",
|
"chat_template",
|
||||||
"from_model_config"
|
"from_model_config",
|
||||||
)
|
)
|
||||||
|
|
||||||
# If that fails, attempt fetching from model name
|
# If that fails, attempt fetching from model name
|
||||||
|
|
@ -147,10 +175,13 @@ class ModelContainer:
|
||||||
|
|
||||||
# Catch all for template lookup errors
|
# Catch all for template lookup errors
|
||||||
if self.prompt_template:
|
if self.prompt_template:
|
||||||
print(f"Using template {self.prompt_template.name} for chat completions.")
|
print(
|
||||||
|
f"Using template {self.prompt_template.name} for chat " "completions."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
"Chat completions are disabled because a prompt template wasn't provided or auto-detected."
|
"Chat completions are disabled because a prompt template",
|
||||||
|
"wasn't provided or auto-detected.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set num of experts per token if provided
|
# Set num of experts per token if provided
|
||||||
|
|
@ -159,11 +190,16 @@ class ModelContainer:
|
||||||
if hasattr(self.config, "num_experts_per_token"):
|
if hasattr(self.config, "num_experts_per_token"):
|
||||||
self.config.num_experts_per_token = num_experts_override
|
self.config.num_experts_per_token = num_experts_override
|
||||||
else:
|
else:
|
||||||
print(" !! Warning: Currently installed ExLlamaV2 does not support overriding MoE experts")
|
print(
|
||||||
|
" !! Warning: Currently installed ExLlamaV2 does not "
|
||||||
|
"support overriding MoE experts"
|
||||||
|
)
|
||||||
|
|
||||||
chunk_size = min(unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len)
|
chunk_size = min(
|
||||||
|
unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len
|
||||||
|
)
|
||||||
self.config.max_input_len = chunk_size
|
self.config.max_input_len = chunk_size
|
||||||
self.config.max_attn_size = chunk_size ** 2
|
self.config.max_attn_size = chunk_size**2
|
||||||
|
|
||||||
draft_args = unwrap(kwargs.get("draft"), {})
|
draft_args = unwrap(kwargs.get("draft"), {})
|
||||||
draft_model_name = draft_args.get("draft_model_name")
|
draft_model_name = draft_args.get("draft_model_name")
|
||||||
|
|
@ -171,23 +207,30 @@ class ModelContainer:
|
||||||
|
|
||||||
# Always disable draft if params are incorrectly configured
|
# Always disable draft if params are incorrectly configured
|
||||||
if draft_args and draft_model_name is None:
|
if draft_args and draft_model_name is None:
|
||||||
print("A draft config was found but a model name was not given. Please check your config.yml! Skipping draft load.")
|
print(
|
||||||
|
"A draft config was found but a model name was not given. "
|
||||||
|
"Please check your config.yml! Skipping draft load."
|
||||||
|
)
|
||||||
enable_draft = False
|
enable_draft = False
|
||||||
|
|
||||||
if enable_draft:
|
if enable_draft:
|
||||||
self.draft_config = ExLlamaV2Config()
|
self.draft_config = ExLlamaV2Config()
|
||||||
draft_model_path = pathlib.Path(unwrap(draft_args.get("draft_model_dir"), "models"))
|
draft_model_path = pathlib.Path(
|
||||||
|
unwrap(draft_args.get("draft_model_dir"), "models")
|
||||||
|
)
|
||||||
draft_model_path = draft_model_path / draft_model_name
|
draft_model_path = draft_model_path / draft_model_name
|
||||||
|
|
||||||
self.draft_config.model_dir = str(draft_model_path.resolve())
|
self.draft_config.model_dir = str(draft_model_path.resolve())
|
||||||
self.draft_config.prepare()
|
self.draft_config.prepare()
|
||||||
|
|
||||||
self.draft_config.scale_pos_emb = unwrap(draft_args.get("draft_rope_scale"), 1.0)
|
self.draft_config.scale_pos_emb = unwrap(
|
||||||
|
draft_args.get("draft_rope_scale"), 1.0
|
||||||
|
)
|
||||||
|
|
||||||
# Automatically calculate draft rope alpha
|
# Automatically calculate draft rope alpha
|
||||||
self.draft_config.scale_alpha_value = unwrap(
|
self.draft_config.scale_alpha_value = unwrap(
|
||||||
draft_args.get("draft_rope_alpha"),
|
draft_args.get("draft_rope_alpha"),
|
||||||
self.calculate_rope_alpha(self.draft_config.max_seq_len)
|
self.calculate_rope_alpha(self.draft_config.max_seq_len),
|
||||||
)
|
)
|
||||||
self.draft_config.max_seq_len = self.config.max_seq_len
|
self.draft_config.max_seq_len = self.config.max_seq_len
|
||||||
|
|
||||||
|
|
@ -196,22 +239,31 @@ class ModelContainer:
|
||||||
self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2
|
self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2
|
||||||
|
|
||||||
     def calculate_rope_alpha(self, base_seq_len):
+        """Calculate the rope alpha value for a given sequence length."""
         ratio = self.config.max_seq_len / base_seq_len

-        # Default to a 1 alpha if the sequence length is ever less than or equal to 1
-        alpha = 1 if ratio <= 1.0 else -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
+        # Default to a 1 alpha if the sequence length is ever less
+        # than or equal to 1
+        if ratio <= 1.0:
+            alpha = 1
+        else:
+            alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
         return alpha

     def get_model_path(self, is_draft: bool = False):
-        model_path = pathlib.Path(self.draft_config.model_dir if is_draft else self.config.model_dir)
+        """Get the path for this model."""
+        model_path = pathlib.Path(
+            self.draft_config.model_dir if is_draft else self.config.model_dir
+        )
         return model_path
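For a feel of the alpha curve in calculate_rope_alpha above: with an illustrative ratio of 2.0 (say, a 4096-context model pushed to 8192), the polynomial lands at roughly 2.63.

ratio = 2.0                                              # illustrative only
alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
assert round(alpha, 5) == 2.62978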
|
|
||||||
def load(self, progress_callback = None):
|
def load(self, progress_callback=None):
|
||||||
"""
|
"""
|
||||||
Load model
|
Load model
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
progress_callback (function, optional): A function to call for each module loaded. Prototype:
|
progress_callback (function, optional): A function to call for each
|
||||||
|
module loaded. Prototype:
|
||||||
def progress(loaded_modules: int, total_modules: int)
|
def progress(loaded_modules: int, total_modules: int)
|
||||||
"""
|
"""
|
||||||
for _ in self.load_gen(progress_callback):
|
for _ in self.load_gen(progress_callback):
|
||||||
|
|
@ -231,25 +283,32 @@ class ModelContainer:
|
||||||
lora_scaling = unwrap(lora.get("scaling"), 1.0)
|
lora_scaling = unwrap(lora.get("scaling"), 1.0)
|
||||||
|
|
||||||
if lora_name is None:
|
if lora_name is None:
|
||||||
print("One of your loras does not have a name. Please check your config.yml! Skipping lora load.")
|
print(
|
||||||
|
"One of your loras does not have a name. Please check your "
|
||||||
|
"config.yml! Skipping lora load."
|
||||||
|
)
|
||||||
failure.append(lora_name)
|
failure.append(lora_name)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f"Loading lora: {lora_name} at scaling {lora_scaling}")
|
print(f"Loading lora: {lora_name} at scaling {lora_scaling}")
|
||||||
lora_path = lora_directory / lora_name
|
lora_path = lora_directory / lora_name
|
||||||
self.active_loras.append(ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling))
|
# FIXME(alpin): Does self.model need to be passed here?
|
||||||
|
self.active_loras.append(
|
||||||
|
ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
|
||||||
|
)
|
||||||
print("Lora successfully loaded.")
|
print("Lora successfully loaded.")
|
||||||
success.append(lora_name)
|
success.append(lora_name)
|
||||||
|
|
||||||
# Return success and failure names
|
# Return success and failure names
|
||||||
return { 'success': success, 'failure': failure }
|
return {"success": success, "failure": failure}
|
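The loras list consumed here matches the 'name'/'scaling' dicts documented in __init__; a hypothetical call, assuming a loaded ModelContainer instance named container (see the construction sketch after __init__ above):

# Names are made up; scaling falls back to 1.0 when omitted (see unwrap above).
container.load_loras(
    pathlib.Path("loras"),
    loras=[
        {"name": "style-lora", "scaling": 0.8},
        {"name": "grammar-lora"},
    ],
)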
||||||
|
|
||||||
def load_gen(self, progress_callback = None):
|
def load_gen(self, progress_callback=None):
|
||||||
"""
|
"""
|
||||||
Load model, generator function
|
Load model, generator function
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
progress_callback (function, optional): A function to call for each module loaded. Prototype:
|
progress_callback (function, optional): A function to call for each
|
||||||
|
module loaded. Prototype:
|
||||||
def progress(loaded_modules: int, total_modules: int)
|
def progress(loaded_modules: int, total_modules: int)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -262,13 +321,18 @@ class ModelContainer:
|
||||||
if not self.quiet:
|
if not self.quiet:
|
||||||
print("Loading draft model: " + self.draft_config.model_dir)
|
print("Loading draft model: " + self.draft_config.model_dir)
|
||||||
|
|
||||||
self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy = True)
|
self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True)
|
||||||
reserve = [auto_split_reserve_bytes] + [0] * 16
|
reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
|
||||||
yield from self.draft_model.load_autosplit_gen(self.draft_cache, reserve_vram = reserve, last_id_only = True, callback_gen = progress_callback)
|
yield from self.draft_model.load_autosplit_gen(
|
||||||
|
self.draft_cache,
|
||||||
|
reserve_vram=reserve,
|
||||||
|
last_id_only=True,
|
||||||
|
callback_gen=progress_callback,
|
||||||
|
)
|
||||||
|
|
||||||
# Test VRAM allocation with a full-length forward pass
|
# Test VRAM allocation with a full-length forward pass
|
||||||
input_ids = torch.zeros((1, self.config.max_input_len), dtype = torch.long)
|
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
|
||||||
self.draft_model.forward(input_ids, cache = self.cache, preprocess_only = True)
|
self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
self.model = ExLlamaV2(self.config)
|
self.model = ExLlamaV2(self.config)
|
||||||
|
|
@ -276,29 +340,41 @@ class ModelContainer:
|
||||||
print("Loading model: " + self.config.model_dir)
|
print("Loading model: " + self.config.model_dir)
|
||||||
|
|
||||||
if not self.gpu_split_auto:
|
if not self.gpu_split_auto:
|
||||||
for value in self.model.load_gen(self.gpu_split, callback_gen = progress_callback):
|
for value in self.model.load_gen(
|
||||||
|
self.gpu_split, callback_gen=progress_callback
|
||||||
|
):
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
yield value
|
yield value
|
||||||
|
|
||||||
if self.cache_fp8:
|
if self.cache_fp8:
|
||||||
self.cache = ExLlamaV2Cache_8bit(self.model, lazy = self.gpu_split_auto)
|
self.cache = ExLlamaV2Cache_8bit(self.model, lazy=self.gpu_split_auto)
|
||||||
else:
|
else:
|
||||||
self.cache = ExLlamaV2Cache(self.model, lazy = self.gpu_split_auto)
|
self.cache = ExLlamaV2Cache(self.model, lazy=self.gpu_split_auto)
|
||||||
|
|
||||||
if self.gpu_split_auto:
|
if self.gpu_split_auto:
|
||||||
reserve = [auto_split_reserve_bytes] + [0] * 16
|
reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
|
||||||
yield from self.model.load_autosplit_gen(self.cache, reserve_vram = reserve, last_id_only = True, callback_gen = progress_callback)
|
yield from self.model.load_autosplit_gen(
|
||||||
|
self.cache,
|
||||||
|
reserve_vram=reserve,
|
||||||
|
last_id_only=True,
|
||||||
|
callback_gen=progress_callback,
|
||||||
|
)
|
||||||
|
|
||||||
# Test VRAM allocation with a full-length forward pass
|
# Test VRAM allocation with a full-length forward pass
|
||||||
input_ids = torch.zeros((1, self.config.max_input_len), dtype = torch.long)
|
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
|
||||||
self.model.forward(input_ids, cache = self.cache, preprocess_only = True)
|
self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
|
||||||
|
|
||||||
# Create generator
|
# Create generator
|
||||||
self.generator = ExLlamaV2StreamingGenerator(self.model, self.cache, self.tokenizer, self.draft_model, self.draft_cache)
|
self.generator = ExLlamaV2StreamingGenerator(
|
||||||
|
self.model,
|
||||||
|
self.cache,
|
||||||
|
self.tokenizer,
|
||||||
|
self.draft_model,
|
||||||
|
self.draft_cache,
|
||||||
|
)
|
||||||
|
|
||||||
print("Model successfully loaded.")
|
print("Model successfully loaded.")
|
||||||
|
|
||||||
|
|
||||||
def unload(self, loras_only: bool = False):
|
def unload(self, loras_only: bool = False):
|
||||||
"""
|
"""
|
||||||
Free all VRAM resources used by this model
|
Free all VRAM resources used by this model
|
||||||
|
|
@ -327,19 +403,24 @@ class ModelContainer:
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
# Common function for token operations
|
|
||||||
def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs):
|
def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs):
|
||||||
|
"""Common function for token operations"""
|
||||||
if text:
|
if text:
|
||||||
# Assume token encoding
|
# Assume token encoding
|
||||||
return self.tokenizer.encode(
|
return self.tokenizer.encode(
|
||||||
text,
|
text,
|
||||||
add_bos = unwrap(kwargs.get("add_bos_token"), True),
|
add_bos=unwrap(kwargs.get("add_bos_token"), True),
|
||||||
encode_special_tokens = unwrap(kwargs.get("encode_special_tokens"), True)
|
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
|
||||||
)
|
)
|
||||||
if ids:
|
if ids:
|
||||||
# Assume token decoding
|
# Assume token decoding
|
||||||
ids = torch.tensor([ids])
|
ids = torch.tensor([ids])
|
||||||
return self.tokenizer.decode(ids, decode_special_tokens = unwrap(kwargs.get("decode_special_tokens"), True))[0]
|
return self.tokenizer.decode(
|
||||||
|
ids,
|
||||||
|
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool):
|
def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool):
|
||||||
return {
|
return {
|
||||||
|
|
@ -350,13 +431,15 @@ class ModelContainer:
|
||||||
}
|
}
|
||||||
|
|
||||||
def generate(self, prompt: str, **kwargs):
|
def generate(self, prompt: str, **kwargs):
|
||||||
|
"""Generate a response to a prompt"""
|
||||||
generation = list(self.generate_gen(prompt, **kwargs))
|
generation = list(self.generate_gen(prompt, **kwargs))
|
||||||
if generation:
|
if generation:
|
||||||
response = "".join(map(lambda chunk: chunk[0], generation))
|
response = "".join(map(lambda chunk: chunk[0], generation))
|
||||||
return response, generation[-1][1], generation[-1][2]
|
return response, generation[-1][1], generation[-1][2]
|
||||||
else:
|
|
||||||
return "", 0, 0
|
return "", 0, 0
|
||||||
|
|
||||||
|
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
|
||||||
def generate_gen(self, prompt: str, **kwargs):
|
def generate_gen(self, prompt: str, **kwargs):
|
||||||
"""
|
"""
|
||||||
Create generator function for prompt completion
|
Create generator function for prompt completion
|
||||||
|
|
@ -366,7 +449,8 @@ class ModelContainer:
|
||||||
**kwargs:
|
**kwargs:
|
||||||
'token_healing' (bool): Use token healing (default: False)
|
'token_healing' (bool): Use token healing (default: False)
|
||||||
'temperature' (float): Sampling temperature (default: 1.0)
|
'temperature' (float): Sampling temperature (default: 1.0)
|
||||||
'temperature_last' (bool): Apply temperature after all other samplers (default: False)
|
'temperature_last' (bool): Apply temperature after all other
|
||||||
|
samplers (default: False)
|
||||||
'top_k' (int): Sampling top-K (default: 0)
|
'top_k' (int): Sampling top-K (default: 0)
|
||||||
'top_p' (float): Sampling top-P (default: 1.0)
|
'top_p' (float): Sampling top-P (default: 1.0)
|
||||||
'min_p' (float): Sampling min-P (default: 0.0)
|
'min_p' (float): Sampling min-P (default: 0.0)
|
||||||
|
|
@ -375,19 +459,27 @@ class ModelContainer:
|
||||||
'mirostat' (bool): Use Mirostat (default: False)
|
'mirostat' (bool): Use Mirostat (default: False)
|
||||||
'mirostat_tau' (float) Mirostat tau parameter (default: 1.5)
|
'mirostat_tau' (float) Mirostat tau parameter (default: 1.5)
|
||||||
'mirostat_eta' (float) Mirostat eta parameter (default: 0.1)
|
'mirostat_eta' (float) Mirostat eta parameter (default: 0.1)
|
||||||
'repetition_penalty' (float): Token repetition/presence penalty (default: 1.15)
|
'repetition_penalty' (float): Token repetition/presence penalty
|
||||||
'repetition_range' (int): Repetition penalty range (default: whole context)
|
(default: 1.15)
|
||||||
'repetition_decay' (int): Repetition penalty range (default: same as range)
|
'repetition_range' (int): Repetition penalty range
|
||||||
'stop' (List[Union[str, int]]): List of stop strings/tokens to end response (default: [EOS])
|
(default: whole context)
|
||||||
|
'repetition_decay' (int): Repetition penalty range
|
||||||
|
(default: same as range)
|
||||||
|
'stop' (List[Union[str, int]]): List of stop strings/tokens to
|
||||||
|
end response (default: [EOS])
|
||||||
'max_tokens' (int): Max no. tokens in response (default: 150)
|
'max_tokens' (int): Max no. tokens in response (default: 150)
|
||||||
'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True)
|
'add_bos_token' (bool): Adds the BOS token to the start of the
|
||||||
'ban_eos_token' (bool): Bans the EOS token from generation (default: False)
|
prompt (default: True)
|
||||||
'logit_bias' (Dict[int, float]): Biases specific tokens to either show up more or less (default: None)
|
'ban_eos_token' (bool): Bans the EOS token from generation
|
||||||
'stream_interval' (float): Interval in seconds between each output chunk (default: immediate)
|
(default: False)
|
||||||
'generate_window' (int): Space to reserve at the end of the model's context when generating.
|
'logit_bias' (Dict[int, float]): Biases specific tokens to
|
||||||
Rolls context window by the same amount if context length is exceeded to allow generating past
|
either show up more or less (default: None)
|
||||||
the models max_seq_len.
|
'stream_interval' (float): Interval in seconds between each
|
||||||
|
output chunk (default: immediate)
|
||||||
|
'generate_window' (int): Space to reserve at the end of the
|
||||||
|
model's context when generating. Rolls context window by
|
||||||
|
the same amount if context length is exceeded to allow
|
||||||
|
generating past the model's max_seq_len.
|
||||||
"""
|
"""
|
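A consumption sketch matching how main.py iterates this generator for streaming; the prompt and sampler values are illustrative, and container stands in for a loaded ModelContainer.

for chunk, prompt_tokens, completion_tokens in container.generate_gen(
    "Once upon a time",   # illustrative prompt
    temperature=0.7,
    top_p=0.9,
    max_tokens=64,
):
    print(chunk, end="", flush=True)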
||||||
|
|
||||||
token_healing = unwrap(kwargs.get("token_healing"), False)
|
token_healing = unwrap(kwargs.get("token_healing"), False)
|
||||||
|
|
@@ -399,17 +491,37 @@ class ModelContainer:
         gen_settings = ExLlamaV2Sampler.Settings()

         # Warn of unsupported settings if the setting is enabled
-        if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(gen_settings, "mirostat"):
-            print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling")
+        if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(
+            gen_settings, "mirostat"
+        ):
+            print(
+                " !! Warning: Currently installed ExLlamaV2 does not support "
+                "Mirostat sampling"
+            )

-        if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
-            print(" !! Warning: Currently installed ExLlamaV2 does not support min-P sampling")
+        if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(
+            gen_settings, "min_p"
+        ):
+            print(
+                " !! Warning: Currently installed ExLlamaV2 does not "
+                "support min-P sampling"
+            )

-        if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
-            print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)")
+        if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(
+            gen_settings, "tfs"
+        ):
+            print(
+                " !! Warning: Currently installed ExLlamaV2 does not support "
+                "tail-free sampling (TFS)"
+            )

-        if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(gen_settings, "temperature_last"):
-            print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last")
+        if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(
+            gen_settings, "temperature_last"
+        ):
+            print(
+                " !! Warning: Currently installed ExLlamaV2 does not support "
+                "temperature_last"
+            )

         # Apply settings
         gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
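The hasattr checks above feature-detect the installed ExLlamaV2 build: a sampler is only warned about when the user requested it and the settings object does not expose it. A minimal standalone sketch of the same pattern, using a stand-in class rather than the real ExLlamaV2 type:

    class SamplerSettings:  # stand-in for ExLlamaV2Sampler.Settings
        temperature = 1.0   # note: no `mirostat` attribute on this "older build"

    settings = SamplerSettings()
    requested_mirostat = True

    # Only warn when the caller asked for a sampler that the installed
    # settings object does not expose
    if requested_mirostat and not hasattr(settings, "mirostat"):
        print(" !! Warning: installed sampler settings do not support Mirostat")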
@@ -424,14 +536,24 @@ class ModelContainer:
         # Default tau and eta fallbacks don't matter if mirostat is off
         gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
         gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
-        gen_settings.token_repetition_penalty = unwrap(kwargs.get("repetition_penalty"), 1.0)
-        gen_settings.token_repetition_range = unwrap(kwargs.get("repetition_range"), self.config.max_seq_len)
+        gen_settings.token_repetition_penalty = unwrap(
+            kwargs.get("repetition_penalty"), 1.0
+        )
+        gen_settings.token_repetition_range = unwrap(
+            kwargs.get("repetition_range"), self.config.max_seq_len
+        )

         # Always make sure the fallback is 0 if range < 0
-        # It's technically fine to use -1, but this just validates the passed fallback
+        # It's technically fine to use -1, but this just validates the passed
+        # fallback
         # Always default to 0 if something goes wrong
-        fallback_decay = 0 if gen_settings.token_repetition_range <= 0 else gen_settings.token_repetition_range
-        gen_settings.token_repetition_decay = coalesce(kwargs.get("repetition_decay"), fallback_decay, 0)
+        if gen_settings.token_repetition_range <= 0:
+            fallback_decay = 0
+        else:
+            fallback_decay = gen_settings.token_repetition_range
+        gen_settings.token_repetition_decay = coalesce(
+            kwargs.get("repetition_decay"), fallback_decay, 0
+        )

         stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
         add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
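A short, self-contained sketch of how coalesce (defined in utils.py later in this diff) resolves repetition_decay against the fallback computed above; the numbers are illustrative only:

    def coalesce(*args):
        """Coalesce function for multiple unwraps."""
        return next((arg for arg in args if arg is not None), None)

    repetition_range = 2048                                     # illustrative
    fallback_decay = repetition_range if repetition_range > 0 else 0

    # Caller did not pass repetition_decay, so the fallback wins
    print(coalesce(None, fallback_decay, 0))   # -> 2048
    # Caller passed an explicit value, so it takes precedence
    print(coalesce(512, fallback_decay, 0))    # -> 512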
@@ -448,13 +570,13 @@ class ModelContainer:
         # Log generation options to console
         # Some options are too large, so log the args instead
         log_generation_params(
-            max_tokens = max_tokens,
+            max_tokens=max_tokens,
             **vars(gen_settings),
-            token_healing = token_healing,
-            add_bos_token = add_bos_token,
-            ban_eos_token = ban_eos_token,
-            stop_conditions = stop_conditions,
-            logit_bias = logit_bias
+            token_healing=token_healing,
+            add_bos_token=add_bos_token,
+            ban_eos_token=ban_eos_token,
+            stop_conditions=stop_conditions,
+            logit_bias=logit_bias,
         )

         # Log prompt to console
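The **vars(gen_settings) expansion above turns every attribute of the sampler settings object into its own keyword argument for the logger. A minimal sketch of that mechanism with stand-in names (neither Settings nor log_params is the project's actual API):

    class Settings:  # stand-in object with a couple of attributes
        def __init__(self):
            self.temperature = 1.0
            self.top_k = 0

    def log_params(**kwargs):
        print("Generation options:", kwargs)

    settings = Settings()
    # vars(settings) is {'temperature': 1.0, 'top_k': 0}, so each attribute
    # arrives at log_params as a separate keyword argument
    log_params(max_tokens=150, **vars(settings), token_healing=False)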
@@ -465,13 +587,17 @@ class ModelContainer:
         # Create a vocab tensor if it doesn't exist for token biasing
         if gen_settings.token_bias is None:
             padding = -self.tokenizer.config.vocab_size % 32
-            gen_settings.token_bias = torch.zeros((self.tokenizer.config.vocab_size + padding,), dtype = torch.float)
+            gen_settings.token_bias = torch.zeros(
+                (self.tokenizer.config.vocab_size + padding,),
+                dtype=torch.float,
+            )

         # Map logits to the tensor with their biases
         for token, bias in logit_bias.items():
             gen_settings.token_bias[token] = bias

-        # Ban the EOS token if specified. If not, append to stop conditions as well.
+        # Ban the EOS token if specified. If not, append to stop conditions
+        # as well.
         # Set this below logging to avoid polluting the stop strings array
         if ban_eos_token:
             gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
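A small standalone sketch of the token-bias tensor logic above, outside of ExLlamaV2. The vocab size and bias values are illustrative; only torch is assumed:

    import torch

    vocab_size = 32000                      # illustrative
    padding = -vocab_size % 32              # pad the tensor length to a multiple of 32
    token_bias = torch.zeros((vocab_size + padding,), dtype=torch.float)

    # logit_bias maps token ids to additive biases, as in the request kwargs
    logit_bias = {13: -5.0, 29871: 2.5}
    for token, bias in logit_bias.items():
        token_bias[token] = bias

    print(token_bias.shape)                 # torch.Size([32000]); 32000 % 32 == 0, so no padding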
@@ -483,16 +609,15 @@ class ModelContainer:

         # Tokenized context
         ids = self.tokenizer.encode(
-            prompt,
-            add_bos = add_bos_token,
-            encode_special_tokens = True
+            prompt, add_bos=add_bos_token, encode_special_tokens=True
         )
         context_len = len(ids[0])

         if context_len > self.config.max_seq_len:
             print(
-                f"WARNING: The context length {context_len} is greater than the max_seq_len {self.config.max_seq_len}.",
-                "Generation is truncated and metrics may not be accurate."
+                f"WARNING: The context length {context_len} is greater than "
+                f"the max_seq_len {self.config.max_seq_len}.",
+                "Generation is truncated and metrics may not be accurate.",
             )

         prompt_tokens = ids.shape[-1]
@@ -503,26 +628,32 @@ class ModelContainer:
         start_time = time.time()
         last_chunk_time = start_time

-        save_tokens = torch.empty((1, 0), dtype = torch.bool)
+        save_tokens = torch.empty((1, 0), dtype=torch.bool)
         chunk_buffer = ""
         chunk_tokens = 0

         while True:
             # Ingest prompt
             if chunk_tokens == 0:
-                ids = torch.cat((ids, save_tokens), dim = - 1)
-                save_tokens = torch.empty((1, 0), dtype = torch.bool)
+                ids = torch.cat((ids, save_tokens), dim=-1)
+                save_tokens = torch.empty((1, 0), dtype=torch.bool)
                 overflow = ids.shape[-1] + generate_window - self.config.max_seq_len
-                active_ids = ids[:, max(0, overflow):]
+                active_ids = ids[:, max(0, overflow) :]
                 chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]

-                self.generator.begin_stream(active_ids, gen_settings, token_healing = token_healing, loras = self.active_loras)
+                self.generator.begin_stream(
+                    active_ids,
+                    gen_settings,
+                    token_healing=token_healing,
+                    loras=self.active_loras,
+                )

             # Generate
             chunk, eos, tokens = self.generator.stream()

             if token_healing:
-                ids[:, -1] = self.generator.sequence_ids[:, -2]  # Extract healed token
+                # Extract healed token
+                ids[:, -1] = self.generator.sequence_ids[:, -2]
                 token_healing = False

             save_tokens = torch.cat((save_tokens, tokens), dim=-1)
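The overflow/active_ids arithmetic above is what lets generation roll the context window and continue past max_seq_len, as described in the generate_window docstring entry. A standalone sketch of the same arithmetic with small illustrative numbers:

    max_seq_len = 16        # illustrative model context limit
    generate_window = 4     # space reserved at the end for new tokens
    context_tokens = 15     # tokens currently held in `ids`

    # How far the reserved window would push us past the context limit
    overflow = context_tokens + generate_window - max_seq_len   # -> 3

    # Drop that many tokens from the front so the window fits again
    start = max(0, overflow)
    active_len = context_tokens - start                         # -> 12
    chunk_tokens = max_seq_len - active_len                      # -> 4 tokens until the next re-ingest

    print(start, active_len, chunk_tokens)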
@@ -535,7 +666,9 @@ class ModelContainer:
             now = time.time()
             elapsed = now - last_chunk_time

-            if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_tokens):
+            if chunk_buffer != "" and (
+                elapsed > stream_interval or eos or generated_tokens == max_tokens
+            ):
                 yield chunk_buffer, prompt_tokens, generated_tokens
                 full_response += chunk_buffer
                 chunk_buffer = ""
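The buffer-and-flush condition above batches output so the generator only yields every stream_interval seconds, or immediately on EOS or the token limit. A minimal standalone sketch of that idea with a fake token source (everything here is illustrative, not the project's API):

    import time

    def fake_token_source():
        for piece in ["Hel", "lo", ", ", "wor", "ld", "!"]:
            yield piece

    stream_interval = 0.5       # seconds between flushes (illustrative)
    chunk_buffer = ""
    last_chunk_time = time.time()

    pieces = list(fake_token_source())
    for index, piece in enumerate(pieces):
        chunk_buffer += piece
        elapsed = time.time() - last_chunk_time
        eos = index == len(pieces) - 1

        # In a fast loop nothing exceeds the interval, so the buffer is
        # flushed once at EOS; that batching is exactly the point
        if chunk_buffer != "" and (elapsed > stream_interval or eos):
            print(chunk_buffer, end="", flush=True)   # stand-in for `yield`
            chunk_buffer = ""
            last_chunk_time = time.time()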
@@ -549,12 +682,20 @@ class ModelContainer:

         elapsed_time = last_chunk_time - start_time

-        initial_response = f"Metrics: {generated_tokens} tokens generated in {round(elapsed_time, 2)} seconds"
+        initial_response = (
+            f"Metrics: {generated_tokens} tokens generated in "
+            f"{round(elapsed_time, 2)} seconds"
+        )
         itemization = []
         extra_parts = []

         # Add tokens per second
-        itemization.append(f"{'Indeterminate' if elapsed_time == 0 else round(generated_tokens / elapsed_time, 2)} T/s")
+        tokens_per_second = (
+            "Indeterminate"
+            if elapsed_time == 0
+            else round(generated_tokens / elapsed_time, 2)
+        )
+        itemization.append(f"{tokens_per_second} T/s")

         # Add context (original token count)
         if ids is not None:
@@ -564,4 +705,10 @@ class ModelContainer:
             extra_parts.append("<-- Not accurate (truncated)")

         # Print output
-        print(initial_response + " (" + ", ".join(itemization) + ") " + " ".join(extra_parts))
+        print(
+            initial_response
+            + " ("
+            + ", ".join(itemization)
+            + ") "
+            + " ".join(extra_parts)
+        )
15
requirements-dev.txt
Normal file

@@ -0,0 +1,15 @@
+# formatting
+ruff==0.1.9
+
+## Implement below dependencies when support is added
+
+# type checking
+# mypy==0.991
+# types-PyYAML
+# types-requests
+# types-setuptools
+
+# testing
+# pytest
+# pytest-forked
+# pytest-asyncio
@@ -1,46 +1,56 @@
+"""Small replication of AutoTokenizer's chat template system for efficiency"""
 import json
 import pathlib
 from functools import lru_cache
 from importlib.metadata import version as package_version

 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from packaging import version
 from pydantic import BaseModel
 from typing import Optional, Dict

-# Small replication of AutoTokenizer's chat template system for efficiency


 class PromptTemplate(BaseModel):
+    """A template for chat completion prompts."""

     name: str
     template: str


-def get_prompt_from_template(messages,
+def get_prompt_from_template(
+    messages,
     prompt_template: PromptTemplate,
     add_generation_prompt: bool,
-    special_tokens: Optional[Dict[str, str]] = None):
+    special_tokens: Optional[Dict[str, str]] = None,
+):
+    """Get a prompt from a template and a list of messages."""
     if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
         raise ImportError(
-            "Parsing these chat completion messages requires jinja2 3.0.0 or greater. "
-            f"Current version: {version('jinja2')}\n"
+            "Parsing these chat completion messages requires jinja2 3.0.0 "
+            f"or greater. Current version: {package_version('jinja2')}\n"
             "Please upgrade jinja by running the following command: "
             "pip install --upgrade jinja2"
         )

     compiled_template = _compile_template(prompt_template.template)
     return compiled_template.render(
-        messages = messages,
-        add_generation_prompt = add_generation_prompt,
+        messages=messages,
+        add_generation_prompt=add_generation_prompt,
         **special_tokens,
     )


-# Inspired from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
+# Inspired from
+# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
 @lru_cache
 def _compile_template(template: str):
-    jinja_env = ImmutableSandboxedEnvironment(trim_blocks = True, lstrip_blocks = True)
+    jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
     jinja_template = jinja_env.from_string(template)
     return jinja_template


-# Find a matching template name from a model path
 def find_template_from_model(model_path: pathlib.Path):
+    """Find a matching template name from a model path."""
     model_name = model_path.name
     template_directory = pathlib.Path("templates")
     for filepath in template_directory.glob("*.jinja"):
@@ -50,14 +60,16 @@ def find_template_from_model(model_path: pathlib.Path):
         if template_name in model_name.lower():
             return template_name

-# Get a template from a jinja file
+    return None


 def get_template_from_file(prompt_template_name: str):
+    """Get a template from a jinja file."""
     template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
     if template_path.exists():
-        with open(template_path, "r", encoding = "utf8") as raw_template:
+        with open(template_path, "r", encoding="utf8") as raw_template:
             return PromptTemplate(
-                name = prompt_template_name,
-                template = raw_template.read()
+                name=prompt_template_name, template=raw_template.read()
             )

     return None
@@ -66,15 +78,12 @@ def get_template_from_file(prompt_template_name: str):
 # Get a template from a JSON file
 # Requires a key and template name
 def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
+    """Get a template from a JSON file. Requires a key and template name"""
     if json_path.exists():
-        with open(json_path, "r", encoding = "utf8") as config_file:
+        with open(json_path, "r", encoding="utf8") as config_file:
             model_config = json.load(config_file)
             chat_template = model_config.get(key)
             if chat_template:
-                return PromptTemplate(
-                    name = name,
-                    template = chat_template
-                )
+                return PromptTemplate(name=name, template=chat_template)

     return None
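A short sketch of how the templating helpers above fit together, assuming the module is importable as templating and jinja2 >= 3.0.0 is installed. The ChatML-style template string and the message list are illustrative, not templates shipped with the project:

    from templating import PromptTemplate, get_prompt_from_template

    template = PromptTemplate(
        name="example",  # illustrative; real templates live in templates/*.jinja
        template=(
            "{% for message in messages %}"
            "<|{{ message['role'] }}|>{{ message['content'] }}\n"
            "{% endfor %}"
            "{% if add_generation_prompt %}<|assistant|>{% endif %}"
        ),
    )

    messages = [{"role": "user", "content": "Hello!"}]
    # Pass an empty dict for special_tokens so the **special_tokens expansion works
    prompt = get_prompt_from_template(
        messages, template, add_generation_prompt=True, special_tokens={}
    )
    print(prompt)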
@@ -1,22 +1,49 @@
+""" Test the model container. """
 from model import ModelContainer


 def progress(module, modules):
+    """Wrapper callback for load progress."""
     yield module, modules


-container = ModelContainer("/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.0bpw/")
-loader = container.load_gen(progress)
-for (module, modules) in loader:
-    print(module, modules)
+def test_load_gen(model_path):
+    """Test loading a model."""
+    container = ModelContainer(model_path)
+    loader = container.load_gen(progress)
+    for module, modules in loader:
+        print(module, modules)
+    container.unload()
+    del container

-generator = container.generate_gen("Once upon a tim", token_healing = True)
-for g in generator:
-    print(g, end = "")

-container.unload()
-del container
+def test_generate_gen(model_path):
+    """Test generating from a model."""
+    container = ModelContainer(model_path)
+    generator = container.generate_gen("Once upon a tim", token_healing=True)
+    for chunk in generator:
+        print(chunk, end="")
+    container.unload()
+    del container

-mc = ModelContainer("/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.65bpw/")
-mc.load(progress)

-response = mc.generate("All work and no play makes turbo a derpy cat.\nAll work and no play makes turbo a derpy cat.\nAll", top_k = 1, max_new_tokens = 1000, stream_interval = 0.5)
-print (response)
+def test_generate(model_path):
+    """Test generating from a model."""
+    model_container = ModelContainer(model_path)
+    model_container.load(progress)
+    prompt = (
+        "All work and no play makes turbo a derpy cat.\n"
+        "All work and no play makes turbo a derpy cat.\nAll"
+    )
+    response = model_container.generate(
+        prompt, top_k=1, max_new_tokens=1000, stream_interval=0.5
+    )
+    print(response)


+if __name__ == "__main__":
+    MODEL1 = "/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.0bpw/"
+    MODEL2 = "/mnt/str/models/_exl2/mistral-7b-instruct-exl2/4.65bpw/"
+    test_load_gen(MODEL1)
+    test_generate_gen(MODEL1)
+    test_generate(MODEL2)
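requirements-dev.txt above lists pytest only as a commented-out, future dependency. If it is adopted later, a fixture could replace the hard-coded paths in the __main__ block so each test's model_path argument comes from the command line. A hypothetical sketch (the fixture name and --model-path option are assumptions, not part of this commit):

    # conftest.py (hypothetical)
    import pytest

    def pytest_addoption(parser):
        # Let the model directory be supplied on the pytest command line
        parser.addoption("--model-path", action="store", default=None)

    @pytest.fixture
    def model_path(request):
        path = request.config.getoption("--model-path")
        if path is None:
            pytest.skip("no --model-path provided")
        return path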
@@ -1,3 +1,4 @@
+""" Test if the wheels are installed correctly. """
 from importlib.metadata import version
 from importlib.util import find_spec

@@ -34,8 +35,12 @@ else:

     print(
         f"\nSuccessful imports: {', '.join(successful_packages)}",
-        f"\nErrored imports: {''.join(errored_packages)}"
+        f"\nErrored imports: {''.join(errored_packages)}",
     )

     if len(errored_packages) > 0:
-        print("\nIf packages are installed, but not found on this test, please check the wheel versions for the correct python version and CUDA version (if applicable).")
+        print(
+            "\nIf packages are installed, but not found on this test, please "
+            "check the wheel versions for the correct python version and CUDA "
+            "version (if applicable)."
+        )
35
utils.py

@@ -1,43 +1,54 @@
+"""Common utilities for the tabbyAPI"""
 import traceback
-from pydantic import BaseModel
 from typing import Optional

-# Wrapper callback for load progress
+from pydantic import BaseModel


 def load_progress(module, modules):
+    """Wrapper callback for load progress."""
     yield module, modules


-# Common error types
 class TabbyGeneratorErrorMessage(BaseModel):
+    """Common error types."""
+
     message: str
     trace: Optional[str] = None


 class TabbyGeneratorError(BaseModel):
+    """Common error types."""
+
     error: TabbyGeneratorErrorMessage


 def get_generator_error(message: str):
+    """Get a generator error."""
     error_message = TabbyGeneratorErrorMessage(
-        message = message,
-        trace = traceback.format_exc()
+        message=message, trace=traceback.format_exc()
     )

-    generator_error = TabbyGeneratorError(
-        error = error_message
-    )
+    generator_error = TabbyGeneratorError(error=error_message)

     # Log and send the exception
     print(f"\n{generator_error.error.trace}")
     return get_sse_packet(generator_error.model_dump_json())


 def get_sse_packet(json_data: str):
+    """Get an SSE packet."""
     return f"data: {json_data}\n\n"


-# Unwrap function for Optionals
-def unwrap(wrapped, default = None):
+def unwrap(wrapped, default=None):
+    """Unwrap function for Optionals."""
     if wrapped is None:
         return default
-    else:
-        return wrapped
+    return wrapped


-# Coalesce function for multiple unwraps
 def coalesce(*args):
+    """Coalesce function for multiple unwraps."""
     return next((arg for arg in args if arg is not None), None)
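A quick sketch of the helpers above in use, assuming the file is importable as utils from the repository root; the values are illustrative:

    from utils import coalesce, get_sse_packet, unwrap

    # unwrap: fall back to a default when an Optional is None
    print(unwrap(None, 150))        # -> 150
    print(unwrap(42, 150))          # -> 42

    # coalesce: the first non-None argument wins
    print(coalesce(None, None, "fallback", "ignored"))   # -> fallback

    # get_sse_packet: wrap a JSON payload as a server-sent event
    print(get_sse_packet('{"text": "hello"}'))
    # -> data: {"text": "hello"}  (followed by a blank line)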