[BUG] Vmap call in value estimate not compatible with torch_geometric radius graph #2537

Open
matteobettini opened this issue Nov 5, 2024 · 0 comments

matteobettini (Contributor) commented Nov 5, 2024

The vmap call in the value estimators is causing an incompatibility issue.

In particular, this is the call:

data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)(

My problem today relates to a GNN feature.

In a GNN, you may need to compute the graph adjacency from an input tensor of node positions, using this function:
https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.radius_graph.html

Given inputs of the same shape, this function can return outputs of different sizes (the number of edges depends on the positions), and it was not designed to be vmap-compatible.
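To make the shape problem concrete, here is a minimal pure-PyTorch sketch. `naive_radius_graph` is a hypothetical stand-in, not torch_geometric's implementation: it handles a single graph, ignores `max_num_neighbors`, and uses `<=` for the radius test. It builds the edge list with `Tensor.nonzero()`, whose output length depends on the values of the positions, which is the kind of data-dependent shape that `torch.vmap` cannot batch.

```python
import torch


def naive_radius_graph(pos: torch.Tensor, r: float) -> torch.Tensor:
    """Hypothetical stand-in for radius_graph: single graph, no self-loops.

    Returns a [2, num_edges] edge index connecting nodes within distance r.
    """
    n = pos.shape[0]
    # Pairwise Euclidean distances, excluding the diagonal (like loop=False).
    mask = (torch.cdist(pos, pos) <= r) & ~torch.eye(n, dtype=torch.bool)
    # nonzero() returns one row per True entry, so num_edges depends on
    # the values in pos, not only on its shape.
    return mask.nonzero().t()


close = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
far = torch.tensor([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])

print(naive_radius_graph(close, r=3.5).shape)  # torch.Size([2, 6])
print(naive_radius_graph(far, r=3.5).shape)    # torch.Size([2, 0])

# Same input shape, different output shapes: vmap cannot stack the results.
try:
    torch.vmap(naive_radius_graph, in_dims=(0, None))(torch.stack([close, far]), 3.5)
    vmap_ok = True
except RuntimeError:
    vmap_ok = False  # vmap rejects ops with data-dependent output shapes
print("vmap succeeded:", vmap_ok)
```

Both inputs have shape `[3, 2]`, yet the edge index has 6 columns in one case and 0 in the other; under `torch.vmap` the `nonzero()` call is expected to raise, which the sketch catches.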

Here is the closest reproduction script I could come up with:

```python
import torch
from torch_geometric.nn import radius_graph


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, batch):
        # edge_index has shape [2, num_edges]; num_edges depends on the values in x
        edge_index = radius_graph(x, r=3.5, batch=batch, loop=False)
        # edge_index is unused here: computing it is enough to trigger the problem
        return x


model = Model()

x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0], [0.5, 0.5]])
batch = torch.tensor([0, 0, 1, 1, 1])

# Normal call works
data_out = model(x, batch)

# Vmap does not: add a leading dimension of size 2 and map over it
x = torch.stack([x, x], dim=0)
batch = torch.stack([batch, batch], dim=0)
data_out = torch.vmap(model, (0, 0))(x, batch)
```

Any suggestions?

More generally, I think this specific vmap call has caused incompatibilities for many users. Would it be possible to make it optional?
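As an illustration of what "optional" could mean, here is a minimal sketch of a switch between `torch.vmap` and a plain Python loop over the leading dimension. `maybe_vmap`, its `use_vmap` flag and the toy `model` are hypothetical names for this example, not TorchRL's API, and they do not reflect how `_vmap_func` is actually wired into the value estimators. The loop fallback tolerates data-dependent shapes inside the function as long as each per-sample output has the same shape, at the cost of losing vmap's vectorization.

```python
from typing import Callable

import torch


def maybe_vmap(func: Callable, use_vmap: bool = True) -> Callable:
    """Hypothetical helper: map `func` over dim 0 of every (tensor) argument."""
    if use_vmap:
        return torch.vmap(func, in_dims=0)

    def looped(*args: torch.Tensor) -> torch.Tensor:
        # Call func once per sample, then stack; works even if func builds
        # intermediate tensors whose size depends on the data.
        batch_size = args[0].shape[0]
        outputs = [func(*(a[i] for a in args)) for i in range(batch_size)]
        return torch.stack(outputs, dim=0)

    return looped


def model(x: torch.Tensor, r: float = 3.5) -> torch.Tensor:
    """Hypothetical GNN-like layer: variable-size edge list, fixed-size output."""
    n = x.shape[0]
    mask = (torch.cdist(x, x) <= r) & ~torch.eye(n, dtype=torch.bool)
    edge_index = mask.nonzero().t()  # [2, num_edges], data-dependent
    # Aggregate back to a fixed [n] shape: number of neighbours per node.
    return torch.bincount(edge_index[0], minlength=n)


close = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
far = torch.tensor([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
x = torch.stack([close, far])  # [2, 3, 2]

counts = maybe_vmap(model, use_vmap=False)(x)
print(counts)  # per-sample neighbour counts, shape [2, 3]
```

With `use_vmap=True`, the same call is expected to fail on the `nonzero()` inside `model`; the looped path returns one row of neighbour counts per sample.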

matteobettini added the bug label Nov 5, 2024