
Commit 51ea6db

Merge pull request #92 from choderalab/no_dgl_import
no global import dgl
yuanqing-wang authored Oct 26, 2021
2 parents c93b501 + ef5c8e3 commit 51ea6db
Showing 16 changed files with 24 additions and 284 deletions.
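Every file in this diff applies the same deferred-import pattern: the module-level import dgl is removed and the import is performed inside the functions and methods that actually use DGL, so that importing espaloma itself no longer requires DGL to be installed. The sketch below illustrates the pattern; it is not code from the repository, and the helper name collate_graphs is hypothetical.

# Before: a module-level import makes dgl a hard dependency of the whole
# package, even for code paths that never touch a graph.
# import dgl

def collate_graphs(graphs):
    """Batch a list of DGL graphs into a single graph (illustrative helper)."""
    import dgl  # deferred import: evaluated only when the function is called
    return dgl.batch(graphs)

After the first call the deferred import is essentially free, since Python caches imported modules in sys.modules.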
3 changes: 0 additions & 3 deletions espaloma/app/experiment.py
@@ -3,9 +3,6 @@
# =============================================================================
import abc
import copy

import numpy as np
import dgl
import torch

import espaloma as esp
3 changes: 2 additions & 1 deletion espaloma/data/utils.py
@@ -3,7 +3,6 @@
# =============================================================================
import random

import dgl
import numpy as np
import pandas as pd
import torch
@@ -105,6 +104,7 @@ def split(ds, partition):

def batch(ds, batch_size, seed=2666):
"""Batch graphs and values after shuffling."""
import dgl
# get the number of data
n_data_points = len(ds)
n_batches = n_data_points // batch_size # drop the rest
@@ -127,6 +127,7 @@ def batch(ds, batch_size, seed=2666):


def collate_fn(graphs):
import dgl
return esp.HomogeneousGraph(dgl.batch(graphs))


6 changes: 3 additions & 3 deletions espaloma/graphs/graph.py
@@ -2,8 +2,6 @@
# IMPORTS
# =============================================================================
import abc

import dgl
import openff.toolkit

import espaloma as esp
@@ -72,7 +70,7 @@ def __init__(self, mol=None, homograph=None, heterograph=None):
def save(self, path):
import os
import json

import dgl
os.mkdir(path)
dgl.save_graphs(path + "/homograph.bin", [self.homograph])
dgl.save_graphs(path + "/heterograph.bin", [self.heterograph])
@@ -82,6 +80,7 @@ def save(self, path):
@classmethod
def load(cls, path):
import json
import dgl

homograph = dgl.load_graphs(path + "/homograph.bin")[0][0]
heterograph = dgl.load_graphs(path + "/heterograph.bin")[0][0]
@@ -117,6 +116,7 @@ def get_homograph_from_mol(mol):

@staticmethod
def get_heterograph_from_graph_and_mol(graph, mol):
import dgl
assert isinstance(
graph, dgl.DGLGraph
), "graph can only be dgl Graph object."
2 changes: 1 addition & 1 deletion espaloma/graphs/utils/read_heterogeneous_graph.py
@@ -4,7 +4,6 @@
# =============================================================================
# IMPORTS
# =============================================================================
import dgl
import numpy as np
import torch
from espaloma.graphs.utils import offmol_indices
@@ -262,6 +261,7 @@ def from_homogeneous_and_mol(g, offmol):
axis=1,
)

import dgl
hg = dgl.heterograph({key: list(value) for key, value in hg.items()})

hg.nodes["n1"].data["h0"] = g.ndata["h0"]
3 changes: 2 additions & 1 deletion espaloma/graphs/utils/read_homogeneous_graph.py
@@ -4,7 +4,6 @@
# =============================================================================
# IMPORTS
# =============================================================================
import dgl
import torch

# =============================================================================
@@ -118,6 +117,7 @@ def fp_rdkit(atom):
# MODULE FUNCTIONS
# =============================================================================
def from_openff_toolkit_mol(mol, use_fp=True):
import dgl
# initialize graph
from rdkit import Chem

@@ -207,6 +207,7 @@ def from_oemol(mol, use_fp=True):


def from_rdkit_mol(mol, use_fp=True):
import dgl
from rdkit import Chem

# initialize graph
4 changes: 2 additions & 2 deletions espaloma/mm/energy.py
@@ -1,7 +1,6 @@
# =============================================================================
# IMPORTS
# =============================================================================
import dgl
import torch

import espaloma as esp
@@ -254,7 +253,7 @@ def energy_in_graph(
"""
# TODO: this is all very restricted for now
# we need to make this better

import dgl
if "n2" in terms:
# apply energy function

@@ -417,6 +416,7 @@ def forward(self, g):
class CarryII(torch.nn.Module):
def forward(self, g):
import math
import dgl

g.multi_update_all(
{
3 changes: 1 addition & 2 deletions espaloma/mm/geometry.py
@@ -1,7 +1,6 @@
# =============================================================================
# IMPORTS
# =============================================================================
import dgl
import torch

# =============================================================================
@@ -181,7 +180,7 @@ def geometry_in_graph(g):
This function modifies graphs in-place.
"""

import dgl
# Copy coordinates to higher-order nodes.
g.multi_update_all(
{
5 changes: 1 addition & 4 deletions espaloma/mm/nonbonded.py
@@ -1,14 +1,11 @@
# =============================================================================
# IMPORTS
# =============================================================================
import dgl
import torch

# =============================================================================
# CONSTANTS
# =============================================================================
from simtk import unit

import espaloma as esp

# CODATA 2018
@@ -43,7 +40,7 @@ def _arithmetic_mean(nodes):
# COMBINATION RULES FOR NONBONDED
# =============================================================================
def lorentz_berthelot(g, suffix=""):

import dgl
g.multi_update_all(
{
"n1_as_%s_in_%s"
9 changes: 2 additions & 7 deletions espaloma/nn/layers/dgl_legacy.py
@@ -2,16 +2,10 @@
"""

import math
from copy import deepcopy

import dgl

# =============================================================================
# IMPORTS
# =============================================================================
import torch
from dgl.nn import pytorch as dgl_pytorch

# =============================================================================
# CONSTANT
@@ -35,7 +29,7 @@ def __init__(
kwargs={},
):
super(GN, self).__init__()

from dgl.nn import pytorch as dgl_pytorch
if kwargs == {}:
if model_name in DEFAULT_MODEL_KWARGS:
kwargs = DEFAULT_MODEL_KWARGS[model_name]
@@ -58,6 +52,7 @@ def forward(self, g, x):


def gn(model_name="GraphConv", kwargs={}):
from dgl.nn import pytorch as dgl_pytorch
if model_name == "GINConv":
return lambda in_features, out_features: dgl_pytorch.conv.GINConv(
apply_func=torch.nn.Linear(in_features, out_features),
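In dgl_legacy.py the same treatment is applied to a submodule import: from dgl.nn import pytorch as dgl_pytorch moves from module scope into GN.__init__ and gn(). A minimal sketch of the idea, assuming a hypothetical class name in place of GN:

import torch


class GraphLayer(torch.nn.Module):
    """Illustrative stand-in for GN; the DGL submodule is imported lazily."""

    def __init__(self, in_features, out_features, model_name="GraphConv"):
        super().__init__()
        # deferred submodule import: evaluated only when a layer is constructed
        from dgl.nn import pytorch as dgl_pytorch
        self.gnn = getattr(dgl_pytorch.conv, model_name)(in_features, out_features)

    def forward(self, g, x):
        return self.gnn(g, x)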
3 changes: 1 addition & 2 deletions espaloma/nn/readout/charge_equilibrium.py
@@ -5,8 +5,6 @@
# IMPORTS
# =============================================================================
import torch
import dgl
import math

# =============================================================================
# UTILITY FUNCTIONS
@@ -67,6 +65,7 @@ def __init__(self):
def forward(self, g, total_charge=0.0):
""" apply charge equilibrium to all molecules in batch """
# calculate $s ^ {-1}$ and $ es ^ {-1}$
import dgl
g.apply_nodes(
lambda node: {"s_inv": node.data["s"] ** -1}, ntype="n1"
)
7 changes: 5 additions & 2 deletions espaloma/nn/readout/graph_level_readout.py
@@ -1,7 +1,6 @@
# =============================================================================
# IMPORTS
# =============================================================================
import dgl
import torch
import espaloma as esp

@@ -17,10 +16,13 @@ def __init__(
config_local,
config_global,
out_name,
pool=dgl.function.sum,
pool=None,
):

super(GraphLevelReadout, self).__init__()
import dgl
if pool is None:
pool = dgl.function.sum
self.in_features = in_features
self.config_local = config_local
self.config_global = config_global
@@ -42,6 +44,7 @@ def __init__(
self.out_name = out_name

def forward(self, g):
import dgl
g.apply_nodes(
lambda node: {"h_global": self.d_local(None, node.data["h"])},
ntype="n1",
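graph_level_readout.py needs one step beyond moving the import: the default argument pool=dgl.function.sum is evaluated when the module is imported, so it becomes pool=None and is resolved inside __init__ once dgl has been imported. A sketch of that idiom, again with a hypothetical class name:

import torch


class Readout(torch.nn.Module):
    """Illustrative stand-in for GraphLevelReadout."""

    def __init__(self, pool=None):
        super().__init__()
        import dgl  # deferred import
        if pool is None:
            # the former signature default, now resolved only when dgl is present
            pool = dgl.function.sum
        self.pool = pool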
5 changes: 3 additions & 2 deletions espaloma/nn/readout/janossy.py
@@ -1,7 +1,6 @@
# =============================================================================
# IMPORTS
# =============================================================================
import dgl
import torch

import espaloma as esp
@@ -98,6 +97,7 @@ def forward(self, g):
g : dgl.DGLHeteroGraph,
input graph.
"""
import dgl

# copy
g.multi_update_all(
@@ -233,7 +233,8 @@ def forward(self, g):
g : dgl.DGLHeteroGraph,
input graph.
"""

import dgl

# copy
g.multi_update_all(
{
4 changes: 1 addition & 3 deletions espaloma/nn/sequential.py
@@ -1,9 +1,7 @@
""" Chain mutiple layers of GN together.
"""
import dgl
import torch


class _Sequential(torch.nn.Module):
"""Sequentially staggered neural networks."""

@@ -115,7 +113,7 @@ def forward(self, g, x=None):
g : `dgl.DGLHeteroGraph`
output graph
"""

import dgl
# get homogeneous subgraph
g_ = dgl.to_homo(g.edge_type_subgraph(["n1_neighbors_n1"]))

62 changes: 0 additions & 62 deletions espaloma/redux/energy.py

This file was deleted.

