-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: torchscript_inference.py
101 lines (92 loc) · 3.19 KB
/
torchscript_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import Dataset
from datasets.utils import disable_progress_bar
from transformers import AutoTokenizer
import gc
from processing import prepare_validation_features, postprocess_qa_predictions
from utils import parse_args_inference
# Avoid fork-related warnings/deadlocks from the Rust tokenizers when the
# DataLoader spawns worker processes (see predict(): num_workers > 0).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Silence the `datasets` per-map progress bars for cleaner batch-job logs.
disable_progress_bar()
@torch.no_grad()
def predict(
    model: nn.Module,
    dataset: "Dataset",
    batch_size: int = 64,
    workers: int = 4,
    device: str = "cuda"
) -> "tuple[np.ndarray, np.ndarray]":
    """Run the QA model over `dataset` and collect start/end span logits.

    Args:
        model: A (TorchScript or eager) module called as
            ``model(input_ids, attention_mask)`` returning a pair whose
            elements are the start and end logits for each token.
        dataset: Any map-style dataset whose items are dicts with
            ``"input_ids"`` and ``"attention_mask"`` tensors (here, an HF
            ``datasets.Dataset`` with torch format set).
        batch_size: Examples per forward pass.
        workers: DataLoader worker processes.
        device: Device the batches (and model) live on.

    Returns:
        ``(start_logits, end_logits)`` as numpy arrays stacked over all
        batches along axis 0.
        NOTE(review): np.vstack requires every batch's sequence length to
        match — assumes the tokenizer pads to a fixed max_length; confirm.
    """
    model.eval()
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # preserve example order for postprocessing
        num_workers=workers,
        pin_memory=True,
    )
    start_logits = []
    end_logits = []
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        output = model(input_ids, attention_mask)
        start_logits.append(output[0].cpu().numpy())
        end_logits.append(output[1].cpu().numpy())
    # Fixed defect: annotation previously claimed a single np.ndarray, but a
    # pair has always been returned (and unpacked by the caller).
    return np.vstack(start_logits), np.vstack(end_logits)
if __name__ == "__main__":
    config = parse_args_inference()

    # Raw examples (question/context CSV) consumed by the processing helpers.
    data = pd.read_csv(config.input_data)
    tokenizer = AutoTokenizer.from_pretrained(config.base_model)
    dataset = Dataset.from_pandas(data)

    # Tokenize; long contexts are split into overlapping windows of
    # max_length with stride doc_stride (handled by the project helper).
    tokenized_dataset = dataset.map(
        prepare_validation_features,
        batched=True,
        remove_columns=dataset.column_names,
        fn_kwargs={
            "tokenizer": tokenizer,
            "max_length": config.max_length,
            "doc_stride": config.doc_stride,
        },
    )

    # Model input: strip the bookkeeping columns that only postprocessing
    # needs. (remove_columns replaces the original identity-map round trip.)
    input_dataset = tokenized_dataset.remove_columns(["example_id", "offset_mapping"])
    input_dataset.set_format(type="torch")

    # Restrict to explicitly selected folds, otherwise run all of them.
    if len(config.select_folds) > 0:
        folds = [int(fold) for fold in config.select_folds]
    else:
        folds = range(config.num_folds)

    for fold in folds:
        print(f"Generating predictions for fold {fold}")

        # Per-fold file stem, e.g. "torchscript_<model>_fold_3".
        model_id = config.model_name if config.model_name else config.base_model
        filename = f"torchscript_{model_id.replace('/', '-')}_fold_{fold}"

        # Fixed defect: the checkpoint/output paths used a literal placeholder
        # instead of {filename}, so `filename` was computed but never used and
        # every fold would read/write the same bogus path.
        checkpoint = os.path.join(config.model_weights_dir, f"{filename}.pt")
        model = torch.jit.load(checkpoint)
        model.to(config.device)

        start_logits, end_logits = predict(
            model,
            input_dataset,
            config.batch_size,
            config.dataloader_workers,
            config.device,
        )

        if config.output_csv:
            pred_df = postprocess_qa_predictions(
                dataset,
                tokenized_dataset,
                (start_logits, end_logits),
                tokenizer,
            )
            pred_df.to_csv(os.path.join(config.save_dir, f"{filename}.csv"), index=False)

        if config.output_logits:
            np.save(os.path.join(config.save_dir, f"{filename}_start_logits.npy"), start_logits)
            np.save(os.path.join(config.save_dir, f"{filename}_end_logits.npy"), end_logits)

        # Release this fold's model before loading the next one.
        del model
        gc.collect()
        if torch.cuda.is_available():  # no-op guard for CPU-only runs
            torch.cuda.empty_cache()