# Source code for ceem.baseline_utils
#
# File: models.py
#
import torch
from ceem.nn import LNMLP
def DEFAULT_NN_KWARGS(H):
    """Return default ``LNMLP`` keyword arguments for a window of length ``H``.

    A ``def`` is used rather than a lambda bound to a name (PEP 8 E731). This
    gives the factory a real ``__name__`` for tracebacks and lets it be pickled
    by reference.

    Args:
        H: number of time steps in the input window. The input layer is sized
            ``10 * H``, which presumably means 10 features per step (the
            ``LagModel.forward`` comment lists 4 controls + 3 velocities + 3
            rotation rates). Confirm against the data pipeline.

    Returns:
        A new dict on every call, so the ``hidden_sizes`` list is never shared
        between callers. It contains:
        ``input_size=10*H``, ``hidden_sizes=[32]*8``, ``output_size=6``,
        ``activation='tanh'``, ``gain=1.0``, ``ln=False``.
    """
    return dict(
        input_size=10 * H,
        hidden_sizes=[32] * 8,
        output_size=6,
        activation='tanh',
        gain=1.0,
        ln=False,
    )
class LagModel(torch.nn.Module):
    """Baseline model: flatten a window of lagged data and map it through an ``LNMLP``."""

    def __init__(self, neural_net_kwargs, H):
        """Build the underlying network.

        Args:
            neural_net_kwargs: keyword arguments forwarded unchanged to
                ``LNMLP`` (presumably something like ``DEFAULT_NN_KWARGS(H)``).
                Its ``input_size`` must equal ``H * ndim`` of the batches passed
                to ``forward``.
            H: window length. It is not used by this constructor and is kept
                for interface compatibility.
        """
        super().__init__()
        self.net = LNMLP(**neural_net_kwargs)

    def forward(self, data_batch):
        """Flatten each trajectory's window and run it through ``self.net``.

        Args:
            data_batch: tensor of size ``ntrajs * H * ndim``. Per the original
                layout note, the last axis holds the 4 controls first, then 3
                velocities, then 3 rotation rates.

        Returns:
            ``self.net`` applied to the input reshaped to ``(ntrajs, H * ndim)``.
            Presumably its shape is ``(ntrajs, output_size)``, depending on
            ``LNMLP``.
        """
        ntrajs = data_batch.shape[0]
        # Use reshape rather than view. view raises on non-contiguous inputs
        # (e.g. transposed or strided batches). reshape returns a view when it
        # can and copies only when it must.
        return self.net(data_batch.reshape(ntrajs, -1))