The vmap call in the value estimators is causing incompatibility issues.
In particular, this is the call at rl/torchrl/objectives/value/advantages.py, line 144 (commit 98b45a6).
My problem today concerns a GNN feature.
In GNNs you may need to compute the graph adjacency from an input tensor of node positions, for example with this function: https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.radius_graph.html
Given inputs of the same shape, this function can return outputs of different sizes, and it was not designed to be vmap-compatible.
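To make the shape problem concrete without depending on torch_geometric, here is a rough dense stand-in for radius_graph in plain PyTorch. `naive_radius_graph` is a hypothetical helper, not the PyG implementation (edge ordering and the convention at exactly distance `r` may differ). Two inputs of identical shape (5, 2) produce edge_index tensors with 8 and 0 columns, which is the kind of data-dependent output size that vmap cannot batch.

```python
import torch


def naive_radius_graph(x: torch.Tensor, r: float, batch: torch.Tensor, loop: bool = False) -> torch.Tensor:
    """Hypothetical dense stand-in for radius_graph (not the PyG implementation).

    Returns a (2, E) edge_index linking nodes of the same graph that lie within distance r.
    """
    n = x.shape[0]
    dist = torch.cdist(x, x)  # (N, N) pairwise Euclidean distances
    same_graph = batch.unsqueeze(0) == batch.unsqueeze(1)  # only connect nodes of the same graph
    mask = (dist <= r) & same_graph
    if not loop:
        # Remove self-loops by masking out the diagonal.
        mask = mask & ~torch.eye(n, dtype=torch.bool, device=x.device)
    # nonzero() emits one row per True entry, so E depends on the values of x, not just its shape.
    return mask.nonzero().t()


x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0], [0.5, 0.5]])
batch = torch.tensor([0, 0, 1, 1, 1])

print(naive_radius_graph(x, 3.5, batch).shape)         # torch.Size([2, 8])
print(naive_radius_graph(x * 10.0, 3.5, batch).shape)  # torch.Size([2, 0])
```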
Here is the closest reproduction script I could come up with:
```python
import torch
from torch_geometric.nn import radius_graph


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, batch):
        edge_index = radius_graph(x, r=3.5, batch=batch, loop=False)
        return x


model = Model()
x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0], [0.5, 0.5]])
batch = torch.tensor([0, 0, 1, 1, 1])

# Normal call works
data_out = model(x, batch)

# Vmap does not
x = torch.stack([x, x], dim=0)
batch = torch.stack([batch, batch], dim=0)
data_out = torch.vmap(model, (0, 0))(x, batch)
```
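If you control the call site (which is not the case inside the TorchRL value estimators), one way to sidestep vmap for this model is to fold the extra leading dimension into PyG's `batch` vector, so that a single `radius_graph` call treats every sample as its own set of graphs. This sketch assumes each sample's `batch` vector is sorted and uses non-negative graph ids; `fold_leading_dim_into_batch` is a hypothetical helper.

```python
import torch


def fold_leading_dim_into_batch(x: torch.Tensor, batch: torch.Tensor):
    """Hypothetical helper: turn (B, N, D) positions and (B, N) graph ids into
    (B*N, D) positions and (B*N,) graph ids with no id shared across samples."""
    B, N, D = x.shape
    num_graphs = int(batch.max()) + 1  # upper bound on graphs per sample
    # Shift sample b's graph ids by b * num_graphs so different samples never collide.
    offsets = torch.arange(B, device=batch.device).unsqueeze(1) * num_graphs
    return x.reshape(B * N, D), (batch + offsets).reshape(B * N)


x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0], [0.5, 0.5]])
batch = torch.tensor([0, 0, 1, 1, 1])
x_flat, batch_flat = fold_leading_dim_into_batch(torch.stack([x, x]), torch.stack([batch, batch]))
print(batch_flat)  # tensor([0, 0, 1, 1, 1, 2, 2, 3, 3, 3])
# edge_index = radius_graph(x_flat, r=3.5, batch=batch_flat, loop=False)  # one call, no vmap
```

Because graphs from different samples get disjoint ids, radius_graph never connects nodes across samples, so the single call yields the same edges as running the samples one at a time, up to the node-index offset.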
Any suggestions?
More generally, I think this specific vmap call has caused incompatibilities for many users. Would it be possible to make it optional?
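To illustrate what an opt-out could look like, here is a sketch of a switch between `torch.vmap` and a plain Python loop over the leading dimension. `maybe_vmap` and its `use_vmap` flag are hypothetical, not an existing TorchRL option. The loop path is slower, but it tolerates functions with data-dependent intermediate shapes, as long as each call returns tensors that can be stacked.

```python
from typing import Callable

import torch


def maybe_vmap(fn: Callable[..., torch.Tensor], use_vmap: bool = True) -> Callable[..., torch.Tensor]:
    """Hypothetical helper: map fn over dim 0 of every positional argument."""
    if use_vmap:
        return torch.vmap(fn, in_dims=0)

    def looped(*args: torch.Tensor) -> torch.Tensor:
        # Call fn on one slice at a time, then stack the per-slice outputs along dim 0.
        outs = [fn(*(a[i] for a in args)) for i in range(args[0].shape[0])]
        return torch.stack(outs, dim=0)

    return looped


def num_close_pairs(x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    # Uses nonzero(), whose output size depends on the data.
    mask = (torch.cdist(x, x) < 3.5) & (batch.unsqueeze(0) == batch.unsqueeze(1))
    return torch.tensor(mask.nonzero().shape[0])


x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0], [0.5, 0.5]])
batch = torch.tensor([0, 0, 1, 1, 1])
xs, batches = torch.stack([x, x]), torch.stack([batch, batch])

print(maybe_vmap(num_close_pairs, use_vmap=False)(xs, batches))  # tensor([13, 13])
```

Keeping the same calling convention for both paths means a caller could flip the flag without touching the wrapped function; the cost of the loop is one forward call per slice instead of one batched call.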
vmoens