Source code for ceem.data_utils

import os
import torch

# Short alias for joining path components; used by the loaders in this module.
opj = os.path.join

# Device that load_helidata moves its returned tensors onto.
device = 'cpu'


def load_helidata(datadir, split, return_files=False):
    """Load and stack every data file found in ``<datadir>/<split>``.

    Each non-hidden entry of the split directory is read with ``torch.load``
    and unpacked as a ``(time, data, controls)`` triple; the ``time`` element
    is not used. The per-file ``data`` and ``controls`` tensors are stacked
    along a new leading dimension, so every file must hold tensors of the
    same shape.

    From the stacked 3-D tensors:

    * ``target`` is the last 6 entries of the final dimension of ``data``;
    * the returned ``data`` is ``controls`` concatenated (along the final
      dimension) with entries ``7:-6`` of the final dimension of ``data``.

    Both results are moved to the module-level ``device``.

    Args:
        datadir: Directory containing one sub-directory per split.
        split: Name of the split sub-directory (formatted into the path as a
            string).
        return_files: If true, also return the names of the files that were
            loaded.

    Returns:
        ``(data, target)``, or ``(data, target, files)`` when ``return_files``
        is true. ``files[i]`` is the file that produced row ``i`` of ``data``
        and ``target``; hidden entries that were skipped are not listed.

    Raises:
        RuntimeError: Raised by ``torch.stack`` when the split directory
            contains no loadable (non-hidden) files.
    """
    split_dir = opj(datadir, f'{split}')

    # Filter hidden entries (e.g. ``.DS_Store``) up front so that the list
    # handed back to the caller lines up one-to-one with the stacked rows.
    files = [f for f in os.listdir(split_dir) if not f.startswith('.')]

    data = []
    controls = []
    for f in files:
        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted data directories.
        _, dat, cont = torch.load(opj(split_dir, f))
        controls.append(cont)
        data.append(dat)

    stacked_data = torch.stack(data, dim=0)
    stacked_controls = torch.stack(controls, dim=0)

    target = stacked_data[:, :, -6:]
    inputs = torch.cat((stacked_controls, stacked_data[:, :, 7:-6]), dim=2)

    inputs = inputs.to(device)
    target = target.to(device)

    if return_files:
        return inputs, target, files
    return inputs, target
def load_statistics(datadir):
    """Read ``statistics.pt`` from ``datadir`` and regroup its contents.

    The file is loaded with ``torch.load`` and unpacked as
    ``(data_mean, data_std, controls_mean, controls_std)``. The data
    statistics are split the same way for mean and std: the last 6 entries
    become the ``y`` statistics, and entries ``7:-6`` are appended to the
    controls statistics (along the last dimension) to form the ``u``
    statistics.

    Args:
        datadir: Directory that contains ``statistics.pt``.

    Returns:
        The tuple ``(y_mean, y_std, u_mean, u_std)``.
    """
    stats_path = opj(datadir, 'statistics.pt')
    data_mean, data_std, controls_mean, controls_std = torch.load(stats_path)

    # Index ranges applied identically to the mean and std tensors.
    y_part = slice(-6, None)
    u_part = slice(7, -6)

    y_mean, y_std = data_mean[y_part], data_std[y_part]
    u_mean = torch.cat((controls_mean, data_mean[u_part]), dim=-1)
    u_std = torch.cat((controls_std, data_std[u_part]), dim=-1)

    return y_mean, y_std, u_mean, u_std