Progress: Switch to Rich
Rich is a more mature library for displaying progress bars, logging, and console output. This should help properly align progress bars within the terminal. Side note: "We're Rich!" Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
39617adb65
commit
fe0ff240e7
6 changed files with 55 additions and 24 deletions
|
|
@ -2,6 +2,14 @@
|
|||
|
||||
import traceback
|
||||
from pydantic import BaseModel
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
TextColumn,
|
||||
BarColumn,
|
||||
TimeRemainingColumn,
|
||||
TaskProgressColumn,
|
||||
MofNCompleteColumn,
|
||||
)
|
||||
from typing import Optional
|
||||
|
||||
from common.logger import init_logger
|
||||
|
|
@ -58,6 +66,18 @@ def get_sse_packet(json_data: str):
|
|||
return f"data: {json_data}\n\n"
|
||||
|
||||
|
||||
def get_loading_progress_bar():
    """Build a Rich Progress instance preconfigured for model-loading tasks.

    Returns:
        A ``rich.progress.Progress`` showing, left to right: the task
        description, a bar, the percentage, an M/N completed count, and
        the estimated time remaining.
    """

    # Assemble the column layout once, then splat it into the constructor.
    columns = (
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        MofNCompleteColumn(),
        TimeRemainingColumn(),
    )
    return Progress(*columns)
|
||||
|
||||
|
||||
def unwrap(wrapped, default=None):
|
||||
"""Unwrap function for Optionals."""
|
||||
if wrapped is None:
|
||||
|
|
|
|||
51
main.py
51
main.py
|
|
@ -15,7 +15,6 @@ from fastapi.concurrency import run_in_threadpool
|
|||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from functools import partial
|
||||
from progress.bar import IncrementalBar
|
||||
|
||||
import common.gen_logging as gen_logging
|
||||
from backends.exllamav2.model import ExllamaV2Container
|
||||
|
|
@ -46,6 +45,7 @@ from common.templating import (
|
|||
)
|
||||
from common.utils import (
|
||||
get_generator_error,
|
||||
get_loading_progress_bar,
|
||||
get_sse_packet,
|
||||
handle_request_error,
|
||||
load_progress,
|
||||
|
|
@ -233,6 +233,9 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
load_status = MODEL_CONTAINER.load_gen(load_progress)
|
||||
|
||||
try:
|
||||
progress = get_loading_progress_bar()
|
||||
progress.start()
|
||||
|
||||
for module, modules in load_status:
|
||||
if await request.is_disconnected():
|
||||
logger.error(
|
||||
|
|
@ -242,10 +245,23 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
break
|
||||
|
||||
if module == 0:
|
||||
loading_bar: IncrementalBar = IncrementalBar("Modules", max=modules)
|
||||
elif module == modules:
|
||||
loading_bar.next()
|
||||
loading_bar.finish()
|
||||
loading_task = progress.add_task(
|
||||
"[cyan]Loading modules", total=modules
|
||||
)
|
||||
else:
|
||||
progress.advance(loading_task)
|
||||
|
||||
response = ModelLoadResponse(
|
||||
model_type=model_type,
|
||||
module=module,
|
||||
modules=modules,
|
||||
status="processing",
|
||||
)
|
||||
|
||||
yield get_sse_packet(response.model_dump_json())
|
||||
|
||||
if module == modules:
|
||||
progress.stop()
|
||||
|
||||
response = ModelLoadResponse(
|
||||
model_type=model_type,
|
||||
|
|
@ -259,17 +275,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
# Switch to model progress if the draft model is loaded
|
||||
if MODEL_CONTAINER.draft_config:
|
||||
model_type = "model"
|
||||
else:
|
||||
loading_bar.next()
|
||||
|
||||
response = ModelLoadResponse(
|
||||
model_type=model_type,
|
||||
module=module,
|
||||
modules=modules,
|
||||
status="processing",
|
||||
)
|
||||
|
||||
yield get_sse_packet(response.model_dump_json())
|
||||
except CancelledError:
|
||||
logger.error(
|
||||
"Model load cancelled by user. "
|
||||
|
|
@ -277,6 +283,8 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
)
|
||||
except Exception as exc:
|
||||
yield get_generator_error(str(exc))
|
||||
finally:
|
||||
progress.stop()
|
||||
|
||||
# Determine whether to use or skip the queue
|
||||
if data.skip_queue:
|
||||
|
|
@ -749,14 +757,17 @@ def entrypoint(args: Optional[dict] = None):
|
|||
model_path.resolve(), False, **model_config
|
||||
)
|
||||
load_status = MODEL_CONTAINER.load_gen(load_progress)
|
||||
|
||||
progress = get_loading_progress_bar()
|
||||
progress.start()
|
||||
for module, modules in load_status:
|
||||
if module == 0:
|
||||
loading_bar: IncrementalBar = IncrementalBar("Modules", max=modules)
|
||||
elif module == modules:
|
||||
loading_bar.next()
|
||||
loading_bar.finish()
|
||||
loading_task = progress.add_task("[cyan]Loading modules", total=modules)
|
||||
else:
|
||||
loading_bar.next()
|
||||
progress.advance(loading_task)
|
||||
|
||||
if module == modules:
|
||||
progress.stop()
|
||||
|
||||
# Load loras after loading the model
|
||||
lora_config = get_lora_config()
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.1
|
|||
fastapi
|
||||
pydantic >= 2.0.0
|
||||
PyYAML
|
||||
progress
|
||||
rich
|
||||
uvicorn
|
||||
jinja2 >= 3.0.0
|
||||
colorlog
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.1
|
|||
fastapi
|
||||
pydantic >= 2.0.0
|
||||
PyYAML
|
||||
progress
|
||||
rich
|
||||
uvicorn
|
||||
jinja2 >= 3.0.0
|
||||
colorlog
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
fastapi
|
||||
pydantic >= 2.0.0
|
||||
PyYAML
|
||||
progress
|
||||
rich
|
||||
uvicorn
|
||||
jinja2 >= 3.0.0
|
||||
colorlog
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.1
|
|||
fastapi
|
||||
pydantic >= 2.0.0
|
||||
PyYAML
|
||||
progress
|
||||
rich
|
||||
uvicorn
|
||||
jinja2 >= 3.0.0
|
||||
colorlog
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue