from collections import OrderedDict
import torch
from torch import nn
from ceem.dynamics import AnalyticObsJacMixin, C2DSystem, DynJacMixin
from ceem.utils import set_rng_seed
class LorenzSystem(C2DSystem, nn.Module, AnalyticObsJacMixin, DynJacMixin):
    """Coupled Lorenz attractors with a linear observation model.

    The last axis of the state has size ``n`` and is laid out as
    ``[x_1..x_m, y_1..y_m, z_1..z_m]`` with ``m = n // 3``. Each ``(x, y, z)``
    triple follows the Lorenz equations and a linear coupling term ``H @ state``
    is added to the derivative. Observations are ``C @ state``.
    """

    def __init__(self, sigmas, rhos, betas, H, C, dt, method='midpoint'):
        """
        Args:
            sigmas (torch.tensor): (n//3,) torch tensor
            rhos (torch.tensor): (n//3,) torch tensor
            betas (torch.tensor): (n//3,) torch tensor
            H (torch.tensor): (n,n) coupling matrix
            C (torch.tensor): (ydim, n) observation matrix
            dt: forwarded unchanged to ``C2DSystem.__init__``
            method (str): forwarded unchanged to ``C2DSystem.__init__``
        """
        C2DSystem.__init__(self, dt=dt, method=method)
        nn.Module.__init__(self)

        n = H.shape[0]
        assert n % 3 == 0, 'n must be a multiple of 3'
        self._n = n

        # Two leading singleton axes are added so the parameters broadcast
        # against states shaped (B, T, n).
        self._sigmas = nn.Parameter(sigmas.unsqueeze(0).unsqueeze(0))
        self._rhos = nn.Parameter(rhos.unsqueeze(0).unsqueeze(0))
        self._betas = nn.Parameter(betas.unsqueeze(0).unsqueeze(0))
        self._H = nn.Parameter(H.unsqueeze(0).unsqueeze(0))

        # Not a learned parameter: kept as a plain tensor attribute.
        # NOTE(review): because it is neither an nn.Parameter nor a registered
        # buffer, nn.Module.to() / state_dict() do not handle it -- confirm
        # this is intended before using the system on another device.
        self._C = C.unsqueeze(0).unsqueeze(0)

        self._xdim = n
        self._udim = None
        self._ydim = C.shape[0]

    def step_derivs(self, t, x, u=None):
        """Return the time derivative of the state.

        Args:
            t: unused by this method
            x (torch.tensor): (B, T, n) states
            u: unused by this method

        Returns:
            torch.tensor: (B, T, n) derivatives, Lorenz terms plus ``H @ x``
        """
        m = self._n // 3

        # Split the last axis into the x-, y- and z-components of every unit.
        xs = x[:, :, :m]
        ys = x[:, :, m:2 * m]
        zs = x[:, :, 2 * m:]

        dxdt = self._sigmas * (ys - xs)
        dydt = xs * (self._rhos - zs) - ys
        dzdt = xs * ys - self._betas * zs
        lorenz_terms = torch.cat([dxdt, dydt, dzdt], dim=2)

        # Linear coupling: batched matrix-vector product with H.
        coupling = (self._H @ x.unsqueeze(3)).squeeze(3)
        return lorenz_terms + coupling

    def observe(self, t, x, u=None):
        """Return the linear observation ``C @ x`` for each (B, T) state.

        Args:
            t: unused by this method
            x (torch.tensor): (B, T, n) states
            u: unused by this method

        Returns:
            torch.tensor: (B, T, ydim) observations
        """
        return (self._C @ x.unsqueeze(3)).squeeze(3)

    def jac_obs_x(self, t, x, u=None):
        """Return the Jacobian of the observation with respect to the state.

        The observation is linear in ``x``, so the Jacobian is ``C`` tiled
        over the two leading axes of ``x``.

        Args:
            t: unused by this method
            x (torch.tensor): (B, T, n) states
            u: unused by this method

        Returns:
            torch.tensor: (B, T, ydim, n) Jacobians
        """
        # Unpacking three values also rejects inputs that are not rank 3.
        B, T, _ = x.shape
        return self._C.repeat(B, T, 1, 1)

    def jac_obs_theta(self, t, x, u=None):
        """Return the Jacobians of the observation w.r.t. learned parameters.

        ``observe`` only uses ``_C``, which is not an ``nn.Parameter``, so the
        observation does not depend on any learned parameter and every entry
        is ``None``.

        Returns:
            OrderedDict: maps each parameter name registered in ``__init__``
            (``_sigmas``, ``_rhos``, ``_betas``, ``_H``) to ``None``
        """
        # Keys match the attribute names under which the nn.Parameters are
        # registered, in registration order.
        return OrderedDict([
            ('_sigmas', None),
            ('_rhos', None),
            ('_betas', None),
            ('_H', None),
        ])
def default_lorenz_system(k, seed=4, obsdif=2, dt=0.04):
    """Build a reproducible ``LorenzSystem`` made of ``k`` coupled units.

    The global torch RNG is seeded with ``seed`` while the random coupling and
    observation matrices are drawn. The RNG state captured on entry is put
    back before returning, including when construction raises.

    Args:
        k (int): number of Lorenz units; the state dimension is ``n = 3 * k``
        seed (int): value passed to ``torch.manual_seed``
        obsdif (int): the observation dimension is ``n - obsdif``
        dt (float): time step passed to ``LorenzSystem``

    Returns:
        LorenzSystem: system with sigma=10, rho=28, beta=8/3 for every unit,
        constructed with ``method='rk4'``
    """
    # NOTE: only the state returned by torch.random.get_rng_state() is
    # saved and restored here.
    saved_rng_state = torch.random.get_rng_state()
    try:
        torch.manual_seed(seed)

        n = 3 * k

        sigmas = torch.tensor([10.] * k)
        rhos = torch.tensor([28.] * k)
        betas = torch.tensor([8. / 3] * k)

        # generate coupling matrix
        H = torch.randn(n, n)
        for i in range(k):
            # Indices i, i+k, i+2k are zeroed against each other, so only
            # entries between different strides remain non-zero.
            H[i::k, i::k] = 0.
        # Rescale so the coupling matrix has norm 5.
        H = 5. * H / H.norm()

        obsdim = n - obsdif
        C = torch.randn(obsdim, n)

        return LorenzSystem(sigmas, rhos, betas, H, C, dt, method='rk4')
    finally:
        # Restore the caller's RNG state even if something above failed.
        torch.random.set_rng_state(saved_rng_state)
# Script entry point: call main() when this file is executed directly.
# NOTE(review): no `main` is defined in the visible part of this file --
# confirm it exists, otherwise running the file as a script raises NameError.
if __name__ == '__main__':
    main()