"""Source code for ``ceem.exp_utils``."""
import torch
import numpy as np
from ceem.baseline_utils import LagModel
from ceem.systems import DiscreteLinear
from ceem.data_utils import *
from ceem.smoother import EKF
from scipy.io import loadmat
def compute_rms(y, ypred, y_std=None):
    """
    Compute the RMS prediction error of each trajectory in a batch.

    Args:
        y (torch.tensor): (B,T,m) tensor of true observations
        ypred (torch.tensor): (B,T,m) tensor of predicted observations
        y_std (torch.tensor or array-like, optional): (m,) per-output scale
            factors, e.g. the standard deviations used to normalize ``y``.
            The errors are MULTIPLIED by these before squaring. Any
            array-like is accepted and moved to ``y``'s device. ``None``
            (default) means no scaling.
    Returns:
        (B,) tensor of RMS errors, one per trajectory, averaged over all
        T*m entries of that trajectory.
    """
    _, T, m = y.shape
    err = y - ypred
    if y_std is not None:
        # Put the scale on the same device as the data so GPU inputs work
        # with a CPU-built (or numpy/list) y_std.
        scale = torch.as_tensor(y_std, device=err.device)
        err = err * scale.view(1, 1, m)
    # Sum over time and channels and divide by T*m. True division also
    # yields a float result for integer inputs, where .mean() would raise.
    mse = err.pow(2).sum(dim=(1, 2)) / (T * m)
    return mse.sqrt()
def gen_ypred_benchmark(net, H, u):
    """
    Apply ``net`` to every length-H sliding window of the inputs ``u``.

    Args:
        net (callable): called once per window as ``net(u[:, t:t+H])``.
            Presumably returns a tensor with leading batch dimension B;
            TODO confirm against the benchmark models.
        H (int): window length.
        u (torch.tensor): (B,T,udim) tensor of inputs. It must be 3-D; the
            shape unpacking enforces this.
    Returns:
        torch.tensor: the outputs of ``net`` for window start indices
        t = 0 .. T-H, stacked along dim 1 (so dim 1 has size T-H+1).
    Note:
        If T < H there are no windows, and ``torch.stack`` raises a
        RuntimeError on the empty list.
    """
    _, T, _ = u.shape
    windows = (u[:, t:t + H] for t in range(T - H + 1))
    return torch.stack([net(window) for window in windows], dim=1)
def gen_ypred_model(model, u, y, s0fac=1.0, rfac=1.0, qfac=1.0):
    """
    Generate output predictions of ``model`` by running ``EKF`` over (u, y).

    The filter starts from a zero initial state with identity-scaled
    covariance matrices.

    Args:
        model: system passed to ``EKF``; must expose ``_xdim`` (its state
            dimension).
        u (torch.tensor): (B,T,udim) tensor of inputs.
        y (torch.tensor): (B,T,ydim) tensor of observations.
        s0fac (float): scale of the (xdim, xdim) initial covariance
            ``sigma0 = s0fac * I``.
        rfac (float): scale of the (ydim, ydim) matrix ``R = rfac * I``
            passed to EKF (presumably observation-noise covariance).
        qfac (float): scale of the (xdim, xdim) matrix ``Q = qfac * I``
            passed to EKF (presumably process-noise covariance).
    Returns:
        The ``ypred`` output of ``EKF``. The state estimates it also
        returns are discarded.
    """
    xdim = model._xdim
    # The unpacking also enforces that u and y are 3-D, as before.
    B, _, _ = u.shape
    _, _, ydim = y.shape
    # Allocate the filter tensors on the data's device so GPU inputs work.
    # dtype is intentionally left at the torch default, as before.
    device = y.device
    x0 = torch.zeros(B, 1, xdim, device=device)
    sigma0 = torch.eye(xdim, device=device) * s0fac
    Q = torch.eye(xdim, device=device) * qfac
    R = torch.eye(ydim, device=device) * rfac
    _, ypred = EKF(x0, y, u, sigma0, Q, R, model)
    return ypred