Source code for ceem.utils

# File: utils.py
#

import random
import sys
import timeit
from contextlib import contextmanager

try:
    import resource
except ImportError:
    # ``resource`` is not available on Windows; peak_memory_mb() guards for this.
    resource = None

import numpy as np
import torch


class Timer(object):
    """Context manager that measures the wall-clock time of the enclosed block."""

    def __enter__(self):
        self.t_start = timeit.default_timer()
        return self

    def __exit__(self, _1, _2, _3):
        self.t_end = timeit.default_timer()
        self.dt = self.t_end - self.t_start
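Usage sketch (not part of the module; illustrative only): time an arbitrary block and read the elapsed seconds from the ``dt`` attribute afterwards.

with Timer() as t:
    _ = sum(i * i for i in range(100_000))
print(f"elapsed: {t.dt:.6f} s")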
class objectview(object):
    """Wraps a dict so that its keys can be accessed as attributes."""

    def __init__(self, d):
        self.__dict__ = d

    def __repr__(self):
        return str(self.__dict__)
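Usage sketch (illustrative; the keys shown are made up): wrap a configuration dict so that its entries read as attributes.

cfg = objectview({'lr': 1e-3, 'epochs': 10})
print(cfg.lr, cfg.epochs)  # 0.001 10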
def peak_memory_mb() -> float:
    """
    Get peak memory usage for this process, as measured by max-resident-set size:

    https://unix.stackexchange.com/questions/30940/getrusage-system-call-what-is-maximum-resident-set-size

    Only works on OSX and Linux, returns 0.0 otherwise.
    """
    if resource is None or sys.platform not in ('linux', 'darwin'):
        return 0.0

    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss  # type: ignore

    if sys.platform == 'darwin':
        # On OSX the result is in bytes.
        return peak / 1_000_000
    else:
        # On Linux the result is in kilobytes.
        return peak / 1_000
def set_rng_seed(rng_seed: int) -> None:
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and CUDA) RNGs."""
    random.seed(rng_seed)
    torch.manual_seed(rng_seed)
    np.random.seed(rng_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(rng_seed)
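Usage sketch (illustrative): seeding twice with the same value makes the torch RNG reproduce the same draw.

set_rng_seed(0)
a = torch.rand(3)
set_rng_seed(0)
b = torch.rand(3)
assert torch.equal(a, b)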
def disable_grad(vs):
    """Detach each tensor in ``vs`` from the autograd graph in place."""
    for v in vs:
        v.detach_()
@contextmanager
def temp_require_grad(vs):
    """Temporarily require gradients on ``vs``; the previous ``requires_grad``
    flags are restored on exit."""
    prev_grad_status = [v.requires_grad for v in vs]
    require_and_zero_grads(vs)
    yield
    for v, status in zip(vs, prev_grad_status):
        v.requires_grad_(status)
@contextmanager
def temp_disable_grad(vs):
    """Temporarily detach ``vs`` from the autograd graph; the previous
    ``requires_grad`` flags are restored on exit."""
    prev_grad_status = [v.requires_grad for v in vs]
    disable_grad(vs)
    yield
    for v, status in zip(vs, prev_grad_status):
        v.requires_grad_(status)
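Usage sketch (illustrative), assuming ``require_and_zero_grads`` (defined elsewhere in ceem.utils) enables and zeroes gradients on the given tensors; the context manager restores the previous ``requires_grad`` flags on exit.

frozen = [torch.zeros(3)]           # requires_grad is False by default
with temp_require_grad(frozen):
    print(frozen[0].requires_grad)  # expected: True inside the block
print(frozen[0].requires_grad)      # False again: previous flag restored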
def get_grad_norm(params):
    """Return the global 2-norm of the gradients of ``params``; parameters
    whose ``grad`` is None are skipped."""
    total_norm = 0.
    for p in params:
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** (1. / 2)
    return total_norm
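Usage sketch (illustrative): compute the global gradient norm of a small model after a backward pass, e.g. for logging or deciding on gradient clipping.

model = torch.nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).pow(2).sum()
loss.backward()
print(get_grad_norm(model.parameters()))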