An unofficial PyTorch implementation of the paper "End-to-End Human-Gaze-Target Detection with Transformers".
We provide a pip requirements file with all the dependencies and recommend installing them inside a conda environment.
# Clone project and submodules
git clone --recursive https://github.com/francescotonini/human-gaze-target-detection-transformer.git
cd human-gaze-target-detection-transformer
# Create conda environment
conda create -n human-gaze-target-detection-transformer python=3.9
conda activate human-gaze-target-detection-transformer
# Install requirements
pip install -r requirements.txt
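To confirm the environment is working, a quick optional sanity check (PyTorch is a core dependency; the CUDA flag is simply False on CPU-only machines):
# Check that PyTorch imports and report CUDA availability
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"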
# (optional) Set up wandb
cp .env.example .env
# Add token to .env
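As a sketch of what goes into .env: wandb reads its token from WANDB_API_KEY by default, but check .env.example for the exact variable name the project expects.
# .env (example; verify the variable name against .env.example)
WANDB_API_KEY={YOUR WANDB TOKEN}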
The code expects that the datasets are placed under the data/ folder. You can change this by modifying the data_dir parameter in the configuration files:
cat <<EOT >> configs/local/default.yaml
# @package _global_
paths:
  data_dir: "{PATH TO DATASETS}"
EOT
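Alternatively, since the project is configured with Hydra, the same key can likely be overridden per run from the command line (a sketch; it assumes the paths.data_dir key defined above):
# One-off override using Hydra's CLI syntax
python src/train.py experiment=hgttr_gazefollow paths.data_dir="{PATH TO DATASETS}"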
The implementation requires face annotations ("auxiliary faces", i.e. those not annotated in GazeFollow or VideoAttentionTarget). Therefore, you need to run the following scripts to extract them.
# GazeFollow
python scripts/gazefollow_get_aux_faces.py --dataset_dir /path/to/gazefollow --subset train
python scripts/gazefollow_get_aux_faces.py --dataset_dir /path/to/gazefollow --subset test
# VideoAttentionTarget
cp data/videoattentiontarget_extended/*.csv /path/to/videoattentiontarget
python scripts/videoattentiontarget_get_aux_faces.py --dataset_dir /path/to/videoattentiontarget --subset train
python scripts/videoattentiontarget_get_aux_faces.py --dataset_dir /path/to/videoattentiontarget --subset test
We provide configurations to train on GazeFollow and VideoAttentionTarget (see configs/experiment/).
# GazeFollow
python src/train.py experiment=hgttr_gazefollow
# VideoAttentionTarget
python src/train.py experiment=hgttr_videoattentiontarget +model.net_pretraining={URL/PATH TO GAZEFOLLOW WEIGHTS}
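Since training runs through PyTorch Lightning and Hydra, common hyperparameters can be overridden from the command line. The key names below (trainer.max_epochs, trainer.devices) are assumptions based on the standard lightning-hydra-template layout, so verify them against the configs:
# Example overrides (key names are assumptions; check configs/ for the exact ones)
python src/train.py experiment=hgttr_gazefollow trainer.max_epochs=40 trainer.devices=1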
The same configuration files are used to evaluate the model.
# GazeFollow
python src/eval.py experiment=hgttr_gazefollow ckpt_path={PATH TO CHECKPOINT}
# VideoAttentionTarget
python src/eval.py experiment=hgttr_videoattentiontarget ckpt_path={PATH TO CHECKPOINT}
We provide model weights for GazeFollow at this URL and for VideoAttentionTarget at this URL.
NOTE: when evaluating with the checkpoints above, replace ckpt_path={PATH_TO_CHECKPOINT} with +model.net_pretraining={PATH_TO_CHECKPOINT}.
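For example, to evaluate the released GazeFollow weights:
# Evaluate released weights (net_pretraining override instead of ckpt_path)
python src/eval.py experiment=hgttr_gazefollow +model.net_pretraining={PATH_TO_CHECKPOINT}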
This code is based on PyTorch Lightning, Hydra, and the official DETR implementation.
If you use this implementation or our trained weights in your research, please cite us:
@inproceedings{tonini2023objectaware,
  title={Object-aware Gaze Target Detection},
  author={Tonini, Francesco and Dall'Asen, Nicola and Beyan, Cigdem and Ricci, Elisa},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={21860--21869},
  year={2023}
}
and the original paper:
@inproceedings{tu2022end,
  title={End-to-end human-gaze-target detection with transformers},
  author={Tu, Danyang and Min, Xiongkuo and Duan, Huiyu and Guo, Guodong and Zhai, Guangtao and Shen, Wei},
  booktitle={2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  month={June},
  year={2022},
}