Source code for ceem.nested

# Based on https://raw.githubusercontent.com/google-research/batch-ppo/master/agents/tools/nested.py
"""Tools for manipulating nested tuples, lists, and dictionaries."""

# Disable linter warning for using `flatten` as argument name.
# pylint: disable=redefined-outer-name

# Keep handles to the real builtins: further down, this module rebinds the
# names zip, map and filter to its own nested versions (zip_, map_, filter_),
# so code in this module that needs the builtins must use these aliases.
_builtin_zip, _builtin_map, _builtin_filter = zip, map, filter


def zip_(*structures, **kwargs):
    # pylint: disable=differing-param-doc,missing-param-doc
    """Combine corresponding elements in multiple nested structures to tuples.

    The nested structures can consist of any combination of lists, tuples, and
    dicts. All provided structures must have the same nesting.

    Args:
      *structures: Nested structures.
      flatten: Whether to flatten the resulting structure into a tuple. Keys of
          dictionaries will be discarded.

    Returns:
      Nested structure whose leaves are tuples of the corresponding leaves of
      the inputs. When only a single structure is given, its leaves are not
      wrapped in 1-tuples.

    Raises:
      ValueError: (from map_) if corresponding lists/tuples differ in length or
          corresponding dicts differ in their keys.
    """
    # Named keyword arguments are not allowed after *args in Python 2.
    flatten = kwargs.pop('flatten', False)
    assert not kwargs, 'zip() got unexpected keyword arguments.'
    # Call map_ explicitly. The bare name `map` only reaches map_ through the
    # module-level `map = map_` alias; the builtin map would reject `flatten=`.
    return map_(lambda *x: x if len(x) > 1 else x[0], *structures, flatten=flatten)
def map_(function, *structures, **kwargs):
    # pylint: disable=differing-param-doc,missing-param-doc
    """Apply a function to every leaf of one or more nested structures.

    When several structures are passed they must share the same nesting; the
    function is then called with the corresponding leaves from each structure.
    The structures may combine lists, tuples (including namedtuples) and dicts,
    and the container types of the first structure are reused for the result.

    Args:
      function: Callable applied at the leaves; it receives one positional
          argument per input structure.
      *structures: One or more nested structures.
      flatten: Whether to flatten the resulting structure into a tuple. Keys of
          dictionaries will be discarded.

    Returns:
      Nested structure.

    Raises:
      ValueError: If lists/tuples of different length, or dicts with different
          key sets, are combined.
    """
    # Python 2 does not allow named keyword arguments after *args.
    flatten = kwargs.pop('flatten', False)
    assert not kwargs, 'map() got unexpected keyword arguments.'

    def impl(function, *structures):
        if not structures:
            return structures
        head = structures[0]

        if all(isinstance(s, (tuple, list)) for s in structures):
            if len({len(s) for s in structures}) > 1:
                raise ValueError('Cannot merge tuples or lists of different length.')
            items = tuple(impl(function, *group) for group in _builtin_zip(*structures))
            if hasattr(head, '_fields'):
                # namedtuple constructors take the fields positionally.
                return type(head)(*items)
            return type(head)(items)

        if all(isinstance(s, dict) for s in structures):
            if len({frozenset(s.keys()) for s in structures}) > 1:
                raise ValueError('Cannot merge dicts with different keys.')
            merged = {key: impl(function, *(s[key] for s in structures)) for key in head}
            return type(head)(merged)

        # Anything that is not a matching container group is treated as a leaf.
        return function(*structures)

    result = impl(function, *structures)
    return flatten_(result) if flatten else result
def flatten_(structure):
    """Combine all leaves of a nested structure into a tuple.

    The nested structure can consist of any combination of tuples, lists, and
    dicts. Dictionary keys are discarded, but their values are emitted in the
    order of the sorted keys.

    Args:
      structure: Nested structure.

    Returns:
      Flat tuple of the leaves. A non-container input yields a 1-tuple.
    """
    if isinstance(structure, dict):
        # Take the values in sorted-key order. Do not use the bare name `zip`
        # here. This module rebinds it to zip_, which zips the dict's values
        # against each other and raises when they differ in structure. The
        # Python 3 builtin zip is not subscriptable either. An empty dict
        # naturally yields ().
        structure = tuple(
            value for _, value in sorted(structure.items(), key=lambda item: item[0]))
    if isinstance(structure, (tuple, list)):
        result = []
        for element in structure:
            result.extend(flatten_(element))
        return tuple(result)
    return (structure,)
def filter_(predicate, *structures, **kwargs):
    # pylint: disable=differing-param-doc,missing-param-doc, too-many-branches
    """Keep only the leaves of nested structures that satisfy a predicate.

    If several structures are given they must share the same nesting. The
    predicate is then called with the corresponding leaves from each one, and a
    kept leaf becomes the tuple of those leaves. Containers left empty by
    filtering are dropped from lists, tuples and dicts. Inside a namedtuple, a
    field whose result is () is replaced by None so the arity is preserved.
    Structures may combine lists, tuples and dicts freely.

    Args:
      predicate: Callable deciding whether a leaf is kept; it receives one
          positional argument per input structure.
      *structures: One or more nested structures.
      flatten: Whether to flatten the resulting structure into a tuple. Keys of
          dictionaries will be discarded.

    Returns:
      Nested structure. A rejected top-level leaf yields ().

    Raises:
      ValueError: If lists/tuples of different length, or dicts with different
          key sets, are combined.
    """
    # Python 2 does not allow named keyword arguments after *args.
    flatten = kwargs.pop('flatten', False)
    assert not kwargs, 'filter() got unexpected keyword arguments.'

    def survives(node):
        # Leaves always survive; containers survive only when non-empty.
        return not isinstance(node, (tuple, list, dict)) or bool(node)

    def impl(predicate, *structures):
        if not structures:
            return structures
        head = structures[0]
        # With several structures, leaves are passed (and kept) as a tuple.
        multi = len(structures) > 1

        if all(isinstance(s, (tuple, list)) for s in structures):
            if len({len(s) for s in structures}) > 1:
                raise ValueError('Cannot merge tuples or lists of different length.')
            if multi:
                children = (impl(predicate, *group) for group in _builtin_zip(*structures))
            else:
                children = (impl(predicate, child) for child in head)
            if hasattr(head, '_fields'):
                # namedtuple: keep every field, blanking emptied ones to None.
                return type(head)(*(c if c != () else None for c in children))
            return type(head)(c for c in children if survives(c))

        if all(isinstance(s, dict) for s in structures):
            if len({frozenset(s.keys()) for s in structures}) > 1:
                raise ValueError('Cannot merge dicts with different keys.')
            if multi:
                pruned = {key: impl(predicate, *(s[key] for s in structures)) for key in head}
            else:
                pruned = {key: impl(predicate, value) for key, value in head.items()}
            return type(head)({key: value for key, value in pruned.items() if survives(value)})

        if multi:
            return structures if predicate(*structures) else ()
        return head if predicate(head) else ()

    result = impl(predicate, *structures)
    return flatten_(result) if flatten else result
# Public aliases for the nested helpers. Within this module they shadow the
# builtins of the same name; the originals remain reachable through the
# _builtin_* handles captured near the top of the module.
# pylint: disable=redefined-builtin
zip, map, flatten, filter = zip_, map_, flatten_, filter_