# Source code for fgen_runtime.testing

"""
Testing tools
"""
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Callable, Protocol

import pint.testing

from fgen_runtime.timeseries_collection import TimeseriesCollection

if TYPE_CHECKING:
    import pint


def assert_timeseries_collection_equal(
    left: TimeseriesCollection, right: TimeseriesCollection
) -> None:
    """
    Check that two :obj:`TimeseriesCollection` are exactly equal

    This is a thin wrapper around :func:`assert_timeseries_collection_compare`
    which uses :func:`pint.testing.assert_equal` for every data comparison.

    Parameters
    ----------
    left
        Collection on the left-hand side of the comparison

    right
        Collection on the right-hand side of the comparison

    Raises
    ------
    AssertionError
        ``left`` and ``right`` are not equal
    """
    assert_timeseries_collection_compare(
        left=left,
        right=right,
        compare_data=pint.testing.assert_equal,  # type: ignore # mypy and pint fighting
        get_data_error_message=lambda idx: (
            f"left_v.data != right_v.data for index={idx}"
        ),
    )
def assert_timeseries_collection_allclose(
    left: TimeseriesCollection,
    right: TimeseriesCollection,
    rtol: float = 1e-6,
    atol: float = 1e-8,
) -> None:
    """
    Assert that two :obj:`TimeseriesCollection`'s values are all close

    The comparison is delegated to :func:`assert_timeseries_collection_compare`
    with :func:`pint.testing.assert_allclose` as the data comparison. As a
    result, the tolerances apply to the data values, the data's last bound,
    the time axis values and the time axis' last bound. Lengths, names and
    splines are still compared exactly.

    Parameters
    ----------
    left
        First object to compare

    right
        Second object to compare

    rtol
        Relative tolerance to pass through to :func:`pint.testing.assert_allclose`

    atol
        Absolute tolerance to pass through to :func:`pint.testing.assert_allclose`

    Raises
    ------
    AssertionError
        The two objects are not close within the given tolerances, or they
        differ in length, names or spline
    """

    def get_msg(i: int) -> str:
        return f"left_v.data not close to right_v.data for index={i}"

    assert_timeseries_collection_compare(
        left=left,
        right=right,
        # Bind the tolerances so the callable matches the CompareDataCallable
        # signature (left, right, msg=None)
        compare_data=partial(pint.testing.assert_allclose, rtol=rtol, atol=atol),
        get_data_error_message=get_msg,
    )
class CompareDataCallable(Protocol):
    """
    Structural type for functions that compare two pieces of data

    Implementations are expected to raise :class:`AssertionError` when
    ``left`` and ``right`` do not compare as required.
    """

    def __call__(
        self,
        left: pint.UnitRegistry.Quantity,
        right: pint.UnitRegistry.Quantity,
        msg: str | None = None,
    ) -> None:
        """
        Compare the data

        Parameters
        ----------
        left
            Data on the left-hand side of the comparison

        right
            Data on the right-hand side of the comparison

        msg
            Optional message to use if the comparison fails
        """
def assert_timeseries_collection_compare(
    left: TimeseriesCollection,
    right: TimeseriesCollection,
    compare_data: CompareDataCallable,
    get_data_error_message: Callable[[int], str],
) -> None:
    """
    Assert that two :obj:`TimeseriesCollection` match, using a custom data comparison

    The collections are compared element by element, in order. For each
    index:

    - the names must be equal (``!=`` comparison)
    - the data values and the data's last bound are compared with
      ``compare_data``
    - the time axis values and the time axis' last bound are compared with
      ``compare_data``
    - the splines must be equal (``!=`` comparison)

    Parameters
    ----------
    left
        First object to compare

    right
        Second object to compare

    compare_data
        Function to use to compare the data and time axis attributes of left
        and right. It is expected to raise :class:`AssertionError` if its
        inputs do not match (the time axis check relies on this).

    get_data_error_message
        Function that creates the error message if the data attributes are
        not identical, based on the index that is being compared.

    Raises
    ------
    AssertionError
        The two objects do not match (different lengths, names, data, time
        axes or splines)
    """
    # First check that there are the same number of Timeseries in each
    n_left = len(left)
    n_right = len(right)
    if n_left != n_right:
        # Say what differs rather than only showing two bare numbers
        msg = f"len(left) != len(right): {n_left} != {n_right}"
        raise AssertionError(msg)

    # Order matters for timeseries collections, so we can check with basic iteration
    for i in range(n_left):
        # I would want to use zip here but mypy was complaining
        left_v = left[i]
        right_v = right[i]

        if left_v.name != right_v.name:
            msg = f"{left_v.name} != {right_v.name} for index={i}"
            raise AssertionError(msg)

        data_msg = get_data_error_message(i)
        compare_data(
            left_v.values.values,
            right_v.values.values,
            msg=data_msg,
        )
        compare_data(
            left_v.values.value_last_bound,
            right_v.values.value_last_bound,
            msg=data_msg,
        )

        # Any failure in the time axis comparison is re-raised with a single,
        # index-specific message; the original failure is kept as the cause.
        try:
            compare_data(left_v.time.values, right_v.time.values)
            compare_data(left_v.time.value_last_bound, right_v.time.value_last_bound)
        except AssertionError as exc:
            msg = f"left_v.time_axis != right_v.time_axis for index={i}"
            raise AssertionError(msg) from exc

        if left_v.spline != right_v.spline:
            msg = f"{left_v.spline} != {right_v.spline} for index={i}"
            raise AssertionError(msg)