#
# File: odesolver.py
#
import abc
import torch
class Euler:
    """Explicit (forward) Euler single-step scheme."""

    @staticmethod
    def step_func(func, t, dt, y, u, transforms=None):
        """Return the state increment ``dt * f(t, y)`` for one Euler step.

        Args:
            func: Callable invoked as ``func(t, y, u=u)``; must return an
                iterable with one derivative per state component.
            t: Time at which the derivative is evaluated.
            dt: Step size.
            y: Current state, passed to ``func`` unchanged.
            u: Extra input forwarded to ``func`` as the keyword ``u``.
            transforms: Accepted for signature parity with the other
                solvers; this method does not use it.

        Returns:
            Tuple of increments (``dt`` times each derivative component).
        """
        return tuple(dt * f_ for f_ in func(t, y, u=u))

    @property
    def order(self):
        """Order of the method (1). Defined as a property, so read it from an instance."""
        return 1
class Midpoint:
    """Explicit midpoint single-step scheme."""

    @staticmethod
    def step_func(func, t, dt, y, u, transforms=None):
        """Return the state increment for one explicit-midpoint step.

        The derivative is evaluated once at ``(t, y)`` to build a half-step
        state, and once more at ``t + dt / 2`` on that half-step state; the
        second evaluation scaled by ``dt`` is the increment.

        Args:
            func: Callable invoked as ``func(t, y, u=u)``; must return an
                iterable with one derivative per state component.
            t: Time at the start of the step.
            dt: Step size.
            y: Iterable of current state components.
            u: Extra input forwarded to ``func`` as the keyword ``u``.
            transforms: Optional iterable of callables, paired positionally
                with the state components and applied to the half-step
                state before the second derivative evaluation. ``None``
                (the default) applies no transform.

        Returns:
            Tuple of increments, one per derivative component returned by
            the second ``func`` evaluation.
        """
        y_mid = tuple(y_ + f_ * dt / 2 for y_, f_ in zip(y, func(t, y, u=u)))
        # The declared default is None; zipping over None would raise
        # TypeError, so only apply transforms when some were supplied.
        if transforms is not None:
            y_mid = tuple(trans(y_) for y_, trans in zip(y_mid, transforms))
        return tuple(dt * f_ for f_ in func(t + dt / 2, y_mid, u=u))

    @property
    def order(self):
        """Order of the method (2). Defined as a property, so read it from an instance."""
        return 2
class RK4:
    """Fourth-order Runge-Kutta single-step scheme backed by ``rk4_alt_step_func``."""

    @staticmethod
    def step_func(func, t, dt, y, u, transforms=None):
        """Return the state increment for one RK4 step.

        Delegates to ``rk4_alt_step_func(func, t, dt, y, u=u)``.

        Args:
            func: Callable invoked as ``func(t, y, u=u)``; must return an
                iterable with one derivative per state component.
            t: Time at the start of the step.
            dt: Step size.
            y: Iterable of current state components.
            u: Extra input forwarded to ``func`` as the keyword ``u``.
            transforms: Accepted for signature parity with the other
                solvers; it is not forwarded, so no transform is applied
                to the intermediate stages.

        Returns:
            Tuple of increments, one per state component.
        """
        return rk4_alt_step_func(func, t, dt, y, u=u)

    @property
    def order(self):
        """Order of the method (4). Defined as a property, so read it from an instance."""
        return 4
def rk4_alt_step_func(func, t, dt, y, k1=None, u=None):
    """Smaller error with slightly more compute.

    Four-stage Runge-Kutta increment using stage times ``t``, ``t + dt/3``,
    ``t + 2*dt/3`` and ``t + dt`` with weights ``1/8, 3/8, 3/8, 1/8``
    (the classical "3/8 rule" tableau).

    Args:
        func: Callable invoked as ``func(t, y, u=u)``; must return an
            iterable with one derivative per state component.
        t: Time at the start of the step.
        dt: Step size.
        y: Iterable of current state components.
        k1: Optional precomputed derivative at ``(t, y)``. When ``None`` it
            is computed with ``func``.
        u: Extra input forwarded to every ``func`` call as the keyword ``u``.

    Returns:
        Tuple with the increment for each state component.
    """
    if k1 is None:
        k1 = func(t, y, u=u)
    # Stage 2: one third of the way along the k1 slope.
    k2 = func(t + dt / 3, tuple(y_ + dt * k1_ / 3 for y_, k1_ in zip(y, k1)), u=u)
    # Stage 3: two thirds of the way, combining -k1/3 and k2.
    k3 = func(t + dt * 2 / 3,
              tuple(y_ + dt * (k1_ / -3 + k2_) for y_, k1_, k2_ in zip(y, k1, k2)), u=u)
    # Stage 4: end of the step, combining k1 - k2 + k3.
    k4 = func(t + dt,
              tuple(y_ + dt * (k1_ - k2_ + k3_) for y_, k1_, k2_, k3_ in zip(y, k1, k2, k3)), u=u)
    return tuple((k1_ + 3 * k2_ + 3 * k3_ + k4_) * (dt / 8)
                 for k1_, k2_, k3_, k4_ in zip(k1, k2, k3, k4))
def odestep(func, t, dt, y0, u=None, method='midpoint', transforms=None):
    """Advance an ODE state by a single fixed step of size ``dt``.

    Args:
        func: Derivative function. For a tensor ``y0`` it is called as
            ``func(t, y, u)`` and returns a tensor; for a tuple ``y0`` it is
            called as ``func(t, y, u=u)`` and returns one derivative per
            state component.
        t: Time at the start of the step; ``_check_inputs`` requires a
            floating point tensor.
        dt: Step size.
        y0: State at time ``t``: a floating point tensor or a tuple of them.
        u: Extra input forwarded to ``func``.
        method: Key into ``SOLVERS`` selecting the stepping scheme.
        transforms: Optional sequence with one callable per state
            component. Each is applied to the corresponding component of
            the updated state, and is also handed to the solver's
            ``step_func``. Defaults to the identity for every component.

    Returns:
        The state after the step: a tensor if ``y0`` was a tensor,
        otherwise a tuple with one entry per state component.

    Raises:
        KeyError: If ``method`` is not a key of ``SOLVERS``.
        TypeError: If ``y0`` or ``t`` fails the floating point checks in
            ``_check_inputs``.
    """
    tensor_input, func, y0, t = _check_inputs(func, y0, t)
    if transforms is None:
        # Default: leave every state component untouched.
        transforms = [lambda x: x for _ in y0]
    dy = SOLVERS[method].step_func(func, t, dt, y0, u=u, transforms=transforms)
    y = tuple(trans(y0_ + dy_) for y0_, dy_, trans in zip(y0, dy, transforms))
    if tensor_input:
        # Unwrap the 1-tuple that _check_inputs created around a bare tensor.
        y = y[0]
    return y
# Registry mapping the `method` name accepted by `odestep` to the solver class
# whose static `step_func` computes the state increment for one step.
SOLVERS = {
    'euler': Euler,
    'midpoint': Midpoint,
    'rk4': RK4,
}
def _check_inputs(func, y0, t):
    tensor_input = False
    if torch.is_tensor(y0):
        tensor_input = True
        y0 = (y0,)
        _base_nontuple_func_ = func
        func = lambda t, y, u: (_base_nontuple_func_(t, y[0], u),)
        assert isinstance(y0, tuple), 'y0 must be either a torch.Tensor or a tuple'
    for y0_ in y0:
        assert torch.is_tensor(y0_), 'each element must be a torch.Tensor but received {}'.format(
            type(y0_))
    for y0_ in y0:
        if not torch.is_floating_point(y0_):
            raise TypeError('`y0` must be a floating point Tensor but is a {}'.format(y0_.type()))
    if not torch.is_floating_point(t):
        raise TypeError('`t` must be a floating point Tensor but is a {}'.format(t.type()))
    return tensor_input, func, y0, t