Source code for ceem.odesolver

#
# File: odesolver.py
#

import abc
import torch


class Euler:
    """Forward Euler integrator (a single evaluation of ``func`` per step)."""

    @staticmethod
    def step_func(func, t, dt, y, u, transforms=None):
        """Return the per-component Euler increment ``dt * func(t, y, u=u)``.

        Args:
            func: callable returning an iterable of derivatives, one per state component.
            t: current time.
            dt: step size.
            y: tuple of state components.
            u: extra input forwarded to ``func`` as ``u=``.
            transforms: accepted so all solvers share one signature; unused here.

        Returns:
            tuple of increments (not the new state).
        """
        slopes = func(t, y, u=u)
        return tuple(dt * slope for slope in slopes)

    @property
    def order(self):
        """Nominal order of the method."""
        return 1
class Midpoint:
    """Explicit midpoint integrator (two evaluations of ``func`` per step)."""

    @staticmethod
    def step_func(func, t, dt, y, u, transforms=None):
        """Return the per-component midpoint increment.

        A half Euler step produces a midpoint estimate. If ``transforms`` is given,
        each midpoint component is passed through its transform. The derivative at
        ``t + dt / 2`` is then evaluated and scaled by ``dt``.

        Args:
            func: callable returning an iterable of derivatives, one per state component.
            t: current time.
            dt: step size.
            y: tuple of state components.
            u: extra input forwarded to ``func`` as ``u=``.
            transforms: optional iterable of callables, one per component, applied
                to the midpoint estimate. ``None`` means identity. Previously ``None``
                crashed with a TypeError in ``zip``.

        Returns:
            tuple of increments (not the new state).
        """
        y_mid = tuple(y_ + f_ * dt / 2 for y_, f_ in zip(y, func(t, y, u=u)))
        if transforms is not None:
            y_mid = tuple(trans(y_) for y_, trans in zip(y_mid, transforms))
        return tuple(dt * f_ for f_ in func(t + dt / 2, y_mid, u=u))

    @property
    def order(self):
        """Nominal order of the method."""
        return 2
class RK4:
    """Fourth-order Runge-Kutta integrator backed by ``rk4_alt_step_func``."""

    @staticmethod
    def step_func(func, t, dt, y, u, transforms=None):
        """Return the per-component fourth-order increment.

        ``transforms`` is accepted for interface uniformity only. It is not
        applied to the intermediate stages.
        """
        increment = rk4_alt_step_func(func, t, dt, y, u=u)
        return increment

    @property
    def order(self):
        """Nominal order of the method."""
        return 4
def rk4_alt_step_func(func, t, dt, y, k1=None, u=None):
    """Smaller error with slightly more compute.

    Computes one step of the four-stage Runge-Kutta "3/8 rule": stage times
    t, t + dt/3, t + 2dt/3, t + dt, and weights (1, 3, 3, 1) / 8.

    Args:
        func: callable ``func(time, state, u=u)`` returning per-component derivatives.
        t: current time.
        dt: step size.
        y: tuple of state components.
        k1: optional precomputed derivative at ``(t, y)``. When given, ``func`` is
            called three times instead of four.
        u: extra input forwarded to every ``func`` call.

    Returns:
        tuple of increments (not the new state).
    """
    def slope_at(time, state):
        return func(time, state, u=u)

    s1 = slope_at(t, y) if k1 is None else k1

    state2 = tuple(y_i + dt * a / 3 for y_i, a in zip(y, s1))
    s2 = slope_at(t + dt / 3, state2)

    state3 = tuple(y_i + dt * (a / -3 + b) for y_i, a, b in zip(y, s1, s2))
    s3 = slope_at(t + dt * 2 / 3, state3)

    state4 = tuple(y_i + dt * (a - b + c) for y_i, a, b, c in zip(y, s1, s2, s3))
    s4 = slope_at(t + dt, state4)

    weighted = zip(s1, s2, s3, s4)
    return tuple((a + 3 * b + 3 * c + d) * (dt / 8) for a, b, c, d in weighted)
def odestep(func, t, dt, y0, u=None, method='midpoint', transforms=None):
    """Advance an ODE state by one step of size ``dt``.

    Args:
        func: derivative callable. When ``y0`` is a single tensor it takes and
            returns a single tensor. When ``y0`` is a tuple it takes and returns
            tuples and is called as ``func(t, y, u=u)``.
        t: current time. ``_check_inputs`` requires a floating-point tensor.
        dt: step size.
        y0: a floating-point ``torch.Tensor`` or a tuple of them.
        u: extra input forwarded to ``func`` unchanged.
        method: solver name, a key of ``SOLVERS`` ('euler', 'midpoint', 'rk4').
        transforms: optional sequence of callables, one per state component. Each
            is applied to its updated component; 'midpoint' also applies them to
            its midpoint estimate. Defaults to identity. Extra entries beyond the
            number of components are ignored.

    Returns:
        the new state, with the same structure as ``y0`` (tensor or tuple).

    Raises:
        KeyError: if ``method`` is not a registered solver.
        ValueError: if fewer transforms than state components are supplied.
        TypeError / AssertionError: propagated from ``_check_inputs`` for invalid inputs.
    """
    tensor_input, func, y0, t = _check_inputs(func, y0, t)

    if transforms is None:
        transforms = [lambda x: x for _ in range(len(y0))]
    else:
        # Materialize once: the solver and the final update below both iterate
        # transforms, so a one-shot iterable would otherwise be exhausted.
        transforms = list(transforms)
    if len(transforms) < len(y0):
        # zip() would silently drop the untransformed state components.
        raise ValueError('expected at least {} transforms (one per state component) '
                         'but received {}'.format(len(y0), len(transforms)))

    try:
        solver = SOLVERS[method]
    except KeyError:
        raise KeyError('unknown method {!r}; expected one of {}'.format(
            method, sorted(SOLVERS))) from None

    dy = solver.step_func(func, t, dt, y0, u=u, transforms=transforms)
    y = tuple(trans(y0_ + dy_) for y0_, dy_, trans in zip(y0, dy, transforms))
    if tensor_input:
        y = y[0]
    return y
# Registry of available single-step solvers, keyed by the ``method`` name odestep accepts.
SOLVERS = {
    'euler': Euler,
    'midpoint': Midpoint,
    'rk4': RK4,
}


def _check_inputs(func, y0, t):
    """Normalize and validate the arguments passed to ``odestep``.

    A bare tensor ``y0`` is wrapped in a 1-tuple. ``func`` is wrapped so it still
    receives the bare tensor (called positionally as ``func(t, y, u)``) while the
    solvers see tuples.

    Returns:
        ``(tensor_input, func, y0, t)``, where ``tensor_input`` records whether the
        original ``y0`` was a single tensor.

    Raises:
        AssertionError: if ``y0`` is neither a tensor nor a tuple of tensors
            (these checks are ``assert`` statements, so they vanish under ``python -O``).
        TypeError: if any state component or ``t`` is not a floating-point tensor.
    """
    tensor_input = torch.is_tensor(y0)
    if tensor_input:
        y0 = (y0,)
        scalar_func = func

        def func(t, y, u):
            return (scalar_func(t, y[0], u),)

    assert isinstance(y0, tuple), 'y0 must be either a torch.Tensor or a tuple'
    # Type check every element first, then the dtype check (same order as before).
    for component in y0:
        assert torch.is_tensor(component), (
            'each element must be a torch.Tensor but received {}'.format(type(component)))

    bad_component = next((c for c in y0 if not torch.is_floating_point(c)), None)
    if bad_component is not None:
        raise TypeError('`y0` must be a floating point Tensor but is a {}'.format(bad_component.type()))

    if not torch.is_floating_point(t):
        raise TypeError('`t` must be a floating point Tensor but is a {}'.format(t.type()))
    return tensor_input, func, y0, t