CaRiNG: Learning Temporal Causal Representation under Non-Invertible Generation Process
Guangyi Chen*, Yifan Shen*, Zhenhao Chen*, Xiangchen Song, Yuewen Sun, Weiran Yao, Xiao Liu, Kun Zhang
Official implementation of the paper "CaRiNG: Learning Temporal Causal Representation under Non-Invertible Generation Process".
Abstract: Identifying the underlying time-delayed latent causal processes in sequential data is vital for grasping temporal dynamics and making downstream reasoning. While some recent methods can robustly identify these latent causal variables, they rely on strict assumptions about the invertible generation process from latent variables to observed data. However, these assumptions are often hard to satisfy in real-world applications containing information loss. For instance, the visual perception process translates a 3D space into 2D images, or the phenomenon of persistence of vision incorporates historical data into current perceptions. To address this challenge, we establish an identifiability theory that allows for the recovery of independent latent components even when they come from a nonlinear and non-invertible mix. Using this theory as a foundation, we propose a principled approach, CaRiNG, to learn the Causal Representation of Non-invertible Generative temporal data with identifiability guarantees. Specifically, we utilize temporal context to recover lost latent information and apply the conditions in our theory to guide the training process. Through experiments conducted on synthetic datasets, we validate that our CaRiNG method reliably identifies the causal process, even when the generation process is non-invertible. Moreover, we demonstrate that our approach considerably improves temporal understanding and reasoning in practical applications.
- To the best of our knowledge, this paper presents the first identifiability theorem that accommodates a non-invertible generation process, complementing the existing body of nonlinear ICA theory.
- We present a principled approach, CaRiNG, to learn the latent causal representation from temporal data under non-invertible generation processes with identifiability guarantees, by integrating temporal context to recover the lost information.
- Our evaluations across synthetic and real-world datasets demonstrate CaRiNG's effectiveness in learning identifiable latent causal representations, leading to improvements in video reasoning tasks.
Qualitative comparisons between baselines (especially TDRL) and CaRiNG in the Non-invertible Generation setting. (a) MCC matrix for all 3 latent variables; (b) Scatter plots of the estimated versus ground-truth latent variables (only the aligned variables are plotted); (c) Validation MCC curves of CaRiNG and the baselines.
Settings\Methods | NG | NG-TDMP |
---|---|---|
CaRiNG | ||
TDRL | ||
LEAP | ||
SlowVAE | ||
PCL | ||
betaVAE | ||
SKD | ||
iVAE | ||
SequentialVAE | ||
MCC scores (with standard deviations over 4 seeds) of CaRiNG and baselines on NG and NG-TDMP settings.
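For readers unfamiliar with the metric, MCC here refers to the mean correlation coefficient commonly used in the nonlinear ICA literature: estimated latents are matched one-to-one with ground-truth latents, and the absolute correlations of the matched pairs are averaged. The following is a minimal sketch of that idea, not the evaluation code shipped in this repository; the function names `pearson` and `mcc` are illustrative, and the choice of Pearson correlation with brute-force matching is an assumption that only suits a small number of latents (such as the 3 used above).

```python
from itertools import permutations
from math import sqrt


def pearson(a, b):
    """Pearson correlation between two equal-length sequences of numbers."""
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / sqrt(var_a * var_b)


def mcc(true_latents, est_latents):
    """Mean absolute correlation under the best one-to-one matching.

    Each argument is a list of components; each component is a list of samples.
    Hypothetical helper for illustration only; the search is O(d!) in the
    number of components d.
    """
    d = len(true_latents)
    # Absolute correlation between every (true, estimated) component pair.
    corr = [[abs(pearson(t, e)) for e in est_latents] for t in true_latents]
    # Try every assignment of estimated components to true components.
    best = max(
        sum(corr[i][perm[i]] for i in range(d))
        for perm in permutations(range(d))
    )
    return best / d


# Estimated components are a permuted, rescaled copy of the true ones,
# so the score should be close to 1.0.
true_z = [[1, 2, 3, 4], [1, -1, 1, -1], [0, 1, 0, 2]]
est_z = [[-2, 2, -2, 2], [0, 3, 0, 6], [5, 7, 9, 11]]
score = mcc(true_z, est_z)
```

Taking absolute values and searching over permutations reflects that latents are typically identified only up to reordering and component-wise transformation. For larger latent dimensions, a linear assignment solver would replace the brute-force search.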
We provide a bash script, caring/scripts/run_caring.sh, to reproduce the experimental results. Please follow the instructions below to run the code.
If you wish to download the code along with the datasets, please make sure you have git-lfs installed.
git lfs install
Then download the code with datasets.
git clone git@github.com:sanshuiii/CaRiNG.git
In case of network issues, you can also download only the code, without the datasets.
GIT_LFS_SKIP_SMUDGE=1 git clone git@github.com:sanshuiii/CaRiNG.git
Installation requires Python>=3.7. Please note that the current implementation of CaRiNG requires a GPU.
cd CaRiNG
conda create -n caring python=3.7
conda activate caring
pip install -e .
Set root_path to your repo path (the directory you are currently in) in both caring/configs/caring_ng.yaml and caring/configs/caring_ng_tdmp.yaml.
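If you prefer to script this step rather than edit both files by hand, the sketch below shows one way to do it with the Python standard library. It assumes that root_path appears in each config as a plain `root_path: value` line; that layout has not been checked against the actual files, and the helper names `set_root_path` and `update_config` are hypothetical, not part of this repository.

```python
import re
from pathlib import Path

# Matches a line of the form "root_path: <anything>", keeping any indentation.
# Assumption: the key is written on a single line in the config.
ROOT_PATH_LINE = re.compile(r"^([ \t]*root_path[ \t]*:[ \t]*).*$", flags=re.MULTILINE)


def set_root_path(yaml_text, new_root):
    """Return yaml_text with the value of every root_path key replaced."""
    # A function replacement avoids backslash escapes in new_root being
    # interpreted as regex group references.
    return ROOT_PATH_LINE.sub(lambda m: m.group(1) + new_root, yaml_text)


def update_config(config_file, new_root):
    """Rewrite root_path in one config file in place (hypothetical helper)."""
    path = Path(config_file)
    path.write_text(set_root_path(path.read_text(), new_root))


# Example on an in-memory string; the keys other than root_path are made up.
sample = "name: caring\nroot_path: /old/path\nseed: 0\n"
updated = set_root_path(sample, "/home/me/CaRiNG")

# From the repo root, the two configs named above could then be updated with:
#   for cfg in ("caring/configs/caring_ng.yaml", "caring/configs/caring_ng_tdmp.yaml"):
#       update_config(cfg, str(Path.cwd()))
```

Editing the text directly, instead of parsing and re-serializing the YAML, leaves comments and key order in the config files untouched.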
Go to the caring/scripts folder and run the following command to reproduce the results.
bash run_caring.sh
You may also generate your own datasets. Go to the data_generator folder and run one of the following commands, depending on the setting you need. Each command takes one argument (12 in the examples below).
- NG:
python data_generator/NG.py 12
- NG-TDMP:
python data_generator/NG-TDMP.py 12