Progress: Switch to Rich

Rich is a more mature library for displaying progress bars, logging,
and console output. This should help properly align progress bars
within the terminal.

Side note: "We're Rich!"

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-03-06 00:37:31 -05:00 committed by Brian Dashore
parent 39617adb65
commit fe0ff240e7
6 changed files with 55 additions and 24 deletions

View file

@ -2,6 +2,14 @@
import traceback
from pydantic import BaseModel
from rich.progress import (
Progress,
TextColumn,
BarColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
)
from typing import Optional
from common.logger import init_logger
@ -58,6 +66,18 @@ def get_sse_packet(json_data: str):
return f"data: {json_data}\n\n"
def get_loading_progress_bar():
"""Gets a pre-made progress bar for loading tasks."""
return Progress(
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
TimeRemainingColumn(),
)
def unwrap(wrapped, default=None):
"""Unwrap function for Optionals."""
if wrapped is None:

51
main.py
View file

@ -15,7 +15,6 @@ from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from functools import partial
from progress.bar import IncrementalBar
import common.gen_logging as gen_logging
from backends.exllamav2.model import ExllamaV2Container
@ -46,6 +45,7 @@ from common.templating import (
)
from common.utils import (
get_generator_error,
get_loading_progress_bar,
get_sse_packet,
handle_request_error,
load_progress,
@ -233,6 +233,9 @@ async def load_model(request: Request, data: ModelLoadRequest):
load_status = MODEL_CONTAINER.load_gen(load_progress)
try:
progress = get_loading_progress_bar()
progress.start()
for module, modules in load_status:
if await request.is_disconnected():
logger.error(
@ -242,10 +245,23 @@ async def load_model(request: Request, data: ModelLoadRequest):
break
if module == 0:
loading_bar: IncrementalBar = IncrementalBar("Modules", max=modules)
elif module == modules:
loading_bar.next()
loading_bar.finish()
loading_task = progress.add_task(
"[cyan]Loading modules", total=modules
)
else:
progress.advance(loading_task)
response = ModelLoadResponse(
model_type=model_type,
module=module,
modules=modules,
status="processing",
)
yield get_sse_packet(response.model_dump_json())
if module == modules:
progress.stop()
response = ModelLoadResponse(
model_type=model_type,
@ -259,17 +275,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
# Switch to model progress if the draft model is loaded
if MODEL_CONTAINER.draft_config:
model_type = "model"
else:
loading_bar.next()
response = ModelLoadResponse(
model_type=model_type,
module=module,
modules=modules,
status="processing",
)
yield get_sse_packet(response.model_dump_json())
except CancelledError:
logger.error(
"Model load cancelled by user. "
@ -277,6 +283,8 @@ async def load_model(request: Request, data: ModelLoadRequest):
)
except Exception as exc:
yield get_generator_error(str(exc))
finally:
progress.stop()
# Determine whether to use or skip the queue
if data.skip_queue:
@ -749,14 +757,17 @@ def entrypoint(args: Optional[dict] = None):
model_path.resolve(), False, **model_config
)
load_status = MODEL_CONTAINER.load_gen(load_progress)
progress = get_loading_progress_bar()
progress.start()
for module, modules in load_status:
if module == 0:
loading_bar: IncrementalBar = IncrementalBar("Modules", max=modules)
elif module == modules:
loading_bar.next()
loading_bar.finish()
loading_task = progress.add_task("[cyan]Loading modules", total=modules)
else:
loading_bar.next()
progress.advance(loading_task)
if module == modules:
progress.stop()
# Load loras after loading the model
lora_config = get_lora_config()

View file

@ -10,7 +10,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.1
fastapi
pydantic >= 2.0.0
PyYAML
progress
rich
uvicorn
jinja2 >= 3.0.0
colorlog

View file

@ -16,7 +16,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.1
fastapi
pydantic >= 2.0.0
PyYAML
progress
rich
uvicorn
jinja2 >= 3.0.0
colorlog

View file

@ -2,7 +2,7 @@
fastapi
pydantic >= 2.0.0
PyYAML
progress
rich
uvicorn
jinja2 >= 3.0.0
colorlog

View file

@ -16,7 +16,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.15/exllamav2-0.0.1
fastapi
pydantic >= 2.0.0
PyYAML
progress
rich
uvicorn
jinja2 >= 3.0.0
colorlog