Model: Fetch from generation_config and tokenizer_config in Exl3
Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
parent 59d081fe83
commit e8f00412f6
2 changed files with 31 additions and 3 deletions
@@ -120,7 +120,6 @@ class ExllamaV2Container(BaseModelContainer):
         self.config.max_seq_len = 4096
 
         self.config.prepare()
-        print(self.config.max_seq_len)
 
         # Check if the model arch is compatible with various exl2 features
         self.config.arch_compat_overrides()
@@ -1,6 +1,7 @@
 import asyncio
 import gc
 import pathlib
+import traceback
 from typing import (
     Any,
     AsyncIterator,
@@ -32,7 +33,7 @@ from common.health import HealthManager
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate, find_prompt_template
-from common.transformers_utils import GenerationConfig
+from common.transformers_utils import GenerationConfig, TokenizerConfig
 from common.utils import coalesce, unwrap
 from endpoints.core.types.model import ModelCard, ModelCardParameters
@@ -60,6 +61,8 @@ class ExllamaV3Container(BaseModelContainer):
     tokenizer: Tokenizer
     config: Config
     generator: Optional[AsyncGenerator] = None
+    generation_config: Optional[GenerationConfig] = None
+    tokenizer_config: Optional[TokenizerConfig] = None
 
     # Class-specific vars
     gpu_split: List[float] | None = None
@@ -96,6 +99,30 @@ class ExllamaV3Container(BaseModelContainer):
         self.model = Model.from_config(self.config)
         self.tokenizer = Tokenizer.from_config(self.config)
 
+        # Load generation config overrides
+        generation_config_path = model_directory / "generation_config.json"
+        if generation_config_path.exists():
+            try:
+                self.generation_config = await GenerationConfig.from_file(
+                    model_directory
+                )
+            except Exception:
+                logger.error(traceback.format_exc())
+                logger.warning(
+                    "Skipping generation config load because of an unexpected error."
+                )
+
+        # Load tokenizer config overrides
+        tokenizer_config_path = model_directory / "tokenizer_config.json"
+        if tokenizer_config_path.exists():
+            try:
+                self.tokenizer_config = await TokenizerConfig.from_file(model_directory)
+            except Exception:
+                logger.error(traceback.format_exc())
+                logger.warning(
+                    "Skipping tokenizer config load because of an unexpected error."
+                )
+
         # Fallback to 4096 since exl3 can't fetch from HF's config.json
         self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
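For context on the new loading block: GenerationConfig.from_file and TokenizerConfig.from_file come from common.transformers_utils and are not shown in this diff. Below is a minimal hypothetical sketch of what such a loader could look like, assuming it parses the HF-style JSON files that sit next to the model weights; the field names and internals here are guesses, not the project's actual code.

# Hypothetical sketch only: the real common.transformers_utils classes are
# not part of this diff, so shape and semantics here are assumptions.
import json
import pathlib
from typing import List, Optional, Union


class GenerationConfig:
    """Assumed subset of HF's generation_config.json used for stop tokens."""

    def __init__(self, eos_token_id: Optional[Union[int, List[int]]] = None):
        self.eos_token_id = eos_token_id

    @classmethod
    async def from_file(cls, model_directory: pathlib.Path) -> "GenerationConfig":
        # The diff passes the model directory rather than a file path, so
        # the loader presumably appends the filename itself.
        path = model_directory / "generation_config.json"
        with open(path, encoding="utf-8") as f:
            contents = json.load(f)
        return cls(eos_token_id=contents.get("eos_token_id"))

    def eos_tokens(self) -> List[int]:
        # HF allows eos_token_id to be a single int or a list of ints;
        # normalize to a list so it can be merged into stop conditions.
        if isinstance(self.eos_token_id, int):
            return [self.eos_token_id]
        return list(self.eos_token_id or [])

Guarding with generation_config_path.exists() before calling the loader means models that ship without these optional files skip the warning path entirely.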
@@ -689,7 +716,9 @@ class ExllamaV3Container(BaseModelContainer):
 
         prompts = [prompt]
         stop_conditions = params.stop
-        add_bos_token = params.add_bos_token
+        add_bos_token = unwrap(
+            params.add_bos_token, self.tokenizer_config.add_bos_token, True
+        )
 
         # Fetch EOS tokens from generation_config if they exist
         eos_tokens = (
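The new add_bos_token resolution leans on unwrap from common.utils, whose signature is not shown in this diff. The call site reads as a first-non-None coalesce; a rough sketch of the assumed semantics:

from typing import Optional, TypeVar

T = TypeVar("T")


def unwrap(value: Optional[T], *fallbacks: T) -> Optional[T]:
    # Assumed behavior: return the first non-None argument, so an explicit
    # per-request setting beats the tokenizer_config default, which in
    # turn beats the hardcoded True.
    for candidate in (value, *fallbacks):
        if candidate is not None:
            return candidate
    return None

With that reading, a request that leaves add_bos_token unset falls back to the model's tokenizer_config.json value, and only then to True.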