From 3f31a4d776ebd7a3cb0bfb94f8bce5eb3d696b85 Mon Sep 17 00:00:00 2001 From: Matt Peddie Date: Fri, 25 Aug 2023 11:21:59 +1000 Subject: [PATCH] Generalize covariance RNG type --- include/albatross/src/utils/random_utils.hpp | 23 ++++++++++---------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/include/albatross/src/utils/random_utils.hpp b/include/albatross/src/utils/random_utils.hpp index c19addd5..4fdcc70e 100644 --- a/include/albatross/src/utils/random_utils.hpp +++ b/include/albatross/src/utils/random_utils.hpp @@ -100,10 +100,10 @@ void gaussian_fill(Eigen::Matrix<_Scalar, _Rows, _Cols> &matrix) { gaussian_fill(matrix, 0., 1., rng); } -template +template inline Eigen::MatrixXd random_covariance_matrix(Eigen::Index k, Distribution &eigen_value_distribution, - std::default_random_engine &gen) { + RandomNumberGenerator &gen) { Eigen::MatrixXd Q(k, k); gaussian_fill(Q, gen); @@ -116,16 +116,17 @@ random_covariance_matrix(Eigen::Index k, Distribution &eigen_value_distribution, return Q * diag.asDiagonal() * Q.transpose(); } -inline Eigen::MatrixXd -random_covariance_matrix(Eigen::Index k, std::default_random_engine &gen) { +template +inline Eigen::MatrixXd random_covariance_matrix(Eigen::Index k, + RandomNumberGenerator &gen) { std::gamma_distribution distribution(1.0, 1.0); return random_covariance_matrix(k, distribution, gen); } -inline Eigen::VectorXd -random_multivariate_normal(const Eigen::VectorXd &mean, - const Eigen::MatrixXd &cov, - std::default_random_engine &gen) { +template +inline Eigen::VectorXd random_multivariate_normal(const Eigen::VectorXd &mean, + const Eigen::MatrixXd &cov, + RandomNumberGenerator &gen) { std::normal_distribution dist; ALBATROSS_ASSERT(mean.size() == cov.rows()); ALBATROSS_ASSERT(cov.rows() == cov.cols()); @@ -140,9 +141,9 @@ random_multivariate_normal(const Eigen::VectorXd &mean, return sample; } -inline Eigen::VectorXd -random_multivariate_normal(const Eigen::MatrixXd &cov, - std::default_random_engine &gen) { +template +inline Eigen::VectorXd random_multivariate_normal(const Eigen::MatrixXd &cov, + RandomNumberGenerator &gen) { return random_multivariate_normal(Eigen::VectorXd::Zero(cov.rows()), cov, gen); }