diff --git a/OAI/types/common.py b/OAI/types/common.py index e90919e..dd25ad6 100644 --- a/OAI/types/common.py +++ b/OAI/types/common.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, Field from typing import List, Dict, Optional -from common.sampling import SamplerParams +from common.sampling import BaseSamplerRequest class LogProbs(BaseModel): @@ -22,7 +22,7 @@ class UsageStats(BaseModel): total_tokens: int -class CommonCompletionRequest(SamplerParams): +class CommonCompletionRequest(BaseSamplerRequest): """Represents a common completion request.""" # Model information @@ -49,5 +49,5 @@ class CommonCompletionRequest(SamplerParams): description="Not parsed. Only used for OAI compliance.", default=None ) - # Generation info (remainder is in SamplerParams superclass) + # Generation info (remainder is in BaseSamplerRequest superclass) stream: Optional[bool] = False diff --git a/OAI/utils_oai.py b/OAI/utils/completion.py similarity index 67% rename from OAI/utils_oai.py rename to OAI/utils/completion.py index 5ad2873..500451a 100644 --- a/OAI/utils_oai.py +++ b/OAI/utils/completion.py @@ -1,5 +1,4 @@ """ Utility functions for the OpenAI server. 
""" -import pathlib from typing import Optional from common.utils import unwrap @@ -12,8 +11,6 @@ from OAI.types.chat_completion import ( ) from OAI.types.completion import CompletionResponse, CompletionRespChoice from OAI.types.common import UsageStats -from OAI.types.lora import LoraList, LoraCard -from OAI.types.model import ModelList, ModelCard def create_completion_response( @@ -82,32 +79,3 @@ def create_chat_completion_stream_chunk( ) return chunk - - -def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None): - """Get the list of models from the provided path.""" - - # Convert the provided draft model path to a pathlib path for - # equality comparisons - if draft_model_path: - draft_model_path = pathlib.Path(draft_model_path).resolve() - - model_card_list = ModelList() - for path in model_path.iterdir(): - # Don't include the draft models path - if path.is_dir() and path != draft_model_path: - model_card = ModelCard(id=path.name) - model_card_list.data.append(model_card) # pylint: disable=no-member - - return model_card_list - - -def get_lora_list(lora_path: pathlib.Path): - """Get the list of Lora cards from the provided path.""" - lora_list = LoraList() - for path in lora_path.iterdir(): - if path.is_dir(): - lora_card = LoraCard(id=path.name) - lora_list.data.append(lora_card) # pylint: disable=no-member - - return lora_list diff --git a/OAI/utils/lora.py b/OAI/utils/lora.py new file mode 100644 index 0000000..81c9f9c --- /dev/null +++ b/OAI/utils/lora.py @@ -0,0 +1,14 @@ +import pathlib + +from OAI.types.lora import LoraCard, LoraList + + +def get_lora_list(lora_path: pathlib.Path): + """Get the list of Lora cards from the provided path.""" + lora_list = LoraList() + for path in lora_path.iterdir(): + if path.is_dir(): + lora_card = LoraCard(id=path.name) + lora_list.data.append(lora_card) # pylint: disable=no-member + + return lora_list diff --git a/OAI/utils/model.py b/OAI/utils/model.py new file mode 100644 index 
0000000..ac6d117 --- /dev/null +++ b/OAI/utils/model.py @@ -0,0 +1,22 @@ +import pathlib +from typing import Optional + +from OAI.types.model import ModelCard, ModelList + + +def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None): + """Get the list of models from the provided path.""" + + # Convert the provided draft model path to a pathlib path for + # equality comparisons + if draft_model_path: + draft_model_path = pathlib.Path(draft_model_path).resolve() + + model_card_list = ModelList() + for path in model_path.iterdir(): + # Don't include the draft models path + if path.is_dir() and path != draft_model_path: + model_card = ModelCard(id=path.name) + model_card_list.data.append(model_card) # pylint: disable=no-member + + return model_card_list diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index f908402..54cb416 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -282,6 +282,7 @@ class ExllamaV2Container: def get_model_path(self, is_draft: bool = False): """Get the path for this model.""" + model_path = pathlib.Path( self.draft_config.model_dir if is_draft else self.config.model_dir ) @@ -296,6 +297,7 @@ class ExllamaV2Container: module loaded. Prototype: def progress(loaded_modules: int, total_modules: int) """ + for _ in self.load_gen(progress_callback): pass diff --git a/main.py b/main.py index cbd8f60..a0f2b39 100644 --- a/main.py +++ b/main.py @@ -55,13 +55,13 @@ from OAI.types.token import ( TokenDecodeRequest, TokenDecodeResponse, ) -from OAI.utils_oai import ( +from OAI.utils.completion import ( create_completion_response, - get_model_list, - get_lora_list, create_chat_completion_response, create_chat_completion_stream_chunk, ) +from OAI.utils.model import get_model_list +from OAI.utils.lora import get_lora_list logger = init_logger(__name__)