Skip to content

Commit

Permalink
dgl serialization and lmdb update
Browse files Browse the repository at this point in the history
  • Loading branch information
Wenbin Xu committed Dec 28, 2023
1 parent a5ea643 commit ea71b8a
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 12 deletions.
54 changes: 50 additions & 4 deletions HiPRGen/lmdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,35 @@
from tqdm import tqdm
import glob

import io
import dgl
import tempfile
import bisect

def load_dgl_graph_from_serialized(serialized_graph):
    """Rebuild a DGL graph from bytes produced by ``dgl.save_graphs``.

    ``dgl.load_graphs`` is called with a file path, so the payload is first
    written to a scratch file.  The scratch file lives in a temporary
    directory and is fully closed before DGL opens it: re-opening a
    ``NamedTemporaryFile`` by name while it is still open is not portable
    (it fails on Windows).

    Args:
        serialized_graph: Bytes-like object holding the content of a file
            written by ``dgl.save_graphs``.

    Returns:
        The first graph stored in the payload.  An ``IndexError`` is raised
        if the payload holds no graph at all.
    """
    import os

    with tempfile.TemporaryDirectory() as scratch_dir:
        graph_path = os.path.join(scratch_dir, "graph.bin")

        # Close the handle before loading so every byte is flushed to disk
        # and no open handle competes with DGL's own reader.
        with open(graph_path, "wb") as handle:
            handle.write(serialized_graph)

        graphs, _ = dgl.load_graphs(graph_path)

    # The writer side stores exactly one graph per payload.
    return graphs[0]

def TransformMol(data_object):
    """Replace the serialized ``molecule_graph`` entry with a DGL graph.

    The entry is deserialized through ``load_dgl_graph_from_serialized`` and
    written back under the same key.  ``data_object`` is modified in place
    and also returned, so the function can be used as a dataset transform.
    """
    graph_key = "molecule_graph"
    data_object[graph_key] = load_dgl_graph_from_serialized(data_object[graph_key])
    return data_object

def TransformReaction(data_object):
    """Replace the serialized ``reaction_graph`` entry with a DGL graph.

    The entry is deserialized through ``load_dgl_graph_from_serialized`` and
    written back under the same key.  ``data_object`` is modified in place
    and also returned, so the function can be used as a dataset transform.
    """
    graph_key = "reaction_graph"
    data_object[graph_key] = load_dgl_graph_from_serialized(data_object[graph_key])
    return data_object

class LmdbBaseDataset(Dataset):

Expand Down Expand Up @@ -124,9 +153,9 @@ def __getitem__(self, idx):

data_object = pickle.loads(datapoint_pickled)

# TODO
if self.transform is not None:
data_object = self.transform(data_object)
# TODO
if self.transform is not None:
data_object = self.transform(data_object)

return data_object

Expand Down Expand Up @@ -631,6 +660,21 @@ def write2moleculelmdb(mp_args
db.close()


def serialize_dgl_graph(dgl_graph):
    """Serialize a single DGL graph to bytes via ``dgl.save_graphs``.

    ``dgl.save_graphs`` is called with a file path, so the graph is saved to
    a scratch file inside a temporary directory and the file content is read
    back.  Reading from the path (instead of from a handle opened before DGL
    wrote the file) does not depend on how DGL creates the file, and avoids
    re-opening a ``NamedTemporaryFile`` by name while it is still open,
    which is not portable (it fails on Windows).

    Args:
        dgl_graph: The DGL graph to serialize.

    Returns:
        The bytes of the file written by ``dgl.save_graphs``; they can be
        turned back into a graph by writing them to a file and calling
        ``dgl.load_graphs`` on it.
    """
    import os

    with tempfile.TemporaryDirectory() as scratch_dir:
        graph_path = os.path.join(scratch_dir, "graph.bin")

        # save_graphs stores a list of graphs; wrap the single graph.
        dgl.save_graphs(graph_path, [dgl_graph])

        with open(graph_path, "rb") as handle:
            return handle.read()

def dump_molecule_lmdb(
indices,
graphs,
Expand All @@ -649,7 +693,9 @@ def dump_molecule_lmdb(
# else:
key_tempalte = ["molecule_index", "molecule_graph", "molecule_wrapper"]

dataset = [{k: v for k, v in zip(key_tempalte, values)} for values in zip(indices, graphs, pmgs)]
serialized_graphs = [serialize_dgl_graph(graph) for graph in graphs]

dataset = [{k: v for k, v in zip(key_tempalte, values)} for values in zip(indices, serialized_graphs, pmgs)]

global_keys = {
"charges" : charges,
Expand Down
6 changes: 4 additions & 2 deletions HiPRGen/rxn_networks_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tqdm
import pickle
from HiPRGen.lmdb_dataset import write_to_lmdb
from HiPRGen.lmdb_dataset import serialize_dgl_graph


class rxn_networks_graph:
Expand All @@ -24,7 +25,7 @@ def __init__(
):
#wx, which one should come from molecule lmdbs?
self.mol_entries = mol_entries
self.dgl_mol_dict = dgl_molecules_dict
self.dgl_mol_dict = dgl_molecules_dict #not used at all?
self.grapher_features = grapher_features
#self.report_file_path = report_file_path

Expand Down Expand Up @@ -249,7 +250,8 @@ def find_total_bonds(rxn, species, reactants, products):

#wx, step 6: structure and save reaction graph.
self.data["reaction_index"] = rxn_id
self.data["reaction_graph"] = rxn_graph
#self.data["reaction_graph"] = rxn_graph
self.data["reaction_graph"] = serialize_dgl_graph(rxn_graph)
self.data["reaction_feature"] = features
self.data["reaction_molecule_info"] = {
"reactants" : { "reactants" : list(rxn["reactants"]), #TODO, keep unique value if two reactants are same?
Expand Down
12 changes: 11 additions & 1 deletion HiPRGen/species_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from HiPRGen.lmdb_dataset import dump_molecule_lmdb
import numpy as np
import copy
import dgl
from HiPRGen.lmdb_dataset import serialize_dgl_graph

"""
Phase 1: species filtering
Expand Down Expand Up @@ -359,7 +361,7 @@ def collapse_isomorphism_group(g):
#wx, dump molecule lmdb.
dump_molecule_lmdb(
molecule_ind_list,
dgl_molecules, #dgl_graphs
dgl_molecules, #molecular wrapper
pmg_objects,
charge_set,
ring_size_set,
Expand All @@ -375,6 +377,14 @@ def collapse_isomorphism_group(g):
with open(mol_entries_pickle_location, "wb") as f:
pickle.dump(mol_entries, f)

#use dgl serialize
# import pdb
# pdb.set_trace()
for graph_id in dgl_molecules_dict:
dgl_molecules_dict[graph_id] = serialize_dgl_graph(dgl_molecules_dict[graph_id])


# dgl.save_graphs(dgl_mol_grphs_pickle_location, list(dgl_molecules_dict.values()))
with open(dgl_mol_grphs_pickle_location, "wb") as f:
pickle.dump(dgl_molecules_dict, f)

Expand Down
6 changes: 4 additions & 2 deletions run_network_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
worker,
DISPATCHER_RANK
)


import dgl
from HiPRGen.lmdb_dataset import load_dgl_graph_from_serialized
# python run_network_generation.py mol_entries_pickle_file dispatcher_payload.json worker_payload.json


Expand All @@ -30,6 +30,8 @@

with open(dgl_molecules_dict_pickle_file, 'rb') as f:
dgl_molecules_dict_pickle_file = pickle.load(f)
for graph_i in dgl_molecules_dict_pickle_file:
dgl_molecules_dict_pickle_file[graph_i] = load_dgl_graph_from_serialized(dgl_molecules_dict_pickle_file[graph_i])

with open(grapher_features_dict_pickle_file, 'rb') as f:
grapher_features_dict_pickle_file = pickle.load(f)
Expand Down
6 changes: 3 additions & 3 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,7 @@ def euvl_bondnet_test():

# "/global/home/users/wenbinxu/data/rep/rep/HiPRGen/test/euvl_phase2_test" + "/dgl_mol_graphs.pickle",
# "/global/home/users/wenbinxu/data/rep/rep/HiPRGen/test/euvl_phase2_test" + "/grapher_features.pickle",
folder + "/dgl_mol_graphs.pickle",
folder + "/dgl_mol_graphs.pickle", #use dgl.save
folder + "/grapher_features.pickle",

#wx, path to write reaction lmdb
Expand Down Expand Up @@ -1029,8 +1029,8 @@ def euvl_bondnet_test():
# flicho_test,
# co2_test,
# euvl_phase1_test,
euvl_phase2_test,
# euvl_bondnet_test
#euvl_phase2_test,
euvl_bondnet_test
]

for test in tests:
Expand Down

0 comments on commit ea71b8a

Please sign in to comment.