[Feature Request] Mapping over compatible trees #169

Open
1 task done
LarsKue opened this issue Oct 28, 2024 · 4 comments
Labels
enhancement New feature or request

Comments

@LarsKue

LarsKue commented Oct 28, 2024

Required prerequisites

  • I have searched the Issue Tracker that this hasn't already been reported. (comment there if it has.)

Motivation

I often work with nested containers that constitute partially available data. It seems that mapping these in conjunction with other compatible trees is not currently supported by the library.

Simple example:

import optree

a = {"one": 1, "two": 2, "three": 3}
b = {"two": 2, "three": 3}

def add(x, y):
    return x + y

optree.tree_map(add, a, b)  # ValueError

It would be nice if we had support for mapping over compatible trees. In particular, I would consider trees compatible when each leaf node of tree b has a corresponding leaf node in a where the path to both is identical.

However, it might be safer to enable this only through an extra argument or a separate function, rather than in the plain tree_map.

Solution

Referring to my example above, I would expect the result:

{"one": 1, "two": 4, "three": 6}
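The requested behaviour can be sketched in plain Python to make the semantics concrete. This is only an illustration and not part of optree: `map_compatible` is a hypothetical name, it handles nested dicts only, and it assumes the smaller tree mirrors the nesting of the larger one wherever it has keys.

```python
def map_compatible(func, full, partial):
    """Apply func(x, y) where `partial` has a leaf; keep the leaf of `full` elsewhere."""
    if isinstance(full, dict):
        # Recurse only into keys that `partial` also provides.
        return {
            key: map_compatible(func, value, partial[key]) if key in partial else value
            for key, value in full.items()
        }
    # Anything that is not a dict is treated as a leaf in this sketch.
    return func(full, partial)


def add(x, y):
    return x + y


a = {"one": 1, "two": 2, "three": 3}
b = {"two": 2, "three": 3}

result = map_compatible(add, a, b)
print(result)  # {'one': 1, 'two': 4, 'three': 6}
```

Leaves of the larger tree that have no counterpart are passed through unchanged, which is one possible choice; dropping them or substituting a default would be equally defensible.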

Alternatives

Allowing us to broadcast trees would also solve this:

a, b = optree.tree_broadcast(a, b)

b  # {"one": None, "two": 2, "three": 3}

optree.tree_map(add, a, b)  # works fine already with these inputs (ignoring the None edge-case)
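A rough idea of what such a broadcast could do, written in plain Python rather than against optree. The name `broadcast_like`, the `fill` parameter, and the restriction to nested dicts are all assumptions made for this sketch.

```python
_MISSING = object()  # sentinel for "no value at this path"


def broadcast_like(reference, partial=_MISSING, fill=None):
    """Return `partial` padded with `fill` so that it has the dict keys of `reference`."""
    if isinstance(reference, dict):
        if isinstance(partial, dict):
            # Keys absent from `partial` are marked as missing.
            children = {key: partial.get(key, _MISSING) for key in reference}
        else:
            # A leaf (or a missing value) is passed down to every child.
            children = dict.fromkeys(reference, partial)
        return {
            key: broadcast_like(value, children[key], fill)
            for key, value in reference.items()
        }
    return fill if partial is _MISSING else partial


a = {"one": 1, "two": 2, "three": 3}
b = {"two": 2, "three": 3}

b_none = broadcast_like(a, b)
print(b_none)  # {'one': None, 'two': 2, 'three': 3}

# Choosing a neutral fill value sidesteps the None edge case for addition.
b_zero = broadcast_like(a, b, fill=0)
total = {key: a[key] + b_zero[key] for key in a}
print(total)  # {'one': 1, 'two': 4, 'three': 6}
```

Once both trees share a structure, an ordinary structure-matching map applies; the open question is which fill value is appropriate for a given function.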

Alternatively, we could also work around this particular problem by using the tree paths or accessors in a generic way:

def add_to_a(accessor_or_path, value):
    # something like this tree_get, which I think does not exist yet
    # (an in-place method would also be nice)
    return optree.tree_get(a, accessor_or_path) + value

optree.tree_map_with_path(add_to_a, b)
# or
optree.tree_map_with_accessor(add_to_a, b)

However, this might be highly inefficient.

Additional context

Perhaps this functionality (or a work-around) already exists. I would be happy to be educated.

LarsKue added the enhancement label on Oct 28, 2024
@XuehaiPan
Member

Hi @LarsKue, could you elaborate more on the definition of "compatible trees"? A tree can be a nested collection of list/tuple/dict/namedtuple etc., not only nested dicts.

I think the following code could be a workaround for this feature:

import optree

def tree_get(accessor, tree, default=None):
    try:
        return accessor(tree)
    except LookupError:
        return default

tree_x = {"one": 1, "two": 2, "three": 3}
tree_y = {"two": 2, "three": 3}

result = optree.tree_map_with_accessor(lambda a, x: x + tree_get(a, tree_y, default=0), tree_x)
# {'one': 1, 'two': 4, 'three': 6}

# What's the expected output of some_tree_add(tree_x, tree_y) and some_tree_add(tree_y, tree_x)?

@LarsKue
Author

LarsKue commented Oct 28, 2024

@XuehaiPan Thank you for the work-around, this already looks great.

I think there are many good definitions for a compatible tree. In particular, consider the set of leaf nodes, along with their paths (nested dict keys, tuple indices, etc.) in the respective trees. We could consider tree A a supertree of tree B if A's leaf node set is a superset of B's. B is then a subtree of A.

Currently, tree_map only works on trees where these sets are exactly identical, which is a little bit limiting. Ideally, we could specify a supertree and allow mapping over any subtree, or simply map over the intersection of both leaf node sets.

Example: List in Dict

a = {"a": [1, 2, 3], "b": 3}
b = {"a": [1, 2, 3, 4], "c": 4}

# the leaf node sets are
# a: {("a", 0), ("a", 1), ("a", 2), ("b",)}
# b: {("a", 0), ("a", 1), ("a", 2), ("a", 3), ("c",)}

add_intersection(a, b)  # hypothetical function
# {"a": [2, 4, 6]}  (only leaf nodes in the intersection of both leaf node sets are returned)

# the final leaf node set is
# {("a", 0), ("a", 1), ("a", 2)}
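One way the intersection semantics could look, again as a plain-Python sketch that does not use optree. `map_intersection` is a hypothetical name; only dicts and lists are treated as containers, and a container meeting a leaf at the same path is reported as an error.

```python
def map_intersection(func, a, b):
    """Apply func to leaves whose paths exist in both trees; drop all other leaves."""
    if isinstance(a, dict) and isinstance(b, dict):
        # Keep only keys present in both dicts.
        return {key: map_intersection(func, a[key], b[key]) for key in a if key in b}
    if isinstance(a, list) and isinstance(b, list):
        # zip stops at the shorter list, i.e. the shared index range.
        return [map_intersection(func, x, y) for x, y in zip(a, b)]
    if isinstance(a, (dict, list)) or isinstance(b, (dict, list)):
        raise TypeError("container and leaf found at the same path")
    return func(a, b)


a = {"a": [1, 2, 3], "b": 3}
b = {"a": [1, 2, 3, 4], "c": 4}

result = map_intersection(lambda x, y: x + y, a, b)
print(result)  # {'a': [2, 4, 6]}
```

Note that the output can have fewer children per node than either input, so its structure matches neither `a` nor `b`.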

@XuehaiPan
Member

> I think there are many good definitions for a compatible tree.

In optree, and also in JAX PyTree / PyTorch PyTree / DM-Tree, a tree X is compatible with a tree Y when X is a prefix tree of Y. That is, if we replace the leaf nodes in tree X with subtrees or leaves, the structure of the new tree can be made identical to that of tree Y.

prefix_tree = {'a': 1, 'b': (2, 3), 'c': {'d': 4, 'e': 5}}
full_tree   = {'a': (1, 1, 1), 'b': ([2, 2], 3), 'c': {'d': 4, 'e': {'f': 5, 'g': 5}}}

We can see that each leaf path of prefix_tree is a prefix of one or more leaf paths of full_tree.

optree.tree_paths(prefix_tree)
# [
#     ('a',),
#     ('b', 0),
#     ('b', 1),
#     ('c', 'd'),
#     ('c', 'e'),
# ]
# vs.
optree.tree_paths(full_tree)
# [
#     ('a', 0),
#     ('a', 1),
#     ('a', 2),
#     ('b', 0, 0),
#     ('b', 0, 1),
#     ('b', 1),
#     ('c', 'd'),
#     ('c', 'e', 'f'),
#     ('c', 'e', 'g'),
# ]
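The path relationship in the two listings above can be checked mechanically. The sketch below copies the paths as literals from the listings and uses plain tuple slicing; the helper `is_prefix` is a hypothetical name. It only compares paths, so it is a necessary condition for the prefix-tree relation rather than a full structural check (node types are not inspected).

```python
def is_prefix(short, long):
    """True if the tuple `short` is a leading slice of the tuple `long`."""
    return long[:len(short)] == short


prefix_paths = [('a',), ('b', 0), ('b', 1), ('c', 'd'), ('c', 'e')]
full_paths = [
    ('a', 0), ('a', 1), ('a', 2),
    ('b', 0, 0), ('b', 0, 1), ('b', 1),
    ('c', 'd'),
    ('c', 'e', 'f'), ('c', 'e', 'g'),
]

# Every leaf path of the full tree extends some leaf path of the prefix tree.
covered = all(any(is_prefix(p, f) for p in prefix_paths) for f in full_paths)
print(covered)  # True

# Group the full tree's leaves under the prefix-tree leaf they belong to.
groups = {p: [f for f in full_paths if is_prefix(p, f)] for p in prefix_paths}
print(groups[('c', 'e')])  # [('c', 'e', 'f'), ('c', 'e', 'g')]

# The relation is not symmetric: the full tree is not a prefix of the prefix tree.
reverse = all(any(is_prefix(f, p) for f in full_paths) for p in prefix_paths)
print(reverse)  # False
```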

> In particular, consider the set of leaf nodes, along with their paths (nested dict keys, tuple indices, etc.) in the respective trees. We could consider tree A a supertree of tree B if A's leaf node set is a superset of B's. B is then a subtree of A.

This definition changes the arity of nodes. The intersection of {'a': 1, 'b': 2, 'c': 3} and {'a': 1, 'c': 3, 'd': 4, 'e': 5} has the key set {'a', 'c'}: the input dicts have arity 3 and 4, but the result has arity 2.

@LarsKue
Copy link
Author

LarsKue commented Oct 28, 2024

I see, thank you for the valuable insights. Using prefix trees as a default for compatibility certainly seems like a good choice.

I still believe my request has merit, particularly considering the tree_get via accessor may be inefficient if called repeatedly. Is there no viable option for you to implement tree spec broadcasting or a different mapping algorithm?
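If repeated accessor calls turn out to be a bottleneck, one possible mitigation is to flatten the smaller tree once into a path-to-leaf dict and then do constant-time lookups while walking the larger tree. The sketch below is plain Python and does not use optree; `flatten_with_paths` and `map_with_lookup` are hypothetical helpers that handle dicts, lists, and plain tuples only (namedtuples and custom nodes are not covered).

```python
def flatten_with_paths(tree, path=()):
    """Yield (path, leaf) pairs for nested dicts, lists, and tuples."""
    if isinstance(tree, dict):
        for key, value in tree.items():
            yield from flatten_with_paths(value, path + (key,))
    elif isinstance(tree, (list, tuple)):
        for index, value in enumerate(tree):
            yield from flatten_with_paths(value, path + (index,))
    else:
        yield path, tree


def map_with_lookup(func, tree, lookup, default, path=()):
    """Rebuild `tree`, calling func(leaf, lookup.get(path, default)) at every leaf."""
    if isinstance(tree, dict):
        return {
            key: map_with_lookup(func, value, lookup, default, path + (key,))
            for key, value in tree.items()
        }
    if isinstance(tree, (list, tuple)):
        return type(tree)(
            map_with_lookup(func, value, lookup, default, path + (index,))
            for index, value in enumerate(tree)
        )
    return func(tree, lookup.get(path, default))


tree_x = {"one": 1, "two": 2, "three": 3}
tree_y = {"two": 2, "three": 3}

# Flatten the partial tree a single time.
lookup = dict(flatten_with_paths(tree_y))
print(lookup)  # {('two',): 2, ('three',): 3}

result = map_with_lookup(lambda x, y: x + y, tree_x, lookup, default=0)
print(result)  # {'one': 1, 'two': 4, 'three': 6}
```

Whether this is actually faster than the accessor-based workaround has not been measured here; it trades one extra traversal and a dict for cheaper per-leaf lookups.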
