diff --git a/README.md b/README.md
index 997443f..97759de 100644
--- a/README.md
+++ b/README.md
@@ -30,6 +30,15 @@ ex)
 python demo.py with num_gpus=0 load_path="weights/vilt_200k_mlm_itm.ckpt"
 ```
 
+## Out-of-the-box VQA Demo
+```bash
+pip install gradio==1.6.4
+python demo_vqa.py with num_gpus=<0 if you have no gpus else 1> load_path="<YOUR_WEIGHT_ROOT>/vilt_vqa.ckpt" test_only=True
+
+ex)
+python demo_vqa.py with num_gpus=0 load_path="weights/vilt_vqa.ckpt" test_only=True
+```
+
 ## Dataset Preparation
 
 See [`DATA.md`](DATA.md)
diff --git a/demo.py b/demo.py
index f8dcf8c..aa79897 100644
--- a/demo.py
+++ b/demo.py
@@ -153,15 +153,15 @@ def infer(url, mp_text, hidx):
 
     inputs = [
         gr.inputs.Textbox(
-            label="A url of an image.",
+            label="URL of an image.",
             lines=5,
         ),
-        gr.inputs.Textbox(label="A caption with [MASK] tokens to be filled.", lines=5),
+        gr.inputs.Textbox(label="Caption with [MASK] tokens to be filled.", lines=5),
         gr.inputs.Slider(
             minimum=0,
             maximum=38,
             step=1,
-            label="A index of token for heatmap visualization (ignored if zero)",
+            label="Index of token for heatmap visualization (ignored if zero)",
         ),
     ]
     outputs = [
diff --git a/demo_vqa.py b/demo_vqa.py
new file mode 100644
index 0000000..5f8f75a
--- /dev/null
+++ b/demo_vqa.py
@@ -0,0 +1,107 @@
+import gradio as gr
+import torch
+import copy
+import requests
+import io
+import numpy as np
+import json
+import urllib.request
+
+from PIL import Image
+
+from vilt.config import ex
+from vilt.modules import ViLTransformerSS
+
+from vilt.transforms import pixelbert_transform
+from vilt.datamodules.datamodule_base import get_pretrained_tokenizer
+
+
+@ex.automain
+def main(_config):
+    _config = copy.deepcopy(_config)
+
+    # Keep only the VQA head active; all other pretraining objectives are off.
+    loss_names = {
+        "itm": 0,
+        "mlm": 0,
+        "mpp": 0,
+        "vqa": 1,
+        "imgcls": 0,
+        "nlvr2": 0,
+        "irtr": 0,
+        "arc": 0,
+    }
+    tokenizer = get_pretrained_tokenizer(_config["tokenizer"])
+
+    # Mapping from VQAv2 answer-class index to answer string.
+    with urllib.request.urlopen("https://dl.dropboxusercontent.com/s/otya4i5sagt4f5p/vqa_dict.json") as url:
+        id2ans = json.loads(url.read().decode())
+
+    _config.update(
+        {
+            "loss_names": loss_names,
+        }
+    )
+
+    model = ViLTransformerSS(_config)
+    model.setup("test")
+    model.eval()
+
+    device = "cuda:0" if _config["num_gpus"] > 0 else "cpu"
+    model.to(device)
+
+    def infer(url, text):
+        # Download the image and apply the pixel-BERT transform.
+        try:
+            res = requests.get(url)
+            image = Image.open(io.BytesIO(res.content)).convert("RGB")
+            img = pixelbert_transform(size=384)(image)
+            img = img.unsqueeze(0).to(device)
+        except Exception:
+            return False
+
+        batch = {"text": [text], "image": [img]}
+
+        with torch.no_grad():
+            encoded = tokenizer(batch["text"])
+            batch["text_ids"] = torch.tensor(encoded["input_ids"]).to(device)
+            batch["text_labels"] = torch.tensor(encoded["input_ids"]).to(device)
+            batch["text_masks"] = torch.tensor(encoded["attention_mask"]).to(device)
+            out = model.infer(batch)
+            vqa_logits = model.vqa_classifier(out["cls_feats"])
+
+        answer = id2ans[str(vqa_logits.argmax().item())]
+
+        return [np.array(image), answer]
+
+    inputs = [
+        gr.inputs.Textbox(
+            label="URL of an image.",
+            lines=5,
+        ),
+        gr.inputs.Textbox(label="Question", lines=5),
+    ]
+    outputs = [
+        gr.outputs.Image(label="Image"),
+        gr.outputs.Textbox(label="Answer"),
+    ]
+
+    interface = gr.Interface(
+        fn=infer,
+        inputs=inputs,
+        outputs=outputs,
+        server_name="0.0.0.0",
+        server_port=8888,
+        examples=[
+            [
+                "https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg",
+                "Is the sky cloudy?",
+            ],
+            [
"https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg", + "Color of flower?", + ], + ], + ) + + interface.launch(debug=True) \ No newline at end of file