Noise Contrastive Estimation (NCE)

Introduction

This is a PyTorch implementation of Noise Contrastive Estimation (NCE) on 2D datasets.

NCE is a method for estimating energy-based models (EBMs) of the form

$$p_\theta(x) = \frac{\exp[-f_\theta(x)]}{Z(\theta)}$$

where

$$Z(\theta) = \int\exp[-f_\theta(x)]dx$$

is the normalizing constant, which is generally intractable to compute. In NCE, the normalizing constant is treated as a trainable parameter $c=\log Z$. With $c$ free, we cannot simply do maximum likelihood estimation (MLE), $\displaystyle\max_\theta p_\theta(x)$, because the likelihood can be made arbitrarily large by letting $Z\to0$ (or $c\to -\infty$). Instead, NCE trains the energy-based model by (nonlinear) logistic regression, that is, by classifying samples as coming either from the data distribution $p_{\mathrm{data}}$ or from some noise distribution $q$.

There are three requirements for the noise distribution $q$:

  (1) log density can be evaluated on any input;

  (2) samples can be obtained from the distribution;

  (3) $q(x)\neq0$ for all $x$ such that $p_{\mathrm{data}}(x)\neq0$.

Here we use a multivariate Gaussian as the noise distribution.
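As a minimal sketch of how a Gaussian meets the three requirements, the snippet below uses `torch.distributions.MultivariateNormal`. The zero mean and identity covariance are illustrative assumptions, not necessarily the values used in this repository.

```python
import torch
from torch.distributions import MultivariateNormal

# Assumed noise distribution: a standard 2D Gaussian. The mean and covariance
# are illustrative choices, not necessarily those used in this repository.
noise = MultivariateNormal(loc=torch.zeros(2), covariance_matrix=torch.eye(2))

# Requirement (2): samples can be drawn from q.
x_noise = noise.sample((5,))            # shape (5, 2)

# Requirement (1): log q(x) can be evaluated at arbitrary inputs.
x_any = torch.tensor([[0.0, 0.0], [3.0, -4.0]])
log_q = noise.log_prob(x_any)           # shape (2,)

# Requirement (3): a Gaussian has full support, so log q(x) is finite
# for every finite x, including points far from the mean.
assert torch.isfinite(log_q).all()
```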

The objective is to maximize the log-likelihood of the classifier's posterior over the two classes (data versus noise):

$$V(\theta) = \mathbb{E}_{x\sim p_{\text{data}}}\log\frac{p_\theta(x)}{p_\theta(x)+q(x)} + \mathbb{E}_{\tilde{x}\sim q}\log\frac{q(\tilde{x})}{p_\theta(\tilde{x}) + q(\tilde{x})}.$$
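Equivalently, $V(\theta)$ is a binary logistic-regression log-likelihood. As a short worked step (the symbols $G_\theta$ and $\sigma$ are introduced here only for illustration), write $G_\theta(x) = \log p_\theta(x) - \log q(x) = -f_\theta(x) - c - \log q(x)$ and $\sigma(z) = 1/(1+e^{-z})$. Dividing numerator and denominator by $p_\theta(x)$ or $q(x)$ gives

$$\frac{p_\theta(x)}{p_\theta(x)+q(x)} = \sigma\big(G_\theta(x)\big), \qquad \frac{q(x)}{p_\theta(x)+q(x)} = \sigma\big(-G_\theta(x)\big),$$

so that

$$V(\theta) = \mathbb{E}_{x\sim p_{\text{data}}}\log\sigma\big(G_\theta(x)\big) + \mathbb{E}_{\tilde{x}\sim q}\log\sigma\big(-G_\theta(\tilde{x})\big).$$

This form only needs the log densities, which avoids exponentiating the energy directly.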

This objective is implemented in the file util.py as the value function (we minimize $-V(\theta)$). We use Adam as the optimizer.
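A minimal sketch of how $-V(\theta)$ could be computed and minimized with Adam is given below. It is not the code from util.py: the function name `nce_loss`, the small MLP used as $f_\theta$, the learning rate, and the stand-in data batch are all assumptions made for illustration. It relies on the identities $\log\frac{p_\theta}{p_\theta+q} = \log\sigma(G)$ and $\log\frac{q}{p_\theta+q} = \log\sigma(-G)$, where $G = \log p_\theta - \log q$ and $\sigma$ is the logistic sigmoid.

```python
import torch
import torch.nn.functional as F
from torch.distributions import MultivariateNormal


def nce_loss(energy, c, noise, x_data, x_noise):
    """Return a batch estimate of -V(theta) (hypothetical helper)."""
    def log_ratio(x):
        # G(x) = log p_theta(x) - log q(x), with log p_theta(x) = -f_theta(x) - c
        return -energy(x).squeeze(-1) - c - noise.log_prob(x)

    # log[p/(p+q)] = logsigmoid(G) on data, log[q/(p+q)] = logsigmoid(-G) on noise
    v = (F.logsigmoid(log_ratio(x_data)).mean()
         + F.logsigmoid(-log_ratio(x_noise)).mean())
    return -v


# Hypothetical setup, for illustration only.
torch.manual_seed(0)
energy = torch.nn.Sequential(            # assumed energy network f_theta
    torch.nn.Linear(2, 32), torch.nn.Softplus(), torch.nn.Linear(32, 1)
)
c = torch.nn.Parameter(torch.zeros(()))  # trainable c = log Z
noise = MultivariateNormal(torch.zeros(2), torch.eye(2))
optimizer = torch.optim.Adam(list(energy.parameters()) + [c], lr=1e-3)

x_data = 0.5 * torch.randn(128, 2) + 1.0  # stand-in for a real data batch
x_noise = noise.sample((128,))            # same number of noise samples

# One optimization step on -V(theta).
optimizer.zero_grad()
loss = nce_loss(energy, c, noise, x_data, x_noise)
loss.backward()
optimizer.step()
```

Note that `c` is passed to the optimizer together with the network weights, so the log normalizing constant is learned jointly with the energy function.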

Installation

Clone the repository to your local machine with

git clone https://github.com/lifeitech/nce.git

Then, in your Python environment, cd into the repository and install the dependencies:

pip install -r requirements.txt

Training

To train the model, do

python train.py

For macOS users: PyTorch's MPS backend does not yet support every operation, so make sure to run the script with PYTORCH_ENABLE_MPS_FALLBACK=1, which lets unsupported operations fall back to the CPU. You can add

export PYTORCH_ENABLE_MPS_FALLBACK=1 

to your .zshrc file.
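If editing the shell configuration is not convenient, the variable can also be set from Python. The sketch below is an assumption about how a script might do this, not a description of what the training script actually does. The helper `pick_device` is hypothetical, and setting the variable before `import torch` is a conservative choice.

```python
import os

# Set the fallback flag only if the user has not already chosen a value.
# Doing this before importing torch is a conservative choice.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch


def pick_device():
    """Hypothetical helper: prefer CUDA, then Apple's MPS backend, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
x = torch.zeros(4, 2, device=device)  # tensors are created on the chosen device
```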

Available datasets:

  • 8gaussians (default)
  • 2spirals
  • checkerboard
  • rings
  • pinwheel

A density plot is saved in the images folder after every epoch. After training, you can combine these plots into GIF animations like the ones in the Examples section by running the Python script in that folder:

cd images
python create_gif.py

Examples

Some visualizations of the learned densities are shown below.

  • 8gaussians dataset

8gaussians

  • pinwheel dataset

pinwheel

  • 2spirals dataset

2spirals