Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use double buffering in particle storage #266

Merged
merged 3 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions beluga/include/beluga/storage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,13 @@ class StoragePolicy : public Mixin {
*/
template <class Range>
void initialize_particles(Range&& input) {
static_assert(std::is_same_v<particle_type, ranges::range_value_t<Range>>, "Invalid value type");
static_assert(std::is_convertible_v<ranges::range_value_t<Range>, particle_type>, "Invalid value type");
const std::size_t size = this->self().max_samples();
particles_.resize(size);
const auto first = std::begin(views::all(particles_));
particles_[1].resize(size);
hidmic marked this conversation as resolved.
Show resolved Hide resolved
const auto first = std::begin(views::all(particles_[1]));
const auto last = ranges::copy(input | ranges::views::take(size), first).out;
hidmic marked this conversation as resolved.
Show resolved Hide resolved
particles_.resize(static_cast<std::size_t>(std::distance(first, last)));
particles_[1].resize(static_cast<std::size_t>(std::distance(first, last)));
std::swap(particles_[0], particles_[1]);
}

/// \copydoc StorageInterface::initialize_states()
Expand All @@ -191,7 +192,7 @@ class StoragePolicy : public Mixin {
}

/// \copydoc StorageInterface::particle_count()
[[nodiscard]] std::size_t particle_count() const final { return particles_.size(); }
[[nodiscard]] std::size_t particle_count() const final { return particles_[0].size(); }

/// \copydoc StorageInterface::states_view()
[[nodiscard]] output_view_type states_view() const final { return this->states(); }
Expand All @@ -200,22 +201,22 @@ class StoragePolicy : public Mixin {
[[nodiscard]] weights_view_type weights_view() const final { return this->weights(); }

/// Returns a view of the particles container.
[[nodiscard]] auto particles() { return views::all(particles_); }
[[nodiscard]] auto particles() { return views::all(particles_[0]); }
/// Returns a const view of the particles container.
[[nodiscard]] auto particles() const { return views::all(particles_) | ranges::views::const_; }
[[nodiscard]] auto particles() const { return views::all(particles_[0]) | ranges::views::const_; }

/// Returns a view of the particles states.
[[nodiscard]] auto states() { return views::states(particles_); }
[[nodiscard]] auto states() { return views::states(particles_[0]); }
/// Returns a const view of the particles states.
[[nodiscard]] auto states() const { return views::states(particles_) | ranges::views::const_; }
[[nodiscard]] auto states() const { return views::states(particles_[0]) | ranges::views::const_; }

/// Returns a view of the particles weight.
[[nodiscard]] auto weights() { return views::weights(particles_); }
[[nodiscard]] auto weights() { return views::weights(particles_[0]); }
/// Returns a const view of the particles weight.
[[nodiscard]] auto weights() const { return views::weights(particles_) | ranges::views::const_; }
[[nodiscard]] auto weights() const { return views::weights(particles_[0]) | ranges::views::const_; }

private:
Container particles_;
std::array<Container, 2> particles_;
};

/// A storage policy that implements a structure of arrays layout.
Expand Down
13 changes: 13 additions & 0 deletions beluga/test/beluga/test_storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
#include <beluga/storage.hpp>
#include <ciabatta/ciabatta.hpp>

#include <range/v3/algorithm/equal.hpp>
#include <range/v3/view/reverse.hpp>

namespace {

using testing::Return;
Expand Down Expand Up @@ -69,4 +72,14 @@ TYPED_TEST(StoragePolicyTest, InitializeWithMoreParticlesThanExpected) {
ASSERT_EQ(mixin.particle_count(), 2);
}

TYPED_TEST(StoragePolicyTest, ResampleParticles) {
auto states = std::vector<int>{1, 2, 3, 4, 5};
auto mixin = TypeParam{};
EXPECT_CALL(mixin, max_samples()).WillOnce(Return(5)).WillOnce(Return(5));
mixin.initialize_states(states | ranges::views::all);
ASSERT_EQ(mixin.particle_count(), 5);
mixin.initialize_particles(mixin.particles() | ranges::views::reverse);
ASSERT_TRUE(ranges::equal(mixin.states(), states | ranges::views::reverse));
}

} // namespace
Loading