# Based on https://raw.githubusercontent.com/google-research/batch-ppo/master/agents/tools/nested.py
"""Tools for manipulating nested tuples, lists, and dictionaries."""
# Disable linter warning for using `flatten` as argument name.
# pylint: disable=redefined-outer-name
# Keep handles to the real builtins: this module later rebinds the names
# `zip`, `map` and `filter` to its own nested-structure versions.
_builtin_zip, _builtin_map, _builtin_filter = zip, map, filter
def zip_(*structures, **kwargs):
    # pylint: disable=differing-param-doc,missing-param-doc
    """Combine corresponding elements of multiple nested structures into tuples.

    The nested structures can consist of any combination of lists, tuples, and
    dicts. All provided structures must have the same nesting. When only one
    structure is given, its leaves are returned as-is rather than wrapped in
    1-tuples.

    Args:
      *structures: Nested structures.
      flatten: Whether to flatten the resulting structure into a tuple. Keys of
          dictionaries will be discarded.

    Returns:
      Nested structure whose leaves are tuples of corresponding elements.

    Raises:
      TypeError: If an unexpected keyword argument is passed.
      ValueError: Propagated from map_ when the structures do not match.
    """
    # Named keyword arguments are not allowed after *args in Python 2.
    flatten = kwargs.pop('flatten', False)
    if kwargs:
        # Raise instead of assert so the check is not stripped under `python -O`.
        raise TypeError('zip() got unexpected keyword arguments.')
    # Call map_ directly rather than through the module-level `map` alias, so
    # this never depends on which object that name happens to be bound to.
    return map_(
        lambda *x: x if len(x) > 1 else x[0], *structures, flatten=flatten)
def map_(function, *structures, **kwargs):
    # pylint: disable=differing-param-doc,missing-param-doc
    """Apply a function to every element in a nested structure.

    If multiple structures are provided as input, their structure must match
    and the function will be applied to corresponding groups of elements. The
    nested structure can consist of any combination of lists, tuples,
    namedtuples, and dicts.

    Args:
      function: The function to apply to the elements of the structure.
          Receives one argument for every structure that is provided.
      *structures: One or more nested structures.
      flatten: Whether to flatten the resulting structure into a tuple. Keys of
          dictionaries will be discarded.

    Returns:
      Nested structure with the container types of the first structure, or a
      flat tuple when `flatten` is set. With no structures, returns ().

    Raises:
      TypeError: If an unexpected keyword argument is passed.
      ValueError: If corresponding lists/tuples differ in length or
          corresponding dicts differ in keys.
    """
    # Named keyword arguments are not allowed after *args in Python 2.
    flatten = kwargs.pop('flatten', False)
    if kwargs:
        # Raise instead of assert so the check is not stripped under `python -O`.
        raise TypeError('map() got unexpected keyword arguments.')

    def impl(function, *structures):
        if not structures:
            return structures
        if all(isinstance(s, (tuple, list)) for s in structures):
            if len(set(len(x) for x in structures)) > 1:
                raise ValueError('Cannot merge tuples or lists of different length.')
            # `_builtin_zip` is the real builtin; the module rebinds `zip`.
            args = tuple(impl(function, *x) for x in _builtin_zip(*structures))
            if hasattr(structures[0], '_fields'):  # namedtuple
                return type(structures[0])(*args)
            return type(structures[0])(args)  # tuple, list
        if all(isinstance(s, dict) for s in structures):
            if len(set(frozenset(x.keys()) for x in structures)) > 1:
                raise ValueError('Cannot merge dicts with different keys.')
            merged = {
                k: impl(function, *(s[k] for s in structures)) for k in structures[0]
            }
            return type(structures[0])(merged)
        # Leaves (or containers of differing kinds) are passed to the function.
        return function(*structures)

    result = impl(function, *structures)
    if flatten:
        result = flatten_(result)
    return result
def flatten_(structure):
    """Combine all leaves of a nested structure into a tuple.

    The nested structure can consist of any combination of tuples, lists, and
    dicts. Dictionary keys are discarded, but values are visited in the order
    of their sorted keys.

    Args:
      structure: Nested structure.

    Returns:
      Flat tuple of leaves. A non-container input yields a 1-tuple.
    """
    if isinstance(structure, dict):
        # Take the values in sorted-key order. Sort on the key alone because
        # values may not be orderable. Do not use this module's nested `zip`
        # here: it recurses into the values and would transpose equally-shaped
        # lists (or raise on differently-shaped ones). Empty dicts give [].
        items = sorted(structure.items(), key=lambda item: item[0])
        structure = [value for _, value in items]
    if isinstance(structure, (tuple, list)):
        result = []
        for element in structure:
            result.extend(flatten_(element))
        return tuple(result)
    return (structure,)
def filter_(predicate, *structures, **kwargs):
    # pylint: disable=differing-param-doc,missing-param-doc, too-many-branches
    """Select elements of a nested structure based on a predicate function.

    If multiple structures are provided as input, their structure must match
    and the predicate will be applied to corresponding groups of elements. The
    nested structure can consist of any combination of lists, tuples,
    namedtuples, and dicts. Inside lists, tuples, and dicts, results that are
    empty lists, tuples, or dicts are dropped; inside namedtuples, fields whose
    result is an empty tuple are set to None to keep the field layout.

    Args:
      predicate: The function to determine whether an element should be kept.
          Receives one argument for every structure that is provided.
      *structures: One or more nested structures.
      flatten: Whether to flatten the resulting structure into a tuple. Keys of
          dictionaries will be discarded.

    Returns:
      Nested structure. With several input structures, kept leaves are tuples
      of the corresponding elements.

    Raises:
      TypeError: If an unexpected keyword argument is passed.
      ValueError: If corresponding lists/tuples differ in length or
          corresponding dicts differ in keys.
    """
    # Named keyword arguments are not allowed after *args in Python 2.
    flatten = kwargs.pop('flatten', False)
    if kwargs:
        # Raise instead of assert so the check is not stripped under `python -O`.
        raise TypeError('filter() got unexpected keyword arguments.')

    def impl(predicate, *structures):
        if not structures:
            return structures
        if all(isinstance(s, (tuple, list)) for s in structures):
            if len(set(len(x) for x in structures)) > 1:
                raise ValueError('Cannot merge tuples or lists of different length.')
            # Only wrap in tuples if more than one structure provided.
            if len(structures) > 1:
                filtered = (impl(predicate, *x) for x in _builtin_zip(*structures))
            else:
                filtered = (impl(predicate, x) for x in structures[0])
            if hasattr(structures[0], '_fields'):  # namedtuple
                # Keep every field; removed ones (empty tuple) become None. Use
                # an explicit empty-tuple test instead of `x != ()` so leaves
                # with elementwise comparison (e.g. array-like objects) are never
                # compared against ().
                filtered = (
                    None if isinstance(x, tuple) and not x else x for x in filtered
                )
                return type(structures[0])(*filtered)
            # tuple, list: drop empty containers.
            filtered = (x for x in filtered if not isinstance(x, (tuple, list, dict)) or x)
            return type(structures[0])(filtered)
        if all(isinstance(s, dict) for s in structures):
            if len(set(frozenset(x.keys()) for x in structures)) > 1:
                raise ValueError('Cannot merge dicts with different keys.')
            # Only wrap in tuples if more than one structure provided.
            if len(structures) > 1:
                filtered = {
                    k: impl(predicate, *(s[k] for s in structures)) for k in structures[0]
                }
            else:
                filtered = {k: impl(predicate, v) for k, v in structures[0].items()}
            # Remove empty containers and construct result structure.
            filtered = {
                k: v for k, v in filtered.items() if not isinstance(v, (tuple, list, dict)) or v
            }
            return type(structures[0])(filtered)
        # Leaf: keep it (or the tuple of corresponding leaves) if the predicate
        # holds, otherwise return () so the enclosing container drops it.
        if len(structures) > 1:
            return structures if predicate(*structures) else ()
        return structures[0] if predicate(structures[0]) else ()

    result = impl(predicate, *structures)
    if flatten:
        result = flatten_(result)
    return result
# Public aliases. These intentionally shadow the builtins `zip`, `map` and
# `filter` within this module; the originals remain available as
# `_builtin_zip`, `_builtin_map` and `_builtin_filter`.
# pylint: disable=redefined-builtin
zip, map, flatten, filter = zip_, map_, flatten_, filter_