diff --git a/OAI/types/model.py b/OAI/types/model.py
index 3845e2d..f23a4c6 100644
--- a/OAI/types/model.py
+++ b/OAI/types/model.py
@@ -1,4 +1,4 @@
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, ConfigDict
 from time import time
 from typing import List, Optional
 from gen_logging import LogConfig
@@ -45,6 +45,9 @@ class ModelLoadRequest(BaseModel):
     draft: Optional[DraftModelLoadRequest] = None
 
 class ModelLoadResponse(BaseModel):
+    # Avoids pydantic namespace warning
+    model_config = ConfigDict(protected_namespaces = [])
+
     model_type: str = "model"
     module: int
     modules: int
diff --git a/auth.py b/auth.py
index dcedeb2..c31e27a 100644
--- a/auth.py
+++ b/auth.py
@@ -30,7 +30,7 @@ def load_auth_keys():
     try:
         with open("api_tokens.yml", "r", encoding = 'utf8') as auth_file:
             auth_keys_dict = yaml.safe_load(auth_file)
-            auth_keys = AuthKeys.parse_obj(auth_keys_dict)
+            auth_keys = AuthKeys.model_validate(auth_keys_dict)
     except Exception as _:
         new_auth_keys = AuthKeys(
             api_key = secrets.token_hex(16),
@@ -39,7 +39,7 @@ def load_auth_keys():
         auth_keys = new_auth_keys
 
         with open("api_tokens.yml", "w", encoding = "utf8") as auth_file:
-            yaml.safe_dump(auth_keys.dict(), auth_file, default_flow_style=False)
+            yaml.safe_dump(auth_keys.model_dump(), auth_file, default_flow_style=False)
 
     print(
         f"Your API key is: {auth_keys.api_key}\n"
diff --git a/gen_logging.py b/gen_logging.py
index ff18ced..e0986e4 100644
--- a/gen_logging.py
+++ b/gen_logging.py
@@ -18,7 +18,7 @@ def update_from_dict(options_dict: Dict[str, bool]):
         if value is None:
             value = False
 
-    config = LogConfig.parse_obj(options_dict)
+    config = LogConfig.model_validate(options_dict)
 
 def broadcast_status():
     enabled = []
diff --git a/main.py b/main.py
index b55fe64..4687459 100644
--- a/main.py
+++ b/main.py
@@ -117,7 +117,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
     model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
     model_path = model_path / data.name
 
-    load_data = data.dict()
+    load_data = data.model_dump()
 
     # TODO: Add API exception if draft directory isn't found
     draft_config = unwrap(model_config.get("draft"), {})
@@ -156,7 +156,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
                     status="finished"
                 )
 
-                yield get_sse_packet(response.json(ensure_ascii = False))
+                yield get_sse_packet(response.model_dump_json())
 
                 # Switch to model progress if the draft model is loaded
                 if model_container.draft_config:
@@ -171,7 +171,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
                     status="processing"
                 )
 
-                yield get_sse_packet(response.json(ensure_ascii=False))
+                yield get_sse_packet(response.model_dump_json())
         except CancelledError:
             print("\nError: Model load cancelled by user. Please make sure to run unload to free up resources.")
         except Exception as e:
@@ -230,7 +230,7 @@ async def load_lora(data: LoraLoadRequest):
     if len(model_container.active_loras) > 0:
         model_container.unload(True)
 
-    result = model_container.load_loras(lora_dir, **data.dict())
+    result = model_container.load_loras(lora_dir, **data.model_dump())
     return LoraLoadResponse(
         success = unwrap(result.get("success"), []),
         failure = unwrap(result.get("failure"), [])
@@ -281,7 +281,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
                     completion_tokens,
                     model_path.name)
 
-                yield get_sse_packet(response.json(ensure_ascii=False))
+                yield get_sse_packet(response.model_dump_json())
         except CancelledError:
             print("Error: Completion request cancelled by user.")
         except Exception as e:
@@ -334,7 +334,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                     model_path.name
                 )
 
-                yield get_sse_packet(response.json(ensure_ascii=False))
+                yield get_sse_packet(response.model_dump_json())
 
                 # Yield a finish response on successful generation
                 finish_response = create_chat_completion_stream_chunk(
@@ -342,7 +342,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                     finish_reason = "stop"
                 )
 
-                yield get_sse_packet(finish_response.json(ensure_ascii=False))
+                yield get_sse_packet(finish_response.model_dump_json())
         except CancelledError:
             print("Error: Chat completion cancelled by user.")
         except Exception as e:
diff --git a/requirements-amd.txt b/requirements-amd.txt
index 00bc209..f4cd4e0 100644
--- a/requirements-amd.txt
+++ b/requirements-amd.txt
@@ -8,7 +8,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.1
 
 # Pip dependencies
 fastapi
-pydantic < 2,>= 1
+pydantic
 PyYAML
 progress
 uvicorn
diff --git a/requirements-cu118.txt b/requirements-cu118.txt
index 96d6a33..2d8bec0 100644
--- a/requirements-cu118.txt
+++ b/requirements-cu118.txt
@@ -14,7 +14,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.1
 
 # Pip dependencies
 fastapi
-pydantic < 2,>= 1
+pydantic
 PyYAML
 progress
 uvicorn
diff --git a/requirements.txt b/requirements.txt
index 5d8d51b..6b47c92 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,7 +14,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.1
 
 # Pip dependencies
 fastapi
-pydantic < 2,>= 1
+pydantic
 PyYAML
 progress
 uvicorn
diff --git a/utils.py b/utils.py
index a243373..a94fcc9 100644
--- a/utils.py
+++ b/utils.py
@@ -26,7 +26,7 @@ def get_generator_error(message: str):
     # Log and send the exception
     print(f"\n{generator_error.error.trace}")
 
-    return get_sse_packet(generator_error.json(ensure_ascii = False))
+    return get_sse_packet(generator_error.model_dump_json())
 
 def get_sse_packet(json_data: str):
     return f"data: {json_data}\n\n"