Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ViLT on GQA #85

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions EVAL.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
# Evaluation
The results will vary slightly because we run batched inference: images in a batch are padded to a common size, and the padded batch is embedded inconsistently by the linear image-patch projection.

## Evaluate GQA
```bash
python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> per_gpu_batchsize=<BS_FITS_YOUR_GPU> task_finetune_gqa_randaug test_only=True precision=32 load_path="<YOUR_WEIGHT_ROOT>/vilt_gqa.ckpt"

ex)
python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 per_gpu_batchsize=64 task_finetune_gqa_randaug test_only=True precision=32 load_path="weights/vilt_gqa.ckpt"

output > This script will generate `result/gqa_submit_last.json`
```
## Evaluate VQAv2
```bash
python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> per_gpu_batchsize=<BS_FITS_YOUR_GPU> task_finetune_vqa_randaug test_only=True precision=32 load_path="<YOUR_WEIGHT_ROOT>/vilt_vqa.ckpt"
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ We provide five pretrained weights
3. ViLT-B/32 200k finetuned on NLVR2 [link](https://github.com/dandelin/ViLT/releases/download/200k/vilt_nlvr2.ckpt)
4. ViLT-B/32 200k finetuned on COCO IR/TR [link](https://github.com/dandelin/ViLT/releases/download/200k/vilt_irtr_coco.ckpt)
5. ViLT-B/32 200k finetuned on F30K IR/TR [link](https://github.com/dandelin/ViLT/releases/download/200k/vilt_irtr_f30k.ckpt)
6. ViLT-B/32 200k finetuned on GQA [link](https://github.com/keshavshivkumar/ViLT/releases/download/vilt_gqa/vilt_gqa.ckpt)

## Out-of-the-box MLM + Visualization Demo
<p align="center">
Expand Down
3 changes: 0 additions & 3 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
def main(_config):
_config = copy.deepcopy(_config)
pl.seed_everything(_config["seed"])

dm = MTDataModule(_config, dist=True)

model = ViLTransformerSS(_config)
Expand All @@ -29,7 +28,6 @@ def main(_config):
_config["log_dir"],
name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}',
)

lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
callbacks = [checkpoint_callback, lr_callback]

Expand All @@ -44,7 +42,6 @@ def main(_config):
)

max_steps = _config["max_steps"] if _config["max_steps"] is not None else None

trainer = pl.Trainer(
gpus=_config["num_gpus"],
num_nodes=_config["num_nodes"],
Expand Down
34 changes: 32 additions & 2 deletions vilt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

ex = Experiment("ViLT")


def _loss_names(d):
ret = {
"itm": 0,
Expand All @@ -11,6 +10,7 @@ def _loss_names(d):
"vqa": 0,
"nlvr2": 0,
"irtr": 0,
"gqa": 0,
}
ret.update(d)
return ret
Expand All @@ -35,6 +35,7 @@ def config():

# Text Setting
vqav2_label_size = 3129
gqa_label_size = 1878
max_text_len = 40
tokenizer = "bert-base-uncased"
vocab_size = 30522
Expand Down Expand Up @@ -77,7 +78,7 @@ def config():
num_gpus = 1
num_nodes = 1
load_path = ""
num_workers = 8
num_workers = 4
precision = 16


Expand Down Expand Up @@ -179,6 +180,35 @@ def task_finetune_vqa_randaug():
val_check_interval = 0.1
lr_mult = 10

# Named config for fine-tuning on GQA, registered on the `ex` experiment via
# `named_config`. Every local assigned below is a configuration value.
# NOTE(review): presumably selected by passing `task_finetune_gqa` on the
# `python run.py with ...` command line, like the `_randaug` variant - confirm.
@ex.named_config
def task_finetune_gqa():
exp_name = "finetune_gqa"
datasets = ["gqa"]
# Switch on only the "gqa" entry of the loss table built by _loss_names.
loss_names = _loss_names({"gqa": 1})
batch_size = 256
max_epoch = 10
max_steps = None
# NOTE(review): 0.1 looks like a fraction of total steps, not a step count - confirm.
warmup_steps = 0.1
draw_false_image = 0
learning_rate = 1e-4
val_check_interval = 0.1
lr_mult = 10

# Named config for fine-tuning on GQA with the "pixelbert_randaug" training
# transform, registered on the `ex` experiment via `named_config`.
# Apart from `exp_name` and `train_transform_keys`, the values match
# `task_finetune_gqa`. Every local assigned below is a configuration value.
@ex.named_config
def task_finetune_gqa_randaug():
exp_name = "finetune_gqa_randaug"
datasets = ["gqa"]
# The only setting that differs from the non-randaug GQA config.
train_transform_keys = ["pixelbert_randaug"]
# Switch on only the "gqa" entry of the loss table built by _loss_names.
loss_names = _loss_names({"gqa": 1})
batch_size = 256
max_epoch = 10
max_steps = None
# NOTE(review): 0.1 looks like a fraction of total steps, not a step count - confirm.
warmup_steps = 0.1
draw_false_image = 0
learning_rate = 1e-4
val_check_interval = 0.1
lr_mult = 10


@ex.named_config
def task_finetune_irtr_coco():
Expand Down
2 changes: 2 additions & 0 deletions vilt/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .sbu_datamodule import SBUCaptionDataModule
from .vqav2_datamodule import VQAv2DataModule
from .nlvr2_datamodule import NLVR2DataModule
from .gqa_datamodule import GQADataModule

_datamodules = {
"vg": VisualGenomeCaptionDataModule,
Expand All @@ -14,4 +15,5 @@
"sbu": SBUCaptionDataModule,
"vqa": VQAv2DataModule,
"nlvr2": NLVR2DataModule,
"gqa": GQADataModule,
}
14 changes: 6 additions & 8 deletions vilt/datamodules/datamodule_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import os

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -132,18 +133,15 @@ def set_test_dataset(self):
image_only=self.image_only,
)

def setup(self, stage):
if not self.setup_flag:
def setup(self, stage=None):
if stage == "fit" or stage == "test" or stage is None:
self.set_train_dataset()
self.set_val_dataset()
self.set_test_dataset()

self.train_dataset.tokenizer = self.tokenizer
self.set_val_dataset()
self.val_dataset.tokenizer = self.tokenizer
self.set_test_dataset()
self.test_dataset.tokenizer = self.tokenizer

self.setup_flag = True

def train_dataloader(self):
loader = DataLoader(
self.train_dataset,
Expand Down Expand Up @@ -175,4 +173,4 @@ def test_dataloader(self):
pin_memory=True,
collate_fn=self.test_dataset.collate,
)
return loader
return loader
63 changes: 63 additions & 0 deletions vilt/datamodules/gqa_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from vilt.datasets import GQADataset
from .datamodule_base import BaseDataModule
from collections import defaultdict
import numpy as np

class GQADataModule(BaseDataModule):
    """Data module for GQA that also builds the answer <-> label-id vocabulary.

    After ``setup`` the instance exposes:

    * ``answer2id`` - dict mapping each answer to its label id,
    * ``id2answer`` - the inverse mapping; unseen ids resolve to ``"unknown"``,
    * ``num_class`` - number of distinct answers in ``answer2id``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def dataset_cls(self):
        """Dataset class instantiated for every split."""
        return GQADataset

    @property
    def dataset_name(self):
        """Short name of this dataset."""
        return "gqa"

    def setup(self, stage=None):
        """Build the datasets, then derive the answer vocabulary from them.

        Args:
            stage: Forwarded unchanged to ``BaseDataModule.setup``. Defaults to
                ``None`` so the signature matches the base class.
        """
        super().setup(stage)

        # Train entries are unpacked last so that, when an answer appears in
        # both splits, the label id from the train split wins.
        self.answer2id = {
            **self._answer_to_label(self.val_dataset),
            **self._answer_to_label(self.train_dataset),
        }
        self.num_class = len(self.answer2id)

        self.id2answer = defaultdict(lambda: "unknown")
        for answer, label in self.answer2id.items():
            self.id2answer[label] = answer

    @staticmethod
    def _answer_to_label(dataset):
        """Return ``{answer: label}`` collected from one dataset's arrow table.

        Reads the ``answers`` and ``answer_label`` columns row by row. Rows in
        which either column is ``None`` are skipped instead of raising. When an
        answer occurs several times, the label of its last occurrence is kept.
        """
        answers = dataset.table["answers"].to_pandas().tolist()
        labels = dataset.table["answer_label"].to_pandas().tolist()

        mapping = {}
        for row_labels, row_answers in zip(labels, answers):
            if row_labels is None or row_answers is None:
                continue
            # NOTE(review): assumes each row holds array-like values exposing
            # ``.tolist()`` (e.g. numpy arrays from ``to_pandas``) - confirm.
            mapping.update(zip(row_answers.tolist(), row_labels.tolist()))
        return mapping
2 changes: 1 addition & 1 deletion vilt/datamodules/multitask_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def prepare_data(self):
def setup(self, stage):
for dm in self.dms:
dm.setup(stage)

self.train_dataset = ConcatDataset([dm.train_dataset for dm in self.dms])
self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.dms])
self.test_dataset = ConcatDataset([dm.test_dataset for dm in self.dms])
Expand Down
1 change: 1 addition & 0 deletions vilt/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from .sbu_caption_dataset import SBUCaptionDataset
from .vqav2_dataset import VQAv2Dataset
from .nlvr2_dataset import NLVR2Dataset
from .gqa_dataset import GQADataset
46 changes: 40 additions & 6 deletions vilt/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def __init__(
"""
assert len(transform_keys) >= 1
super().__init__()

self.transforms = keys_to_transforms(transform_keys, size=image_size)
self.text_column_name = text_column_name
self.names = names
Expand All @@ -38,21 +37,38 @@ def __init__(
self.draw_false_text = draw_false_text
self.image_only = image_only
self.data_dir = data_dir

if len(names) != 0:
print(f"Attempting to load the following dataset files: {[f'{data_dir}/{name}.arrow' for name in names]}")
tables = [
pa.ipc.RecordBatchFileReader(
pa.memory_map(f"{data_dir}/{name}.arrow", "r")
).read_all()
for name in names
if os.path.isfile(f"{data_dir}/{name}.arrow")
]

self.table_names = list()
self.table = list()
for i, name in enumerate(names):
self.table_names += [name] * len(tables[i])
if i < len(tables):
self.table += [name] * len(tables[i])
else:
print(f"Warning: Skipping {name} as the index is out of range in tables.")

if self.table is None:
print("Error: The table is not properly loaded. Please check the dataset files and their paths.")
if len(tables) > 0:
self.table = pa.concat_tables(tables, promote=True)
else:
print("Warning: No tables to concatenate. Check if dataset is properly loaded.")
self.table = None

# if self.table is not None:
# print("Column names in the table schema:", [field.name for field in self.table.schema])
# print("Sample answer_scores:")
# for i in range(min(10, len(self.table))):
# print(f"Row {i}: {self.table['answer_scores'][i].as_py()}")


self.table = pa.concat_tables(tables, promote=True)
if text_column_name != "":
self.text_column_name = text_column_name
self.all_texts = self.table[text_column_name].to_pandas().tolist()
Expand All @@ -65,6 +81,22 @@ def __init__(
self.all_texts = list()
else:
self.all_texts = list()


if self.table is not None and text_column_name != "":
self.text_column_name = text_column_name
try:
self.all_texts = self.table[text_column_name].to_pandas().tolist()
self.all_texts = (
[list(set(texts)) for texts in self.all_texts]
if remove_duplicate
else self.all_texts
)
except KeyError:
print(f"Error: The text column '{text_column_name}' was not found in the table.")
self.all_texts = list()
else:
self.all_texts = list()

self.index_mapper = dict()

Expand All @@ -77,6 +109,8 @@ def __init__(
else:
for i in range(len(self.table)):
self.index_mapper[i] = (i, None)
print("Length of tables:", [len(t) for t in tables])


@property
def corpus(self):
Expand Down
56 changes: 56 additions & 0 deletions vilt/datasets/gqa_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from .base_dataset import BaseDataset

class GQADataset(BaseDataset):
    """GQA visual question answering dataset read from ``gqa_*`` arrow tables.

    Each item is a dict with the image tensor, the question text, the question
    id and, for labelled splits, the answers, label ids and scores.
    """

    # Arrow table names loaded for each accepted split. "testdev" is an alias
    # of "test": it was accepted by the split check but previously had no
    # table assigned, so construction failed with UnboundLocalError.
    _SPLIT_NAMES = {
        "train": ["gqa_train", "gqa_trainable_val"],
        "val": ["gqa_rest_val"],
        "test": ["gqa_testdev"],
        "testdev": ["gqa_testdev"],
    }

    # Splits for which no answers / labels / scores are returned.
    _UNLABELED_SPLITS = ("test", "testdev")

    def __init__(self, *args, split="", **kwargs):
        """Select the arrow tables for ``split`` and initialise the base dataset.

        Args:
            *args: Positional arguments forwarded to ``BaseDataset``.
            split: One of ``"train"``, ``"val"``, ``"test"`` or ``"testdev"``.
            **kwargs: Keyword arguments forwarded to ``BaseDataset``.

        Raises:
            AssertionError: If ``split`` is not one of the accepted values.
        """
        assert split in ["train", "val", "test", "testdev"]
        self.split = split
        # Kept for compatibility; it only served ad-hoc debug printing.
        self.print_counter = 0

        super().__init__(
            *args,
            **kwargs,
            names=list(self._SPLIT_NAMES[split]),
            text_column_name="questions",
            remove_duplicate=False,
        )

    def __getitem__(self, index):
        """Return the sample at flat position ``index``.

        Returns:
            dict with keys ``"image"``, ``"text"``, ``"gqa_answer"``,
            ``"gqa_label"``, ``"gqa_scores"`` and ``"qid"``. The three
            ``gqa_*`` values are empty lists for the unlabelled splits.
        """
        image_tensor = self.get_image(index)["image"]
        text = self.get_text(index)["text"]

        # index_mapper resolves the flat index to a table row and the position
        # of the question within that row.
        row_index, question_index = self.index_mapper[index]
        qid = self.table["question_id"][row_index][question_index].as_py()

        if self.split in self._UNLABELED_SPLITS:
            answers, labels, scores = [], [], []
        else:
            answers = self.table["answers"][row_index][question_index].as_py()
            labels = self.table["answer_label"][row_index][question_index].as_py()
            scores = self.table["answer_scores"][row_index][question_index].as_py()

        return {
            "image": image_tensor,
            "text": text,
            "gqa_answer": answers,
            "gqa_label": labels,
            "gqa_scores": scores,
            "qid": qid,
        }

Loading