#
# File: learner.py
#
from copy import deepcopy

import numpy as np
import torch
from scipy.optimize import least_squares, minimize
from torch.autograd import backward
from torch.nn.utils import parameters_to_vector, vector_to_parameters

from ceem import utils
from ceem.opt_criteria import GroupCriterion, STRParamCriterion


def learner(model, criterion_list, criterion_x_list, opt_list, params_list, crit_kwargs_list,
            opt_kwargs_list=None, subsetinds=None):
    """
    Generic learner function

    Args:
        model (DiscreteDynamicalSystem)
        criterion_list: list of Criterion instances
        criterion_x_list: list of torch.tensor to pass to criteria
        opt_list: list of optimizer names (keys of OPTIMIZERS)
        params_list: list of parameter lists, one per optimizer
        crit_kwargs_list: list of kwargs dicts, one per criterion
        opt_kwargs_list: list of kwargs dicts, one per optimizer; if None, defaults
            from DEFAULT_OPTIMIZER_KWARGS are used
        subsetinds: array specifying the batch indices to operate on
    Returns:
        opt_result_list (list): list of dicts summarizing each optimizer run
    """
    if opt_kwargs_list is None:
        opt_kwargs_list = [DEFAULT_OPTIMIZER_KWARGS[opt] for opt in opt_list]

    # set all requires_grad to False so only the model parameters receive gradients
    for x in criterion_x_list:
        x.requires_grad_(False)

    opt_result_list = []
    for i in range(len(criterion_list)):
        criterion = criterion_list[i]
        criterion_x = criterion_x_list[i]
        opt = OPTIMIZERS[opt_list[i]]
        params = params_list[i]
        crit_kwargs = crit_kwargs_list[i]
        crit_kwargs['inds'] = subsetinds
        opt_kwargs = opt_kwargs_list[i]

        opt_result = opt(criterion, model, criterion_x, params, crit_kwargs, opt_kwargs)
        opt_result_list.append(opt_result)

    return opt_result_list
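

# Example (hypothetical sketch, not part of the original module): a single criterion
# optimized with scipy over a model's parameters. The names `sys`, `crit`, and
# `x_smooth` are placeholders for a DiscreteDynamicalSystem, a Criterion, and a
# (B, T, n) state tensor supplied by the caller.
#
#   results = learner(sys, [crit], [x_smooth], ['scipy_minimize'],
#                     [list(sys.parameters())], [{}],
#                     opt_kwargs_list=[{'method': 'BFGS', 'options': {'maxiter': 50}}])
#   print(results[0]['end_feval'], results[0]['net_update_norm'])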


def scipy_minimize(criterion, model, criterion_x, params, crit_kwargs, opt_kwargs):
    """ Wrapper function to call scipy.optimize.minimize on a criterion
    """
    opt_kwargs = deepcopy(opt_kwargs)

    # optionally add a soft trust-region penalty about the current parameter values
    if 'tr_rho' in opt_kwargs:
        tr_rho = opt_kwargs.pop('tr_rho')
        criterion = GroupCriterion([criterion, STRParamCriterion(tr_rho, params)])

    B, T, n = criterion_x.shape

    vparams0 = parameters_to_vector(params).clone().detach()

    def eval_f(vparams):
        # evaluate the criterion at vparams, then restore the original parameters
        vparams = torch.tensor(vparams).to(torch.get_default_dtype())
        vector_to_parameters(vparams, params)
        with torch.no_grad():
            loss = criterion(model, criterion_x, **crit_kwargs)
        vector_to_parameters(vparams0, params)
        return loss.detach().numpy()

    def eval_g(vparams):
        # evaluate the criterion gradient at vparams, then restore the original parameters
        vparams = torch.tensor(vparams).to(torch.get_default_dtype())
        vector_to_parameters(vparams, params)
        # clear any stale gradients so backward() does not accumulate across calls
        for p in params:
            if p.grad is not None:
                p.grad.zero_()
        loss = criterion(model, criterion_x, **crit_kwargs)
        loss.backward()
        grads = torch.cat([
            p.grad.view(-1) if p.grad is not None else torch.zeros_like(p).view(-1) for p in params
        ])
        grads = grads.detach().numpy()
        vector_to_parameters(vparams0, params)
        return grads

    start_feval = eval_f(vparams0.numpy())
    start_gradnorm = np.linalg.norm(eval_g(vparams0.numpy()))

    with utils.Timer() as time:
        opt_result_ = minimize(eval_f, vparams0.numpy(), jac=eval_g, **opt_kwargs)

    vparams = opt_result_['x']
    end_feval = eval_f(vparams)
    end_gradnorm = np.linalg.norm(eval_g(vparams))

    # load the optimized parameter vector back into the model parameters
    vparams = torch.tensor(vparams).to(torch.get_default_dtype())
    vector_to_parameters(vparams, params)

    net_update_norm = (vparams - vparams0).norm()

    opt_result = {'net_update_norm': net_update_norm.detach().item()}
    opt_result['start_feval'] = float(start_feval)
    opt_result['start_gradnorm'] = start_gradnorm
    opt_result['end_feval'] = float(end_feval)
    opt_result['end_gradnorm'] = end_gradnorm
    opt_result['time'] = time.dt

    return opt_result
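

# Example (hypothetical sketch): `opt_kwargs` are forwarded to scipy.optimize.minimize
# after the optional 'tr_rho' key is removed, so a call might look like
#
#   opt_kwargs = {'method': 'L-BFGS-B', 'options': {'maxiter': 100}, 'tr_rho': 0.1}
#   result = scipy_minimize(crit, sys, x_smooth, list(sys.parameters()),
#                           {'inds': None}, opt_kwargs)
#
# where `crit`, `sys`, and `x_smooth` are placeholders as in the example above, and
# 'tr_rho' wraps the criterion with an STRParamCriterion soft trust-region penalty
# around the current parameter values.
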
TORCH_OPTIMIZERS = {'LBFGS': torch.optim.LBFGS, 'SGD': torch.optim.SGD, 'Adam': torch.optim.Adam}


def ensure_default_torch_kwargs(opt_kwargs):
    """ Fill in defaults for any torch_minimize settings the caller did not specify.
    """
    if 'method' not in opt_kwargs:
        opt_kwargs['method'] = 'LBFGS'
    if 'lr' not in opt_kwargs:
        if opt_kwargs['method'] == 'LBFGS':
            opt_kwargs['lr'] = 1e-1
        else:
            opt_kwargs['lr'] = 1e-4
    if 'nepochs' not in opt_kwargs:
        opt_kwargs['nepochs'] = 100
    if 'max_grad_norm' not in opt_kwargs:
        opt_kwargs['max_grad_norm'] = 1.0
    return opt_kwargs
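

# A quick illustration of the defaults above: a partially specified dict is completed
# in place, e.g.
#
#   ensure_default_torch_kwargs({'method': 'Adam'})
#   # -> {'method': 'Adam', 'lr': 1e-4, 'nepochs': 100, 'max_grad_norm': 1.0}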


def torch_minimize(criterion, model, criterion_x, params, crit_kwargs, opt_kwargs):
    """ Wrapper function to use torch.optim optimizers
    """
    opt_kwargs = deepcopy(opt_kwargs)
    opt_kwargs = ensure_default_torch_kwargs(opt_kwargs)

    # optionally add a soft trust-region penalty about the current parameter values
    if 'tr_rho' in opt_kwargs:
        tr_rho = opt_kwargs.pop('tr_rho')
        criterion = GroupCriterion([criterion, STRParamCriterion(tr_rho, params)])

    method = opt_kwargs.pop('method')
    nepochs = opt_kwargs.pop('nepochs')
    max_grad_norm = opt_kwargs.pop('max_grad_norm')

    opt = TORCH_OPTIMIZERS[method](params, **opt_kwargs)

    def closure():
        # evaluate the criterion and populate parameter gradients
        opt.zero_grad()
        loss = criterion(model, criterion_x, **crit_kwargs)
        loss.backward()
        return loss

    start_feval = closure()
    start_gradnorm = utils.get_grad_norm(params)
    vparams0 = parameters_to_vector(params).clone().detach()

    with utils.Timer() as time:
        for epoch in range(nepochs):
            if method == 'LBFGS':
                opt.step(closure=closure)
                loss = closure()
            else:
                loss = closure()
                torch.nn.utils.clip_grad_norm_(params, max_grad_norm)
                opt.step()

            if epoch % 100 == 0:
                print('Epoch %d, Loss %.3e' % (epoch, float(loss)))

    end_feval = closure()
    end_gradnorm = utils.get_grad_norm(params)

    vparams = parameters_to_vector(params).clone().detach()
    net_update_norm = (vparams - vparams0).norm()

    opt_result = {
        'time': time.dt,
        'start_feval': start_feval.detach().item(),
        'start_gradnorm': start_gradnorm,
        'end_feval': end_feval.detach().item(),
        'end_gradnorm': end_gradnorm,
        'net_update_norm': net_update_norm.detach().item()
    }

    return opt_result
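

# Example (hypothetical sketch): the same call pattern as scipy_minimize, but with
# torch.optim doing the updates; `crit`, `sys`, and `x_smooth` remain placeholders.
#
#   opt_kwargs = {'method': 'Adam', 'lr': 1e-3, 'nepochs': 500, 'tr_rho': 0.1}
#   result = torch_minimize(crit, sys, x_smooth, list(sys.parameters()),
#                           {'inds': None}, opt_kwargs)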


# registries used by learner() to look up optimizers by name and their default kwargs
OPTIMIZERS = {'scipy_minimize': scipy_minimize, 'torch_minimize': torch_minimize}

DEFAULT_OPTIMIZER_KWARGS = {
    'scipy_minimize': {
        'method': 'Nelder-Mead'
    },
    'torch_minimize': ensure_default_torch_kwargs({})
}
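
# Example (hypothetical sketch): with opt_kwargs_list=None, learner() falls back on
# DEFAULT_OPTIMIZER_KWARGS, so the two-optimizer call below would use Nelder-Mead for
# the scipy step and LBFGS (lr=1e-1, 100 epochs) for the torch step; `sys`, `crit_a`,
# `crit_b`, `x_a`, and `x_b` are placeholders.
#
#   results = learner(sys, [crit_a, crit_b], [x_a, x_b],
#                     ['scipy_minimize', 'torch_minimize'],
#                     [list(sys.parameters()), list(sys.parameters())],
#                     [{}, {}])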