Source code for ceem.baseline_utils

#
# File: models.py
#
import torch

from ceem.nn import LNMLP

# NOTE: the upper-case name is kept so existing callers of
# ``DEFAULT_NN_KWARGS(H)`` keep working.
def DEFAULT_NN_KWARGS(H):
    """Build the default network keyword arguments for a given ``H``.

    The result matches the keyword set that ``LagModel`` forwards to
    ``LNMLP``. A new dict (and a new ``hidden_sizes`` list) is created on
    every call, so callers may mutate the result freely.

    Args:
        H: Multiplier for the input width; ``input_size`` is ``10 * H``.
            Presumably the number of history steps, each carrying 10
            features -- confirm against the data pipeline.

    Returns:
        dict: ``input_size``, ``hidden_sizes`` (eight layers of width 32),
        ``output_size`` (6), ``activation`` (``'tanh'``), ``gain`` (1.0)
        and ``ln`` (``False``).
    """
    return dict(
        input_size=10 * H,
        hidden_sizes=[32] * 8,
        output_size=6,
        activation='tanh',
        gain=1.0,
        ln=False,
    )

class LagModel(torch.nn.Module):
    """Network that maps a flattened window of past samples to a prediction.

    The model wraps a single ``LNMLP`` and feeds it each trajectory's
    window flattened into one feature vector.
    """

    def __init__(self, neural_net_kwargs, H):
        """Create the underlying ``LNMLP``.

        Args:
            neural_net_kwargs: Keyword arguments forwarded unchanged to
                ``LNMLP`` (for example the dict built by
                ``DEFAULT_NN_KWARGS``).
            H: Not used by this constructor; kept so the signature stays
                compatible with existing callers. Presumably the window
                length that ``neural_net_kwargs['input_size']`` was sized
                for -- confirm with callers.
        """
        super().__init__()
        self.net = LNMLP(**neural_net_kwargs)

    def forward(self, data_batch):
        """Run the network on a batch of windows.

        Args:
            data_batch: Tensor of size ``ntrajs * H * ndim``. Along the last
                axis the first 4 entries are controls, then 3 velocities,
                then 3 rotation rates.

        Returns:
            The output of ``self.net`` applied to ``data_batch`` reshaped to
            ``(ntrajs, H * ndim)``.
        """
        # flatten(start_dim=1) keeps the batch axis and merges the rest. Unlike
        # view(ntrajs, -1), it also accepts non-contiguous tensors and an empty
        # batch, where -1 cannot be inferred.
        flat_batch = data_batch.flatten(start_dim=1)
        return self.net(flat_batch)