109 lines
2.8 KiB
Python
109 lines
2.8 KiB
Python
"""Common utility functions"""
|
|
|
|
import inspect
|
|
from types import NoneType
|
|
from typing import Dict, Optional, Type, Union, get_args, get_origin, TypeVar
|
|
|
|
# Type variable tying unwrap()'s wrapped value, default, and return type together.
T = TypeVar("T")


def unwrap(wrapped: Optional[T], default: T = None) -> T:
    """
    Return ``wrapped`` unless it is None, in which case return ``default``.

    Only None triggers the fallback; falsy values such as 0 or "" are
    returned unchanged.
    """
    return default if wrapped is None else wrapped
|
|
|
|
|
|
def coalesce(*args):
    """
    Return the first positional argument that is not None.

    Returns None when called with no arguments or when every argument is None.
    Falsy-but-not-None values (0, "", []) count as present.
    """
    for candidate in args:
        if candidate is not None:
            return candidate
    return None
|
|
|
|
|
|
def filter_none_values(collection: Union[dict, list]) -> Union[dict, list]:
    """
    Recursively drop None entries from nested dicts and lists.

    Dict entries whose value is None and list elements that are None are
    removed; every surviving value is cleaned the same way. Containers that
    become empty after cleaning are kept (e.g. {"a": {"b": None}} becomes
    {"a": {}}). Anything that is neither a dict nor a list is returned as-is.
    """
    if isinstance(collection, list):
        return [
            filter_none_values(element)
            for element in collection
            if element is not None
        ]

    if not isinstance(collection, dict):
        return collection

    cleaned = {}
    for key, value in collection.items():
        if value is None:
            continue
        cleaned[key] = filter_none_values(value)
    return cleaned
|
|
|
|
|
|
def merge_dict(dict1: Dict, dict2: Dict) -> Dict:
    """
    Recursively merge ``dict2`` into ``dict1`` in place.

    Values from ``dict2`` override those in ``dict1``, except that when both
    sides hold a dict under the same key the two are merged recursively.

    Dict values taken from ``dict2`` are copied (dict layers only; other
    values such as lists are still shared) so that ``dict1`` never holds a
    reference to one of ``dict2``'s nested dicts. Without the copy, a later
    merge into ``dict1`` would recurse into that shared dict and silently
    modify ``dict2``.

    Args:
        dict1: Destination dictionary; modified in place.
        dict2: Source dictionary; its dict structure is not modified.

    Returns:
        ``dict1``, after merging.
    """
    for key, value in dict2.items():
        if not isinstance(value, dict):
            dict1[key] = value
            continue

        existing = dict1.get(key)
        if isinstance(existing, dict):
            merge_dict(existing, value)
        else:
            # Copy the nested dict rather than aliasing dict2's object.
            dict1[key] = merge_dict({}, value)

    return dict1
|
|
|
|
|
|
def merge_dicts(*dicts: Dict) -> Dict:
    """
    Merge any number of dictionaries, left to right, into a new dict.

    Later dictionaries take precedence over earlier ones; where two of them
    hold a dict under the same key, those are merged recursively (see
    merge_dict). With no arguments an empty dict is returned.
    """
    accumulated: Dict = {}
    for source in dicts:
        accumulated = merge_dict(accumulated, source)
    return accumulated
|
|
|
|
|
|
def flat_map(input_list):
    """
    Flatten one level of nesting.

    Concatenates the items of each sub-iterable in ``input_list``, in order,
    into a single new list.
    """
    flattened = []
    for sublist in input_list:
        flattened.extend(sublist)
    return flattened
|
|
|
|
|
|
def is_list_type(type_hint) -> bool:
    """
    Check whether a type hint is, or contains, a list type.

    Matches the bare ``list`` class, parameterized forms such as
    ``List[int]`` / ``list[int]``, and any union or generic whose type
    arguments contain one of those (e.g. ``Optional[List[str]]``), searched
    recursively.
    """
    # get_origin() returns None for the unparameterized ``list`` class, so
    # it has to be matched directly.
    if type_hint is list or get_origin(type_hint) is list:
        return True

    # Recursively check for lists inside type arguments; a hint with no
    # arguments yields an empty any() and therefore False.
    return any(is_list_type(arg) for arg in get_args(type_hint))
|
|
|
|
|
|
def unwrap_optional_type(type_hint) -> Type:
|
|
"""
|
|
Unwrap Optional[type] annotations.
|
|
This is not the same as unwrap.
|
|
"""
|
|
|
|
if get_origin(type_hint) is Union:
|
|
args = get_args(type_hint)
|
|
if NoneType in args:
|
|
for arg in args:
|
|
if arg is not NoneType:
|
|
return arg
|
|
|
|
return type_hint
|
|
|
|
|
|
def calculate_rope_alpha(base_seq_len: int, target_seq_len: int) -> float:
    """
    Converts a given max sequence length to a rope alpha value.

    The alpha is a quadratic function of the ratio
    ``target_seq_len / base_seq_len``. Any ratio at or below 1.0 (the target
    is not longer than the model's configured length) yields 1.0.

    Args:
        base_seq_len: The model's configured sequence length. Must be positive.
        target_seq_len: The user-specified max sequence length.

    Returns:
        The rope alpha, always as a float.

    Raises:
        ValueError: If ``base_seq_len`` is not positive.
    """
    if base_seq_len <= 0:
        raise ValueError(f"base_seq_len must be positive, got {base_seq_len}")

    # Ratio of the target sequence length to the model's configured length
    ratio = target_seq_len / base_seq_len

    # Default to an alpha of 1.0 when the target does not exceed the base length
    if ratio <= 1.0:
        return 1.0

    # Quadratic in the ratio; coefficients kept exactly as originally specified.
    return -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
|