# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.16.1
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# %% [markdown]
# (input-interpolation)=
# # Input interpolation
#
# Some thoughts to consider when deciding how to interpolate your inputs.
# This may seem trivial,
# but it turns out that there are more choices than you might initially expect.
# In this notebook, we go through some of those choices.

# %% [markdown]
# ## Imports
#
# Packages we'll use in this notebook.

# %%
import matplotlib.pyplot as plt
import numpy as np
import scipy.interpolate

# %% [markdown] jp-MarkdownHeadingCollapsed=true
# ## Example model
#
# To help this discussion, let's start with an example.
# Let's assume we're solving the following simple energy balance model
# (although the model details really don't matter too much for this discussion):
#
# $$
#     C \frac{dT(t)}{dt} = F(t) - \lambda_0 T(t)
# $$
#
# We want to solve the initial-value problem
# defined by this model combined with initial conditions.
# We want to solve for the temperature of the upper ocean ($T$)
# based on some input radiative forcing ($F$).
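
# %% [markdown]
# To make the model concrete, here is a minimal sketch that integrates it with
# `scipy.integrate.solve_ivp` for a constant forcing
# and compares the result with the analytic solution for that case
# (constant $F$ and $T(0) = 0$):
#
# $$
#     T(t) = \frac{F}{\lambda_0} \left(1 - e^{-\lambda_0 t / C}\right)
# $$
#
# The parameter values (`C`, `LAMBDA_0`) and the constant forcing value are hypothetical,
# chosen only for illustration.
#
```python
import numpy as np
from scipy.integrate import solve_ivp

# Hypothetical parameter values, chosen only for illustration
C = 8.0  # heat capacity of the upper ocean
LAMBDA_0 = 1.2  # climate feedback parameter
F_CONST = 0.34  # constant forcing


def rhs(t, T):
    """Right-hand side of C dT/dt = F - lambda_0 T for constant forcing."""
    return (F_CONST - LAMBDA_0 * T) / C


t_eval = np.linspace(0.0, 10.0, 11)
sol = solve_ivp(
    rhs, (t_eval[0], t_eval[-1]), [0.0], t_eval=t_eval, rtol=1e-10, atol=1e-12
)

# Analytic solution for constant forcing and T(0) = 0
T_analytic = F_CONST / LAMBDA_0 * (1.0 - np.exp(-LAMBDA_0 * t_eval / C))

print("max |numerical - analytic|:", np.max(np.abs(sol.y[0] - T_analytic)))
```
#
# The two should agree to within the solver tolerance.
# With constant forcing there is nothing to interpolate;
# that stops being true as soon as the forcing varies in time.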

# %% [markdown]
# ## Our inputs
#
# The key subtlety is how we handle our inputs when solving the initial-value problem.
#
# The issue is that we generally don't specify our inputs as continuous functions.
# Instead, we pass data around at discrete points in time.
# For example, we may define our forcing as
# (this would be even better if put inside an xarray `DataArray` or something else
# that keeps the co-ordinates right next to the data,
# but let's not add that complication here):

# %%
# Our forcing is defined as discrete data at discrete points
# in time, rather than as continuous functions.
erf = np.array([0, 0.34, 0.34, 0.34])
time = np.array([1849, 1850, 1851, 1852])

# %% editable=true slideshow={"slide_type": ""}
# Visualise the forcing
PLT_SCATTER_KWARGS = dict(marker="x", s=100, label="discrete points", zorder=3)


fig, ax = plt.subplots()
ax.scatter(time, erf, **PLT_SCATTER_KWARGS)
ax.set_xticks(time)
_ = ax.legend(loc="center left", bbox_to_anchor=(1.05, 0.5))

# %% [markdown]
# However, to solve our model,
# our solver needs to be able to determine the forcing at any point in time,
# not just at the discrete points given
# (some solvers don't require this, but many common ones do,
# e.g. Runge-Kutta methods, which evaluate the right-hand side
# at intermediate times within each step).
#
# **The key point**
#
# To solve our model,
# we have to make a decision about
# how to go from the discrete points we have defined to a continuous function.
# The key point of this notebook is examining this decision.
# This decision is made by you, the modeller,
# and should be actively considered when solving an initial-value problem.
# There is no perfect decision for all situations:
# it depends on what you want the experiment to represent,
# i.e. how you want the model to be solved.
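
# %% [markdown]
# To see that a Runge-Kutta solver really does ask for the forcing
# between the discrete points, here is a minimal sketch that records every time
# at which `solve_ivp` (with its default RK45 method) evaluates the right-hand side.
# The parameter values are hypothetical,
# and linear interpolation is used only so that the forcing is defined everywhere,
# not as a recommendation.
#
```python
import numpy as np
from scipy.integrate import solve_ivp

erf = np.array([0, 0.34, 0.34, 0.34])
time = np.array([1849, 1850, 1851, 1852])

# Hypothetical parameter values, chosen only for illustration
C = 8.0
LAMBDA_0 = 1.2

# Every time at which the solver asks for the forcing
requested_times = []


def rhs(t, T):
    requested_times.append(t)
    # Linear interpolation is used here only so that F(t) is defined off the grid
    forcing = np.interp(t, time, erf)
    return (forcing - LAMBDA_0 * T) / C


sol = solve_ivp(rhs, (time[0], time[-1]), [0.0], t_eval=time, method="RK45")

grid = set(time.tolist())
off_grid = [t for t in requested_times if t not in grid]
print(
    f"{len(requested_times)} forcing evaluations, "
    f"{len(off_grid)} of them off the grid"
)
```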

# %% [markdown]
# ## Some common choices

# %% [markdown]
# ### Linear interpolation
#
# One option is to just linearly interpolate between the discrete points
# in order to create a continuous function.
# In many cases, this is a simple and good choice.
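
# %% [markdown]
# For example, at 1849.5 (halfway between the first two points),
# linear interpolation gives the average of the neighbouring values, 0.17.
# A quick check, comparing `numpy.interp` with the `interp1d` call used below
# (the two should agree inside the range of the data):
#
```python
import numpy as np
import scipy.interpolate

erf = np.array([0, 0.34, 0.34, 0.34])
time = np.array([1849, 1850, 1851, 1852])

query_times = np.array([1849.25, 1849.5, 1850.5])

# Two ways of linearly interpolating between the discrete points
via_numpy = np.interp(query_times, time, erf)
via_scipy = scipy.interpolate.interp1d(time, erf, kind="linear")(query_times)

print(via_numpy)
print(via_scipy)
```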

# %%
linear_spline = scipy.interpolate.interp1d(time, erf, kind="linear")

# %% [markdown]
# If we make such a choice, our continuous function looks like the below.

# %%
PLT_LINE_KWARGS = dict(alpha=0.6)
PLT_LINEAR_KWARGS = dict(
    **PLT_LINE_KWARGS, color="tab:orange", label="linear interpolation/piecewise linear"
)
TIME_FINE = np.linspace(time[0], time[-1], 250)

fig, ax = plt.subplots()
ax.scatter(time, erf, **PLT_SCATTER_KWARGS)
ax.plot(TIME_FINE, linear_spline(TIME_FINE), **{**PLT_LINEAR_KWARGS, "alpha": 1})

ax.set_xticks(time)
_ = ax.legend(loc="center left", bbox_to_anchor=(1.05, 0.5))

# %% [markdown]
# If we make such a choice,
# then the model's temperature response will be something like the below
# (the magnitudes depend on the parameter values,
# but they don't matter for this illustration).
# We only plot the temperature at discrete points,
# because solvers only report their output at discrete points.

# %%
linear_approx_tempreature_response = np.array([0.0, 0.1, 0.15, 0.18])

fig, ax = plt.subplots()
ax.scatter(
    time, linear_approx_tempreature_response, **{**PLT_LINEAR_KWARGS, "alpha": 1}
)

ax.set_xticks(time)
_ = ax.legend(loc="center left", bbox_to_anchor=(1.05, 0.5))

# %% [markdown]
# The thing to notice here is that,
# because the forcing is greater than zero between 1849 and 1850,
# the temperature in 1850 is greater than zero.
# This is sometimes what we want, but in other cases it won't be.
# One obvious example is a step experiment,
# where we want the step to happen abruptly (i.e. with a sharp leading edge)
# rather than gradually over a year.
# For such cases, the next choice, constant interpolation, is better.
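
# %% [markdown]
# To check this reasoning, here is a minimal sketch that actually solves
# the example model with linearly interpolated forcing,
# using `scipy.integrate.solve_ivp`.
# The parameter values (`C`, `LAMBDA_0`) are hypothetical,
# so only the sign of the temperature in 1850 matters here, not its size.
#
```python
import numpy as np
from scipy.integrate import solve_ivp

erf = np.array([0, 0.34, 0.34, 0.34])
time = np.array([1849, 1850, 1851, 1852])

# Hypothetical parameter values, chosen only for illustration
C = 8.0
LAMBDA_0 = 1.2


def rhs_linear(t, T):
    # Forcing linearly interpolated between the discrete points
    forcing = np.interp(t, time, erf)
    return (forcing - LAMBDA_0 * T) / C


sol_linear = solve_ivp(
    rhs_linear, (time[0], time[-1]), [0.0], t_eval=time, rtol=1e-8, atol=1e-10
)
T_linear = sol_linear.y[0]

# The forcing ramps up during 1849-1850, so T in 1850 is already above zero
print(T_linear)
```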

# %% [markdown]
# ### Constant interpolation
#
# The next option is to assume that the input is constant between consecutive discrete points.

# %%
constant_spline = scipy.interpolate.interp1d(time, erf, kind="previous")
# Here we use kind="previous" (the value at the start of each interval is held
# constant), but you could also use kind="next" or kind="nearest"
# to get a different kind of constant interpolation.
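
# %% [markdown]
# To make the differences between these kinds concrete,
# here is a quick sketch evaluating each of them at two points
# inside the first interval (1849 to 1850):
# 'previous' holds the value at the start of the interval (0),
# 'next' takes the value at the end of the interval (0.34),
# and 'nearest' switches from one to the other halfway through.
#
```python
import numpy as np
import scipy.interpolate

erf = np.array([0, 0.34, 0.34, 0.34])
time = np.array([1849, 1850, 1851, 1852])

# Points strictly inside the first interval, avoiding the exact midpoint
# (where 'nearest' would have to break a tie)
query_times = np.array([1849.25, 1849.75])

results = {
    kind: scipy.interpolate.interp1d(time, erf, kind=kind)(query_times)
    for kind in ["previous", "next", "nearest"]
}
for kind, values in results.items():
    print(kind, values)
```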

# %% [markdown]
# If we make such a choice, our continuous function looks like the below
# (with the earlier interpolation included for comparison).

# %%
PLT_CONSTANT_KWARGS = dict(
    **PLT_LINE_KWARGS,
    color="tab:red",
    label="constant interpolation/piecewise constant",
)


fig, ax = plt.subplots()
ax.scatter(time, erf, **PLT_SCATTER_KWARGS)
ax.plot(TIME_FINE, linear_spline(TIME_FINE), **PLT_LINEAR_KWARGS)
ax.plot(TIME_FINE, constant_spline(TIME_FINE), **{**PLT_CONSTANT_KWARGS, "alpha": 1})

ax.set_xticks(time)
_ = ax.legend(loc="center left", bbox_to_anchor=(1.05, 0.5))

# %% [markdown]
# If we make such a choice,
# then the model's temperature response will be something like the below
# (again, the magnitudes depend on the parameter values,
# but they don't matter for this illustration).
# Again, we only plot the temperature at discrete points,
# because solvers only report their output at discrete points.

# %%
constant_approx_tempreature_response = np.array([0.0, 0.0, 0.1, 0.16])

fig, ax = plt.subplots()
ax.scatter(time, linear_approx_tempreature_response, **PLT_LINEAR_KWARGS)
ax.scatter(
    time, constant_approx_tempreature_response, **{**PLT_CONSTANT_KWARGS, "alpha": 1}
)

ax.set_xticks(time)
_ = ax.legend(loc="center left", bbox_to_anchor=(1.05, 0.5))

# %% [markdown]
# The thing to notice here is that, because we have a sharp edge,
# the reported temperature in 1850 will be zero.
# This is what we would want in a step experiment,
# but we wouldn't want something like this in, e.g.,
# a linear forcing experiment (almost by definition).
# In a linear forcing experiment,
# we want the forcing to increase linearly and the model to respond accordingly,
# rather than in a series of steps
# (so, in such an experiment,
# we would be better off using linear interpolation instead).
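
# %% [markdown]
# This can be checked without a numerical solver.
# Because the forcing is constant over each interval,
# the model has an exact solution across each interval of length $\Delta t$:
#
# $$
#     T_{i+1} = \frac{F_i}{\lambda_0}
#     + \left(T_i - \frac{F_i}{\lambda_0}\right) e^{-\lambda_0 \Delta t / C}
# $$
#
# With 'previous' interpolation, $F_i$ is the forcing value at the start of interval $i$,
# so $F_0 = 0$ over 1849 to 1850 and $T$ stays exactly zero until 1850.
# The sketch below steps this formula forward; the parameter values are hypothetical.
#
```python
import numpy as np

erf = np.array([0, 0.34, 0.34, 0.34])
time = np.array([1849, 1850, 1851, 1852])

# Hypothetical parameter values, chosen only for illustration
C = 8.0
LAMBDA_0 = 1.2

T_constant = np.zeros(time.size)
for i in range(time.size - 1):
    # 'previous' interpolation: forcing over [time[i], time[i + 1]) is erf[i]
    forcing = erf[i]
    dt = time[i + 1] - time[i]
    equilibrium = forcing / LAMBDA_0
    # Exact solution of C dT/dt = F - lambda_0 T over one interval
    T_constant[i + 1] = equilibrium + (T_constant[i] - equilibrium) * np.exp(
        -LAMBDA_0 * dt / C
    )

print(T_constant)
```
#
# Stepping interval by interval like this
# (analytically here, or by restarting a numerical solver at each discrete point)
# is one common way to handle piecewise-constant inputs,
# because it avoids an adaptive solver having to step across the discontinuity
# at each point.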

# %% [markdown]
# ### Quadratic interpolation
#
# The next option is to assume that the input is quadratic between consecutive discrete points.

# %%
quadratic_spline = scipy.interpolate.interp1d(time, erf, kind="quadratic")

# %% [markdown]
# If we make such a choice, our continuous function looks like the below
# (with the earlier interpolations included for comparison).

# %%
PLT_QUDRATIC_KWARGS = dict(
    **PLT_LINE_KWARGS,
    color="tab:purple",
    label="quadratic interpolation/piecewise quadratic",
)


fig, ax = plt.subplots()
ax.scatter(time, erf, **PLT_SCATTER_KWARGS)
ax.plot(TIME_FINE, linear_spline(TIME_FINE), **PLT_LINEAR_KWARGS)
ax.plot(TIME_FINE, constant_spline(TIME_FINE), **PLT_CONSTANT_KWARGS)
ax.plot(TIME_FINE, quadratic_spline(TIME_FINE), **{**PLT_QUDRATIC_KWARGS, "alpha": 1})

ax.set_xticks(time)
_ = ax.legend(loc="center left", bbox_to_anchor=(1.05, 0.5))
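
# %% [markdown]
# One consequence visible in the plot above is that the quadratic interpolant overshoots:
# it goes above the largest data value (0.34) and dips slightly below it later on,
# even though the data never do.
# A quick sketch to check this numerically:
#
```python
import numpy as np
import scipy.interpolate

erf = np.array([0, 0.34, 0.34, 0.34])
time = np.array([1849, 1850, 1851, 1852])

quadratic_spline = scipy.interpolate.interp1d(time, erf, kind="quadratic")

# Evaluate on a fine grid and compare against the range of the data
values = quadratic_spline(np.linspace(time[0], time[-1], 3001))
print(f"data range: [{erf.min():.3f}, {erf.max():.3f}]")
print(f"quadratic interpolant range: [{values.min():.3f}, {values.max():.3f}]")
print(f"value at 1851.5: {float(quadratic_spline(1851.5)):.4f}")
```
#
# Whether such overshoot matters depends on the experiment,
# but it is another way in which the interpolation choice,
# rather than the data, shapes the model's input.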

# %% [markdown]
# If we make such a choice,
# then the model's temperature response will be something like the below
# (again, the magnitudes depend on the parameter values,
# but they don't matter for this illustration).
# Again, we only plot the temperature at discrete points,
# because solvers only report their output at discrete points.

# %%
quadratic_approx_tempreature_response = np.array([0.0, 0.12, 0.18, 0.19])

fig, ax = plt.subplots()
ax.scatter(time, linear_approx_tempreature_response, **PLT_LINEAR_KWARGS)
ax.scatter(time, constant_approx_tempreature_response, **PLT_CONSTANT_KWARGS)
ax.scatter(
    time, quadratic_approx_tempreature_response, **{**PLT_QUDRATIC_KWARGS, "alpha": 1}
)

ax.set_xticks(time)
_ = ax.legend(loc="center left", bbox_to_anchor=(1.05, 0.5))

# %% [markdown]
# The key thing to notice here is that the model's response differs slightly,
# purely because of a different choice of interpolation.
# Quadratic interpolation might be the best choice
# if the first derivative of the model's inputs should be continuous,
# to avoid 'shock' effects appearing in the model's outputs
# (constant interpolation is discontinuous in the input itself,
# while linear interpolation is continuous but has a discontinuous first derivative).
# Emissions input to a carbon cycle model may be such a case,
# as abrupt changes in the first derivative of emissions
# may cause odd effects in the model's output.
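
# %% [markdown]
# The continuity claims can be checked numerically.
# The sketch below compares the first derivative just before and just after 1850
# for piecewise-linear and piecewise-quadratic interpolants.
# It uses `scipy.interpolate.make_interp_spline` rather than `interp1d`,
# because the spline object it returns has a `derivative` method.
# The linear interpolant's slope jumps from 0.34 to 0, while the quadratic's barely changes.
#
```python
import numpy as np
from scipy.interpolate import make_interp_spline

erf = np.array([0, 0.34, 0.34, 0.34])
time = np.array([1849.0, 1850.0, 1851.0, 1852.0])

EPS = 1e-6  # small offset used to approach a point from either side


def first_derivative_jump(spline, t):
    """First derivative just after t minus first derivative just before t."""
    d1 = spline.derivative(1)
    return float(d1(t + EPS) - d1(t - EPS))


linear = make_interp_spline(time, erf, k=1)  # piecewise linear
quadratic = make_interp_spline(time, erf, k=2)  # piecewise quadratic

jump_linear = first_derivative_jump(linear, 1850.0)
jump_quadratic = first_derivative_jump(quadratic, 1850.0)
print(f"linear: {jump_linear:.3g}, quadratic: {jump_quadratic:.3g}")
```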

# %% [markdown]
# ### Cubic interpolation
#
# The last option covered in this notebook is
# to assume that the input is cubic between consecutive discrete points.

# %%
cubic_spline = scipy.interpolate.interp1d(time, erf, kind="cubic")

# %% [markdown]
# If we make such a choice, our continuous function looks like the below
# (with the earlier interpolations included for comparison).

# %%
PLT_CUBIC_KWARGS = dict(
    **PLT_LINE_KWARGS, color="tab:green", label="cubic interpolation/piecewise cubic"
)


fig, ax = plt.subplots()
ax.scatter(time, erf, **PLT_SCATTER_KWARGS)
ax.plot(TIME_FINE, linear_spline(TIME_FINE), **PLT_LINEAR_KWARGS)
ax.plot(TIME_FINE, constant_spline(TIME_FINE), **PLT_CONSTANT_KWARGS)
ax.plot(TIME_FINE, quadratic_spline(TIME_FINE), **PLT_QUDRATIC_KWARGS)
ax.plot(TIME_FINE, cubic_spline(TIME_FINE), **{**PLT_CUBIC_KWARGS, "alpha": 1})

ax.set_xticks(time)
_ = ax.legend(loc="center left", bbox_to_anchor=(1.05, 0.5))

# %% [markdown]
# If we make such a choice,
# then the model's temperature response will be something like the below
# (again, the magnitudes depend on the parameter values,
# but they don't matter for this illustration).
# Again, we only plot the temperature at discrete points,
# because solvers only report their output at discrete points.

# %%
cubic_approx_tempreature_response = np.array([0.0, 0.13, 0.182, 0.185])

fig, ax = plt.subplots()
ax.scatter(time, linear_approx_tempreature_response, **PLT_LINEAR_KWARGS)
ax.scatter(time, constant_approx_tempreature_response, **PLT_CONSTANT_KWARGS)
ax.scatter(time, quadratic_approx_tempreature_response, **PLT_QUDRATIC_KWARGS)
ax.scatter(time, cubic_approx_tempreature_response, **{**PLT_CUBIC_KWARGS, "alpha": 1})

ax.set_xticks(time)
_ = ax.legend(loc="center left", bbox_to_anchor=(1.05, 0.5))

# %% [markdown]
# Again, it is worth noticing that the model's response differs slightly,
# purely because of a different choice of interpolation.
# Cubic interpolation is the next step up from quadratic interpolation:
# in addition to having a continuous first derivative,
# it also has a continuous second derivative.
# In some cases, this may be desirable
# (although we couldn't think of any at the time of writing).
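
# %% [markdown]
# Similarly, the sketch below evaluates the second derivative of quadratic and cubic
# interpolants (built with `scipy.interpolate.make_interp_spline`) on a fine grid
# and reports the largest change between neighbouring grid points.
# A large value for the quadratic indicates a jump in its second derivative,
# whereas the cubic's second derivative changes smoothly.
#
```python
import numpy as np
from scipy.interpolate import make_interp_spline

erf = np.array([0, 0.34, 0.34, 0.34])
time = np.array([1849.0, 1850.0, 1851.0, 1852.0])

time_fine = np.linspace(time[0], time[-1], 3001)


def max_second_derivative_step(k):
    """Largest change in the second derivative between neighbouring fine points."""
    spline = make_interp_spline(time, erf, k=k)
    d2 = spline.derivative(2)(time_fine)
    return float(np.max(np.abs(np.diff(d2))))


quadratic_step = max_second_derivative_step(2)
cubic_step = max_second_derivative_step(3)
print(f"quadratic: {quadratic_step:.3g}, cubic: {cubic_step:.3g}")
```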

# %% [markdown]
# ## Summary
#
# The key takeaway is that the choice of input interpolation is surprisingly important.
# When solving models with fgen, it is a choice you need to make deliberately.
# We aim to provide a clean interface for these choices with the `fgen_interp1d` module.
# [TODO When we've worked it out,
# put a brief statement in here about where to look
# in order to control this interpolation choice when solving with `fgen_solve_ivp`.]
#
# For completeness, we plot all the input interpolations
# and approximate model responses below.

# %%
fig, axes = plt.subplots(nrows=2, sharex=True, figsize=(10, 6))

axes[0].set_title("Input")
axes[0].scatter(time, erf, **PLT_SCATTER_KWARGS)
axes[0].plot(TIME_FINE, linear_spline(TIME_FINE), **PLT_LINEAR_KWARGS)
axes[0].plot(TIME_FINE, constant_spline(TIME_FINE), **PLT_CONSTANT_KWARGS)
axes[0].plot(TIME_FINE, quadratic_spline(TIME_FINE), **PLT_QUDRATIC_KWARGS)
axes[0].plot(TIME_FINE, cubic_spline(TIME_FINE), **PLT_CUBIC_KWARGS)
axes[0].legend(loc="center left", bbox_to_anchor=(1.05, 0.5))

axes[1].set_title("Output")
axes[1].scatter(time, linear_approx_tempreature_response, **PLT_LINEAR_KWARGS)
axes[1].scatter(time, constant_approx_tempreature_response, **PLT_CONSTANT_KWARGS)
axes[1].scatter(time, quadratic_approx_tempreature_response, **PLT_QUDRATIC_KWARGS)
axes[1].scatter(time, cubic_approx_tempreature_response, **PLT_CUBIC_KWARGS)
axes[1].legend(loc="center left", bbox_to_anchor=(1.05, 0.5))

axes[1].set_xticks(time)
plt.tight_layout()
