Source code for ceem.nn
# File: nn.py
#

from math import sqrt

import torch
from torch.nn import Module, Parameter, init

ACTIVATIONS = {'tanh': torch.nn.Tanh, 'relu': torch.nn.ReLU}
class LNMLP(torch.nn.Module):
    """Multilayer perceptron with optional per-layer LayerNorm.

    Args:
        input_size: dimension of the input features
        hidden_sizes: list of hidden-layer widths (may be empty)
        output_size: dimension of the output
        activation: key into ACTIVATIONS ('relu' or 'tanh')
        gain: scale passed to the weight initializer (see init_normc_)
        ln: if True, apply LayerNorm after each hidden activation
    """

    def __init__(self, input_size, hidden_sizes, output_size, activation='relu', gain=1.0,
                 ln=False):
        self._hidden_sizes = hidden_sizes
        self._gain = gain
        self._ln = ln

        super().__init__()

        if len(hidden_sizes) > 0:
            activation = ACTIVATIONS[activation]
            layers = [torch.nn.Linear(input_size, hidden_sizes[0])]
            layers.append(activation())
            if ln:
                layers.append(torch.nn.LayerNorm(hidden_sizes[0]))
            for i in range(len(hidden_sizes) - 1):
                layers.append(torch.nn.Linear(hidden_sizes[i], hidden_sizes[i + 1]))
                layers.append(activation())
                if ln:
                    layers.append(torch.nn.LayerNorm(hidden_sizes[i + 1]))
            layers.append(torch.nn.Linear(hidden_sizes[-1], output_size))
        else:
            # No hidden layers: the network reduces to a single linear map.
            layers = [torch.nn.Linear(input_size, output_size)]

        self._layers = layers
        self.mlp = torch.nn.Sequential(*layers)

        self.reset_params(gain=gain)

    def forward(self, inp):
        return self.mlp(inp)

    def reset_params(self, gain=1.0):
        # Re-initialize every Linear layer in place.
        self.apply(lambda x: weights_init_mlp(x, gain=gain))
def weights_init_mlp(m, gain=1.0):
    """Initialize a Linear module: normalized Gaussian weights, zero biases."""
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        init_normc_(m.weight.data, gain)
        if m.bias is not None:
            m.bias.data.fill_(0)
def init_normc_(weight, gain=1.0):
    """Draw weights from N(0, 1), then rescale each row to Euclidean norm `gain`."""
    weight.normal_(0, 1)
    weight *= gain / torch.sqrt(weight.pow(2).sum(1, keepdim=True))
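

A minimal usage sketch (not part of the module above), assuming ceem is installed and importable: build an LNMLP with layer normalization, run a forward pass, and check that init_normc_ gives each weight row the requested norm.

import torch
from ceem.nn import LNMLP, init_normc_

# Two hidden layers of width 32, tanh activations, LayerNorm after each one.
net = LNMLP(input_size=4, hidden_sizes=[32, 32], output_size=2,
            activation='tanh', gain=1.0, ln=True)

x = torch.randn(8, 4)   # batch of 8 four-dimensional inputs
y = net(x)              # y.shape == (8, 2)

# init_normc_ rescales each row of a weight matrix to Euclidean norm `gain`.
w = torch.empty(5, 3)
init_normc_(w, gain=2.0)
print(y.shape, w.norm(dim=1))   # each row of w has norm ~2.0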