From 894be4a81859fd3d3d93c04d8dfc9e7a697acf4b Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 11 Mar 2024 21:55:16 -0400 Subject: [PATCH] Startup: Check if the port is available and fallback Similar to Gradio, fall back to port + 1 if the config port isn't bindable. If both ports aren't available, let the user know and exit. An infinite loop of finding a port isn't advisable. Signed-off-by: kingbri --- common/utils.py | 12 ++++++++++++ main.py | 26 +++++++++++++++++++++++--- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/common/utils.py b/common/utils.py index 0f207f4..5ad12a1 100644 --- a/common/utils.py +++ b/common/utils.py @@ -1,5 +1,6 @@ """Common utility functions""" +import socket import traceback from loguru import logger from pydantic import BaseModel @@ -67,3 +68,14 @@ def prune_dict(input_dict): """Trim out instances of None from a dictionary""" return {k: v for k, v in input_dict.items() if v is not None} + + +def is_port_in_use(port: int) -> bool: + """ + Checks if a port is in use + + From https://stackoverflow.com/questions/2470971/fast-way-to-test-if-a-port-is-in-use-using-python + """ + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("localhost", port)) == 0 diff --git a/main.py b/main.py index dfb6c06..8b0aebd 100644 --- a/main.py +++ b/main.py @@ -53,6 +53,7 @@ from common.utils import ( get_generator_error, handle_request_error, load_progress, + is_port_in_use, unwrap, ) from OAI.types.completion import CompletionRequest @@ -732,6 +733,28 @@ def entrypoint(args: Optional[dict] = None): network_config = get_network_config() + host = unwrap(network_config.get("host"), "127.0.0.1") + port = unwrap(network_config.get("port"), 5000) + + # Check if the port is available and attempt to bind a fallback + if is_port_in_use(port): + fallback_port = port + 1 + + if is_port_in_use(fallback_port): + logger.error( + f"Ports {port} and {fallback_port} are in use by different services.\n" + "Please free up those ports or specify a different one.\n" + "Exiting." + ) + + return + else: + logger.warning( + f"Port {port} is currently in use. Switching to {fallback_port}." + ) + + port = fallback_port + # Initialize auth keys load_auth_keys(unwrap(network_config.get("disable_auth"), False)) @@ -788,9 +811,6 @@ def entrypoint(args: Optional[dict] = None): lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras")) MODEL_CONTAINER.load_loras(lora_dir.resolve(), **lora_config) - host = unwrap(network_config.get("host"), "127.0.0.1") - port = unwrap(network_config.get("port"), 5000) - # TODO: Replace this with abortables, async via producer consumer, or something else api_thread = threading.Thread(target=partial(start_api, host, port), daemon=True)