Source code for ceem.nn

# File: nn.py
#
from math import sqrt

import torch
from torch.nn import Module, Parameter, init

ACTIVATIONS = {'tanh': torch.nn.Tanh, 'relu': torch.nn.ReLU}


class LNMLP(torch.nn.Module):
    """Multilayer perceptron with optional LayerNorm after each hidden activation."""

    def __init__(self, input_size, hidden_sizes, output_size, activation='relu', gain=1.0,
                 ln=False):
        self._hidden_sizes = hidden_sizes
        self._gain = gain
        self._ln = ln

        super().__init__()

        if len(hidden_sizes) > 0:
            activation = ACTIVATIONS[activation]

            # First hidden layer
            layers = [torch.nn.Linear(input_size, hidden_sizes[0])]
            layers.append(activation())
            if ln:
                layers.append(torch.nn.LayerNorm(hidden_sizes[0]))

            # Remaining hidden layers
            for i in range(len(hidden_sizes) - 1):
                layers.append(torch.nn.Linear(hidden_sizes[i], hidden_sizes[i + 1]))
                layers.append(activation())
                if ln:
                    layers.append(torch.nn.LayerNorm(hidden_sizes[i + 1]))

            # Output layer
            layers.append(torch.nn.Linear(hidden_sizes[-1], output_size))
        else:
            # No hidden layers: a single affine map
            layers = [torch.nn.Linear(input_size, output_size)]

        self._layers = layers
        self.mlp = torch.nn.Sequential(*layers)

        self.reset_params(gain=gain)

    def forward(self, inp):
        return self.mlp(inp)

    def reset_params(self, gain=1.0):
        """Re-initialize all Linear layers with normalized-column weights."""
        self.apply(lambda x: weights_init_mlp(x, gain=gain))
def weights_init_mlp(m, gain=1.0):
    # Initialize Linear modules only; leave activations and LayerNorm untouched.
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        init_normc_(m.weight.data, gain)
        if m.bias is not None:
            m.bias.data.fill_(0)


def init_normc_(weight, gain=1.0):
    # Fill with N(0, 1) samples, then rescale each output unit's weight vector to norm `gain`.
    weight.normal_(0, 1)
    weight *= gain / torch.sqrt(weight.pow(2).sum(1, keepdim=True))
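
A minimal usage sketch (not part of the module source): building an LNMLP with two hidden layers and LayerNorm, running a forward pass, and re-initializing the weights. The sizes, batch shape, and gain below are illustrative assumptions.

# Illustrative example only; shapes and hyperparameters are assumptions.
net = LNMLP(input_size=4, hidden_sizes=[64, 64], output_size=2,
            activation='tanh', gain=1.0, ln=True)

x = torch.randn(32, 4)          # batch of 32 inputs with 4 features each
y = net(x)                      # forward pass through the Sequential stack
assert y.shape == (32, 2)

# Weights can be re-initialized at any time with a different gain.
net.reset_params(gain=0.5)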