The differential equation solvers in ProbNum can handle events. At the moment, an event is either a set of grid points that must be included in the posterior, or a state for which a condition function evaluates to `True`.

This notebook shows how to use both kinds of events in ProbNum (some examples are taken from https://diffeq.sciml.ai/stable/features/callback_functions/).
What is the easiest way to force events into your ODE solution? Let us define a simple linear ODE that describes exponential decay, $\dot y(t) = -y(t)$ with $y(0) = 4$ on $[0, 5]$.
```python
# Make inline plots vector graphics instead of raster graphics
%matplotlib inline
from matplotlib_inline.backend_inline import set_matplotlib_formats

set_matplotlib_formats("pdf", "svg")

# Plotting
import matplotlib.pyplot as plt

plt.style.use("../../probnum.mplstyle")
```
```python
from probnum import diffeq, randvars, randprocs, problems
import numpy as np

# For easy modification of the states in the callbacks
import dataclasses
```
```python
def f(t, y):
    return -y


def df(t, y):
    # Jacobian of f with respect to y
    return -1.0 * np.eye(len(y))


t0 = 0.0
tmax = 5.0
y0 = np.array([4.0])
```
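Because this ODE is linear, its exact solution $y(t) = y_0 e^{-(t - t_0)}$ is a convenient reference. The sketch below uses only NumPy; the finite-difference helper `fd_jacobian` and the function `exact_solution` are hypothetical names introduced here. It checks that `df` is the Jacobian of `f` and that the exact solution satisfies the ODE.

```python
import numpy as np


def f(t, y):
    return -y


def df(t, y):
    return -1.0 * np.eye(len(y))


def fd_jacobian(fun, t, y, eps=1e-6):
    """Central finite-difference Jacobian of fun(t, y) w.r.t. y (hypothetical helper)."""
    y = np.asarray(y, dtype=float)
    jac = np.zeros((len(y), len(y)))
    for j in range(len(y)):
        step = np.zeros_like(y)
        step[j] = eps
        jac[:, j] = (fun(t, y + step) - fun(t, y - step)) / (2 * eps)
    return jac


t0, tmax = 0.0, 5.0
y0 = np.array([4.0])


def exact_solution(t):
    # y(t) = y0 * exp(-(t - t0)) solves dy/dt = -y with y(t0) = y0
    return y0 * np.exp(-(t - t0))


jac_error = np.max(np.abs(fd_jacobian(f, t0, y0) - df(t0, y0)))
print("Jacobian error:", jac_error)
print("y(tmax):", exact_solution(tmax))
```

Since `f` is linear, the central difference recovers the Jacobian up to rounding error, so `jac_error` should be tiny.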
To demonstrate how to include a fixed set of grid points, let us define a dense grid on a subset of the integration domain.
```python
time_stops = np.linspace(3.5, 4.0, 50)
```
To force the ODE solver to include these time stamps, pass them to `probsolve_ivp` via the `time_stops` argument. Here, we pick a large relative tolerance (`rtol=0.8`) because we want to see some spread among the samples; the ODE is so simple that it is solved very accurately even on large steps.
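Why does a large `rtol` permit large steps? Many adaptive step-size rules accept a step when a scaled local error estimate is at most one, using the componentwise tolerance $\text{atol} + \text{rtol}\,|y|$. The sketch below illustrates this common criterion with a root-mean-square norm; the helper `accept_step` is hypothetical, and ProbNum's `AdaptiveSteps` may differ in details such as the norm or safety factors.

```python
import numpy as np


def accept_step(local_error, y, atol, rtol):
    """Return (accepted, scaled_norm) for a common error-based acceptance test (illustrative)."""
    tolerance = atol + rtol * np.abs(y)
    # Root-mean-square of the componentwise ratio error / tolerance
    scaled_norm = np.sqrt(np.mean((local_error / tolerance) ** 2))
    return scaled_norm <= 1.0, scaled_norm


y = np.array([4.0])
local_error = np.array([0.5])

# With a tight tolerance, this error estimate is rejected, forcing a smaller step ...
print(accept_step(local_error, y, atol=1e-6, rtol=1e-3))
# ... while rtol=0.8 accepts it, so the solver can keep taking long steps.
print(accept_step(local_error, y, atol=1e-6, rtol=0.8))
```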
```python
probsol = diffeq.probsolve_ivp(
    f=f,
    t0=t0,
    tmax=tmax,
    y0=y0,
    time_stops=time_stops,
    rtol=0.8,
)
```
```python
# Draw 10 samples from the posterior and plot.
rng = np.random.default_rng(seed=2)
samples = probsol.sample(size=10, rng=rng)
for sample in samples:
    plt.plot(probsol.locations, sample, "o-", color="C0")
plt.show()
```
Observe how there is a dense gathering of grid-points between 3.5 and 4.0. These are our events!
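One simple way for a step-size controller to honor time stops is to snap any step that would reach or pass the next stop onto that stop, so that every stop becomes a grid point. The sketch below assumes this strategy with a constant proposed step; the helpers `advance` and `build_grid` are hypothetical and illustrate the idea only, not ProbNum's internal code.

```python
import numpy as np


def advance(t, proposed_dt, time_stops, tmax):
    """Return the next grid point, landing exactly on the next time stop (or tmax)
    if the proposed step would reach or pass it."""
    upcoming = [float(s) for s in time_stops if s > t]
    next_stop = min(upcoming + [tmax])
    if t + proposed_dt >= next_stop:
        return next_stop  # snap to the stop instead of computing t + (stop - t)
    return t + proposed_dt


def build_grid(t0, tmax, proposed_dt, time_stops):
    """Step from t0 to tmax with a constant proposed step size, honoring every time stop."""
    grid = [t0]
    t = t0
    while t < tmax:
        t = advance(t, proposed_dt, time_stops, tmax)
        grid.append(t)
    return np.array(grid)


toy_stops = np.linspace(3.5, 4.0, 5)
grid = build_grid(0.0, 5.0, proposed_dt=1.0, time_stops=toy_stops)
print(grid)
```

Snapping to the stop (rather than adding `stop - t` to `t`) avoids rounding leaving a tiny gap before the stop. An adaptive solver would pass its own, changing step proposal into `advance` at every iteration.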
The same works for, e.g., `perturbsolve_ivp`. Let us compute 10 perturbed solutions so that the plot looks similar to the samples from the posterior of the probabilistic solver.
```python
# Unseeded generator: every solve is random
rng = np.random.default_rng()
time_stops = np.linspace(3.5, 4.0, 100)
perturbsols = [
    diffeq.perturbsolve_ivp(
        f=f,
        t0=t0,
        tmax=tmax,
        y0=y0,
        rng=rng,
        noise_scale=0.05,
        time_stops=time_stops,
    )
    for _ in range(10)
]
```
```python
for perturbsol in perturbsols:
    plt.plot(perturbsol.locations, perturbsol.states.mean, "o-", color="C1")
plt.show()
```
Again, observe how there are many locations between 3.5 and 4.0.
It is also possible to modify the solver state whenever an event happens. This is not possible via the top-level interface functions (e.g., `probsolve_ivp`); instead, we have to build an ODE solver from scratch (see the respective notebook for an explanation).
```python
# Construct the IVP, the prior process, the step rule, and the solver.
# Linearization, diffusion model, and initialization use the solver's defaults.
ivp = problems.InitialValueProblem(t0=t0, tmax=tmax, y0=y0, f=f, df=df)

prior_process = randprocs.markov.integrator.IntegratedWienerProcess(
    initarg=ivp.t0,
    num_derivatives=1,
    wiener_process_dimension=ivp.dimension,
    forward_implementation="sqrt",
    backward_implementation="sqrt",
)

firststep = diffeq.stepsize.propose_firststep(ivp)
steprule = diffeq.stepsize.AdaptiveSteps(firststep=firststep, atol=1e-1, rtol=1e-1)

solver = diffeq.odefilter.ODEFilter(
    steprule=steprule,
    prior_process=prior_process,
    with_smoothing=False,
)
```
To describe a discrete event, we define a condition function that checks whether the current time point is either 2.0 or 4.0. At both locations, a replacement function resets the current state to $y = 6$. (Careful! The state of a filter-based solver consists of $[y, \dot y, \ddot y, \dots]$. With `num_derivatives=1`, it is $[y, \dot y]$, so we also set $\dot y = -y = -6$.)
Let us construct both functions and pass them to a `DiscreteCallback`.

Since the solver is unlikely to stop at exactly 2.0 or 4.0 on its own, we force these locations into the ODE solver posterior via `stop_at`.
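Resetting only $y$ would leave a derivative component that no longer agrees with the ODE. For this scalar problem with `num_derivatives=1`, a consistent reset stacks $y$ and $f(t, y)$, which gives the mean $[6, -6]$ used in `replace`. A minimal NumPy sketch of that computation follows; the helper `consistent_reset_mean` is hypothetical and only covers a scalar ODE with one derivative (the component ordering for multi-dimensional states depends on the prior's convention).

```python
import numpy as np


def f(t, y):
    return -y


def consistent_reset_mean(t, y_new, f):
    """Stack [y, f(t, y)] so the derivative entry agrees with the ODE.

    Assumes a scalar ODE and a prior with num_derivatives=1 (state [y, dy/dt]).
    """
    y_new = np.asarray(y_new, dtype=float)
    return np.concatenate([y_new, f(t, y_new)])


print(consistent_reset_mean(2.0, np.array([6.0]), f))
```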
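```python
def condition(state: diffeq.ODESolverState) -> bool:
    """Return True at the event times 2.0 and 4.0."""
    return state.t in [2.0, 4.0]


def replace(state: diffeq.ODESolverState) -> diffeq.ODESolverState:
    """Replace an ODE solver state whenever a condition is True."""
    # Reset y to 6 and its derivative to f(y) = -6; the state layout is [y, dy/dt]
    new_mean = np.array([6.0, -6.0])
    # Zero covariance (and Cholesky factor): the reset state is known exactly
    new_rv = randvars.Normal(
        new_mean, cov=0 * state.rv.cov, cov_cholesky=0 * state.rv.cov_cholesky
    )
    return dataclasses.replace(state, rv=new_rv)
```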
```python
callback = diffeq.callbacks.DiscreteCallback(condition=condition, replace=replace)
odesol = solver.solve(ivp=ivp, stop_at=[2.0, 4.0], callbacks=callback)
```
```python
plt.plot(odesol.locations, odesol.states.mean, "o-")
plt.show()
```
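The condition above compares `state.t` to 2.0 and 4.0 exactly, which relies on `stop_at` making the solver land on exactly these floats. If event times could be off by rounding, a tolerance-based comparison is a more forgiving alternative. So that the sketch runs without ProbNum, it uses a hypothetical frozen dataclass `ToyState` as a stand-in for `diffeq.ODESolverState`, with hypothetical functions `toy_condition` and `toy_replace`. It also shows the `dataclasses.replace` pattern used in `replace`.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for an ODE solver state: a time point and a mean vector."""

    t: float
    mean: np.ndarray


EVENT_TIMES = np.array([2.0, 4.0])


def toy_condition(state: ToyState) -> bool:
    # True if state.t matches an event time up to an absolute tolerance
    return bool(np.any(np.isclose(state.t, EVENT_TIMES, rtol=0.0, atol=1e-10)))


def toy_replace(state: ToyState) -> ToyState:
    # dataclasses.replace returns a new instance; the original state is untouched
    return dataclasses.replace(state, mean=np.array([6.0, -6.0]))


# A time point that misses 2.0 by a rounding-sized amount
state = ToyState(t=2.0 + 1e-13, mean=np.array([1.0, -1.0]))
print(state.t == 2.0, toy_condition(state))
new_state = toy_replace(state) if toy_condition(state) else state
print(new_state.mean)
```

The exact check `state.t in [2.0, 4.0]` would miss this event, while the tolerance-based check catches it. The tolerance should stay well below the smallest step the solver takes, so that neighboring grid points are not mistaken for events.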