Skip to content

Commit

Permalink
Add wasserstein-2 metric calculation (#488)
Browse files Browse the repository at this point in the history
The Wasserstein metric measures the distance between two distributions,
accounting for differences in both shape and location. This can be used
to compare two multivariate Gaussians.

This change is in support of PIC.
  • Loading branch information
peddie authored Jun 25, 2024
1 parent 92f3a99 commit d34e284
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 0 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ swift_cc_test(
linkopts = ["-lz"],
local_defines = ["CSV_IO_NO_THREAD"],
type = UNIT,
size = "large",
deps = [
":albatross",
":serialize-testsuite",
Expand Down
24 changes: 24 additions & 0 deletions include/albatross/src/evaluation/prediction_metrics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,30 @@ struct ChiSquaredCdf : public PredictionMetric<JointDistribution> {
ChiSquaredCdf() : PredictionMetric<JointDistribution>(chi_squared_cdf){};
};

namespace distance {

namespace detail {

inline Eigen::MatrixXd principal_sqrt(const Eigen::MatrixXd &input) {
const Eigen::SelfAdjointEigenSolver<Eigen::MatrixXd> eigs(input);
return eigs.eigenvectors() *
eigs.eigenvalues().array().sqrt().matrix().asDiagonal() *
eigs.eigenvectors().transpose();
}

} // namespace detail

inline double wasserstein_2(const JointDistribution &a,
const JointDistribution &b) {
auto b_sqrt{detail::principal_sqrt(b.covariance)};
return (a.mean - b.mean).squaredNorm() +
(a.covariance + b.covariance -
2 * detail::principal_sqrt(b_sqrt * a.covariance * b_sqrt))
.trace();
}

} // namespace distance

} // namespace albatross

#endif /* ALBATROSS_EVALUATION_PREDICTION_METRICS_H_ */
98 changes: 98 additions & 0 deletions tests/test_stats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <albatross/Common>
#include <albatross/Distribution>
#include <albatross/Evaluation>
#include <albatross/Stats>
#include <albatross/utils/RandomUtils>

Expand Down Expand Up @@ -172,4 +173,101 @@ TEST(test_stats, test_chi_squared_cdf_monotonic_1d) {
}
}

template <typename RandomNumberGenerator>
JointDistribution random_distribution(Eigen::Index dimension,
RandomNumberGenerator &gen) {
const auto covariance = random_covariance_matrix(dimension, gen);
Eigen::VectorXd mean(dimension);
gaussian_fill(mean, gen);

return {mean, covariance};
}

static constexpr Eigen::Index cDistributionDimension = 30;
static constexpr std::size_t cNumIterations = 10000;

// The Wasserstein distance between a distribution and itself should
// be zero to within numerical precision.
TEST(test_stats, test_wasserstein_zero) {
std::default_random_engine gen(2222);

for (std::size_t iter = 0; iter < cNumIterations; ++iter) {
const Eigen::Index dimension = std::uniform_int_distribution<Eigen::Index>(
1, cDistributionDimension)(gen);
const auto dist = random_distribution(dimension, gen);

EXPECT_LT(distance::wasserstein_2(dist, dist),
1.e-12 * dist.covariance.trace() +
1.e-12 * dist.mean.squaredNorm());
}
}

// The Wasserstein distance between two distributions should aways be
// nonnegative.
TEST(test_stats, test_wasserstein_nonnegative) {
std::default_random_engine gen(2222);

for (std::size_t iter = 0; iter < cNumIterations; ++iter) {
const Eigen::Index dimension = std::uniform_int_distribution<Eigen::Index>(
1, cDistributionDimension)(gen);
const auto dist_a = random_distribution(dimension, gen);
const auto dist_b = random_distribution(dimension, gen);

EXPECT_GE(distance::wasserstein_2(dist_a, dist_b), 0);
}
}

// If two distributions differ only in their mean, then the
// Wasserstein 2-distance should differ according to the square of the
// distance between means (i.e. the Wasserstein distance has the same
// units as the mean).
TEST(test_stats, test_wasserstein_shift) {
std::default_random_engine gen(2222);

for (std::size_t iter = 0; iter < cNumIterations; ++iter) {
const Eigen::Index dimension = std::uniform_int_distribution<Eigen::Index>(
1, cDistributionDimension)(gen);
auto dist_a = random_distribution(dimension, gen);
auto dist_b = dist_a;
gaussian_fill(dist_b.mean, gen);

const double distance = distance::wasserstein_2(dist_a, dist_b);

const double mean_distance = (dist_a.mean - dist_b.mean).squaredNorm();

EXPECT_LT(distance - mean_distance, 1.e-10);
}
}

// If we inflate the covariance of the distribution, the Wasserstein
// distance to the original distribution should increase.
TEST(test_stats, test_wasserstein_grows_with_covariance) {
std::default_random_engine gen(2222);

for (std::size_t iter = 0; iter < cNumIterations; ++iter) {
const Eigen::Index dimension = std::uniform_int_distribution<Eigen::Index>(
1, cDistributionDimension)(gen);
auto dist_a = random_distribution(dimension, gen);
const Eigen::SelfAdjointEigenSolver<Eigen::MatrixXd> cov_eigs(
dist_a.covariance);

auto dist_b = dist_a;
dist_b.covariance =
cov_eigs.eigenvectors() *
(cov_eigs.eigenvalues().array() * 2).matrix().asDiagonal() *
cov_eigs.eigenvectors().transpose();

auto dist_c = dist_a;
dist_c.covariance =
cov_eigs.eigenvectors() *
(cov_eigs.eigenvalues().array() * 4).matrix().asDiagonal() *
cov_eigs.eigenvectors().transpose();

const double distance_ab = distance::wasserstein_2(dist_a, dist_b);
const double distance_ac = distance::wasserstein_2(dist_a, dist_c);

EXPECT_GT(distance_ac, distance_ab);
}
}

} // namespace albatross

0 comments on commit d34e284

Please sign in to comment.