from termcolor import colored
import collections
import datetime
import dateutil.tz
import json
import logging
import os
import os.path as osp
import sys
import tempfile
import subprocess
import shutil
from torch.utils.tensorboard import SummaryWriter
__all__ = ['setup', 'scoped_setup']
class KVWriter:
    """Interface for output formats that can write a key/value mapping."""

    def writekvs(self, kvs):
        """Write the mapping ``kvs``; concrete writers must override this."""
        raise NotImplementedError()
class SeqWriter:
    """Interface for output formats that can write a sequence of strings."""

    def writeseq(self, seq):
        """Write the items of ``seq``; concrete writers must override this."""
        raise NotImplementedError()
class HumanOutputFormat(KVWriter, SeqWriter):
    """Writer that renders key/value pairs as an ASCII table and sequences as text.

    Args:
        filename_or_file: either a path (``str``), which is opened for text
            writing and closed again by :meth:`close`, or an already-open
            file-like object, which is written to but never closed here.
    """

    def __init__(self, filename_or_file):
        if isinstance(filename_or_file, str):
            self.file = open(filename_or_file, 'wt')
            self.own_file = True
        else:
            # The object is only ever written to, so ``write`` is the
            # capability that has to be present (not ``read``).
            assert hasattr(filename_or_file,
                           'write'), 'expected file or str, got %s' % filename_or_file
            self.file = filename_or_file
            self.own_file = False

    def writekvs(self, kvs):
        """Write ``kvs`` as a two-column table, one row per key, sorted by key.

        Floats are formatted with ``%-8.6g``; every other value goes through
        ``str``. Keys and values are shortened by :meth:`_truncate`. An empty
        mapping only prints a warning and writes nothing.
        """
        # Create strings for printing
        key2str = {}
        for key, val in sorted(kvs.items()):
            if isinstance(val, float):
                valstr = '%-8.6g' % (val,)
            else:
                valstr = str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        if not key2str:
            print('WARNING: tried to write empty key-value dict')
            return

        # Find max widths
        keywidth = max(map(len, key2str.keys()))
        valwidth = max(map(len, key2str.values()))

        # Write out the data
        dashes = '-' * (keywidth + valwidth + 7)
        lines = [dashes]
        for key, val in sorted(key2str.items()):
            lines.append('| %s | %s |' % (key.ljust(keywidth), val.ljust(valwidth)))
        lines.append(dashes)
        self.file.write('\n'.join(lines) + '\n')

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, s):
        """Return ``s`` unchanged up to 33 characters, else its first 30 plus ``'...'``."""
        return s[:30] + '...' if len(s) > 33 else s

    def writeseq(self, seq):
        """Write every item of ``seq`` back to back, then a newline, then flush.

        NOTE(review): indentation was lost in this copy of the file; the
        newline is assumed to be written once after the loop -- confirm.
        """
        for arg in seq:
            self.file.write(arg)
        self.file.write('\n')
        self.file.flush()

    def close(self):
        """Close the underlying file only if this object opened it."""
        if self.own_file:
            self.file.close()
class JSONOutputFormat(KVWriter):
    """Writer that appends one JSON object per call to a text file.

    Args:
        filename: path of the file to create/truncate and write to.
    """

    def __init__(self, filename):
        self.file = open(filename, 'wt')

    def writekvs(self, kvs):
        """Serialize ``kvs`` as a single JSON line and flush.

        Values exposing a ``dtype`` attribute are converted with
        ``float(v.tolist())`` before serialization. The conversion is done on
        a copy so the caller's mapping is left untouched; ``Logger.dumpkvs``
        hands the very same mapping to every writer.
        """
        record = dict(kvs)
        for k, v in kvs.items():
            if hasattr(v, 'dtype'):
                record[k] = float(v.tolist())
        self.file.write(json.dumps(record) + '\n')
        self.file.flush()

    def close(self):
        """Close the output file."""
        self.file.close()
class CSVOutputFormat(KVWriter):
    """Writer that maintains a CSV file whose header grows as new keys appear.

    Args:
        filename: path of the file to create/truncate; it is opened ``w+t``
            because the existing rows are re-read when the header changes.
    """

    def __init__(self, filename):
        self.file = open(filename, 'w+t')
        self.keys = []
        self.sep = ','

    def writekvs(self, kvs):
        """Append one row for ``kvs``, widening the header first if needed.

        When ``kvs`` has keys not seen before, the file is rewritten from the
        start: a new header line, then every earlier data row padded with one
        empty cell per new column. Keys missing from ``kvs`` (or mapped to
        ``None``) produce an empty cell.
        """
        # Sort the new keys so the column order is deterministic; the set
        # difference alone has no stable order.
        extra_keys = sorted(kvs.keys() - self.keys)
        if extra_keys:
            self.keys.extend(extra_keys)
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            self.file.write(self.sep.join(self.keys))
            self.file.write('\n')
            for line in lines[1:]:
                # Drop the trailing newline, pad for the new columns.
                self.file.write(line[:-1])
                self.file.write(self.sep * len(extra_keys))
                self.file.write('\n')

        cells = []
        for k in self.keys:
            v = kvs.get(k)
            # Only a missing/None value yields an empty cell; falsy values
            # such as 0 or 0.0 are real data and must be written.
            cells.append('' if v is None else str(v))
        self.file.write(self.sep.join(cells))
        self.file.write('\n')
        self.file.flush()

    def close(self):
        """Close the output file."""
        self.file.close()
class TensorboardOutputFormat(KVWriter):
    """Writer that forwards scalars to a ``SummaryWriter`` rooted at ``dirname``."""

    def __init__(self, dirname):
        self.step = 0
        self._writer = SummaryWriter(dirname)

    def writekvs(self, kvs):
        """Log every entry of ``kvs`` as a scalar at the current step.

        NOTE(review): indentation was lost in this copy of the file; the step
        is assumed to advance once per call, after all entries -- confirm.
        """
        current_step = self.step
        for name in kvs:
            self._writer.add_scalar(name, kvs[name], current_step)
        self.step = current_step + 1

    def close(self):
        """Close the underlying summary writer."""
        self._writer.close()
class _MyFormatter(logging.Formatter):
    """Formatter that prefixes records with a green timestamp/location tag.

    Warnings get a red ``WRN`` marker; errors and criticals get a red,
    underlined ``ERR`` marker.
    """

    def format(self, record):
        prefix = colored('[%(asctime)s @%(filename)s:%(lineno)d]', 'green')
        level = record.levelno
        if level == logging.WARNING:
            marker = colored('WRN', 'red', attrs=['blink'])
        elif level in (logging.ERROR, logging.CRITICAL):
            marker = colored('ERR', 'red', attrs=['blink', 'underline'])
        else:
            marker = None

        parts = [prefix]
        if marker is not None:
            parts.append(marker)
        parts.append('%(message)s')
        fmt = ' '.join(parts)

        # Python 3 formatters keep the format string on a style object too.
        if hasattr(self, '_style'):
            self._style._fmt = fmt
        self._fmt = fmt
        return super().format(record)
class Logger:
    """Collects key/value diagnostics and fans them out to output formats.

    Also configures the ``'ceem'`` ``logging`` logger with a stdout handler.
    """

    # The active logger used by the module-level helper functions.
    CURRENT = None

    def __init__(self, dirname, output_formats):
        self.name2val = collections.defaultdict(float)
        self.name2cnt = collections.defaultdict(int)
        self.dirname = dirname
        self.output_formats = output_formats

        stream_handler = logging.StreamHandler(sys.stdout)
        stream_handler.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S'))
        self._handler = stream_handler

        ceem_logger = logging.getLogger('ceem')
        ceem_logger.propagate = False
        ceem_logger.setLevel(logging.INFO)
        ceem_logger.addHandler(stream_handler)
        self._logger = ceem_logger

    def _tensorboard_formats(self):
        """Yield the output formats that are ``TensorboardOutputFormat`` instances."""
        for candidate in self.output_formats:
            if isinstance(candidate, TensorboardOutputFormat):
                yield candidate

    def logkv(self, key, val):
        """Record ``val`` under ``key``, replacing any earlier value."""
        self.name2val[key] = val

    def logkv_mean(self, key, val):
        """Fold ``val`` into the running mean stored under ``key``.

        A ``None`` value is stored as-is and does not touch the count.
        """
        if val is None:
            self.name2val[key] = None
            return
        previous = self.name2val[key]
        seen = self.name2cnt[key]
        total = seen + 1
        self.name2val[key] = previous * seen / total + val / total
        self.name2cnt[key] = total

    def dumpkvs(self):
        """Send the recorded values to every ``KVWriter`` format, then reset them."""
        for writer in self.output_formats:
            if not isinstance(writer, KVWriter):
                continue
            writer.writekvs(self.name2val)
        self.name2val.clear()
        self.name2cnt.clear()

    def add_image(self, key, img):
        """Forward an image to each tensorboard format at its current step."""
        for tb in self._tensorboard_formats():
            tb._writer.add_image(key, img, tb.step)

    def add_figure(self, key, fig, copy=False):
        """Forward a figure to each tensorboard format at its current step.

        With ``copy`` set and a log directory configured, the figure is also
        saved as ``figs/<key>_epoch=<step>.png`` under that directory.

        NOTE(review): indentation was lost in this copy of the file; the save
        is assumed to sit inside the tensorboard branch (it reads the
        format's ``step``) -- confirm.
        """
        for tb in self._tensorboard_formats():
            tb._writer.add_figure(key, fig, tb.step)
            if not copy or self.dirname is None:
                continue
            png_name = "{}_epoch={}.png".format(key, tb.step)
            fig.savefig(osp.join(self.dirname, "figs", png_name))

    def add_text(self, key, txt):
        """Forward a text entry to each tensorboard format at its current step."""
        for tb in self._tensorboard_formats():
            tb._writer.add_text(key, txt, tb.step)

    def add_hist(self, key, hist):
        """Forward a histogram to each tensorboard format at its current step."""
        for tb in self._tensorboard_formats():
            tb._writer.add_histogram(key, hist, tb.step)

    def close(self):
        """Close every output format."""
        for writer in self.output_formats:
            writer.close()

    def get_dir(self):
        """Return the log directory given at construction (may be ``None``)."""
        return self.dirname
def make_output_format(format, ev_dir):
    """Build the output writer named by ``format``, creating ``ev_dir`` if needed.

    Args:
        format: one of ``'stdout'``, ``'json'``, ``'csv'`` or ``'tensorboard'``.
        ev_dir: directory in which file-backed writers place their output.

    Raises:
        ValueError: if ``format`` is not one of the names above.
    """
    os.makedirs(ev_dir, exist_ok=True)
    if format == 'stdout':
        return HumanOutputFormat(sys.stdout)
    if format == 'json':
        return JSONOutputFormat(osp.join(ev_dir, 'progress.json'))
    if format == 'csv':
        return CSVOutputFormat(osp.join(ev_dir, 'progress.csv'))
    if format == 'tensorboard':
        return TensorboardOutputFormat(osp.join(ev_dir, 'tb'))
    raise ValueError('Unknown format specified: %s' % (format,))
# Install a default logger at import time: no log directory and a single
# human-readable writer on stdout. The module-level helper functions all go
# through Logger.CURRENT.
Logger.CURRENT = Logger(dirname=None, output_formats=[HumanOutputFormat(sys.stdout)])
def _get_time_str():
    """Return the current local time formatted as ``MMDD-HHMMSS``."""
    now = datetime.datetime.now()
    return '{:%m%d-%H%M%S}'.format(now)
def _resolve_existing_logdir(dirname, action):
    """Apply ``action`` to a non-empty log directory and return the directory to use."""
    if not action:
        warning("Log directory {} exists! Please either backup/delete it, "  # noqa: F821
                "or use a new directory.".format(dirname))
        warning("If you're resuming from a previous run you can choose to keep it.")  # noqa: F821
        info("Select Action: k (keep) / d (delete) / q (quit):")  # noqa: F821
    # Keep prompting until the user types something non-empty.
    while not action:
        action = input().lower().strip()
    act = action
    if act == 'b':
        backup_name = dirname + _get_time_str()
        shutil.move(dirname, backup_name)
        info("Directory '{}' backuped to '{}'".format(dirname, backup_name))  # noqa: F821
    elif act == 'd':
        shutil.rmtree(dirname)
    elif act == 'n':
        dirname = dirname + _get_time_str()
        info("Use a new log directory {}".format(dirname))  # noqa: F821
    elif act == 'k':
        pass
    elif act == 'q':
        raise OSError("Directory {} exits!".format(dirname))
    else:
        raise ValueError("Unknown action: {}".format(act))
    return dirname


def _log_git_commit():
    """Log the current git commit, tagged dirty when the work tree has changes."""
    timestamp = datetime.datetime.now(dateutil.tz.tzlocal()).strftime('%Y_%m_%d_%H_%M_%S')
    try:
        current_commit = subprocess.check_output(["git", "rev-parse",
                                                  "HEAD"]).strip().decode("utf-8")
        clean_state = len(subprocess.check_output(["git", "status", "--porcelain"])) == 0
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git ran but failed (e.g. not a repository).
        # OSError: the git executable could not be started at all.
        Logger.CURRENT._logger.warning("Warning: failed to execute git commands")
        return
    if clean_state:
        Logger.CURRENT._logger.info("Commit: {}".format(current_commit))
    else:
        Logger.CURRENT._logger.info("Commit: {}_dirty_{}".format(current_commit, timestamp))


def setup(dirname=None, format_strs=None, action=None):
    """Create the log directory and install a new ``Logger.CURRENT``.

    Args:
        dirname: log directory. When ``None`` the ``CEEM_LOGDIR`` environment
            variable is used, and failing that a timestamped directory under
            the system temp dir.
        format_strs: names of output formats understood by
            ``make_output_format``. ``None`` selects
            ``['stdout', 'tensorboard', 'csv']``.
        action: what to do when ``dirname`` already exists and is non-empty:
            ``'b'`` (move it to a backup), ``'d'`` (delete), ``'n'`` (use a new
            timestamp-suffixed directory), ``'k'`` (keep) or ``'q'`` (quit).
            When falsy the user is prompted on stdin.

    Raises:
        OSError: if the chosen action is ``'q'``.
        ValueError: if the action is not one of the letters above.
    """
    if format_strs is None:
        # Resolved here instead of in the signature so the default list is
        # never shared between calls, and so an explicit None is accepted.
        format_strs = ['stdout', 'tensorboard', 'csv']
    if dirname is None:
        dirname = os.getenv('CEEM_LOGDIR')
    if dirname is None:
        dirname = osp.join(tempfile.gettempdir(),
                           datetime.datetime.now().strftime("ceem-%Y-%m-%d-%H-%M-%S-%f"))
    if os.path.isdir(dirname) and len(os.listdir(dirname)):
        dirname = _resolve_existing_logdir(dirname, action)

    os.makedirs(dirname, exist_ok=True)
    os.makedirs(osp.join(dirname, 'figs'), exist_ok=True)
    os.makedirs(osp.join(dirname, 'ckpts'), exist_ok=True)

    output_formats = [make_output_format(f, dirname) for f in format_strs]
    Logger.CURRENT = Logger(dirname=dirname, output_formats=output_formats)

    # Swap the stdout handler the new Logger just attached for a file handler
    # writing to <dirname>/log.log.
    hdl = logging.FileHandler(filename=osp.join(dirname, 'log.log'), encoding='utf-8', mode='w')
    hdl.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S'))
    Logger.CURRENT._logger.removeHandler(Logger.CURRENT._handler)
    Logger.CURRENT._logger.addHandler(hdl)

    Logger.CURRENT._logger.info("Argv: " + ' '.join(sys.argv))
    _log_git_commit()
_LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug', 'setLevel']
# Export the 'ceem' logger's methods as module-level functions.
for func in _LOGGING_METHOD:
    # ``Logger.warn`` is a deprecated alias of ``Logger.warning``; bind the
    # exported ``warn`` to ``warning`` so this module never depends on the
    # alias being present.
    _method_name = 'warning' if func == 'warn' else func
    globals()[func] = getattr(Logger.CURRENT._logger, _method_name)
    __all__.append(func)
def get_dir():
    """Return the log directory of the current logger."""
    active = Logger.CURRENT
    return active.get_dir()
def logkv(key, val):
    """
    Record one diagnostic value on the current logger.
    Meant to be called once per diagnostic quantity, each iteration.
    """
    active = Logger.CURRENT
    active.logkv(key, val)
def logkvs(d):
    """
    Record every key-value pair of the dictionary ``d``.
    """
    for pair in d.items():
        logkv(*pair)
def dumpkvs():
    """
    Write out all diagnostics recorded on the current logger since the
    last dump; the logger resets its recorded values afterwards.
    """
    active = Logger.CURRENT
    active.dumpkvs()
def add_image(key, img):
    """Forward an image to the current logger."""
    active = Logger.CURRENT
    active.add_image(key, img)
def add_text(key, txt):
    """Forward a text entry to the current logger."""
    active = Logger.CURRENT
    active.add_text(key, txt)
def add_hist(key, hist):
    """Forward a histogram to the current logger."""
    active = Logger.CURRENT
    active.add_hist(key, hist)
def add_figure(key, fig, copy=False):
    """Forward a figure (and the ``copy`` flag) to the current logger."""
    active = Logger.CURRENT
    active.add_figure(key, fig, copy)
def set_step(stepval):
    """Set the step counter of every tensorboard format on the current logger."""
    tensorboard_formats = [
        candidate for candidate in Logger.CURRENT.output_formats
        if isinstance(candidate, TensorboardOutputFormat)
    ]
    for tb in tensorboard_formats:
        tb.step = stepval
class scoped_setup(object):
    """Context manager that runs ``setup`` on entry and restores the previous logger on exit.

    Args:
        dirname: forwarded to ``setup``.
        format_strs: forwarded to ``setup`` when not ``None``; otherwise
            ``setup``'s own default formats are used.
        action: forwarded to ``setup``; defaults to ``'b'``.
    """

    def __init__(self, dirname=None, format_strs=None, action='b'):
        self.dirname = dirname
        self.format_strs = format_strs
        self.prevlogger = None
        self.action = action

    def __enter__(self):
        self.prevlogger = Logger.CURRENT
        setup_kwargs = {'dirname': self.dirname, 'action': self.action}
        # Only forward format_strs when the caller chose some: setup iterates
        # over it, so passing None along would fail.
        if self.format_strs is not None:
            setup_kwargs['format_strs'] = self.format_strs
        setup(**setup_kwargs)

    def __exit__(self, *args):
        # Close the scoped logger's output formats, then reinstate the
        # logger that was current before entering.
        Logger.CURRENT.close()
        Logger.CURRENT = self.prevlogger