diff --git a/OAI/utils.py b/OAI/utils.py index 87ce8be..769991b 100644 --- a/OAI/utils.py +++ b/OAI/utils.py @@ -1,5 +1,4 @@ import pathlib -from fastchat.model.model_adapter import get_conversation_template, Conversation from OAI.types.completion import CompletionResponse, CompletionRespChoice from OAI.types.chat_completion import ( ChatCompletionMessage, @@ -10,8 +9,17 @@ from OAI.types.chat_completion import ( ) from OAI.types.common import UsageStats from OAI.types.model import ModelList, ModelCard +from packaging import version from typing import Optional, List +# Check fastchat +try: + import fastchat + from fastchat.model.model_adapter import get_conversation_template + _fastchat_available = True +except ImportError: + _fastchat_available = False + def create_completion_response(text: str, model_name: Optional[str]): # TODO: Add method to get token amounts in model for UsageStats @@ -78,6 +86,21 @@ def get_model_list(model_path: pathlib.Path): return model_card_list def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]): + # Check if fastchat is available + + if not _fastchat_available: + raise ModuleNotFoundError( + "Fastchat must be installed to parse these chat completion messages.\n" + "Please run the following command: pip install fschat[model_worker]" + ) + if version.parse(fastchat.__version__) < version.parse("0.2.23"): + raise ImportError( + "Parsing these chat completion messages requires fastchat 0.2.23 or greater. " + f"Current version: {fastchat.__version__}\n" + "Please upgrade fastchat by running the following command: " + "pip install -U fschat[model_worker]" + ) + conv = get_conversation_template(model_path) for message in messages: msg_role = message.role diff --git a/requirements.txt b/requirements.txt index 7dd575d..2db8867 100644 Binary files a/requirements.txt and b/requirements.txt differ