# Source code for ceem.exp_utils
import torch
import numpy as np
from ceem.baseline_utils import LagModel
from ceem.systems import DiscreteLinear
from ceem.data_utils import *
from ceem.smoother import EKF
from scipy.io import loadmat
def compute_rms(y, ypred, y_std=None):
    """
    Compute the RMS prediction error of each trajectory in a batch.

    The residual ``y - ypred`` is multiplied channel-wise by ``y_std``,
    squared, averaged over the time and channel dimensions, and square-rooted.

    Args: y (torch.tensor): (B,T,m) tensor of true observations
          ypred (torch.tensor): (B,T,m) tensor of predicted observations
          y_std (torch.tensor): (m,) standard deviations used to scale the
              residual of each channel; defaults to ones (no scaling)
    Returns:
          (B,) rms errors on each demo
    """
    _, T, m = y.shape
    if y_std is None:
        # Allocate the default on y's device so the function also works for
        # inputs that do not live on the CPU.
        y_std = torch.ones(m, device=y.device)
    scaled_sq_err = ((y - ypred) * y_std.reshape(1, 1, m)) ** 2
    # Mean over the channel and time dimensions, one value per trajectory.
    mean_sq_err = scaled_sq_err.sum(-1).sum(-1) / (T * m)
    return mean_sq_err.sqrt()
def gen_ypred_benchmark(net, H, u):
    """
    Generate predictions by applying ``net`` to every length-``H`` window of ``u``.

    Args: net (callable): maps a window ``u[:, t:t+H]`` to a prediction tensor
          H (int): window length, must satisfy 1 <= H <= T
          u (torch.tensor): input sequence whose second dimension is time,
              e.g. (B,T,m)
    Returns:
          tensor stacking the T-H+1 window predictions along dim 1
    Raises:
          ValueError: if H is smaller than 1 or larger than T
    """
    T = u.shape[1]
    if H < 1 or H > T:
        # Without this check torch.stack fails on an empty list (H > T) or
        # the windows silently degenerate (H < 1).
        raise ValueError(
            f"H must satisfy 1 <= H <= T, got H={H} with T={T}")
    window_preds = [net(u[:, t:t + H]) for t in range(T - H + 1)]
    return torch.stack(window_preds, dim=1)
def gen_ypred_model(model, u, y, s0fac=1.0, rfac=1.0, qfac=1.0):
    """
    Generate predictions by running ``EKF`` with ``model`` over a batch.

    The filter is started from an all-zero state, and the three matrices
    handed to ``EKF`` are scaled identities.

    Args: model: system passed to ``EKF``; must expose ``_xdim`` (state dim)
          u (torch.tensor): (B,T,udim) tensor of inputs
          y (torch.tensor): (B,T,ydim) tensor of observations
          s0fac (float): scale of the (xdim,xdim) identity ``sigma0``
          rfac (float): scale of the (ydim,ydim) identity ``R``
              (presumably the observation-noise covariance -- confirm in EKF)
          qfac (float): scale of the (xdim,xdim) identity ``Q``
              (presumably the process-noise covariance -- confirm in EKF)
    Returns:
          second output of ``EKF`` (the predicted observations, per the
          variable name; shape is determined by ``EKF``)
    """
    xdim = model._xdim
    B = u.shape[0]
    ydim = y.shape[-1]
    # Build the filter's initial quantities on the same device as the data,
    # so the function is not restricted to CPU tensors.
    device = y.device
    x0 = torch.zeros(B, 1, xdim, device=device)
    sigma0 = torch.eye(xdim, device=device) * s0fac
    Q = torch.eye(xdim, device=device) * qfac
    R = torch.eye(ydim, device=device) * rfac
    # Only the predictions are needed; the filtered states are discarded.
    _, ypred = EKF(x0, y, u, sigma0, Q, R, model)
    return ypred