Source code for ceem.learner

#
# File: learner.py
#

from copy import deepcopy

import numpy as np
import torch
from scipy.optimize import least_squares, minimize
from torch.autograd import backward
from torch.nn.utils import parameters_to_vector, vector_to_parameters

from ceem import utils
from ceem.opt_criteria import GroupCriterion, STRParamCriterion


def learner(model, criterion_list, criterion_x_list, opt_list, params_list, crit_kwargs_list,
            opt_kwargs_list=None, subsetinds=None):
    """Generic learner function

    Args:
        model (DiscreteDynamicalSystem)
        criterion_list: list of Criterion instances
        criterion_x_list: list of torch.tensor to pass to criteria
        opt_list: list of optimizer names (see OPTIMIZERS)
        params_list: list of list of parameters for each optimizer
        crit_kwargs_list: list of kwargs for each criterion
        opt_kwargs_list: list of kwargs for each optimizer (defaults to DEFAULT_OPTIMIZER_KWARGS)
        subsetinds: array specifying the batch indices to operate on

    Returns:
        opt_result_list (list): list of result dicts, one per criterion/optimizer pair
    """
    if opt_kwargs_list is None:
        opt_kwargs_list = [DEFAULT_OPTIMIZER_KWARGS[opt] for opt in opt_list]

    # set all requires_grad to False
    for x in criterion_x_list:
        x.requires_grad_(False)

    opt_result_list = []

    for i in range(len(criterion_list)):
        criterion = criterion_list[i]
        criterion_x = criterion_x_list[i]
        opt = OPTIMIZERS[opt_list[i]]
        params = params_list[i]
        crit_kwargs = crit_kwargs_list[i]
        crit_kwargs['inds'] = subsetinds
        opt_kwargs = opt_kwargs_list[i]

        opt_result = opt(criterion, model, criterion_x, params, crit_kwargs, opt_kwargs)
        opt_result_list.append(opt_result)

    return opt_result_list
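
# A minimal usage sketch of `learner` with a single torch_minimize pass. The toy
# model and criterion below are hypothetical stand-ins (a plain nn.Linear and a
# quadratic penalty), not ceem classes; a real call pairs a
# DiscreteDynamicalSystem with criteria from ceem.opt_criteria.
def _example_learner_call():
    class _ToyCriterion:
        # hypothetical criterion: callable as criterion(model, x, inds=None)
        def __call__(self, model, x, inds=None):
            return (model(x) ** 2).mean()

    model = torch.nn.Linear(3, 3)      # stand-in for a DiscreteDynamicalSystem
    x = torch.randn(4, 10, 3)          # (B, T, n) tensor handed to the criterion
    params = list(model.parameters())

    results = learner(model,
                      criterion_list=[_ToyCriterion()],
                      criterion_x_list=[x],
                      opt_list=['torch_minimize'],
                      params_list=[params],
                      crit_kwargs_list=[{}])
    return results  # one result dict per criterion/optimizer pair
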
def scipy_minimize(criterion, model, criterion_x, params, crit_kwargs, opt_kwargs):
    """Wrapper function to call scipy optimizers"""
    opt_kwargs = deepcopy(opt_kwargs)

    if 'tr_rho' in opt_kwargs:
        # add a soft trust-region penalty on the parameters
        tr_rho = opt_kwargs.pop('tr_rho')
        criterion = GroupCriterion([criterion, STRParamCriterion(tr_rho, params)])

    B, T, n = criterion_x.shape

    vparams0 = parameters_to_vector(params).clone().detach()

    def eval_f(vparams):
        # evaluate the criterion at a flat parameter vector, then restore the originals
        vparams = torch.tensor(vparams).to(torch.get_default_dtype())
        vector_to_parameters(vparams, params)
        with torch.no_grad():
            loss = criterion(model, criterion_x, **crit_kwargs)
        vector_to_parameters(vparams0, params)
        return loss.detach().numpy()

    def eval_g(vparams):
        # gradient of the criterion w.r.t. the flat parameter vector
        vparams = torch.tensor(vparams).to(torch.get_default_dtype())
        vector_to_parameters(vparams, params)
        # clear any stale gradients so repeated calls do not accumulate
        for p in params:
            p.grad = None
        loss = criterion(model, criterion_x, **crit_kwargs)
        loss.backward()
        grads = torch.cat([
            p.grad.view(-1) if p.grad is not None else torch.zeros_like(p).view(-1)
            for p in params
        ])
        grads = grads.detach().numpy()
        vector_to_parameters(vparams0, params)
        return grads

    start_feval = eval_f(vparams0.numpy())
    start_gradnorm = np.linalg.norm(eval_g(vparams0.numpy()))

    with utils.Timer() as time:
        opt_result_ = minimize(eval_f, vparams0.numpy(), jac=eval_g, **opt_kwargs)

    vparams = opt_result_['x']

    end_feval = eval_f(vparams)
    end_gradnorm = np.linalg.norm(eval_g(vparams))

    # write the optimized parameters back into the model
    vparams = torch.tensor(vparams).to(torch.get_default_dtype())
    vector_to_parameters(vparams, params)

    net_update_norm = (vparams - vparams0).norm()

    opt_result = {'net_update_norm': net_update_norm.detach().item()}
    opt_result['start_feval'] = float(start_feval)
    opt_result['start_gradnorm'] = start_gradnorm
    opt_result['end_feval'] = float(end_feval)
    opt_result['end_gradnorm'] = end_gradnorm
    opt_result['time'] = time.dt

    return opt_result
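
# A minimal sketch of the flat-vector bridge used above, in isolation:
# parameters_to_vector / vector_to_parameters map between a module's parameters
# and the 1-D array that scipy.optimize.minimize expects, with gradients
# gathered the same way eval_g does. The tiny least-squares problem here is
# purely illustrative.
def _example_scipy_bridge():
    lin = torch.nn.Linear(2, 1)
    params = list(lin.parameters())
    x, y = torch.randn(16, 2), torch.randn(16, 1)

    def f(v):
        vector_to_parameters(torch.tensor(v).to(torch.get_default_dtype()), params)
        with torch.no_grad():
            return float(((lin(x) - y) ** 2).mean())

    def g(v):
        vector_to_parameters(torch.tensor(v).to(torch.get_default_dtype()), params)
        for p in params:
            p.grad = None
        ((lin(x) - y) ** 2).mean().backward()
        return torch.cat([p.grad.view(-1) for p in params]).numpy().astype(np.float64)

    v0 = parameters_to_vector(params).detach().numpy().astype(np.float64)
    res = minimize(f, v0, jac=g, method='BFGS')
    vector_to_parameters(torch.tensor(res.x).to(torch.get_default_dtype()), params)
    return res
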
TORCH_OPTIMIZERS = {'LBFGS': torch.optim.LBFGS, 'SGD': torch.optim.SGD, 'Adam': torch.optim.Adam}
def ensure_default_torch_kwargs(opt_kwargs):
    if 'method' not in opt_kwargs:
        opt_kwargs['method'] = 'LBFGS'
    if 'lr' not in opt_kwargs:
        if opt_kwargs['method'] == 'LBFGS':
            opt_kwargs['lr'] = 1e-1
        else:
            opt_kwargs['lr'] = 1e-4
    if 'nepochs' not in opt_kwargs:
        opt_kwargs['nepochs'] = 100
    if 'max_grad_norm' not in opt_kwargs:
        opt_kwargs['max_grad_norm'] = 1.0
    return opt_kwargs
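
# A quick check of what the defaults above resolve to: an empty dict selects
# LBFGS with lr=1e-1, while requesting a first-order method such as Adam picks
# the smaller default learning rate.
def _example_default_kwargs():
    assert ensure_default_torch_kwargs({}) == {
        'method': 'LBFGS', 'lr': 1e-1, 'nepochs': 100, 'max_grad_norm': 1.0}
    assert ensure_default_torch_kwargs({'method': 'Adam'}) == {
        'method': 'Adam', 'lr': 1e-4, 'nepochs': 100, 'max_grad_norm': 1.0}
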
def torch_minimize(criterion, model, criterion_x, params, crit_kwargs, opt_kwargs):
    """Wrapper function to use torch.optim optimizers"""
    opt_kwargs = deepcopy(opt_kwargs)
    opt_kwargs = ensure_default_torch_kwargs(opt_kwargs)

    if 'tr_rho' in opt_kwargs:
        # add a soft trust-region penalty on the parameters
        tr_rho = opt_kwargs.pop('tr_rho')
        criterion = GroupCriterion([criterion, STRParamCriterion(tr_rho, params)])

    method = opt_kwargs.pop('method')
    nepochs = opt_kwargs.pop('nepochs')
    max_grad_norm = opt_kwargs.pop('max_grad_norm')

    opt = TORCH_OPTIMIZERS[method](params, **opt_kwargs)

    def closure():
        opt.zero_grad()
        loss = criterion(model, criterion_x, **crit_kwargs)
        loss.backward()
        return loss

    start_feval = closure()
    start_gradnorm = utils.get_grad_norm(params)

    vparams0 = parameters_to_vector(params).clone().detach()

    with utils.Timer() as time:
        for epoch in range(nepochs):
            if method == 'LBFGS':
                # LBFGS re-evaluates the loss internally via the closure
                opt.step(closure=closure)
                loss = closure()
            else:
                loss = closure()
                torch.nn.utils.clip_grad_norm_(params, max_grad_norm)
                opt.step()

            if epoch % 100 == 0:
                print('Epoch %d, Loss %.3e' % (epoch, float(loss)))

    end_feval = closure()
    end_gradnorm = utils.get_grad_norm(params)

    vparams = parameters_to_vector(params).clone().detach()
    net_update_norm = (vparams - vparams0).norm()

    opt_result = {
        'time': time.dt,
        'start_feval': start_feval.detach().item(),
        'start_gradnorm': start_gradnorm,
        'end_feval': end_feval.detach().item(),
        'end_gradnorm': end_gradnorm,
        'net_update_norm': net_update_norm.detach().item()
    }

    return opt_result
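
# A minimal sketch of the two stepping patterns wrapped above, outside the ceem
# plumbing: LBFGS drives the loss closure itself, whereas first-order methods
# evaluate once, clip the gradient norm, then step. The toy regression problem
# is purely illustrative.
def _example_step_patterns():
    lin = torch.nn.Linear(2, 1)
    x, y = torch.randn(32, 2), torch.randn(32, 1)

    def closure():
        opt.zero_grad()
        loss = ((lin(x) - y) ** 2).mean()
        loss.backward()
        return loss

    # LBFGS branch: the optimizer re-evaluates the closure internally.
    opt = torch.optim.LBFGS(lin.parameters(), lr=1e-1)
    opt.step(closure=closure)

    # First-order branch: evaluate, clip, then step.
    opt = torch.optim.Adam(lin.parameters(), lr=1e-4)
    loss = closure()
    torch.nn.utils.clip_grad_norm_(lin.parameters(), 1.0)
    opt.step()
    return float(loss)
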
OPTIMIZERS = {'scipy_minimize': scipy_minimize, 'torch_minimize': torch_minimize}

DEFAULT_OPTIMIZER_KWARGS = {
    'scipy_minimize': {
        'method': 'Nelder-Mead'
    },
    'torch_minimize': ensure_default_torch_kwargs({})
}
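
# `learner` resolves optimizers by name through OPTIMIZERS and, when no
# opt_kwargs_list is supplied, falls back to the matching entry of
# DEFAULT_OPTIMIZER_KWARGS; per-call overrides are passed as a parallel
# opt_kwargs_list instead. A small sanity sketch:
def _example_registry_lookup():
    assert OPTIMIZERS['torch_minimize'] is torch_minimize
    assert DEFAULT_OPTIMIZER_KWARGS['torch_minimize']['method'] == 'LBFGS'
    assert DEFAULT_OPTIMIZER_KWARGS['scipy_minimize'] == {'method': 'Nelder-Mead'}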