OAI: Split up utility functions

Just like types, put utility functions in their own separate module
based on the route.
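After this change, each route-specific group of helpers lives under OAI/utils/: completion response builders in OAI/utils/completion.py, model listing in OAI/utils/model.py, and Lora listing in OAI/utils/lora.py. Callers import from the new modules directly, as the main.py hunk at the bottom of this diff shows:

    from OAI.utils.completion import (
        create_completion_response,
        create_chat_completion_response,
        create_chat_completion_stream_chunk,
    )
    from OAI.utils.model import get_model_list
    from OAI.utils.lora import get_lora_list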

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-02-01 00:26:42 -05:00
parent 634d299fd9
commit d3781920b3
6 changed files with 44 additions and 38 deletions


@@ -2,7 +2,7 @@
 from pydantic import BaseModel, Field
 from typing import List, Dict, Optional
-from common.sampling import SamplerParams
+from common.sampling import BaseSamplerRequest
 
 
 class LogProbs(BaseModel):
@@ -22,7 +22,7 @@ class UsageStats(BaseModel):
     total_tokens: int
 
 
-class CommonCompletionRequest(SamplerParams):
+class CommonCompletionRequest(BaseSamplerRequest):
     """Represents a common completion request."""
 
     # Model information
@@ -49,5 +49,5 @@ class CommonCompletionRequest(SamplerParams):
         description="Not parsed. Only used for OAI compliance.", default=None
     )
 
-    # Generation info (remainder is in SamplerParams superclass)
+    # Generation info (remainder is in BaseSamplerRequest superclass)
     stream: Optional[bool] = False


@@ -1,5 +1,4 @@
 """ Utility functions for the OpenAI server. """
-import pathlib
 from typing import Optional
 
 from common.utils import unwrap
@@ -12,8 +11,6 @@ from OAI.types.chat_completion import (
 )
 from OAI.types.completion import CompletionResponse, CompletionRespChoice
 from OAI.types.common import UsageStats
-from OAI.types.lora import LoraList, LoraCard
-from OAI.types.model import ModelList, ModelCard
 
 
 def create_completion_response(
@@ -82,32 +79,3 @@ def create_chat_completion_stream_chunk(
     )
 
     return chunk
-
-
-def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
-    """Get the list of models from the provided path."""
-    # Convert the provided draft model path to a pathlib path for
-    # equality comparisons
-    if draft_model_path:
-        draft_model_path = pathlib.Path(draft_model_path).resolve()
-
-    model_card_list = ModelList()
-    for path in model_path.iterdir():
-        # Don't include the draft models path
-        if path.is_dir() and path != draft_model_path:
-            model_card = ModelCard(id=path.name)
-            model_card_list.data.append(model_card)  # pylint: disable=no-member
-
-    return model_card_list
-
-
-def get_lora_list(lora_path: pathlib.Path):
-    """Get the list of Lora cards from the provided path."""
-    lora_list = LoraList()
-
-    for path in lora_path.iterdir():
-        if path.is_dir():
-            lora_card = LoraCard(id=path.name)
-            lora_list.data.append(lora_card)  # pylint: disable=no-member
-
-    return lora_list

OAI/utils/lora.py (new file, +14)

@@ -0,0 +1,14 @@
+import pathlib
+
+from OAI.types.lora import LoraCard, LoraList
+
+
+def get_lora_list(lora_path: pathlib.Path):
+    """Get the list of Lora cards from the provided path."""
+    lora_list = LoraList()
+
+    for path in lora_path.iterdir():
+        if path.is_dir():
+            lora_card = LoraCard(id=path.name)
+            lora_list.data.append(lora_card)  # pylint: disable=no-member
+    return lora_list

OAI/utils/model.py (new file, +22)

@@ -0,0 +1,22 @@
+import pathlib
+from typing import Optional
+
+from OAI.types.model import ModelCard, ModelList
+
+
+def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
+    """Get the list of models from the provided path."""
+
+    # Convert the provided draft model path to a pathlib path for
+    # equality comparisons
+    if draft_model_path:
+        draft_model_path = pathlib.Path(draft_model_path).resolve()
+
+    model_card_list = ModelList()
+    for path in model_path.iterdir():
+        # Don't include the draft models path
+        if path.is_dir() and path != draft_model_path:
+            model_card = ModelCard(id=path.name)
+            model_card_list.data.append(model_card)  # pylint: disable=no-member
+
+    return model_card_list
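
For reference, a quick usage sketch of the two new helpers; the directory names below are hypothetical and not part of this commit:

    import pathlib

    from OAI.utils.model import get_model_list
    from OAI.utils.lora import get_lora_list

    # Hypothetical paths; in the server these come from the loaded config
    models = get_model_list(pathlib.Path("models"), draft_model_path="draft_models")
    loras = get_lora_list(pathlib.Path("loras"))

    print([card.id for card in models.data])  # one ModelCard per model folder
    print([card.id for card in loras.data])   # one LoraCard per lora folder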


@@ -282,6 +282,7 @@ class ExllamaV2Container:
     def get_model_path(self, is_draft: bool = False):
         """Get the path for this model."""
+
         model_path = pathlib.Path(
             self.draft_config.model_dir if is_draft else self.config.model_dir
         )
@@ -296,6 +297,7 @@ class ExllamaV2Container:
         module loaded. Prototype:
         def progress(loaded_modules: int, total_modules: int)
         """
+
         for _ in self.load_gen(progress_callback):
             pass


@@ -55,13 +55,13 @@ from OAI.types.token import (
     TokenDecodeRequest,
     TokenDecodeResponse,
 )
-from OAI.utils_oai import (
+from OAI.utils.completion import (
     create_completion_response,
-    get_model_list,
-    get_lora_list,
     create_chat_completion_response,
     create_chat_completion_stream_chunk,
 )
+from OAI.utils.model import get_model_list
+from OAI.utils.lora import get_lora_list
 
 
 logger = init_logger(__name__)