Config: Add experimental torch cuda malloc backend
This option saves some VRAM, but does have the chance to error out. Add this in the experimental config section. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
664e2c417e
commit
949248fb94
3 changed files with 16 additions and 1 deletion
|
|
@ -140,3 +140,8 @@ def add_developer_args(parser: argparse.ArgumentParser):
|
|||
type=str_to_bool,
|
||||
help="Disables API request streaming",
|
||||
)
|
||||
developer_group.add_argument(
|
||||
"--cuda-malloc-backend",
|
||||
type=str_to_bool,
|
||||
help="Enables the experimental torch CUDA malloc backend",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ sampling:
|
|||
# WARNING: Using this can result in a generation speed penalty
|
||||
#override_preset:
|
||||
|
||||
# Options for development
|
||||
# Options for development and experimentation
|
||||
developer:
|
||||
# Skips exllamav2 version check (default: False)
|
||||
# It's highly recommended to update your dependencies rather than enabling this flag
|
||||
|
|
@ -46,6 +46,10 @@ developer:
|
|||
# A kill switch for turning off SSE in the API server
|
||||
#disable_request_streaming: False
|
||||
|
||||
# Enable the torch CUDA malloc backend (default: False)
|
||||
# This can save a few MBs of VRAM, but has a risk of errors. Use at your own risk.
|
||||
#cuda_malloc_backend: False
|
||||
|
||||
# Options for model overrides and loading
|
||||
model:
|
||||
# Overrides the directory to look for models (default: models)
|
||||
|
|
|
|||
6
main.py
6
main.py
|
|
@ -1,4 +1,5 @@
|
|||
"""The main tabbyAPI module. Contains the FastAPI server and endpoints."""
|
||||
import os
|
||||
import pathlib
|
||||
import uvicorn
|
||||
from asyncio import CancelledError
|
||||
|
|
@ -600,6 +601,11 @@ def entrypoint(args: Optional[dict] = None):
|
|||
else:
|
||||
check_exllama_version()
|
||||
|
||||
# Enable CUDA malloc backend
|
||||
if unwrap(developer_config.get("cuda_malloc_backend"), False):
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
|
||||
logger.warning("Enabled the experimental CUDA malloc backend.")
|
||||
|
||||
network_config = get_network_config()
|
||||
|
||||
# Initialize auth keys
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue