Source code for auto_uncertainties.util

from __future__ import annotations

from functools import wraps
from typing import TypeVar
import warnings

from jax import Array
import numpy as np
from numpy.typing import NDArray

from . import DowncastWarning

# Type variable for NumPy scalar types: bounded by ``np.generic`` and declared
# covariant. Used to parameterize ``NDArray`` annotations in this module.
T = TypeVar(
    "T",
    covariant=True,
    bound=np.generic,
)
def ignore_runtime_warnings(f):
    """
    A decorator to ignore runtime warnings.

    Only `RuntimeWarning` is silenced while the wrapped function runs. Every
    other warning category is still handled by the warning filters that were
    active when the wrapped function was called.

    :param f: The function to wrap
    :return: The wrapped function
    """

    @wraps(f)
    def runtime_warn_inner(*args, **kwargs):
        # Deliberately no ``record=True``: recording captures *every* warning
        # raised by ``f`` into a list that is then thrown away, which would
        # silently discard unrelated categories as well.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            return f(*args, **kwargs)

    return runtime_warn_inner
def ignore_numpy_downcast_warnings(f):
    """
    A decorator to ignore `DowncastWarning`.

    Only `DowncastWarning` is silenced while the wrapped function runs. Every
    other warning category is still handled by the warning filters that were
    active when the wrapped function was called.

    :param f: The function to wrap
    :return: The wrapped function
    """

    @wraps(f)
    def user_warn_inner(*args, **kwargs):
        # Deliberately no ``record=True``: recording captures *every* warning
        # raised by ``f`` into a list that is then thrown away, which would
        # silently discard unrelated categories as well.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=DowncastWarning)
            return f(*args, **kwargs)

    return user_warn_inner
def deprecated(reason="", category=DeprecationWarning):
    """
    Raise a deprecation warning for a decorated function.

    The warning text is ``"<name> is deprecated (<reason>)."`` when a reason is
    given and ``"<name> is deprecated."`` otherwise.

    :param reason: Why the deprecation is being issued
    :param category: The type of warning to issue
    :return: A decorator whose wrapped function emits the warning on each call
    """

    def decorator(func):
        # Build the suffix on its own: a conditional expression binds more
        # loosely than ``+``, so writing it inline without parentheses made an
        # empty reason collapse the entire message to ".".
        suffix = f" ({reason})." if reason else "."
        msg = f"{func.__name__} is deprecated{suffix}"

        @wraps(func)
        def wrapper(*args, **kwargs):
            # stacklevel=2 attributes the warning to the caller of ``func``.
            warnings.warn(msg, category, stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return decorator
def is_iterable(y):
    """
    Check whether an object can be iterated over.

    :param y: The object to test
    :return: `True` if ``iter(y)`` succeeds, `False` if it raises `TypeError`
    """
    try:
        iter(y)
    except TypeError:
        return False
    else:
        return True
def has_length(y):
    """
    Check whether an object supports `len`.

    :param y: The object to test
    :return: `True` if ``len(y)`` succeeds, `False` if it raises `TypeError`
    """
    sized = True
    try:
        len(y)
    except TypeError:
        sized = False
    return sized
def ndarray_to_scalar(value: NDArray[T]) -> T:
    """
    Convert an array to a scalar.

    The input is first passed through `strip_device_array`, then its element
    is extracted with `numpy.ndarray.item`.

    :param value: The array to convert
    :return: The value produced by `numpy.ndarray.item`
    """
    host_array = strip_device_array(value)
    return np.ndarray.item(host_array)
def strip_device_array(value: Array | NDArray | float) -> NDArray:
    """
    Convert a value to a NumPy array.

    :param value: The JAX array, NumPy array, or float to convert
    :return: The result of ``np.array(value)``
    """
    as_numpy = np.array(value)
    return as_numpy