RelaxedE3NN

[GRaM at ICML'24] Relaxed Equivariant Graph Neural Networks.

This repo contains code for Relaxed Equivariant Graph Neural Networks (https://arxiv.org/abs/2407.20471).

Install

Dependencies

PyTorch

e3nn requires PyTorch >= 1.8.0. For installation instructions, please see the PyTorch homepage.

torch_geometric

First, install PyTorch Geometric (torch_geometric) together with its torch-scatter and torch-sparse dependencies. For torch 1.11 without CUDA support:

CUDA=cpu

pip install --upgrade --force-reinstall torch-scatter -f https://data.pyg.org/whl/torch-1.11.0+${CUDA}.html
pip install --upgrade --force-reinstall torch-sparse -f https://data.pyg.org/whl/torch-1.11.0+${CUDA}.html
pip install torch-geometric

See the PyTorch Geometric installation instructions for CUDA support or newer versions.
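The wheel index URL in the commands above encodes both the PyTorch version and a CUDA tag (`cpu` here). As a convenience sketch, the hypothetical helper below builds that URL from a PyTorch version string and a CUDA version. It assumes the `torch-<version>+<tag>.html` pattern shown above and tags of the form `cu113` for CUDA 11.3; PyG does not publish wheels for every combination, so check the generated URL against the PyG installation page before relying on it.

```python
from typing import Optional


def pyg_wheel_index(torch_version: str, cuda_version: Optional[str]) -> str:
    """Hypothetical helper: build a PyG wheel index URL.

    Assumes the URL pattern used in the install commands, e.g.
    https://data.pyg.org/whl/torch-1.11.0+cpu.html
    """
    # Drop any local build suffix: "1.11.0+cu113" -> "1.11.0"
    base = torch_version.split("+")[0]
    # No CUDA -> "cpu"; CUDA "11.3" -> "cu113" (assumed tag format)
    tag = "cpu" if cuda_version is None else "cu" + cuda_version.replace(".", "")
    return f"https://data.pyg.org/whl/torch-{base}+{tag}.html"


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed yet")
    else:
        # torch.version.cuda is None for CPU-only builds
        print(pyg_wheel_index(torch.__version__, torch.version.cuda))
```

The printed URL can then be substituted for the `-f` argument of the pip commands above.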

e3nn

Stable (PyPI)

$ pip install e3nn
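After installing, a quick sanity check can confirm that the packages above are importable and that PyTorch meets the >= 1.8.0 requirement. The script below is a sketch and not part of the repo: the helper names are hypothetical, and it assumes the import names `torch`, `torch_scatter`, `torch_sparse`, `torch_geometric`, and `e3nn` for the pip packages listed above.

```python
import importlib.util
import re
from typing import List, Sequence, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Keep the leading numeric release segment: "1.11.0+cpu" -> (1, 11, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def version_at_least(version: str, minimum: Sequence[int]) -> bool:
    """Compare release segments, padding the shorter one with zeros."""
    current = list(parse_version(version))
    target = list(minimum)
    width = max(len(current), len(target))
    current += [0] * (width - len(current))
    target += [0] * (width - len(target))
    return current >= target


def missing_modules(names: Sequence[str]) -> List[str]:
    """Return the top-level modules that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    required = ["torch", "torch_scatter", "torch_sparse", "torch_geometric", "e3nn"]
    print("missing:", missing_modules(required) or "none")
    if importlib.util.find_spec("torch") is not None:
        import torch

        print("PyTorch >= 1.8.0:", version_at_least(torch.__version__, (1, 8, 0)))
```

`importlib.util.find_spec` only locates a module without importing it, so the check stays cheap even when a package is present but slow to import.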

Organization

We provide code for two toy experiments.

Shape Deformations

shape_deformations_3d.ipynb contains symmetry-breaking examples that deform a cube into a rectangular prism or into an asymmetric shape. We show that the relaxed weights are interpretable by plotting their spherical harmonic projections (see the paper for details).
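To give some intuition for what a spherical harmonic projection shows, the sketch below evaluates a truncated real spherical harmonic expansion, f(u) = w0 + sqrt(3) (w1 · u) + w20 Y_2^0(u), on unit vectors u. It is a hand-written illustration with hypothetical coefficients, not code or values from shape_deformations_3d.ipynb. The basis (l = 0, l = 1, and only the l = 2, m = 0 term) and the "component" normalization are assumptions; e3nn's own spherical harmonics (e.g. `e3nn.o3.spherical_harmonics`) use their own ordering and normalization conventions.

```python
import math
from typing import Tuple

Vec = Tuple[float, float, float]


def y20(z: float) -> float:
    # Real Y_2^0 in "component" normalization (mean square over the sphere = 1)
    return math.sqrt(5.0) * (3.0 * z * z - 1.0) / 2.0


def project(direction: Vec, l0: float = 0.0,
            l1: Vec = (0.0, 0.0, 0.0), l2m0: float = 0.0) -> float:
    """Evaluate a truncated spherical harmonic expansion at one direction.

    l0:   scalar coefficient (the rotation-invariant part)
    l1:   vector coefficient, contributes sqrt(3) * (l1 . u)
    l2m0: coefficient of Y_2^0, which singles out the z axis
    """
    x, y, z = direction
    norm = math.sqrt(x * x + y * y + z * z)
    x, y, z = x / norm, y / norm, z / norm  # project onto the unit sphere
    value = l0
    value += math.sqrt(3.0) * (l1[0] * x + l1[1] * y + l1[2] * z)
    value += l2m0 * y20(z)
    return value


if __name__ == "__main__":
    axes = {"x": (1.0, 0.0, 0.0), "y": (0.0, 1.0, 0.0), "z": (0.0, 0.0, 1.0)}
    for name, axis in axes.items():
        # Hypothetical coefficients for a cube stretched along z
        print(name, round(project(axis, l0=1.0, l2m0=0.5), 3))
```

With only an l = 0 coefficient the function is constant on the sphere, so no direction is preferred. An l = 2, m = 0 term distinguishes the z axis from x and y while treating u and -u alike, as a cube stretched along z would. An l = 1 term distinguishes u from -u, which a centrosymmetric prism cannot do but an asymmetric shape can.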

Electric Field Simulation

electric_field_sim.ipynb contains an example that learns the directions of the electric and magnetic forces on a charged particle in a magnetic field.
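As an illustration of the physics behind this task, the sketch below computes the two terms of the Lorentz force, F = q(E + v × B), and their directions for a single particle. The function names and the idea of using the normalized force as the learning target are assumptions for illustration; the notebook may set up fields and labels differently. In e3nn's irreps notation, E, v, and the force are odd-parity vectors (`1o`), while the magnetic field B is a pseudovector (`1e`).

```python
import math
from typing import Tuple

Vec = Tuple[float, float, float]


def cross(a: Vec, b: Vec) -> Vec:
    return (
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    )


def lorentz_force_terms(q: float, E: Vec, v: Vec, B: Vec) -> Tuple[Vec, Vec]:
    """Return the (electric, magnetic) parts of F = q (E + v x B)."""
    electric = tuple(q * e for e in E)
    magnetic = tuple(q * c for c in cross(v, B))
    return electric, magnetic


def unit(vec: Vec) -> Vec:
    """Normalize a force to its direction (the hypothetical target)."""
    norm = math.sqrt(sum(c * c for c in vec))
    if norm == 0.0:
        raise ValueError("zero vector has no direction")
    return tuple(c / norm for c in vec)


if __name__ == "__main__":
    q, E, v, B = 1.0, (2.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 0.0, 1.0)
    fe, fm = lorentz_force_terms(q, E, v, B)
    print("electric direction:", unit(fe))
    print("magnetic direction:", unit(fm))
    # Under inversion E and v flip sign but the pseudovector B does not,
    # so both force terms flip, as a polar vector should.
    neg = lambda u: tuple(-c for c in u)
    print(lorentz_force_terms(q, neg(E), neg(v), B))
```

The magnetic term is always perpendicular to the velocity, and the different parity of B is what makes this a useful test case for a parity-aware model.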

Models

relaxed_e3nn_conv.py contains a simple relaxed e3nn convolution model built from stacked RelaxedConvolution layers. The electric_field_model folder shows how to incorporate the relaxed e3nn layer into a more complex model, based on the example networks included in e3nn; the modified files are specifically based on SimpleNetwork and NetworkForAGraphWithAttributes from e3nn. Within electric_field_model, relaxed_points_conv.py modifies the points convolution from e3nn, gate_points_message_passing_relaxed.py modifies the message passing to use the relaxed convolution, and gate_points_networks_relaxed.py modifies e3nn's example message-passing graph neural network models to use the relaxed convolution layer.
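The key property of the relaxed layers can be illustrated without e3nn. The sketch below is a toy analogue, not the repo's RelaxedConvolution: it maps a vector x to w0 x + x × w1, where the scalar w0 stands in for an ordinary equivariant weight and the vector w1 plays the role of a hypothetical l = 1 relaxed weight. With w1 = 0 the map commutes with every rotation; with w1 ≠ 0 it stays equivariant only under rotations that leave w1 fixed, which is the kind of controlled symmetry breaking the relaxed weights allow.

```python
import math
from typing import Callable, Tuple

Vec = Tuple[float, float, float]


def cross(a: Vec, b: Vec) -> Vec:
    return (
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    )


def toy_relaxed_layer(x: Vec, w0: float, w1: Vec) -> Vec:
    """Toy map y = w0 * x + x cross w1 (hypothetical, not the repo's layer)."""
    c = cross(x, w1)
    return (w0 * x[0] + c[0], w0 * x[1] + c[1], w0 * x[2] + c[2])


def rotate_z(v: Vec, angle: float) -> Vec:
    c, s = math.cos(angle), math.sin(angle)
    return (c * v[0] - s * v[1], s * v[0] + c * v[1], v[2])


def rotate_x(v: Vec, angle: float) -> Vec:
    c, s = math.cos(angle), math.sin(angle)
    return (v[0], c * v[1] - s * v[2], s * v[1] + c * v[2])


def equivariance_gap(rotate: Callable[[Vec], Vec], x: Vec, w0: float, w1: Vec) -> float:
    """Largest component of |f(R x) - R f(x)|; zero means f commutes with R."""
    lhs = toy_relaxed_layer(rotate(x), w0, w1)
    rhs = rotate(toy_relaxed_layer(x, w0, w1))
    return max(abs(a - b) for a, b in zip(lhs, rhs))


if __name__ == "__main__":
    x = (0.3, -0.5, 0.8)
    about_z = lambda v: rotate_z(v, 0.7)
    about_x = lambda v: rotate_x(v, 0.7)
    for w1 in [(0.0, 0.0, 0.0), (0.0, 0.0, 1.0)]:
        print(w1, equivariance_gap(about_z, x, 1.5, w1), equivariance_gap(about_x, x, 1.5, w1))
```

The actual layer in relaxed_e3nn_conv.py works with e3nn irreps and learned relaxed weights; this toy only shows why a non-scalar weight reduces rotation equivariance to the subgroup that fixes it.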