Source code for fgen_runtime.units

"""
Unit-handling
"""
from __future__ import annotations

from collections.abc import Iterable
from typing import Any, Callable, TypeVar, Union, overload

import numpy as np
import numpy.typing as npt
import pint
from typing_extensions import ParamSpec, TypeAlias

UnitDefinition: TypeAlias = Union[str, pint.registry.Unit]
"""Types which can be used to specify the exepcted (input) unit"""

ReturnUnitDefinition: TypeAlias = Union[
    str, tuple[str, ...], pint.registry.Unit, tuple[pint.registry.Unit, ...]
]
"""Types which can be used to specify the exepcted return unit"""

VerifyUnitsSupported: TypeAlias = Union[UnitDefinition, None]
"""Types which are supported as inputs to :func:`verify_units`"""

Q: TypeAlias = pint.registry.UnitRegistry.Quantity
"""Short-hand to save space and typing"""

PintSupported: TypeAlias = Union[
    float, int, np.number[Any], npt.NDArray[np.number[Any]], npt.ArrayLike
]
"""
Types that are able to be cast to :obj:`pint.registry.UnitRegistry.Quantity` by pint
"""

P = ParamSpec("P")
FuncT = TypeVar("FuncT", bound=Callable[..., Any])


# Overload variants created with scripts/make-verify-units-overloads.py
# Best not to do this by hand unless you want cramp

T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")
X = TypeVar("X")
Y = TypeVar("Y")
Z = TypeVar("Z")
UT = TypeVar("UT")
VT = TypeVar("VT")
XT = TypeVar("XT")
YT = TypeVar("YT")
PintT = TypeVar("PintT", bound=PintSupported)
PintU = TypeVar("PintU", bound=PintSupported)
PintV = TypeVar("PintV", bound=PintSupported)
PintX = TypeVar("PintX", bound=PintSupported)
PintY = TypeVar("PintY", bound=PintSupported)
PintZ = TypeVar("PintZ", bound=PintSupported)
PintUT = TypeVar("PintUT", bound=PintSupported)
PintVT = TypeVar("PintVT", bound=PintSupported)
PintXT = TypeVar("PintXT", bound=PintSupported)
PintYT = TypeVar("PintYT", bound=PintSupported)


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[()],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[P, PintT]], Callable[P, Q]]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[PintU], PintT]], Callable[[Q], Q]]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[None],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[P, PintT]], Callable[P, Q]]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[PintU], T]], Callable[[Q], T]]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[UnitDefinition, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[PintU, PintV], PintT]], Callable[[Q, Q], Q]]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[UnitDefinition, None],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[PintU, U], PintT]], Callable[[Q, U], Q]]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[None, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[U, PintU], PintT]], Callable[[U, Q], Q]]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[None, None],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[P, PintT]], Callable[P, Q]]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[UnitDefinition, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[PintU, PintV], T]], Callable[[Q, Q], T]]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[UnitDefinition, None],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[PintU, U], T]], Callable[[Q, U], T]]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[None, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[U, PintU], T]], Callable[[U, Q], T]]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[UnitDefinition, UnitDefinition, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[PintU, PintV, PintX], PintT]], Callable[[Q, Q, Q], Q]]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[UnitDefinition, UnitDefinition, None],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[PintU, PintV, U], PintT]], Callable[[Q, Q, U], Q]]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[UnitDefinition, None, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[PintU, U, PintV], PintT]], Callable[[Q, U, Q], Q]]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[UnitDefinition, None, None],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[PintU, U, V], PintT]], Callable[[Q, U, V], Q]]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[None, UnitDefinition, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[U, PintU, PintV], PintT]], Callable[[U, Q, Q], Q]]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[None, UnitDefinition, None],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[U, PintU, V], PintT]], Callable[[U, Q, V], Q]]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[None, None, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[U, V, PintU], PintT]], Callable[[U, V, Q], Q]]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[None, None, None],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[P, PintT]], Callable[P, Q]]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[UnitDefinition, UnitDefinition, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[PintU, PintV, PintX], T]], Callable[[Q, Q, Q], T]]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[UnitDefinition, UnitDefinition, None],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[PintU, PintV, U], T]], Callable[[Q, Q, U], T]]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[UnitDefinition, None, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[PintU, U, PintV], T]], Callable[[Q, U, Q], T]]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[UnitDefinition, None, None],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[PintU, U, V], T]], Callable[[Q, U, V], T]]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[None, UnitDefinition, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[U, PintU, PintV], T]], Callable[[U, Q, Q], T]]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[None, UnitDefinition, None],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[U, PintU, V], T]], Callable[[U, Q, V], T]]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[None, None, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[U, V, PintU], T]], Callable[[U, V, Q], T]]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[UnitDefinition, UnitDefinition, UnitDefinition, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[PintU, PintV, PintX, PintY], PintT]], Callable[[Q, Q, Q, Q], Q]
]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[None, UnitDefinition, UnitDefinition, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[U, PintU, PintV, PintX], PintT]], Callable[[U, Q, Q, Q], Q]]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[UnitDefinition, UnitDefinition, UnitDefinition, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[PintU, PintV, PintX, PintY], T]], Callable[[Q, Q, Q, Q], T]]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[None, UnitDefinition, UnitDefinition, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[U, PintU, PintV, PintX], T]], Callable[[U, Q, Q, Q], T]]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[
        UnitDefinition, UnitDefinition, UnitDefinition, UnitDefinition, UnitDefinition
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[PintU, PintV, PintX, PintY, PintZ], PintT]], Callable[[Q, Q, Q, Q, Q], Q]
]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[None, UnitDefinition, UnitDefinition, UnitDefinition, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[U, PintU, PintV, PintX, PintY], PintT]], Callable[[U, Q, Q, Q, Q], Q]
]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[
        UnitDefinition, UnitDefinition, UnitDefinition, UnitDefinition, UnitDefinition
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[PintU, PintV, PintX, PintY, PintZ], T]], Callable[[Q, Q, Q, Q, Q], T]
]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[None, UnitDefinition, UnitDefinition, UnitDefinition, UnitDefinition],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[U, PintU, PintV, PintX, PintY], T]], Callable[[U, Q, Q, Q, Q], T]
]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[PintU, PintV, PintX, PintY, PintZ, PintUT], PintT]],
    Callable[[Q, Q, Q, Q, Q, Q], Q],
]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[
        None,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[U, PintU, PintV, PintX, PintY, PintZ], PintT]],
    Callable[[U, Q, Q, Q, Q, Q], Q],
]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[PintU, PintV, PintX, PintY, PintZ, PintUT], T]],
    Callable[[Q, Q, Q, Q, Q, Q], T],
]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[
        None,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[U, PintU, PintV, PintX, PintY, PintZ], T]],
    Callable[[U, Q, Q, Q, Q, Q], T],
]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[PintU, PintV, PintX, PintY, PintZ, PintUT, PintVT], PintT]],
    Callable[[Q, Q, Q, Q, Q, Q, Q], Q],
]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[
        None,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[U, PintU, PintV, PintX, PintY, PintZ, PintUT], PintT]],
    Callable[[U, Q, Q, Q, Q, Q, Q], Q],
]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[PintU, PintV, PintX, PintY, PintZ, PintUT, PintVT], T]],
    Callable[[Q, Q, Q, Q, Q, Q, Q], T],
]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[
        None,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[U, PintU, PintV, PintX, PintY, PintZ, PintUT], T]],
    Callable[[U, Q, Q, Q, Q, Q, Q], T],
]:
    ...


# Below here are overloads written by hand
# Our lack of all combos means we have to write some by hand.
# This could become extremely painful, or this file may become ludicrously long.
@overload
def verify_units(
    ret: None,
    args: tuple[
        None,
        UnitDefinition,
        UnitDefinition,
        None,
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[U, PintU, PintV, V], X]], Callable[[U, Q, Q, V], X]]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[
        None,
        UnitDefinition,
        UnitDefinition,
        None,
        None,
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[U, PintU, PintV, V, X], Y]], Callable[[U, Q, Q, V, X], Y]]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[
        None,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        None,
        UnitDefinition,
        UnitDefinition,
        None,
        None,
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[U, PintU, PintV, PintX, V, PintY, PintZ, X, Y], Z]],
    Callable[[U, Q, Q, Q, V, Q, Q, X, Y], Z],
]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[
        None,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        None,
        UnitDefinition,
        UnitDefinition,
        None,
        None,
        None,
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[U, PintU, PintV, PintX, V, PintY, PintZ, X, Y, Z], UT]],
    Callable[[U, Q, Q, Q, V, Q, Q, X, Y, Z], UT],
]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[
        None,
        None,
        None,
        None,
        UnitDefinition,
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[U, V, X, Y, PintU], Z]],
    Callable[[U, V, X, Y, Q], Z],
]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[
        None,
        None,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[U, V, PintU, PintV, PintX, PintZ], Z]],
    Callable[[U, V, Q, Q, Q, Q], Z],
]:
    ...


@overload
def verify_units(
    ret: ReturnUnitDefinition,
    args: tuple[
        None,
        None,
        None,
        None,
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[U, V, X, Y], PintT]],
    Callable[[U, V, X, Y], Q],
]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[
        None,
        None,
        UnitDefinition,
        UnitDefinition,
        UnitDefinition,
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[Callable[[U, V, PintU, PintV, PintX], Y]], Callable[[U, V, Q, Q, Q], Y]]:
    ...


@overload
def verify_units(
    ret: None,
    args: tuple[
        None,
        None,
        None,
        None,
        None,
        None,
        UnitDefinition,
        None,
        None,
    ],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[
    [Callable[[U, V, X, Y, Z, UT, PintU, VT, XT], YT]],
    Callable[[U, V, X, Y, Z, UT, Q, VT, XT], YT],
]:
    ...


# The do nothing case. Here the decorator does not alter the signature of the
# input function.
@overload
def verify_units(
    ret: None,
    args: tuple[None, ...],
    strict: bool = True,
    ureg: pint.registry.UnitRegistry | None = None,
) -> Callable[[FuncT], FuncT]:
    ...


[docs]def verify_units( ret: VerifyUnitsSupported | Iterable[VerifyUnitsSupported], args: Iterable[VerifyUnitsSupported], strict: bool = True, ureg: pint.registry.UnitRegistry | None = None, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """ Wrap a function to make it pint-aware This wraps :func:`unit_registry.wraps`, but provides better handling for types Parameters ---------- ret Units of each of the return values. Use `None` to skip argument conversion. args Units of each of the arguments. Use `None` to skip argument conversion. strict Indicates that only quantities are accepted. (Default value = True) ureg Unit registry to use Defaults to :attr:`unit_registry` in the `fgen.units` module. Returns ------- Decorator for wrapping callables """ # Work around quirk where wraps doesn't accept dimensionless clean_args = [arg if arg != "dimensionless" else "" for arg in args] if ureg is None: ureg: pint.registry.UnitRegistry = pint.get_application_registry() # type: ignore return ureg.wraps(ret, clean_args, strict) # type: ignore