# gen_corpus_embeddings.py (forked from rajpurkarlab/CXR-RePaiR)
import argparse
import os

import pandas as pd
import torch
from tqdm import tqdm

import clip
from utils import nonpretrained_params


def encode_texts(imps, model, device):
    """Tokenize a batch of impressions and return unit-normalized CLIP text embeddings."""
    with torch.no_grad():
        imp_toks = clip.tokenize(
            imps, context_length=model.context_length
        ).to(device)
        embeddings = model.encode_text(imp_toks)
        # Normalize to unit length so dot products act as cosine similarities.
        embeddings /= embeddings.norm(dim=-1, keepdim=True)
    return embeddings
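

# A minimal usage sketch for encode_texts (illustrative only; `model` and
# `device` are assumed to come from the loading logic in main below):
#
#   embs = encode_texts(["no acute cardiopulmonary process"], model, device)
#   sims = embs @ embs.T  # cosine similarities, since rows are unit-normalized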


def main(args):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if args.clip_pretrained:
        # Model was first pre-trained on natural images, so it uses the
        # standard OpenAI ViT-B/32 architecture.
        model, _ = clip.load("ViT-B/32", device=device, jit=False)
        print("Loaded pretrained CLIP architecture.")
    else:
        # Model was trained from scratch on chest X-rays, with constructor
        # hyperparameters taken from utils.nonpretrained_params.
        model = clip.CLIP(**nonpretrained_params)
        print("Loaded non-pretrained CLIP architecture.")
    model.load_state_dict(torch.load(args.clip_model_path, map_location=device))
    model = model.to(device)
    impressions = pd.read_csv(args.data_path)["report"]
    # Replace null impressions so tokenization does not fail on NaN values.
    impressions = impressions.fillna("None")
    impressions_size = impressions.shape[0]
    bs = args.batch_size
    num_batches = impressions_size // bs
    tensors = []
    for i in tqdm(range(num_batches)):
        batch = impressions[bs * i : bs * (i + 1)]
        tensors.append(encode_texts(batch, model, device))
    # Encode the final partial batch, if any.
    if impressions_size % bs != 0:
        tensors.append(encode_texts(impressions[bs * num_batches :], model, device))
    clip_embeddings = torch.cat(tensors)
    print(impressions.shape, clip_embeddings.shape)
    # Save the impressions alongside their embeddings so downstream retrieval
    # can map an embedding index back to its report text.
    out_data = (impressions, clip_embeddings)
    os.makedirs("corpus_embeddings", exist_ok=True)
    out_path = os.path.join("corpus_embeddings", args.out)
    torch.save(out_data, out_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate CLIP embeddings for a training corpus (either sentence level or report level)"
    )
    parser.add_argument(
        "--clip_model_path",
        type=str,
        required=True,
        help="path to the CLIP model state dictionary used to generate embeddings",
    )
    parser.add_argument(
        "--clip_pretrained",
        action="store_true",
        help="whether the CLIP model was first pre-trained on natural images",
    )
parser.add_argument(
"--data_path",
type=str,
required=True,
help="path of csv file containing training corpus (either sentence level or report level)",
)
parser.add_argument(
"--out",
type=str,
required=True,
help="name for saved corpus embeddings (include .pt extension)",
)
    parser.add_argument(
        "--batch_size",
        type=int,
        default=2000,
        help="batch size for generating CLIP embeddings",
    )
args = parser.parse_args()
main(args)
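
# Example invocation (hypothetical file names, shown for illustration only):
#
#   python gen_corpus_embeddings.py \
#       --clip_model_path path/to/clip_state_dict.pt \
#       --data_path path/to/corpus.csv \
#       --out corpus_embeddings.pt
#
# The saved artifact is a (pandas.Series, torch.Tensor) tuple and can be
# loaded back with, e.g.:
#
#   impressions, embeddings = torch.load("corpus_embeddings/corpus_embeddings.pt")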