add an out-of-the-box captioning and visualization demo
dandelin committed Apr 3, 2021
1 parent 71f671d commit e1962f0
Showing 6 changed files with 260 additions and 600 deletions.
12 changes: 7 additions & 5 deletions EVAL.md
@@ -3,7 +3,7 @@ The results will vary a bit since we do a batched-inference, which yields padded

## Evaluate VQAv2
```bash
-python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> per_gpu_batchsize=<BS_FITS_YOUR_GPU> task_finetune_vqa_randaug test_only=True load_path="weights/vilt_vqa.ckpt"
+python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> per_gpu_batchsize=<BS_FITS_YOUR_GPU> task_finetune_vqa_randaug test_only=True load_path="<YOUR_WEIGHT_ROOT>/vilt_vqa.ckpt"

ex)
python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 per_gpu_batchsize=64 task_finetune_vqa_randaug test_only=True load_path="weights/vilt_vqa.ckpt"
@@ -14,7 +14,7 @@ output > This script will generate `result/vqa_submit_vilt_vqa.json`, you can up

## Evaluate NLVR2
```bash
-python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> per_gpu_batchsize=<BS_FITS_YOUR_GPU> task_finetune_nlvr2_randaug test_only=True load_path="weights/vilt_nlvr2.ckpt"
+python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> per_gpu_batchsize=<BS_FITS_YOUR_GPU> task_finetune_nlvr2_randaug test_only=True load_path="<YOUR_WEIGHT_ROOT>/vilt_nlvr2.ckpt"

ex)
python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 per_gpu_batchsize=64 task_finetune_nlvr2_randaug test_only=True load_path="weights/vilt_nlvr2.ckpt"
@@ -37,9 +37,9 @@ INFO - ViLT - Completed after 0:01:31

## Evaluate COCO IR/TR
```bash
-python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> per_gpu_batchsize=<BS_FITS_YOUR_GPU> task_finetune_irtr_coco_randaug test_only=True load_path="weights/vilt_irtr_coco.ckpt"
+python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> per_gpu_batchsize=<BS_FITS_YOUR_GPU> task_finetune_irtr_coco_randaug test_only=True load_path="<YOUR_WEIGHT_ROOT>/vilt_irtr_coco.ckpt"

-or you can evaluate zero-shot performance simply by using "weights/vilt_200k_mlm_itm.ckpt" instead.
+or you can evaluate zero-shot performance simply by using "<YOUR_WEIGHT_ROOT>/vilt_200k_mlm_itm.ckpt" instead.

ex)
python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 per_gpu_batchsize=4 task_finetune_irtr_coco_randaug test_only=True load_path="weights/vilt_irtr_coco.ckpt"
@@ -58,7 +58,9 @@ INFO - ViLT - Completed after 1 day, 10:59:12

## Evaluate F30K IR/TR
```bash
-python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> per_gpu_batchsize=<BS_FITS_YOUR_GPU> task_finetune_irtr_f30k_randaug test_only=True load_path="weights/vilt_irtr_f30k.ckpt"
+python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> per_gpu_batchsize=<BS_FITS_YOUR_GPU> task_finetune_irtr_f30k_randaug test_only=True load_path="<YOUR_WEIGHT_ROOT>/vilt_irtr_f30k.ckpt"

+or you can evaluate zero-shot performance simply by using "<YOUR_WEIGHT_ROOT>/vilt_200k_mlm_itm.ckpt" instead.

ex)
python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 per_gpu_batchsize=4 task_finetune_irtr_f30k_randaug test_only=True load_path="weights/vilt_irtr_f30k.ckpt"
15 changes: 12 additions & 3 deletions README.md
@@ -13,9 +13,6 @@ pip install -r requirements.txt
pip install -e .
```

-## Dataset Preparation
-See [`DATA.md`](DATA.md)
-
## Download Pretrained Weights
We provide five pretrained weights
1. ViLT-B/32 Pretrained with MLM+ITM for 200k steps on GCC+SBU+COCO+VG (ViLT-B/32 200k) [link](https://www.dropbox.com/s/5b3slhy5uvdw8k0/vilt_200k_mlm_itm.ckpt?dl=0)
@@ -24,6 +21,18 @@ We provide five pretrained weights
4. ViLT-B/32 200k finetuned on COCO IR/TR [link](https://www.dropbox.com/s/dx3id644873fcgn/vilt_irtr_coco.ckpt?dl=0)
5. ViLT-B/32 200k finetuned on F30K IR/TR [link](https://www.dropbox.com/s/asidty0d4a1p2f4/vilt_irtr_f30k.ckpt?dl=0)

+## Out-of-the-box MLM + Visualization Demo
+```bash
+pip install gradio==1.6.4
+python demo.py with num_gpus=<0 if you have no gpus else 1> load_path="<YOUR_WEIGHT_ROOT>/vilt_200k_mlm_itm.ckpt"
+
+ex)
+python demo.py with num_gpus=0 load_path="weights/vilt_200k_mlm_itm.ckpt"
+```
+
+## Dataset Preparation
+See [`DATA.md`](DATA.md)
## Train New Models
See [`TRAIN.md`](TRAIN.md)
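
Before moving on to the TRAIN.md changes: if you prefer to script the new demo rather than drive it from the browser, a request along the following lines may work. This is a minimal sketch, not part of the commit — it assumes the server is reachable on localhost:8888 and that your (old) Gradio release exposes the classic `/api/predict/` route, so adjust both to your setup.

```python
import requests

# Inputs mirror the demo's three Gradio fields: image URL, masked caption, token index.
payload = {
    "data": [
        "https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
        "a room with a [MASK], a [MASK], a [MASK], and a [MASK].",
        0,  # a token index of 0 skips the heatmap visualization
    ]
}
resp = requests.post("http://localhost:8888/api/predict/", json=payload)
print(resp.json())  # expected: image array, completed caption, selected token
```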
8 changes: 4 additions & 4 deletions TRAIN.md
@@ -16,7 +16,7 @@ python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_ml
export MASTER_ADDR=$DIST_0_IP
export MASTER_PORT=$DIST_0_PORT
export NODE_RANK=$DIST_RANK
-python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> task_finetune_vqa_trainval_randaug per_gpu_batchsize=<BS_FITS_YOUR_GPU> load_path="weights/vilt_200k_mlm_itm.ckpt"
+python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> task_finetune_vqa_trainval_randaug per_gpu_batchsize=<BS_FITS_YOUR_GPU> load_path="<YOUR_WEIGHT_ROOT>/vilt_200k_mlm_itm.ckpt"

ex)
python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_finetune_vqa_trainval_randaug per_gpu_batchsize=64 load_path="weights/vilt_200k_mlm_itm.ckpt"
@@ -27,7 +27,7 @@ python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_fi
export MASTER_ADDR=$DIST_0_IP
export MASTER_PORT=$DIST_0_PORT
export NODE_RANK=$DIST_RANK
-python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> task_finetune_nlvr2_randaug per_gpu_batchsize=<BS_FITS_YOUR_GPU> load_path="weights/vilt_200k_mlm_itm.ckpt"
+python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> task_finetune_nlvr2_randaug per_gpu_batchsize=<BS_FITS_YOUR_GPU> load_path="<YOUR_WEIGHT_ROOT>/vilt_200k_mlm_itm.ckpt"

ex)
python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_finetune_nlvr2_randaug per_gpu_batchsize=32 load_path="weights/vilt_200k_mlm_itm.ckpt"
@@ -38,7 +38,7 @@ python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_fi
export MASTER_ADDR=$DIST_0_IP
export MASTER_PORT=$DIST_0_PORT
export NODE_RANK=$DIST_RANK
-python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> task_finetune_irtr_coco_randaug per_gpu_batchsize=<BS_FITS_YOUR_GPU> load_path="weights/vilt_200k_mlm_itm.ckpt"
+python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> task_finetune_irtr_coco_randaug per_gpu_batchsize=<BS_FITS_YOUR_GPU> load_path="<YOUR_WEIGHT_ROOT>/vilt_200k_mlm_itm.ckpt"

ex)
python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_finetune_irtr_coco_randaug per_gpu_batchsize=4 load_path="weights/vilt_200k_mlm_itm.ckpt"
@@ -49,7 +49,7 @@ python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_fi
export MASTER_ADDR=$DIST_0_IP
export MASTER_PORT=$DIST_0_PORT
export NODE_RANK=$DIST_RANK
-python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> task_finetune_irtr_f30k_randaug per_gpu_batchsize=<BS_FITS_YOUR_GPU> load_path="weights/vilt_200k_mlm_itm.ckpt"
+python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> task_finetune_irtr_f30k_randaug per_gpu_batchsize=<BS_FITS_YOUR_GPU> load_path="<YOUR_WEIGHT_ROOT>/vilt_200k_mlm_itm.ckpt"

ex)
python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 task_finetune_irtr_f30k_randaug per_gpu_batchsize=4 load_path="weights/vilt_200k_mlm_itm.ckpt"
233 changes: 233 additions & 0 deletions demo.py
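
The new `demo.py` below is a Sacred-driven script: it loads the pretrained checkpoint with only the MLM head enabled, greedily fills the `[MASK]` tokens of a user-supplied caption, optionally renders a word-patch alignment heatmap computed with IPOT optimal transport, and serves the whole thing as a Gradio app with preloaded examples.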
@@ -0,0 +1,233 @@
import gradio as gr
import torch
import copy
import time
import requests
import io
import numpy as np
import re

from PIL import Image

from vilt.config import ex
from vilt.modules import ViLTransformerSS

from vilt.modules.objectives import cost_matrix_cosine, ipot
from vilt.transforms import pixelbert_transform
from vilt.datamodules.datamodule_base import get_pretrained_tokenizer


# Sacred CLI entry point, e.g. `python demo.py with num_gpus=0 load_path=...`.
@ex.automain
def main(_config):
_config = copy.deepcopy(_config)

    # Activate only the MLM head; every other pretraining objective is off.
    loss_names = {
"itm": 0,
"mlm": 0.5,
"mpp": 0,
"vqa": 0,
"imgcls": 0,
"nlvr2": 0,
"irtr": 0,
"arc": 0,
}
tokenizer = get_pretrained_tokenizer(_config["tokenizer"])

_config.update(
{
"loss_names": loss_names,
}
)

model = ViLTransformerSS(_config)
model.setup("test")
model.eval()

device = "cuda:0" if _config["num_gpus"] > 0 else "cpu"
model.to(device)

def infer(url, mp_text, hidx):
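        # Fetch the image, fill each [MASK] in mp_text one token at a time,
        # and optionally visualize how token `hidx` aligns with image patches.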
try:
res = requests.get(url)
image = Image.open(io.BytesIO(res.content)).convert("RGB")
img = pixelbert_transform(size=384)(image)
img = img.unsqueeze(0).to(device)
        except Exception:
            # Unreachable URL or undecodable image.
            return False

batch = {"text": [""], "image": [None]}
        tl = len(re.findall(r"\[MASK\]", mp_text))  # number of [MASK] tokens to fill
inferred_token = [mp_text]
batch["image"][0] = img

with torch.no_grad():
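            # Greedy iterative decoding: each pass re-encodes the caption and
            # commits the single most confident [MASK] prediction.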
for i in range(tl):
batch["text"] = inferred_token
encoded = tokenizer(inferred_token)
batch["text_ids"] = torch.tensor(encoded["input_ids"]).to(device)
batch["text_labels"] = torch.tensor(encoded["input_ids"]).to(device)
batch["text_masks"] = torch.tensor(encoded["attention_mask"]).to(device)
encoded = encoded["input_ids"][0][1:-1]
infer = model(batch)
mlm_logits = model.mlm_score(infer["text_feats"])[0, 1:-1]
mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
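                # 103 is the [MASK] id in BERT's vocabulary; only positions
                # that are still masked may win the argmax below.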
mlm_values[torch.tensor(encoded) != 103] = 0
select = mlm_values.argmax().item()
encoded[select] = mlm_ids[select].item()
inferred_token = [tokenizer.decode(encoded)]

selected_token = ""
encoded = tokenizer(inferred_token)

if hidx > 0 and hidx < len(encoded["input_ids"][0][:-1]):
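            # Word-patch alignment: compute an IPOT optimal-transport plan
            # between text and image features, reusing `cost_matrix_cosine`
            # and `ipot` from ViLT's word-patch alignment objective.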
with torch.no_grad():
batch["text"] = inferred_token
batch["text_ids"] = torch.tensor(encoded["input_ids"]).to(device)
batch["text_labels"] = torch.tensor(encoded["input_ids"]).to(device)
batch["text_masks"] = torch.tensor(encoded["attention_mask"]).to(device)
infer = model(batch)
txt_emb, img_emb = infer["text_feats"], infer["image_feats"]
txt_mask, img_mask = (
infer["text_masks"].bool(),
infer["image_masks"].bool(),
)
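                # Exclude [CLS], each sequence's trailing [SEP], and padding
                # from the transport plan.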
for i, _len in enumerate(txt_mask.sum(dim=1)):
txt_mask[i, _len - 1] = False
txt_mask[:, 0] = False
img_mask[:, 0] = False
txt_pad, img_pad = ~txt_mask, ~img_mask

cost = cost_matrix_cosine(txt_emb.float(), img_emb.float())
joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2)
cost.masked_fill_(joint_pad, 0)

txt_len = (txt_pad.size(1) - txt_pad.sum(dim=1, keepdim=False)).to(
dtype=cost.dtype
)
img_len = (img_pad.size(1) - img_pad.sum(dim=1, keepdim=False)).to(
dtype=cost.dtype
)
T = ipot(
cost.detach(),
txt_len,
txt_pad,
img_len,
img_pad,
joint_pad,
0.1,
1000,
1,
)

plan = T[0]
plan_single = plan * len(txt_emb)
cost_ = plan_single.t()

cost_ = cost_[hidx][1:].cpu()

patch_index, (H, W) = infer["patch_index"]
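                # `patch_index` maps each flattened patch back to its
                # (row, col) position on the H x W patch grid.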
heatmap = torch.zeros(H, W)
for i, pidx in enumerate(patch_index[0]):
h, w = pidx[0].item(), pidx[1].item()
heatmap[h, w] = cost_[i]

                # Standardize, keep only responses above +1 sigma, then rescale
                # to [0, 1] for use as an alpha mask (np.clip also converts the
                # torch tensor to a NumPy array).
                heatmap = (heatmap - heatmap.mean()) / heatmap.std()
                heatmap = np.clip(heatmap, 1.0, 3.0)
                heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())

_w, _h = image.size
overlay = Image.fromarray(np.uint8(heatmap * 255), "L").resize(
(_w, _h), resample=Image.NEAREST
)
image_rgba = image.copy()
image_rgba.putalpha(overlay)
image = image_rgba

selected_token = tokenizer.convert_ids_to_tokens(
encoded["input_ids"][0][hidx]
)

return [np.array(image), inferred_token[0], selected_token]

inputs = [
gr.inputs.Textbox(
label="A url of an image.",
lines=5,
),
gr.inputs.Textbox(label="A caption with [MASK] tokens to be filled.", lines=5),
gr.inputs.Slider(
minimum=0,
maximum=38,
step=1,
label="A index of token for heatmap visualization (ignored if zero)",
),
]
outputs = [
gr.outputs.Image(label="Image"),
gr.outputs.Textbox(label="description"),
gr.outputs.Textbox(label="selected token"),
]

interface = gr.Interface(
fn=infer,
inputs=inputs,
outputs=outputs,
server_name="0.0.0.0",
server_port=8888,
examples=[
[
"https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
"a display of flowers growing out and over the [MASK] [MASK] in front of [MASK] on a [MASK] [MASK].",
0,
],
[
"https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
"a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
4,
],
[
"https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
"a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
11,
],
[
"https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
"a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
15,
],
[
"https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
"a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day.",
18,
],
[
"https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
"a room with a [MASK], a [MASK], a [MASK], and a [MASK].",
0,
],
[
"https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
"a room with a rug, a chair, a painting, and a plant.",
5,
],
[
"https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
"a room with a rug, a chair, a painting, and a plant.",
8,
],
[
"https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
"a room with a rug, a chair, a painting, and a plant.",
11,
],
[
"https://upload.wikimedia.org/wikipedia/commons/thumb/4/40/Living_Room.jpg/800px-Living_Room.jpg",
"a room with a rug, a chair, a painting, and a plant.",
15,
],
],
)

interface.launch(debug=True)
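
As configured, the script asks Gradio to bind the interface to 0.0.0.0:8888, and `debug=True` keeps `launch()` in the foreground so errors are printed to the console.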