Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/feature/set-attribute-for-writab…
Browse files Browse the repository at this point in the history
…le-dataset' into feature/columnar-data-c-api

Signed-off-by: Martijn Govers <[email protected]>
  • Loading branch information
mgovers committed Sep 10, 2024
2 parents 71ae93a + 7b03cc3 commit 874c770
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -302,37 +302,16 @@ template <dataset_type_tag dataset_type_> class Dataset {
}
}

void add_attribute_buffer(std::string_view component, std::string_view attribute, Data* data) {
Idx const idx = find_component(component, true);
Buffer& buffer = buffers_[idx];
if (!is_columnar(buffer)) {
throw DatasetError{"Cannot add attribute buffers to row-based dataset!\n"};
}
if (std::ranges::find_if(buffer.attributes, [&attribute](auto const& buffer_attribute) {
return buffer_attribute.meta_attribute->name == attribute;
}) != buffer.attributes.end()) {
throw DatasetError{"Cannot have duplicated attribute buffers!\n"};
}
AttributeBuffer<Data> const attribute_buffer{
.data = data, .meta_attribute = &dataset_info_.component_info[idx].component->get_attribute(attribute)};
buffer.attributes.emplace_back(attribute_buffer);
void add_attribute_buffer(std::string_view component, std::string_view attribute, Data* data)
requires(!is_indptr_mutable_v<dataset_type>)
{
add_attribute_buffer_impl(component, attribute, data);
}

void set_attribute_buffer(std::string_view component, std::string_view attribute, Data* data)
requires is_data_mutable_v<dataset_type>
requires is_indptr_mutable_v<dataset_type>
{
Idx const idx = find_component(component, true);
Buffer& buffer = buffers_[idx];
if (!is_columnar(buffer)) {
throw DatasetError{"Cannot set attribute buffers for row-based dataset!\n"};
}
auto it = std::ranges::find_if(buffer.attributes, [&attribute](auto const& buffer_attribute) {
return buffer_attribute.meta_attribute->name == attribute;
});
if (it == buffer.attributes.end()) {
throw DatasetError{"Attribute buffer not found!\n"};
}
it->data = data;
add_attribute_buffer_impl(component, attribute, data);
}

// get buffer by component type
Expand Down Expand Up @@ -459,6 +438,22 @@ template <dataset_type_tag dataset_type_> class Dataset {
buffers_.push_back(Buffer{});
}

void add_attribute_buffer_impl(std::string_view component, std::string_view attribute, Data* data) {
Idx const idx = find_component(component, true);
Buffer& buffer = buffers_[idx];
if (!is_columnar(buffer)) {
throw DatasetError{"Cannot add attribute buffers to row-based dataset!\n"};
}
if (std::ranges::find_if(buffer.attributes, [&attribute](auto const& buffer_attribute) {
return buffer_attribute.meta_attribute->name == attribute;
}) != buffer.attributes.end()) {
throw DatasetError{"Cannot have duplicated attribute buffers!\n"};
}
AttributeBuffer<Data> attribute_buffer{
.data = data, .meta_attribute = &dataset_info_.component_info[idx].component->get_attribute(attribute)};
buffer.attributes.emplace_back(attribute_buffer);
}

template <class RangeType>
RangeType get_span_impl(RangeType const& total_range, Idx scenario, Buffer const& buffer,
ComponentInfo const& info) const {
Expand Down
8 changes: 7 additions & 1 deletion tests/cpp_unit_tests/test_dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,13 @@ TEST_CASE_TEMPLATE("Test dataset (common)", DatasetType, ConstDataset, MutableDa
}
};
auto const add_attribute_buffer = [](DatasetType& dataset, std::string_view name, std::string_view attribute,
auto* data) { dataset.add_attribute_buffer(name, attribute, data); };
auto* data) {
if constexpr (std::same_as<DatasetType, WritableDataset>) {
dataset.set_attribute_buffer(name, attribute, data);
} else {
dataset.add_attribute_buffer(name, attribute, data);
}
};
auto const add_homogeneous_buffer = [&add_buffer](DatasetType& dataset, std::string_view name,
Idx elements_per_scenario, void* data) {
add_buffer(dataset, name, elements_per_scenario, elements_per_scenario * dataset.batch_size(), nullptr, data);
Expand Down
44 changes: 22 additions & 22 deletions tests/cpp_unit_tests/test_deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,23 +358,23 @@ TEST_CASE("Deserializer") {

auto& info = deserializer.get_dataset_info();
info.set_buffer("node", nullptr, nullptr);
info.add_attribute_buffer("node", "id", node_id.data());
info.add_attribute_buffer("node", "u_rated", node_u_rated.data());
info.set_attribute_buffer("node", "id", node_id.data());
info.set_attribute_buffer("node", "u_rated", node_u_rated.data());
info.set_buffer("line", nullptr, nullptr);
info.add_attribute_buffer("line", "id", line_id.data());
info.add_attribute_buffer("line", "r1", line_r1.data());
info.add_attribute_buffer("line", "r0", line_r0.data());
info.add_attribute_buffer("line", "x1", line_x1.data());
info.add_attribute_buffer("line", "x0", line_x0.data());
info.set_attribute_buffer("line", "id", line_id.data());
info.set_attribute_buffer("line", "r1", line_r1.data());
info.set_attribute_buffer("line", "r0", line_r0.data());
info.set_attribute_buffer("line", "x1", line_x1.data());
info.set_attribute_buffer("line", "x0", line_x0.data());
info.set_buffer("source", nullptr, nullptr);
info.add_attribute_buffer("source", "id", source_id.data());
info.add_attribute_buffer("source", "u_ref", source_u_ref.data());
info.add_attribute_buffer("source", "sk", source_sk.data());
info.add_attribute_buffer("source", "rx_ratio", source_rx_ratio.data());
info.set_attribute_buffer("source", "id", source_id.data());
info.set_attribute_buffer("source", "u_ref", source_u_ref.data());
info.set_attribute_buffer("source", "sk", source_sk.data());
info.set_attribute_buffer("source", "rx_ratio", source_rx_ratio.data());
info.set_buffer("sym_load", nullptr, nullptr);
info.add_attribute_buffer("sym_load", "id", sym_load_id.data());
info.add_attribute_buffer("sym_load", "p_specified", sym_load_p_specified.data());
info.add_attribute_buffer("sym_load", "q_specified", sym_load_q_specified.data());
info.set_attribute_buffer("sym_load", "id", sym_load_id.data());
info.set_attribute_buffer("sym_load", "p_specified", sym_load_p_specified.data());
info.set_attribute_buffer("sym_load", "q_specified", sym_load_q_specified.data());

deserializer.parse();
// check node
Expand Down Expand Up @@ -498,15 +498,15 @@ TEST_CASE("Deserializer") {

auto& info = deserializer.get_dataset_info();
info.set_buffer("sym_load", sym_load_indptr.data(), nullptr);
info.add_attribute_buffer("sym_load", "id", sym_load_id.data());
info.add_attribute_buffer("sym_load", "status", sym_load_status.data());
info.add_attribute_buffer("sym_load", "p_specified", sym_load_p_specified.data());
info.add_attribute_buffer("sym_load", "q_specified", sym_load_q_specified.data());
info.set_attribute_buffer("sym_load", "id", sym_load_id.data());
info.set_attribute_buffer("sym_load", "status", sym_load_status.data());
info.set_attribute_buffer("sym_load", "p_specified", sym_load_p_specified.data());
info.set_attribute_buffer("sym_load", "q_specified", sym_load_q_specified.data());
info.set_buffer("asym_load", nullptr, nullptr);
info.add_attribute_buffer("asym_load", "id", asym_load_id.data());
info.add_attribute_buffer("asym_load", "status", asym_load_status.data());
info.add_attribute_buffer("asym_load", "p_specified", asym_load_p_specified.data());
info.add_attribute_buffer("asym_load", "q_specified", asym_load_q_specified.data());
info.set_attribute_buffer("asym_load", "id", asym_load_id.data());
info.set_attribute_buffer("asym_load", "status", asym_load_status.data());
info.set_attribute_buffer("asym_load", "p_specified", asym_load_p_specified.data());
info.set_attribute_buffer("asym_load", "q_specified", asym_load_q_specified.data());

deserializer.parse();

Expand Down

0 comments on commit 874c770

Please sign in to comment.