Skip to content

Commit

Permalink
extend dataset with forgotten functionality
Browse files Browse the repository at this point in the history
Signed-off-by: Martijn Govers <[email protected]>
  • Loading branch information
mgovers committed Aug 30, 2024
1 parent f8d9331 commit c7e47b2
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ template <dataset_type_tag dataset_type_> class Dataset {
Dataset get_individual_scenario(Idx scenario)
requires(!is_indptr_mutable_v<dataset_type>)
{
using AdvanceablePtr = std::conditional_t<is_data_mutable_v<dataset_type>, char*, char const*>;

assert(0 <= scenario && scenario < batch_size());

Dataset result{false, 1, dataset().name, meta_data()};
Expand All @@ -366,10 +368,17 @@ template <dataset_type_tag dataset_type_> class Dataset {
Idx size = component_info.elements_per_scenario >= 0
? component_info.elements_per_scenario
: buffer.indptr[scenario + 1] - buffer.indptr[scenario];
Data* data = component_info.elements_per_scenario >= 0
? component_info.component->advance_ptr(buffer.data, size * scenario)
: component_info.component->advance_ptr(buffer.data, buffer.indptr[scenario]);
result.add_buffer(component_info.component->name, size, size, nullptr, data);
Idx offset = component_info.elements_per_scenario >= 0 ? size * scenario : buffer.indptr[scenario];
if (is_columnar(buffer)) {
result.add_buffer(component_info.component->name, size, size, nullptr, nullptr);
for (auto const& attribute_buffer : buffer.attributes) {
result.add_attribute_buffer(component_info.component->name, attribute_buffer.meta_attribute->name,
static_cast<Data*>(static_cast<AdvanceablePtr>(attribute_buffer.data)));
}
} else {
Data* data = component_info.component->advance_ptr(buffer.data, offset);
result.add_buffer(component_info.component->name, size, size, nullptr, data);
}
}
return result;
}
Expand Down
117 changes: 82 additions & 35 deletions tests/cpp_unit_tests/test_dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -967,41 +967,88 @@ TEST_CASE_TEMPLATE("Test dataset (common)", DatasetType, ConstDataset, MutableDa

auto dataset = create_dataset(true, batch_size, dataset_type);

auto a_buffer = std::vector<A::InputType>(a_elements_per_scenario * batch_size);
auto b_buffer = std::vector<A::InputType>(3);
auto b_indptr = std::vector<Idx>{0, 0, narrow_cast<Idx>(b_buffer.size())};

add_homogeneous_buffer(dataset, A::name, a_elements_per_scenario, static_cast<void*>(a_buffer.data()));
add_inhomogeneous_buffer(dataset, B::name, b_buffer.size(), b_indptr.data(),
static_cast<void*>(b_buffer.data()));

for (auto scenario = 0; scenario < batch_size; ++scenario) {
auto const scenario_dataset = dataset.get_individual_scenario(scenario);

CHECK(&scenario_dataset.meta_data() == &dataset.meta_data());
CHECK(!scenario_dataset.empty());
CHECK(scenario_dataset.is_batch() == false);
CHECK(scenario_dataset.batch_size() == 1);
CHECK(scenario_dataset.n_components() == dataset.n_components());

CHECK(scenario_dataset.get_component_info(A::name).component == &dataset_type.get_component(A::name));
CHECK(scenario_dataset.get_component_info(A::name).elements_per_scenario == a_elements_per_scenario);
CHECK(scenario_dataset.get_component_info(A::name).total_elements == a_elements_per_scenario);

CHECK(scenario_dataset.get_component_info(B::name).component == &dataset_type.get_component(B::name));
CHECK(scenario_dataset.get_component_info(B::name).elements_per_scenario ==
dataset.template get_buffer_span<input_getter_s, B>(scenario).size());
CHECK(scenario_dataset.get_component_info(B::name).total_elements ==
scenario_dataset.get_component_info(B::name).elements_per_scenario);

auto const scenario_span_a = scenario_dataset.template get_buffer_span<input_getter_s, A>();
auto const scenario_span_b = scenario_dataset.template get_buffer_span<input_getter_s, B>();
auto const dataset_span_a = dataset.template get_buffer_span<input_getter_s, A>(scenario);
auto const dataset_span_b = dataset.template get_buffer_span<input_getter_s, B>(scenario);
CHECK(scenario_span_a.data() == dataset_span_a.data());
CHECK(scenario_span_a.size() == dataset_span_a.size());
CHECK(scenario_span_b.data() == dataset_span_b.data());
CHECK(scenario_span_b.size() == dataset_span_b.size());
auto const check_get_individual_scenario = [&] {
for (auto scenario = 0; scenario < batch_size; ++scenario) {
CAPTURE(scenario);
auto const scenario_dataset = dataset.get_individual_scenario(scenario);

CHECK(&scenario_dataset.meta_data() == &dataset.meta_data());
CHECK(!scenario_dataset.empty());
CHECK(scenario_dataset.is_batch() == false);
CHECK(scenario_dataset.batch_size() == 1);
CHECK(scenario_dataset.n_components() == dataset.n_components());

CHECK(scenario_dataset.get_component_info(A::name).component ==
&dataset_type.get_component(A::name));
CHECK(scenario_dataset.get_component_info(A::name).elements_per_scenario ==
a_elements_per_scenario);
CHECK(scenario_dataset.get_component_info(A::name).total_elements == a_elements_per_scenario);

CHECK(scenario_dataset.get_component_info(B::name).component ==
&dataset_type.get_component(B::name));
auto const expected_size =
dataset.is_columnar(dataset.get_buffer(B::name))
? dataset.template get_columnar_buffer_span<input_getter_s, B>(scenario).size()
: dataset.template get_buffer_span<input_getter_s, B>(scenario).size();
CHECK(scenario_dataset.get_component_info(B::name).elements_per_scenario == expected_size);
CHECK(scenario_dataset.get_component_info(B::name).total_elements ==
scenario_dataset.get_component_info(B::name).elements_per_scenario);

if (dataset.is_columnar(dataset.get_buffer(A::name))) {
auto const scenario_span_a =
scenario_dataset.template get_columnar_buffer_span<input_getter_s, A>();
auto const dataset_span_a =
dataset.template get_columnar_buffer_span<input_getter_s, A>(scenario);
REQUIRE(scenario_span_a.size() == dataset_span_a.size());
for (Idx idx = 0; idx < scenario_span_a.size(); ++idx) {
auto const scenario_element = scenario_span_a[idx].get();
auto const& dataset_element = dataset_span_a[idx].get();
CHECK(scenario_element.id == dataset_element.id);
CHECK(scenario_element.a1 == dataset_element.a1);
}
} else {
auto const scenario_span_a = scenario_dataset.template get_buffer_span<input_getter_s, A>();
auto const dataset_span_a = dataset.template get_buffer_span<input_getter_s, A>(scenario);
CHECK(scenario_span_a.data() == dataset_span_a.data());
CHECK(scenario_span_a.size() == dataset_span_a.size());
}
if (dataset.is_columnar(dataset.get_buffer(B::name))) {
auto const scenario_span_b =
scenario_dataset.template get_columnar_buffer_span<input_getter_s, B>();
auto const dataset_span_b =
dataset.template get_columnar_buffer_span<input_getter_s, B>(scenario);
CHECK(scenario_span_b.begin() == dataset_span_b.begin());
CHECK(scenario_span_b.size() == dataset_span_b.size());
} else {
auto const scenario_span_b = scenario_dataset.template get_buffer_span<input_getter_s, B>();
auto const dataset_span_b = dataset.template get_buffer_span<input_getter_s, B>(scenario);
CHECK(scenario_span_b.data() == dataset_span_b.data());
CHECK(scenario_span_b.size() == dataset_span_b.size());
}
}
};

SUBCASE("row-based") {
auto a_buffer = std::vector<A::InputType>(a_elements_per_scenario * batch_size);
auto b_buffer = std::vector<A::InputType>(3);
auto b_indptr = std::vector<Idx>{0, 0, narrow_cast<Idx>(b_buffer.size())};
add_homogeneous_buffer(dataset, A::name, a_elements_per_scenario, static_cast<void*>(a_buffer.data()));
add_inhomogeneous_buffer(dataset, B::name, b_buffer.size(), b_indptr.data(),
static_cast<void*>(b_buffer.data()));

check_get_individual_scenario();
}
SUBCASE("columnar") {
auto a_id_buffer = std::vector<ID>(a_elements_per_scenario * batch_size);
auto a_a1_buffer = std::vector<double>(a_elements_per_scenario * batch_size);
auto b_indptr = std::vector<Idx>{0, 0, 3};

add_homogeneous_buffer(dataset, A::name, a_elements_per_scenario, nullptr);
add_attribute_buffer(dataset, A::name, "id", static_cast<void*>(a_id_buffer.data()));
add_attribute_buffer(dataset, A::name, "a1", static_cast<void*>(a_a1_buffer.data()));
add_inhomogeneous_buffer(dataset, B::name, b_indptr.back(), b_indptr.data(), nullptr);

check_get_individual_scenario();
}
}
}
Expand Down

0 comments on commit c7e47b2

Please sign in to comment.