Commit

add vqa demo
dandelin committed Apr 30, 2021
1 parent e1962f0 commit 041c08f
Showing 3 changed files with 119 additions and 3 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -30,6 +30,15 @@ ex)
python demo.py with num_gpus=0 load_path="weights/vilt_200k_mlm_itm.ckpt"
```
## Out-of-the-box VQA Demo
```bash
pip install gradio==1.6.4
python demo_vqa.py with num_gpus=<0 if you have no gpus else 1> load_path="<YOUR_WEIGHT_ROOT>/vilt_vqa.ckpt" test_only=True

ex)
python demo_vqa.py with num_gpus=0 load_path="weights/vilt_vqa.ckpt" test_only=True
```
## Dataset Preparation
See [`DATA.md`](DATA.md)
6 changes: 3 additions & 3 deletions demo.py

@@ -153,15 +153,15 @@ def infer(url, mp_text, hidx):

     inputs = [
         gr.inputs.Textbox(
-            label="A url of an image.",
+            label="Url of an image.",
             lines=5,
         ),
-        gr.inputs.Textbox(label="A caption with [MASK] tokens to be filled.", lines=5),
+        gr.inputs.Textbox(label="Caption with [MASK] tokens to be filled.", lines=5),
         gr.inputs.Slider(
             minimum=0,
             maximum=38,
             step=1,
-            label="A index of token for heatmap visualization (ignored if zero)",
+            label="Index of token for heatmap visualization (ignored if zero)",
         ),
     ]
     outputs = [
107 changes: 107 additions & 0 deletions demo_vqa.py

@@ -0,0 +1,107 @@
import gradio as gr
import torch
import copy
import requests
import io
import numpy as np
import json
import urllib.request

from PIL import Image

from vilt.config import ex
from vilt.modules import ViLTransformerSS

from vilt.transforms import pixelbert_transform
from vilt.datamodules.datamodule_base import get_pretrained_tokenizer

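# "ex" is the experiment object imported from vilt.config; overrides such as
# num_gpus and load_path arrive through the "with key=value" CLI syntax.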
@ex.automain
def main(_config):
    _config = copy.deepcopy(_config)

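    # Enable only the VQA head; all other pre-training objectives stay off.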
    loss_names = {
        "itm": 0,
        "mlm": 0,
        "mpp": 0,
        "vqa": 1,
        "imgcls": 0,
        "nlvr2": 0,
        "irtr": 0,
        "arc": 0,
    }
    tokenizer = get_pretrained_tokenizer(_config["tokenizer"])

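    # Download the mapping from VQA answer indices to answer strings.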
    with urllib.request.urlopen(
        "https://dl.dropboxusercontent.com/s/otya4i5sagt4f5p/vqa_dict.json"
    ) as url:
        id2ans = json.loads(url.read().decode())

    _config.update(
        {
            "loss_names": loss_names,
        }
    )

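    # Build ViLT with the VQA head and switch it to inference mode.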
    model = ViLTransformerSS(_config)
    model.setup("test")
    model.eval()

    device = "cuda:0" if _config["num_gpus"] > 0 else "cpu"
    model.to(device)

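    # Fetch an image from a URL, ask the given question about it, and
    # return the image together with the predicted answer.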
    def infer(url, text):
        try:
            res = requests.get(url)
            image = Image.open(io.BytesIO(res.content)).convert("RGB")
            img = pixelbert_transform(size=384)(image)
            img = img.unsqueeze(0).to(device)
        except Exception:
            # Surface fetch/decode failures in the UI instead of returning False,
            # which the two-output interface cannot unpack.
            return [np.zeros((384, 384, 3), dtype=np.uint8), "invalid url"]

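        # A single-example batch: one question string and one transformed image.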
        batch = {"text": [text], "image": [img]}

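        # Tokenize the question, run ViLT, and score all candidate answers.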
        with torch.no_grad():
            encoded = tokenizer(batch["text"])
            batch["text_ids"] = torch.tensor(encoded["input_ids"]).to(device)
            batch["text_labels"] = torch.tensor(encoded["input_ids"]).to(device)
            batch["text_masks"] = torch.tensor(encoded["attention_mask"]).to(device)
            ret = model.infer(batch)
            vqa_logits = model.vqa_classifier(ret["cls_feats"])

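        # Map the highest-scoring index back to its answer string.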
        answer = id2ans[str(vqa_logits.argmax().item())]

        return [np.array(image), answer]

    inputs = [
        gr.inputs.Textbox(
            label="Url of an image.",
            lines=5,
        ),
        gr.inputs.Textbox(label="Question", lines=5),
    ]
    outputs = [
        gr.outputs.Image(label="Image"),
        gr.outputs.Textbox(label="Answer"),
    ]

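    # Serve the demo on port 8888, prefilled with two example questions.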
    interface = gr.Interface(
        fn=infer,
        inputs=inputs,
        outputs=outputs,
        server_name="0.0.0.0",
        server_port=8888,
        examples=[
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "Is the sky cloudy?",
            ],
            [
                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
                "Color of flower?",
            ],
        ],
    )

    interface.launch(debug=True)
