# file_builder.py
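# Builds a CSV of 2-D embeddings for a set of test reactions: each reaction SMILES is
# converted to a fingerprint (BERT-based rxnfp fingerprints or descriptor-based reaction
# fingerprints), passed through a saved PyTorch model, and written out together with
# plotting metadata (reaction class, colour, size, alpha, number of reagents).
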
import argparse
import json
import torch
import pandas as pd
from tqdm import tqdm
import pickle
from bokeh.palettes import d3
from utils.reactions import reaction_fps

parser = argparse.ArgumentParser()
parser.add_argument("--settings", "-s", type=str, help="Path to the model settings JSON file")
parser.add_argument("--model", "-m", type=str, help="Path to a saved PyTorch model")
parser.add_argument("--test-data", "-t", type=str, help="Path to the test data file")
parser.add_argument("--output", "-o", type=str, help="Name of the output file")
parser.add_argument("--device", "-d", type=str, default="cpu", help="Device: cpu or cuda")
parser.add_argument("--transformer", default=False, action="store_true",
                    help="If set, encode test reactions with BERT (rxnfp) fingerprints "
                         "instead of descriptor-based ones")
parser.add_argument("--remove-repeated", default=False, action="store_true",
                    help="If set, remove repeated compounds from each reaction "
                         "before calculating fingerprints")
parser.add_argument("--banned-reagents", default="", type=str,
                    help="Comma-separated reagents to remove from reactions "
                         "before calculating fingerprints. Example: 'CCO,CN'")
args = parser.parse_args()
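
# Example invocation (paths are illustrative):
#   python file_builder.py -s settings.json -m model.pt -t test_data.csv -o embeddings.csv -d cpu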

dev = torch.device(args.device)
loaded_model = torch.load(args.model, map_location=dev)
loaded_model.eval()

# Reagents to strip from every reaction; empty set when the option was not given.
banned_reagents = set(args.banned_reagents.split(",")) if args.banned_reagents else set()

def extract_ag_reag_prod(reac_smi: str):
    # Split a reaction SMILES into its reactant, agent and product parts.
    reactants, agents, products = reac_smi.split(">")
    return reactants.rstrip("."), agents.rstrip("."), products

def count_ag_reag(reac_smi: str):
    # Count reactants and agents; strip(".") also covers an empty reactant or agent part.
    reag_smi, ag_smi, _ = extract_ag_reag_prod(reac_smi)
    return len((reag_smi + "." + ag_smi).strip(".").split("."))

def remove_repeated_reagents(reac_smi: str):
    reag_smi, ag_smi, prod_smi = extract_ag_reag_prod(reac_smi)
    present_rgs = set()
    refined_reag = []
    refined_ag = []
    for rg in reag_smi.split("."):
        if rg not in present_rgs:
            present_rgs.add(rg)
            refined_reag.append(rg)
    for rg in ag_smi.split("."):
        if rg not in present_rgs:
            present_rgs.add(rg)
            refined_ag.append(rg)
    return f'{".".join(refined_reag)}>{".".join(refined_ag)}>{prod_smi}'
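
# For example (illustrative SMILES), remove_repeated_reagents("CCO.CCO>CCO.CN>CC")
# returns "CCO>CN>CC": the duplicated ethanol is kept once, on the reactant side.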

def remove_particular_reagents(reac_smi: str):
    reag_smi, ag_smi, prod_smi = extract_ag_reag_prod(reac_smi)
    refined_reag = []
    refined_ag = []
    for rg in reag_smi.split("."):
        if rg not in banned_reagents:
            refined_reag.append(rg)
    for rg in ag_smi.split("."):
        if rg not in banned_reagents:
            refined_ag.append(rg)
    return f'{".".join(refined_reag)}>{".".join(refined_ag)}>{prod_smi}'
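
# For example (illustrative SMILES), with banned_reagents == {"CN"},
# remove_particular_reagents("CCO>CN.O>CC") returns "CCO>O>CC".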

with open("data/visual_validation/rxnClasses.pickle", "rb") as f:
    classes = pickle.load(f)
classes = {int(k): v for k, v in classes.items()}
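# `classes` now maps integer reaction-class labels to reaction class names (used below for colouring).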

data = pd.read_csv(args.test_data, sep=";", header=None)
try:
    data.columns = ["smiles", "label"]
except ValueError:
    data.columns = ["smiles"]
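# Two columns mean reaction SMILES plus an integer class label; a single column means SMILES only.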

all_embs = {"x": [], "y": []}
if args.remove_repeated:
    data["smiles"] = data["smiles"].map(remove_repeated_reagents)
if len(banned_reagents) > 0:
    data["smiles"] = data["smiles"].map(remove_particular_reagents)

if args.transformer:
    # Encode reactions with pretrained BERT (rxnfp) fingerprints; the default model
    # produces 256-dimensional fingerprints, hence the resize below.
    from rxnfp.transformer_fingerprints import (
        get_default_model_and_tokenizer, RXNBERTFingerprintGenerator
    )
    model, tokenizer = get_default_model_and_tokenizer()
    rxnfp_generator = RXNBERTFingerprintGenerator(model, tokenizer)
    for i in tqdm(range(data.shape[0])):
        smi = data.iloc[i]["smiles"]
        # Keep the input on the same device as the loaded model.
        fp = torch.tensor(rxnfp_generator.convert(smi)).to(dev)
        fp.resize_((1, 256))
        with torch.no_grad():
            embs = loaded_model(fp)
        x, y = embs.tolist()[0]
        all_embs["x"].append(x)
        all_embs["y"].append(y)
else:
    with open(args.settings) as f:
        all_settings = json.load(f)
    settings = all_settings["settings"]
    fp_method = settings["fp_method"]
    params = {"n_bits": settings["n_bits"],
              "fp_type": settings["fp_type"],
              "include_agents": settings["include_agents"],
              "agent_weight": settings["agent_weight"],
              "non_agent_weight": settings["non_agent_weight"],
              "bit_ratio_agents": settings["bit_ratio_agents"]}
    for i in tqdm(range(data.shape[0])):
        smi = data.iloc[i]["smiles"]
        descriptors = reaction_fps(smi, fp_method=fp_method, **params)
        fp = torch.from_numpy(descriptors).float().to(dev)
        fp.resize_((1, settings["n_bits"]))
        with torch.no_grad():
            embs = loaded_model(fp)
        x, y = embs.tolist()[0]
        all_embs["x"].append(x)
        all_embs["y"].append(y)
all_embs = pd.DataFrame(all_embs)
res = pd.concat((data, all_embs), axis=1)
res["alpha"] = 1
res["sizes"] = 5
try:
    res["reaction_class"] = res["label"].map(classes)
    factors = [v for v in classes.values()]
    palette = d3['Category10'][len(factors)]
    color_map = {k: v for k, v in zip(factors, palette)}
    res["color_transform"] = res["reaction_class"].map(color_map)
except KeyError:
    # No "label" column in the test data: skip class-based colouring.
    pass
finally:
    res["num_reagents"] = res["smiles"].map(count_ag_reag)
    res.to_csv(args.output, sep=",", header=True, index=False)