From 12597a21796750aecde4dde78e99f7874aeb992b Mon Sep 17 00:00:00 2001 From: Giulio Eulisse <10544+ktf@users.noreply.github.com> Date: Thu, 12 Dec 2024 14:16:11 +0100 Subject: [PATCH] DPL Analysis: add RNTuple arrow::Dataset support As part of the changes, move the actual logic which serialises / deserialises things to plugins so that we do not need to depend on RNTuple in production code. Include an initial converter to go from AO2Ds to RNTuple based files. --- Framework/AnalysisSupport/CMakeLists.txt | 10 + .../AnalysisSupport/src/RNTuplePlugin.cxx | 825 +++++++++++++++++ Framework/AnalysisSupport/src/TTreePlugin.cxx | 862 ++++++++++++++++++ Framework/Core/CMakeLists.txt | 5 + Framework/Core/include/Framework/Plugins.h | 8 + .../include/Framework/RootArrowFilesystem.h | 174 +--- Framework/Core/src/Plugin.cxx | 71 +- Framework/Core/src/RootArrowFilesystem.cxx | 721 +-------------- Framework/Core/test/o2AO2DToAO3D.cxx | 165 ++++ Framework/Core/test/test_Root2ArrowTable.cxx | 105 ++- 10 files changed, 2118 insertions(+), 828 deletions(-) create mode 100644 Framework/AnalysisSupport/src/RNTuplePlugin.cxx create mode 100644 Framework/AnalysisSupport/src/TTreePlugin.cxx create mode 100644 Framework/Core/test/o2AO2DToAO3D.cxx diff --git a/Framework/AnalysisSupport/CMakeLists.txt b/Framework/AnalysisSupport/CMakeLists.txt index 5fb1282469711..dedbf8cb590b2 100644 --- a/Framework/AnalysisSupport/CMakeLists.txt +++ b/Framework/AnalysisSupport/CMakeLists.txt @@ -24,6 +24,16 @@ o2_add_library(FrameworkAnalysisSupport PRIVATE_INCLUDE_DIRECTORIES ${CMAKE_CURRENT_LIST_DIR}/src PUBLIC_LINK_LIBRARIES O2::Framework ${EXTRA_TARGETS} ROOT::TreePlayer) +o2_add_library(FrameworkAnalysisRNTupleSupport + SOURCES src/RNTuplePlugin.cxx + PRIVATE_INCLUDE_DIRECTORIES ${CMAKE_CURRENT_LIST_DIR}/src + PUBLIC_LINK_LIBRARIES O2::Framework ${EXTRA_TARGETS} ROOT::ROOTNTuple ROOT::ROOTNTupleUtil) + +o2_add_library(FrameworkAnalysisTTreeSupport + SOURCES src/TTreePlugin.cxx + 
PRIVATE_INCLUDE_DIRECTORIES ${CMAKE_CURRENT_LIST_DIR}/src + PUBLIC_LINK_LIBRARIES O2::Framework ${EXTRA_TARGETS} ROOT::TreePlayer) + o2_add_test(DataInputDirector NAME test_Framework_test_DataInputDirector SOURCES test/test_DataInputDirector.cxx COMPONENT_NAME Framework diff --git a/Framework/AnalysisSupport/src/RNTuplePlugin.cxx b/Framework/AnalysisSupport/src/RNTuplePlugin.cxx new file mode 100644 index 0000000000000..9f67785f1a069 --- /dev/null +++ b/Framework/AnalysisSupport/src/RNTuplePlugin.cxx @@ -0,0 +1,825 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. 
+ +#include "Framework/RuntimeError.h" +#include "Framework/RootArrowFilesystem.h" +#include "Framework/Plugins.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +template class + std::unique_ptr; + +namespace o2::framework +{ + +class RNTupleFileWriteOptions : public arrow::dataset::FileWriteOptions +{ + public: + RNTupleFileWriteOptions(std::shared_ptr format) + : FileWriteOptions(format) + { + } +}; + +// A filesystem which allows me to get a RNTuple +class RNTupleFileSystem : public VirtualRootFileSystemBase +{ + public: + ~RNTupleFileSystem() override; + + std::shared_ptr GetSubFilesystem(arrow::dataset::FileSource source) override + { + return std::dynamic_pointer_cast(shared_from_this()); + }; + virtual ROOT::Experimental::RNTuple* GetRNTuple(arrow::dataset::FileSource source) = 0; +}; + +class SingleRNTupleFileSystem : public RNTupleFileSystem +{ + public: + SingleRNTupleFileSystem(ROOT::Experimental::RNTuple* tuple) + : RNTupleFileSystem(), + mTuple(tuple) + { + } + + arrow::Result GetFileInfo(std::string const& path) override; + + std::string type_name() const override + { + return "rntuple"; + } + + ROOT::Experimental::RNTuple* GetRNTuple(arrow::dataset::FileSource) override + { + // Simply return the only TTree we have + return mTuple; + } + + private: + ROOT::Experimental::RNTuple* mTuple; +}; + +arrow::Result SingleRNTupleFileSystem::GetFileInfo(std::string const& path) +{ + arrow::dataset::FileSource source(path, shared_from_this()); + arrow::fs::FileInfo result; + result.set_path(path); + result.set_type(arrow::fs::FileType::File); + return result; +} + +class RNTupleFileFragment : public arrow::dataset::FileFragment +{ + public: + RNTupleFileFragment(arrow::dataset::FileSource source, + std::shared_ptr format, + arrow::compute::Expression partition_expression, + std::shared_ptr physical_schema) + : FileFragment(std::move(source), 
std::move(format), std::move(partition_expression), std::move(physical_schema)) + { + } +}; + +class RNTupleFileFormat : public arrow::dataset::FileFormat +{ + size_t& mTotCompressedSize; + size_t& mTotUncompressedSize; + + public: + RNTupleFileFormat(size_t& totalCompressedSize, size_t& totalUncompressedSize) + : FileFormat({}), + mTotCompressedSize(totalCompressedSize), + mTotUncompressedSize(totalUncompressedSize) + { + } + + ~RNTupleFileFormat() override = default; + + std::string type_name() const override + { + return "rntuple"; + } + + bool Equals(const FileFormat& other) const override + { + return other.type_name() == this->type_name(); + } + + arrow::Result IsSupported(const arrow::dataset::FileSource& source) const override + { + auto fs = std::dynamic_pointer_cast(source.filesystem()); + auto subFs = fs->GetSubFilesystem(source); + if (std::dynamic_pointer_cast(subFs)) { + return true; + } + return false; + } + + arrow::Result> Inspect(const arrow::dataset::FileSource& source) const override; + + arrow::Result ScanBatchesAsync( + const std::shared_ptr& options, + const std::shared_ptr& fragment) const override; + + std::shared_ptr DefaultWriteOptions() override; + + arrow::Result> MakeWriter(std::shared_ptr destination, + std::shared_ptr schema, + std::shared_ptr options, + arrow::fs::FileLocator destination_locator) const override; + arrow::Result> MakeFragment( + arrow::dataset::FileSource source, arrow::compute::Expression partition_expression, + std::shared_ptr physical_schema) override; +}; + +struct RootNTupleVisitor : public ROOT::Experimental::Detail::RFieldVisitor { + void VisitArrayField(const ROOT::Experimental::RArrayField& field) override + { + int size = field.GetLength(); + RootNTupleVisitor valueVisitor{}; + auto valueField = field.GetSubFields()[0]; + valueField->AcceptVisitor(valueVisitor); + auto type = valueVisitor.datatype; + this->datatype = arrow::fixed_size_list(type, size); + } + + void VisitRVecField(const 
ROOT::Experimental::RRVecField& field) override + { + RootNTupleVisitor valueVisitor{}; + auto valueField = field.GetSubFields()[0]; + valueField->AcceptVisitor(valueVisitor); + auto type = valueVisitor.datatype; + this->datatype = arrow::list(type); + } + + void VisitField(const ROOT::Experimental::RFieldBase& field) override + { + throw o2::framework::runtime_error_f("Unknown field %s with type %s", field.GetFieldName().c_str(), field.GetTypeName().c_str()); + } + + void VisitIntField(const ROOT::Experimental::RField& field) override + { + this->datatype = arrow::int32(); + } + + void VisitBoolField(const ROOT::Experimental::RField& field) override + { + this->datatype = arrow::boolean(); + } + + void VisitFloatField(const ROOT::Experimental::RField& field) override + { + this->datatype = arrow::float32(); + } + + void VisitDoubleField(const ROOT::Experimental::RField& field) override + { + this->datatype = arrow::float64(); + } + std::shared_ptr datatype; +}; +} // namespace o2::framework + +auto arrowTypeFromRNTuple(ROOT::Experimental::RFieldBase const& field, int size) +{ + o2::framework::RootNTupleVisitor visitor; + field.AcceptVisitor(visitor); + return visitor.datatype; +} + +namespace o2::framework +{ +std::unique_ptr rootFieldFromArrow(std::shared_ptr field, std::string name) +{ + using namespace ROOT::Experimental; + switch (field->type()->id()) { + case arrow::Type::BOOL: + return std::make_unique>(name); + case arrow::Type::UINT8: + return std::make_unique>(name); + case arrow::Type::UINT16: + return std::make_unique>(name); + case arrow::Type::UINT32: + return std::make_unique>(name); + case arrow::Type::UINT64: + return std::make_unique>(name); + case arrow::Type::INT8: + return std::make_unique>(name); + case arrow::Type::INT16: + return std::make_unique>(name); + case arrow::Type::INT32: + return std::make_unique>(name); + case arrow::Type::INT64: + return std::make_unique>(name); + case arrow::Type::FLOAT: + return std::make_unique>(name); + case 
arrow::Type::DOUBLE: + return std::make_unique>(name); + default: + throw runtime_error("Unsupported arrow column type"); + } +} + +class RNTupleFileWriter : public arrow::dataset::FileWriter +{ + std::shared_ptr mWriter; + bool firstBatch = true; + std::vector> valueArrays; + std::vector> valueTypes; + std::vector valueCount; + + public: + RNTupleFileWriter(std::shared_ptr schema, std::shared_ptr options, + std::shared_ptr destination, + arrow::fs::FileLocator destination_locator) + : FileWriter(schema, options, destination, destination_locator) + { + using namespace ROOT::Experimental; + + auto model = RNTupleModel::CreateBare(); + // Let's create a model from the physical schema + for (auto i = 0u; i < schema->fields().size(); ++i) { + auto& field = schema->field(i); + + // Construct all the needed branches. + switch (field->type()->id()) { + case arrow::Type::FIXED_SIZE_LIST: { + auto list = std::static_pointer_cast(field->type()); + auto valueField = field->type()->field(0); + model->AddField(std::make_unique(field->name(), rootFieldFromArrow(valueField, "_0"), list->list_size())); + } break; + case arrow::Type::LIST: { + auto valueField = field->type()->field(0); + model->AddField(std::make_unique(field->name(), rootFieldFromArrow(valueField, "_0"))); + } break; + default: { + model->AddField(rootFieldFromArrow(field, field->name())); + } break; + } + } + auto fileStream = std::dynamic_pointer_cast(destination_); + auto* file = dynamic_cast(fileStream->GetDirectory()); + mWriter = RNTupleWriter::Append(std::move(model), destination_locator_.path, *file, {}); + } + + arrow::Status Write(const std::shared_ptr& batch) override + { + if (firstBatch) { + firstBatch = false; + } + + // Support writing empty tables + if (batch->columns().empty() || batch->num_rows() == 0) { + return arrow::Status::OK(); + } + + for (auto i = 0u; i < batch->columns().size(); ++i) { + auto column = batch->column(i); + auto& field = batch->schema()->field(i); + + 
valueArrays.push_back(nullptr); + valueTypes.push_back(nullptr); + valueCount.push_back(1); + + switch (field->type()->id()) { + case arrow::Type::FIXED_SIZE_LIST: { + auto list = std::static_pointer_cast(column); + auto listType = std::static_pointer_cast(field->type()); + if (field->type()->field(0)->type()->id() == arrow::Type::BOOL) { + auto boolArray = std::static_pointer_cast(list->values()); + int64_t length = boolArray->length(); + arrow::UInt8Builder builder; + auto ok = builder.Reserve(length); + + for (int64_t i = 0; i < length; ++i) { + if (boolArray->IsValid(i)) { + // Expand each boolean value (true/false) to uint8 (1/0) + uint8_t value = boolArray->Value(i) ? 1 : 0; + auto ok = builder.Append(value); + } else { + // Append null for invalid entries + auto ok = builder.AppendNull(); + } + } + valueArrays.back() = *builder.Finish(); + valueTypes.back() = valueArrays.back()->type(); + } else { + valueArrays.back() = list->values(); + valueTypes.back() = field->type()->field(0)->type(); + } + valueCount.back() = listType->list_size(); + } break; + case arrow::Type::LIST: { + auto list = std::static_pointer_cast(column); + valueArrays.back() = list; + valueTypes.back() = field->type()->field(0)->type(); + valueCount.back() = -1; + } break; + case arrow::Type::BOOL: { + // We unpack the array + auto boolArray = std::static_pointer_cast(column); + int64_t length = boolArray->length(); + arrow::UInt8Builder builder; + auto ok = builder.Reserve(length); + + for (int64_t i = 0; i < length; ++i) { + if (boolArray->IsValid(i)) { + // Expand each boolean value (true/false) to uint8 (1/0) + uint8_t value = boolArray->Value(i) ? 
1 : 0; + auto ok = builder.Append(value); + } else { + // Append null for invalid entries + auto ok = builder.AppendNull(); + } + } + valueArrays.back() = *builder.Finish(); + valueTypes.back() = valueArrays.back()->type(); + } break; + default: + valueArrays.back() = column; + valueTypes.back() = field->type(); + break; + } + } + + int64_t pos = 0; + + auto entry = mWriter->CreateEntry(); + std::vector tokens; + tokens.reserve(batch->num_columns()); + std::vector typeIds; + typeIds.reserve(batch->num_columns()); + + for (size_t ci = 0; ci < batch->num_columns(); ++ci) { + auto& field = batch->schema()->field(ci); + typeIds.push_back(batch->column(ci)->type()->id()); + tokens.push_back(entry->GetToken(field->name())); + } + + while (pos < batch->num_rows()) { + for (size_t ci = 0; ci < batch->num_columns(); ++ci) { + auto typeId = typeIds[ci]; + auto token = tokens[ci]; + + switch (typeId) { + case arrow::Type::LIST: { + auto list = std::static_pointer_cast(valueArrays[ci]); + auto value_slice = list->value_slice(pos); + + valueCount[ci] = value_slice->length(); + auto bindValue = [&vc = valueCount, ci, token](auto array, std::unique_ptr& entry) -> void { + using value_type = std::decay_t::value_type; + auto v = std::make_shared>((value_type*)array->raw_values(), vc[ci]); + entry->BindValue(token, v); + }; + switch (valueTypes[ci]->id()) { + case arrow::Type::FLOAT: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::DOUBLE: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::INT8: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::INT16: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::INT32: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::INT64: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::UINT8: { + 
bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::UINT16: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::UINT32: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + case arrow::Type::UINT64: { + bindValue(std::static_pointer_cast(value_slice), entry); + } break; + default: { + throw runtime_error("Unsupported kind of VLA"); + } break; + } + } break; + case arrow::Type::FIXED_SIZE_LIST: { + entry->BindRawPtr(token, (void*)(valueArrays[ci]->data()->buffers[1]->data() + pos * valueCount[ci] * valueTypes[ci]->byte_width())); + } break; + case arrow::Type::BOOL: { + // Not sure we actually need this + entry->BindRawPtr(token, (bool*)(valueArrays[ci]->data()->buffers[1]->data() + pos * 1)); + } break; + default: + // By default we consider things scalars. + entry->BindRawPtr(token, (void*)(valueArrays[ci]->data()->buffers[1]->data() + pos * valueTypes[ci]->byte_width())); + break; + } + } + mWriter->Fill(*entry); + ++pos; + } + // mWriter->CommitCluster(); + + return arrow::Status::OK(); + } + + arrow::Future<> + FinishInternal() override + { + return {}; + }; +}; + +arrow::Result> RNTupleFileFormat::Inspect(const arrow::dataset::FileSource& source) const +{ + + auto fs = std::dynamic_pointer_cast(source.filesystem()); + // Actually get the TTree from the ROOT file. 
+ auto ntupleFs = std::dynamic_pointer_cast(fs->GetSubFilesystem(source)); + if (!ntupleFs.get()) { + throw runtime_error_f("Unknown filesystem %s\n", source.filesystem()->type_name().c_str()); + } + ROOT::Experimental::RNTuple* rntuple = ntupleFs->GetRNTuple(source); + + auto inspector = ROOT::Experimental::RNTupleInspector::Create(rntuple); + + auto reader = ROOT::Experimental::RNTupleReader::Open(rntuple); + + auto& tupleField0 = reader->GetModel().GetFieldZero(); + std::vector> fields; + for (auto& tupleField : tupleField0.GetSubFields()) { + auto field = std::make_shared(tupleField->GetFieldName(), arrowTypeFromRNTuple(*tupleField, tupleField->GetValueSize())); + fields.push_back(field); + } + + return std::make_shared(fields); +} + +arrow::Result RNTupleFileFormat::ScanBatchesAsync( + const std::shared_ptr& options, + const std::shared_ptr& fragment) const +{ + auto dataset_schema = options->dataset_schema; + auto ntupleFragment = std::dynamic_pointer_cast(fragment); + + auto generator = [pool = options->pool, ntupleFragment, dataset_schema, &totalCompressedSize = mTotCompressedSize, + &totalUncompressedSize = mTotUncompressedSize]() -> arrow::Future> { + using namespace ROOT::Experimental; + std::vector> columns; + std::vector> fields = dataset_schema->fields(); + + auto containerFS = std::dynamic_pointer_cast(ntupleFragment->source().filesystem()); + auto fs = std::dynamic_pointer_cast(containerFS->GetSubFilesystem(ntupleFragment->source())); + + int64_t rows = -1; + ROOT::Experimental::RNTuple* rntuple = fs->GetRNTuple(ntupleFragment->source()); + auto reader = ROOT::Experimental::RNTupleReader::Open(rntuple); + auto& model = reader->GetModel(); + for (auto& physicalField : fields) { + auto bulk = model.CreateBulk(physicalField->name()); + + auto listType = std::dynamic_pointer_cast(physicalField->type()); + + auto& descriptor = reader->GetDescriptor(); + auto totalEntries = reader->GetNEntries(); + + if (rows == -1) { + rows = totalEntries; + } + if (rows 
!= totalEntries) { + throw runtime_error_f("Unmatching number of rows for branch %s", physicalField->name().c_str()); + } + arrow::Status status; + int readEntries = 0; + std::shared_ptr array; + if (physicalField->type() == arrow::boolean() || + (listType && physicalField->type()->field(0)->type() == arrow::boolean())) { + if (listType) { + std::unique_ptr builder = nullptr; + auto status = arrow::MakeBuilder(pool, physicalField->type()->field(0)->type(), &builder); + if (!status.ok()) { + throw runtime_error("Cannot create value builder"); + } + auto listBuilder = std::make_unique(pool, std::move(builder), listType->list_size()); + auto valueBuilder = listBuilder.get()->value_builder(); + // boolean array special case: we need to use builder to create the bitmap + status = valueBuilder->Reserve(totalEntries * listType->list_size()); + status &= listBuilder->Reserve(totalEntries); + if (!status.ok()) { + throw runtime_error("Failed to reserve memory for array builder"); + } + auto clusterIt = descriptor.FindClusterId(0, 0); + // No adoption for now... 
+ // bulk.AdoptBuffer(buffer, totalEntries) + while (clusterIt != kInvalidDescriptorId) { + auto& index = descriptor.GetClusterDescriptor(clusterIt); + auto mask = std::make_unique(index.GetNEntries()); + std::fill(mask.get(), mask.get() + index.GetNEntries(), true); + void* ptr = bulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()); + int readLast = index.GetNEntries(); + readEntries += readLast; + status &= static_cast(valueBuilder)->AppendValues(reinterpret_cast(ptr), readLast * listType->list_size()); + clusterIt = descriptor.FindNextClusterId(clusterIt); + } + status &= static_cast(listBuilder.get())->AppendValues(readEntries); + if (!status.ok()) { + throw runtime_error("Failed to append values to array"); + } + status &= listBuilder->Finish(&array); + if (!status.ok()) { + throw runtime_error("Failed to create array"); + } + } else if (listType == nullptr) { + std::unique_ptr builder = nullptr; + auto status = arrow::MakeBuilder(pool, physicalField->type(), &builder); + if (!status.ok()) { + throw runtime_error("Cannot create builder"); + } + auto valueBuilder = static_cast(builder.get()); + // boolean array special case: we need to use builder to create the bitmap + status = valueBuilder->Reserve(totalEntries); + if (!status.ok()) { + throw runtime_error("Failed to reserve memory for array builder"); + } + auto clusterIt = descriptor.FindClusterId(0, 0); + while (clusterIt != kInvalidDescriptorId) { + auto& index = descriptor.GetClusterDescriptor(clusterIt); + auto mask = std::make_unique(index.GetNEntries()); + std::fill(mask.get(), mask.get() + index.GetNEntries(), true); + void* ptr = bulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()); + int readLast = index.GetNEntries(); + readEntries += readLast; + status &= valueBuilder->AppendValues(reinterpret_cast(ptr), readLast); + clusterIt = descriptor.FindNextClusterId(clusterIt); + } + if (!status.ok()) { + throw 
runtime_error("Failed to append values to array"); + } + status &= valueBuilder->Finish(&array); + if (!status.ok()) { + throw runtime_error("Failed to create array"); + } + } + } else { + // other types: use serialized read to build arrays directly. + auto typeSize = physicalField->type()->byte_width(); + // FIXME: for now... + auto bytes = 0; + auto branchSize = bytes ? bytes : 1000000; + auto&& result = arrow::AllocateResizableBuffer(branchSize, pool); + if (!result.ok()) { + throw runtime_error("Cannot allocate values buffer"); + } + std::shared_ptr arrowValuesBuffer = std::move(result).ValueUnsafe(); + auto ptr = arrowValuesBuffer->mutable_data(); + if (ptr == nullptr) { + throw runtime_error("Invalid buffer"); + } + + std::unique_ptr offsetBuffer = nullptr; + + std::shared_ptr arrowOffsetBuffer; + std::span offsets; + int size = 0; + uint32_t totalSize = 0; + int64_t listSize = 1; + if (auto fixedSizeList = std::dynamic_pointer_cast(physicalField->type())) { + listSize = fixedSizeList->list_size(); + typeSize = fixedSizeList->field(0)->type()->byte_width(); + auto clusterIt = descriptor.FindClusterId(0, 0); + while (clusterIt != kInvalidDescriptorId) { + auto& index = descriptor.GetClusterDescriptor(clusterIt); + auto mask = std::make_unique(index.GetNEntries()); + std::fill(mask.get(), mask.get() + index.GetNEntries(), true); + void* inPtr = bulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()); + + int readLast = index.GetNEntries(); + if (listSize == -1) { + size = offsets[readEntries + readLast] - offsets[readEntries]; + } else { + size = readLast * listSize; + } + readEntries += readLast; + memcpy(ptr, inPtr, size * typeSize); + ptr += (ptrdiff_t)(size * typeSize); + clusterIt = descriptor.FindNextClusterId(clusterIt); + } + } else if (auto vlaListType = std::dynamic_pointer_cast(physicalField->type())) { + listSize = -1; + typeSize = vlaListType->field(0)->type()->byte_width(); + offsetBuffer = 
std::make_unique(TBuffer::EMode::kWrite, 4 * 1024 * 1024); + result = arrow::AllocateResizableBuffer((totalEntries + 1) * (int64_t)sizeof(int), pool); + if (!result.ok()) { + throw runtime_error("Cannot allocate offset buffer"); + } + arrowOffsetBuffer = std::move(result).ValueUnsafe(); + + // Offset bulk + auto offsetBulk = model.CreateBulk(physicalField->name()); + // Actual values are in a different place... + bulk = model.CreateBulk(physicalField->name()); + auto clusterIt = descriptor.FindClusterId(0, 0); + auto* ptrOffset = reinterpret_cast(arrowOffsetBuffer->mutable_data()); + auto* tPtrOffset = reinterpret_cast(ptrOffset); + offsets = std::span{tPtrOffset, tPtrOffset + totalEntries + 1}; + + auto copyOffsets = [&arrowValuesBuffer, &pool, &ptrOffset, &ptr, &totalSize](auto inPtr, size_t total) { + using value_type = typename std::decay_t::value_type; + for (size_t i = 0; i < total; i++) { + *ptrOffset++ = totalSize; + totalSize += inPtr[i].size(); + } + *ptrOffset = totalSize; + auto&& result = arrow::AllocateResizableBuffer(totalSize * sizeof(value_type), pool); + if (!result.ok()) { + throw runtime_error("Cannot allocate values buffer"); + } + arrowValuesBuffer = std::move(result).ValueUnsafe(); + ptr = (uint8_t*)(arrowValuesBuffer->mutable_data()); + // Calculate the size of the buffer here. 
+ for (size_t i = 0; i < total; i++) { + int vlaSizeInBytes = inPtr[i].size() * sizeof(value_type); + if (vlaSizeInBytes == 0) { + continue; + } + memcpy(ptr, inPtr[i].data(), vlaSizeInBytes); + ptr += vlaSizeInBytes; + } + }; + + while (clusterIt != kInvalidDescriptorId) { + auto& index = descriptor.GetClusterDescriptor(clusterIt); + auto mask = std::make_unique(index.GetNEntries()); + std::fill(mask.get(), mask.get() + index.GetNEntries(), true); + int readLast = index.GetNEntries(); + switch (vlaListType->field(0)->type()->id()) { + case arrow::Type::FLOAT: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::DOUBLE: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::INT8: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::INT16: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::INT32: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::INT64: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::UINT8: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::UINT16: { + 
copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::UINT32: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + case arrow::Type::UINT64: { + copyOffsets((ROOT::Internal::VecOps::RVec*)offsetBulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()), readLast); + } break; + default: { + throw runtime_error("Unsupported kind of VLA"); + } break; + } + + readEntries += readLast; + clusterIt = descriptor.FindNextClusterId(clusterIt); + } + } else { + auto clusterIt = descriptor.FindClusterId(0, 0); + while (clusterIt != kInvalidDescriptorId) { + auto& index = descriptor.GetClusterDescriptor(clusterIt); + auto mask = std::make_unique(index.GetNEntries()); + std::fill(mask.get(), mask.get() + index.GetNEntries(), true); + void* inPtr = bulk.ReadBulk(RClusterIndex(clusterIt, index.GetFirstEntryIndex()), mask.get(), index.GetNEntries()); + + int readLast = index.GetNEntries(); + if (listSize == -1) { + size = offsets[readEntries + readLast] - offsets[readEntries]; + } else { + size = readLast * listSize; + } + readEntries += readLast; + memcpy(ptr, inPtr, size * typeSize); + ptr += (ptrdiff_t)(size * typeSize); + clusterIt = descriptor.FindNextClusterId(clusterIt); + } + } + switch (listSize) { + case -1: { + auto varray = std::make_shared(physicalField->type()->field(0)->type(), totalSize, arrowValuesBuffer); + array = std::make_shared(physicalField->type(), readEntries, arrowOffsetBuffer, varray); + } break; + case 1: { + totalSize = readEntries * listSize; + array = std::make_shared(physicalField->type(), readEntries, arrowValuesBuffer); + + } break; + default: { + totalSize = readEntries * listSize; + auto varray = 
std::make_shared(physicalField->type()->field(0)->type(), totalSize, arrowValuesBuffer); + array = std::make_shared(physicalField->type(), readEntries, varray); + } + } + } + columns.push_back(array); + } + + auto batch = arrow::RecordBatch::Make(dataset_schema, rows, columns); + return batch; + }; + + return generator; +} + +arrow::Result> RNTupleFileFormat::MakeWriter(std::shared_ptr destination, + std::shared_ptr schema, + std::shared_ptr options, + arrow::fs::FileLocator destination_locator) const +{ + auto writer = std::make_shared(schema, options, destination, destination_locator); + return std::dynamic_pointer_cast(writer); +} + +arrow::Result> RNTupleFileFormat::MakeFragment( + arrow::dataset::FileSource source, arrow::compute::Expression partition_expression, + std::shared_ptr physical_schema) +{ + std::shared_ptr format = std::make_shared(mTotCompressedSize, mTotUncompressedSize); + + auto fragment = std::make_shared(std::move(source), std::move(format), + std::move(partition_expression), + std::move(physical_schema)); + return std::dynamic_pointer_cast(fragment); +} + +RNTupleFileSystem::~RNTupleFileSystem() = default; + +std::shared_ptr + RNTupleFileFormat::DefaultWriteOptions() +{ + return std::make_shared(shared_from_this()); +} + +struct RNTuplePluginContext { + size_t totalCompressedSize = 0; + size_t totalUncompressedSize = 0; + std::shared_ptr format = nullptr; +}; + +struct RNTupleObjectReadingImplementation : public RootArrowFactoryPlugin { + RootArrowFactory* create() override + { + auto context = new RNTuplePluginContext; + context->format = std::make_shared(context->totalCompressedSize, context->totalUncompressedSize); + return new RootArrowFactory{ + .options = [context]() { return context->format->DefaultWriteOptions(); }, + .format = [context]() { return context->format; }, + .getSubFilesystem = [](void* handle) { + auto rntuple = (ROOT::Experimental::RNTuple*)handle; + return std::shared_ptr(new SingleRNTupleFileSystem(rntuple)); }, + }; 
+ } +}; + +DEFINE_DPL_PLUGINS_BEGIN +DEFINE_DPL_PLUGIN_INSTANCE(RNTupleObjectReadingImplementation, RootObjectReadingImplementation); +DEFINE_DPL_PLUGINS_END +} // namespace o2::framework diff --git a/Framework/AnalysisSupport/src/TTreePlugin.cxx b/Framework/AnalysisSupport/src/TTreePlugin.cxx new file mode 100644 index 0000000000000..e376ed8b96268 --- /dev/null +++ b/Framework/AnalysisSupport/src/TTreePlugin.cxx @@ -0,0 +1,862 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. + +#include "Framework/RootArrowFilesystem.h" +#include "Framework/Plugins.h" +#include "Framework/Signpost.h" +#include "Framework/Endian.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +O2_DECLARE_DYNAMIC_LOG(root_arrow_fs); + +namespace o2::framework +{ + +class TTreeFileWriteOptions : public arrow::dataset::FileWriteOptions +{ + public: + TTreeFileWriteOptions(std::shared_ptr format) + : FileWriteOptions(format) + { + } +}; + +// A filesystem which allows me to get a TTree +class TTreeFileSystem : public VirtualRootFileSystemBase +{ + public: + ~TTreeFileSystem() override; + + std::shared_ptr GetSubFilesystem(arrow::dataset::FileSource source) override + { + return std::dynamic_pointer_cast(shared_from_this()); + }; + + arrow::Result> OpenOutputStream( + const std::string& path, + const std::shared_ptr& metadata) override; + + virtual TTree* GetTree(arrow::dataset::FileSource source) = 0; +}; + +class SingleTreeFileSystem : public 
TTreeFileSystem +{ + public: + SingleTreeFileSystem(TTree* tree) + : TTreeFileSystem(), + mTree(tree) + { + } + + arrow::Result GetFileInfo(std::string const& path) override; + + std::string type_name() const override + { + return "ttree"; + } + + TTree* GetTree(arrow::dataset::FileSource) override + { + // Simply return the only TTree we have + return mTree; + } + + private: + TTree* mTree; +}; + +arrow::Result SingleTreeFileSystem::GetFileInfo(std::string const& path) +{ + arrow::dataset::FileSource source(path, shared_from_this()); + arrow::fs::FileInfo result; + result.set_path(path); + result.set_type(arrow::fs::FileType::File); + return result; +} + +class TTreeFileFragment : public arrow::dataset::FileFragment +{ + public: + TTreeFileFragment(arrow::dataset::FileSource source, + std::shared_ptr format, + arrow::compute::Expression partition_expression, + std::shared_ptr physical_schema) + : FileFragment(std::move(source), std::move(format), std::move(partition_expression), std::move(physical_schema)) + { + } +}; + +class TTreeFileFormat : public arrow::dataset::FileFormat +{ + size_t& mTotCompressedSize; + size_t& mTotUncompressedSize; + + public: + TTreeFileFormat(size_t& totalCompressedSize, size_t& totalUncompressedSize) + : FileFormat({}), + mTotCompressedSize(totalCompressedSize), + mTotUncompressedSize(totalUncompressedSize) + { + } + + ~TTreeFileFormat() override = default; + + std::string type_name() const override + { + return "ttree"; + } + + bool Equals(const FileFormat& other) const override + { + return other.type_name() == this->type_name(); + } + + arrow::Result IsSupported(const arrow::dataset::FileSource& source) const override + { + auto fs = std::dynamic_pointer_cast(source.filesystem()); + auto subFs = fs->GetSubFilesystem(source); + if (std::dynamic_pointer_cast(subFs)) { + return true; + } + return false; + } + + arrow::Result> Inspect(const arrow::dataset::FileSource& source) const override; + /// \brief Create a FileFragment for a 
FileSource. + arrow::Result> MakeFragment( + arrow::dataset::FileSource source, arrow::compute::Expression partition_expression, + std::shared_ptr physical_schema) override; + + arrow::Result> MakeWriter(std::shared_ptr destination, std::shared_ptr schema, std::shared_ptr options, arrow::fs::FileLocator destination_locator) const override; + + std::shared_ptr DefaultWriteOptions() override; + + arrow::Result ScanBatchesAsync( + const std::shared_ptr& options, + const std::shared_ptr& fragment) const override; +}; + +// An arrow outputstream which allows to write to a TTree. Eventually +// with a prefix for the branches. +class TTreeOutputStream : public arrow::io::OutputStream +{ + public: + TTreeOutputStream(TTree*, std::string branchPrefix); + + arrow::Status Close() override; + + arrow::Result Tell() const override; + + arrow::Status Write(const void* data, int64_t nbytes) override; + + bool closed() const override; + + TBranch* CreateBranch(char const* branchName, char const* sizeBranch); + + TTree* GetTree() + { + return mTree; + } + + private: + TTree* mTree; + std::string mBranchPrefix; +}; + +// An arrow outputstream which allows to write to a ttree +// @a branch prefix is to be used to identify a set of branches which all belong to +// the same table. 
+TTreeOutputStream::TTreeOutputStream(TTree* f, std::string branchPrefix) + : mTree(f), + mBranchPrefix(std::move(branchPrefix)) +{ +} + +arrow::Status TTreeOutputStream::Close() +{ + if (mTree->GetCurrentFile() == nullptr) { + return arrow::Status::Invalid("Cannot close a tree not attached to a file"); + } + mTree->GetCurrentFile()->Close(); + return arrow::Status::OK(); +} + +arrow::Result TTreeOutputStream::Tell() const +{ + return arrow::Result(arrow::Status::NotImplemented("Cannot move")); +} + +arrow::Status TTreeOutputStream::Write(const void* data, int64_t nbytes) +{ + return arrow::Status::NotImplemented("Cannot write raw bytes to a TTree"); +} + +bool TTreeOutputStream::closed() const +{ + // A standalone tree is never closed. + if (mTree->GetCurrentFile() == nullptr) { + return false; + } + return mTree->GetCurrentFile()->IsOpen() == false; +} + +TBranch* TTreeOutputStream::CreateBranch(char const* branchName, char const* sizeBranch) +{ + return mTree->Branch((mBranchPrefix + "/" + branchName).c_str(), (char*)nullptr, (mBranchPrefix + sizeBranch).c_str()); +} + +struct TTreePluginContext { + size_t totalCompressedSize = 0; + size_t totalUncompressedSize = 0; + std::shared_ptr format = nullptr; +}; + +struct TTreeObjectReadingImplementation : public RootArrowFactoryPlugin { + RootArrowFactory* create() override + { + auto context = new TTreePluginContext; + context->format = std::make_shared(context->totalCompressedSize, context->totalUncompressedSize); + return new RootArrowFactory{ + .options = [context]() { return context->format->DefaultWriteOptions(); }, + .format = [context]() { return context->format; }, + .getSubFilesystem = [](void* handle) { + auto tree = (TTree*)handle; + return std::shared_ptr(new SingleTreeFileSystem(tree)); }, + }; + } +}; + +arrow::Result TTreeFileFormat::ScanBatchesAsync( + const std::shared_ptr& options, + const std::shared_ptr& fragment) const +{ + // Get the fragment as a TTreeFragment. This might be PART of a TTree. 
+ auto treeFragment = std::dynamic_pointer_cast(fragment); + // This is the schema we want to read + auto dataset_schema = options->dataset_schema; + + auto generator = [pool = options->pool, treeFragment, dataset_schema, &totalCompressedSize = mTotCompressedSize, + &totalUncompressedSize = mTotUncompressedSize]() -> arrow::Future> { + auto schema = treeFragment->format()->Inspect(treeFragment->source()); + + std::vector> columns; + std::vector> fields = dataset_schema->fields(); + auto physical_schema = *treeFragment->ReadPhysicalSchema(); + + static TBufferFile buffer{TBuffer::EMode::kWrite, 4 * 1024 * 1024}; + auto containerFS = std::dynamic_pointer_cast(treeFragment->source().filesystem()); + auto fs = std::dynamic_pointer_cast(containerFS->GetSubFilesystem(treeFragment->source())); + + int64_t rows = -1; + TTree* tree = fs->GetTree(treeFragment->source()); + for (auto& field : fields) { + // The field actually on disk + auto physicalField = physical_schema->GetFieldByName(field->name()); + TBranch* branch = tree->GetBranch(physicalField->name().c_str()); + assert(branch); + buffer.Reset(); + auto totalEntries = branch->GetEntries(); + if (rows == -1) { + rows = totalEntries; + } + if (rows != totalEntries) { + throw runtime_error_f("Unmatching number of rows for branch %s", branch->GetName()); + } + arrow::Status status; + int readEntries = 0; + std::shared_ptr array; + auto listType = std::dynamic_pointer_cast(physicalField->type()); + if (physicalField->type() == arrow::boolean() || + (listType && physicalField->type()->field(0)->type() == arrow::boolean())) { + if (listType) { + std::unique_ptr builder = nullptr; + auto status = arrow::MakeBuilder(pool, physicalField->type()->field(0)->type(), &builder); + if (!status.ok()) { + throw runtime_error("Cannot create value builder"); + } + auto listBuilder = std::make_unique(pool, std::move(builder), listType->list_size()); + auto valueBuilder = listBuilder.get()->value_builder(); + // boolean array special 
case: we need to use builder to create the bitmap + status = valueBuilder->Reserve(totalEntries * listType->list_size()); + status &= listBuilder->Reserve(totalEntries); + if (!status.ok()) { + throw runtime_error("Failed to reserve memory for array builder"); + } + while (readEntries < totalEntries) { + auto readLast = branch->GetBulkRead().GetBulkEntries(readEntries, buffer); + readEntries += readLast; + status &= static_cast(valueBuilder)->AppendValues(reinterpret_cast(buffer.GetCurrent()), readLast * listType->list_size()); + } + status &= static_cast(listBuilder.get())->AppendValues(readEntries); + if (!status.ok()) { + throw runtime_error("Failed to append values to array"); + } + status &= listBuilder->Finish(&array); + if (!status.ok()) { + throw runtime_error("Failed to create array"); + } + } else if (listType == nullptr) { + std::unique_ptr builder = nullptr; + auto status = arrow::MakeBuilder(pool, physicalField->type(), &builder); + if (!status.ok()) { + throw runtime_error("Cannot create builder"); + } + auto valueBuilder = static_cast(builder.get()); + // boolean array special case: we need to use builder to create the bitmap + status = valueBuilder->Reserve(totalEntries); + if (!status.ok()) { + throw runtime_error("Failed to reserve memory for array builder"); + } + while (readEntries < totalEntries) { + auto readLast = branch->GetBulkRead().GetBulkEntries(readEntries, buffer); + readEntries += readLast; + status &= valueBuilder->AppendValues(reinterpret_cast(buffer.GetCurrent()), readLast); + } + if (!status.ok()) { + throw runtime_error("Failed to append values to array"); + } + status &= valueBuilder->Finish(&array); + if (!status.ok()) { + throw runtime_error("Failed to create array"); + } + } + } else { + // other types: use serialized read to build arrays directly. + auto typeSize = physicalField->type()->byte_width(); + // This is needed for branches which have not been persisted. 
+ auto bytes = branch->GetTotBytes(); + auto branchSize = bytes ? bytes : 1000000; + auto&& result = arrow::AllocateResizableBuffer(branchSize, pool); + if (!result.ok()) { + throw runtime_error("Cannot allocate values buffer"); + } + std::shared_ptr arrowValuesBuffer = std::move(result).ValueUnsafe(); + auto ptr = arrowValuesBuffer->mutable_data(); + if (ptr == nullptr) { + throw runtime_error("Invalid buffer"); + } + + std::unique_ptr offsetBuffer = nullptr; + + uint32_t offset = 0; + int count = 0; + std::shared_ptr arrowOffsetBuffer; + std::span offsets; + int size = 0; + uint32_t totalSize = 0; + TBranch* mSizeBranch = nullptr; + int64_t listSize = 1; + if (auto fixedSizeList = std::dynamic_pointer_cast(physicalField->type())) { + listSize = fixedSizeList->list_size(); + typeSize = fixedSizeList->field(0)->type()->byte_width(); + } else if (auto vlaListType = std::dynamic_pointer_cast(physicalField->type())) { + listSize = -1; + typeSize = vlaListType->field(0)->type()->byte_width(); + } + if (listSize == -1) { + mSizeBranch = branch->GetTree()->GetBranch((std::string{branch->GetName()} + "_size").c_str()); + offsetBuffer = std::make_unique(TBuffer::EMode::kWrite, 4 * 1024 * 1024); + result = arrow::AllocateResizableBuffer((totalEntries + 1) * (int64_t)sizeof(int), pool); + if (!result.ok()) { + throw runtime_error("Cannot allocate offset buffer"); + } + arrowOffsetBuffer = std::move(result).ValueUnsafe(); + unsigned char* ptrOffset = arrowOffsetBuffer->mutable_data(); + auto* tPtrOffset = reinterpret_cast(ptrOffset); + offsets = std::span{tPtrOffset, tPtrOffset + totalEntries + 1}; + + // read sizes first + while (readEntries < totalEntries) { + auto readLast = mSizeBranch->GetBulkRead().GetEntriesSerialized(readEntries, *offsetBuffer); + readEntries += readLast; + for (auto i = 0; i < readLast; ++i) { + offsets[count++] = (int)offset; + offset += swap32_(reinterpret_cast(offsetBuffer->GetCurrent())[i]); + } + } + offsets[count] = (int)offset; + totalSize = 
offset; + readEntries = 0; + } + + while (readEntries < totalEntries) { + auto readLast = branch->GetBulkRead().GetEntriesSerialized(readEntries, buffer); + if (listSize == -1) { + size = offsets[readEntries + readLast] - offsets[readEntries]; + } else { + size = readLast * listSize; + } + readEntries += readLast; + swapCopy(ptr, buffer.GetCurrent(), size, typeSize); + ptr += (ptrdiff_t)(size * typeSize); + } + if (listSize >= 1) { + totalSize = readEntries * listSize; + } + std::shared_ptr varray; + switch (listSize) { + case -1: + varray = std::make_shared(physicalField->type()->field(0)->type(), totalSize, arrowValuesBuffer); + array = std::make_shared(physicalField->type(), readEntries, arrowOffsetBuffer, varray); + break; + case 1: + array = std::make_shared(physicalField->type(), readEntries, arrowValuesBuffer); + break; + default: + varray = std::make_shared(physicalField->type()->field(0)->type(), totalSize, arrowValuesBuffer); + array = std::make_shared(physicalField->type(), readEntries, varray); + } + } + + branch->SetStatus(false); + branch->DropBaskets("all"); + branch->Reset(); + branch->GetTransientBuffer(0)->Expand(0); + + columns.push_back(array); + } + auto batch = arrow::RecordBatch::Make(dataset_schema, rows, columns); + totalCompressedSize += tree->GetZipBytes(); + totalUncompressedSize += tree->GetTotBytes(); + return batch; + }; + return generator; +} + +char const* rootSuffixFromArrow(arrow::Type::type id) +{ + switch (id) { + case arrow::Type::BOOL: + return "/O"; + case arrow::Type::UINT8: + return "/b"; + case arrow::Type::UINT16: + return "/s"; + case arrow::Type::UINT32: + return "/i"; + case arrow::Type::UINT64: + return "/l"; + case arrow::Type::INT8: + return "/B"; + case arrow::Type::INT16: + return "/S"; + case arrow::Type::INT32: + return "/I"; + case arrow::Type::INT64: + return "/L"; + case arrow::Type::FLOAT: + return "/F"; + case arrow::Type::DOUBLE: + return "/D"; + default: + throw runtime_error("Unsupported arrow column 
type"); + } +} + +arrow::Result> TTreeFileSystem::OpenOutputStream( + const std::string& path, + const std::shared_ptr& metadata) +{ + arrow::dataset::FileSource source{path, shared_from_this()}; + auto prefix = metadata->Get("branch_prefix"); + if (prefix.ok()) { + return std::make_shared(GetTree(source), *prefix); + } + return std::make_shared(GetTree(source), ""); +} + +namespace +{ +struct BranchInfo { + std::string name; + TBranch* ptr; + bool mVLA; +}; +} // namespace + +auto arrowTypeFromROOT(EDataType type, int size) +{ + auto typeGenerator = [](std::shared_ptr const& type, int size) -> std::shared_ptr { + switch (size) { + case -1: + return arrow::list(type); + case 1: + return std::move(type); + default: + return arrow::fixed_size_list(type, size); + } + }; + + switch (type) { + case EDataType::kBool_t: + return typeGenerator(arrow::boolean(), size); + case EDataType::kUChar_t: + return typeGenerator(arrow::uint8(), size); + case EDataType::kUShort_t: + return typeGenerator(arrow::uint16(), size); + case EDataType::kUInt_t: + return typeGenerator(arrow::uint32(), size); + case EDataType::kULong64_t: + return typeGenerator(arrow::uint64(), size); + case EDataType::kChar_t: + return typeGenerator(arrow::int8(), size); + case EDataType::kShort_t: + return typeGenerator(arrow::int16(), size); + case EDataType::kInt_t: + return typeGenerator(arrow::int32(), size); + case EDataType::kLong64_t: + return typeGenerator(arrow::int64(), size); + case EDataType::kFloat_t: + return typeGenerator(arrow::float32(), size); + case EDataType::kDouble_t: + return typeGenerator(arrow::float64(), size); + default: + throw o2::framework::runtime_error_f("Unsupported branch type: %d", static_cast(type)); + } +} + +arrow::Result> TTreeFileFormat::Inspect(const arrow::dataset::FileSource& source) const +{ + arrow::Schema schema{{}}; + auto fs = std::dynamic_pointer_cast(source.filesystem()); + // Actually get the TTree from the ROOT file. 
+ auto treeFs = std::dynamic_pointer_cast(fs->GetSubFilesystem(source)); + if (!treeFs.get()) { + throw runtime_error_f("Unknown filesystem %s\n", source.filesystem()->type_name().c_str()); + } + TTree* tree = treeFs->GetTree(source); + + auto branches = tree->GetListOfBranches(); + auto n = branches->GetEntries(); + + std::vector branchInfos; + for (auto i = 0; i < n; ++i) { + auto branch = static_cast(branches->At(i)); + auto name = std::string{branch->GetName()}; + auto pos = name.find("_size"); + if (pos != std::string::npos) { + name.erase(pos); + branchInfos.emplace_back(BranchInfo{name, (TBranch*)nullptr, true}); + } else { + auto lookup = std::find_if(branchInfos.begin(), branchInfos.end(), [&](BranchInfo const& bi) { + return bi.name == name; + }); + if (lookup == branchInfos.end()) { + branchInfos.emplace_back(BranchInfo{name, branch, false}); + } else { + lookup->ptr = branch; + } + } + } + + std::vector> fields; + tree->SetCacheSize(25000000); + for (auto& bi : branchInfos) { + static TClass* cls; + EDataType type; + bi.ptr->GetExpectedType(cls, type); + auto listSize = -1; + if (!bi.mVLA) { + listSize = static_cast(bi.ptr->GetListOfLeaves()->At(0))->GetLenStatic(); + } + auto field = std::make_shared(bi.ptr->GetName(), arrowTypeFromROOT(type, listSize)); + fields.push_back(field); + + tree->AddBranchToCache(bi.ptr); + if (strncmp(bi.ptr->GetName(), "fIndexArray", strlen("fIndexArray")) == 0) { + std::string sizeBranchName = bi.ptr->GetName(); + sizeBranchName += "_size"; + auto* sizeBranch = (TBranch*)tree->GetBranch(sizeBranchName.c_str()); + if (sizeBranch) { + tree->AddBranchToCache(sizeBranch); + } + } + } + tree->StopCacheLearningPhase(); + + return std::make_shared(fields); +} + +/// \brief Create a FileFragment for a FileSource. 
+arrow::Result> TTreeFileFormat::MakeFragment( + arrow::dataset::FileSource source, arrow::compute::Expression partition_expression, + std::shared_ptr physical_schema) +{ + std::shared_ptr format = std::make_shared(mTotCompressedSize, mTotUncompressedSize); + + auto fragment = std::make_shared(std::move(source), std::move(format), + std::move(partition_expression), + std::move(physical_schema)); + return std::dynamic_pointer_cast(fragment); +} + +class TTreeFileWriter : public arrow::dataset::FileWriter +{ + std::vector branches; + std::vector sizesBranches; + std::vector> valueArrays; + std::vector> sizeArrays; + std::vector> valueTypes; + + std::vector valuesIdealBasketSize; + std::vector sizeIdealBasketSize; + + std::vector typeSizes; + std::vector listSizes; + bool firstBasket = true; + + // This is to create a batsket size according to the first batch. + void finaliseBasketSize(std::shared_ptr firstBatch) + { + O2_SIGNPOST_ID_FROM_POINTER(sid, root_arrow_fs, this); + O2_SIGNPOST_START(root_arrow_fs, sid, "finaliseBasketSize", "First batch with %lli rows received and %zu columns", + firstBatch->num_rows(), firstBatch->columns().size()); + for (size_t i = 0; i < branches.size(); i++) { + auto* branch = branches[i]; + auto* sizeBranch = sizesBranches[i]; + + int valueSize = valueTypes[i]->byte_width(); + if (listSizes[i] == 1) { + O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s exists and uses %d bytes per entry for %lli entries.", + branch->GetName(), valueSize, firstBatch->num_rows()); + assert(sizeBranch == nullptr); + branch->SetBasketSize(1024 + firstBatch->num_rows() * valueSize); + } else if (listSizes[i] == -1) { + O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s exists and uses %d bytes per entry.", + branch->GetName(), valueSize); + // This should probably lookup the + auto column = firstBatch->GetColumnByName(schema_->field(i)->name()); + auto list = std::static_pointer_cast(column); + 
O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s needed. Associated size branch %s and there are %lli entries of size %d in that list.", + branch->GetName(), sizeBranch->GetName(), list->length(), valueSize); + branch->SetBasketSize(1024 + firstBatch->num_rows() * valueSize * list->length()); + sizeBranch->SetBasketSize(1024 + firstBatch->num_rows() * 4); + } else { + O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s needed. There are %lli entries per array of size %d in that list.", + branch->GetName(), listSizes[i], valueSize); + assert(sizeBranch == nullptr); + branch->SetBasketSize(1024 + firstBatch->num_rows() * valueSize * listSizes[i]); + } + + auto field = firstBatch->schema()->field(i); + if (field->name().starts_with("fIndexArray")) { + // One int per array to keep track of the size + int idealBasketSize = 4 * firstBatch->num_rows() + 1024 + field->type()->byte_width() * firstBatch->num_rows(); // minimal additional size needed, otherwise we get 2 baskets + int basketSize = std::max(32000, idealBasketSize); // keep a minimum value + sizeBranch->SetBasketSize(basketSize); + branch->SetBasketSize(basketSize); + } + } + O2_SIGNPOST_END(root_arrow_fs, sid, "finaliseBasketSize", "Done"); + } + + public: + // Create the TTree based on the physical_schema, not the one in the batch. + // The write method will have to reconcile the two schemas. + TTreeFileWriter(std::shared_ptr schema, std::shared_ptr options, + std::shared_ptr destination, + arrow::fs::FileLocator destination_locator) + : FileWriter(schema, options, destination, destination_locator) + { + // Batches have the same number of entries for each column. 
+ auto directoryStream = std::dynamic_pointer_cast(destination_); + auto treeStream = std::dynamic_pointer_cast(destination_); + + if (directoryStream.get()) { + TDirectoryFile* dir = directoryStream->GetDirectory(); + dir->cd(); + auto* tree = new TTree(destination_locator_.path.c_str(), ""); + treeStream = std::make_shared(tree, ""); + } else if (treeStream.get()) { + // We already have a tree stream, let's derive a new one + // with the destination_locator_.path as prefix for the branches + // This way we can multiplex multiple tables in the same tree. + auto tree = treeStream->GetTree(); + treeStream = std::make_shared(tree, destination_locator_.path); + } else { + // I could simply set a prefix here to merge to an already existing tree. + throw std::runtime_error("Unsupported backend."); + } + + for (auto i = 0u; i < schema->fields().size(); ++i) { + auto& field = schema->field(i); + listSizes.push_back(1); + + int valuesIdealBasketSize = 0; + // Construct all the needed branches. + switch (field->type()->id()) { + case arrow::Type::FIXED_SIZE_LIST: { + listSizes.back() = std::static_pointer_cast(field->type())->list_size(); + valuesIdealBasketSize = 1024 + valueTypes.back()->byte_width() * listSizes.back(); + valueTypes.push_back(field->type()->field(0)->type()); + sizesBranches.push_back(nullptr); + std::string leafList = fmt::format("{}[{}]{}", field->name(), listSizes.back(), rootSuffixFromArrow(valueTypes.back()->id())); + branches.push_back(treeStream->CreateBranch(field->name().c_str(), leafList.c_str())); + } break; + case arrow::Type::LIST: { + valueTypes.push_back(field->type()->field(0)->type()); + std::string leafList = fmt::format("{}[{}_size]{}", field->name(), field->name(), rootSuffixFromArrow(valueTypes.back()->id())); + listSizes.back() = -1; // VLA, we need to calculate it on the fly; + std::string sizeLeafList = field->name() + "_size/I"; + sizesBranches.push_back(treeStream->CreateBranch((field->name() + "_size").c_str(), 
sizeLeafList.c_str())); + branches.push_back(treeStream->CreateBranch(field->name().c_str(), leafList.c_str())); + // Notice that this could be replaced by a better guess of the + // average size of the list elements, but this is not trivial. + } break; + default: { + valueTypes.push_back(field->type()); + std::string leafList = field->name() + rootSuffixFromArrow(valueTypes.back()->id()); + sizesBranches.push_back(nullptr); + branches.push_back(treeStream->CreateBranch(field->name().c_str(), leafList.c_str())); + } break; + } + } + // We create the branches from the schema + } + + arrow::Status Write(const std::shared_ptr& batch) override + { + if (firstBasket) { + firstBasket = false; + finaliseBasketSize(batch); + } + + // Support writing empty tables + if (batch->columns().empty() || batch->num_rows() == 0) { + return arrow::Status::OK(); + } + + // Batches have the same number of entries for each column. + auto directoryStream = std::dynamic_pointer_cast(destination_); + TTree* tree = nullptr; + if (directoryStream.get()) { + TDirectoryFile* dir = directoryStream->GetDirectory(); + tree = (TTree*)dir->Get(destination_locator_.path.c_str()); + } + auto treeStream = std::dynamic_pointer_cast(destination_); + + if (!tree) { + // I could simply set a prefix here to merge to an already existing tree. + throw std::runtime_error("Unsupported backend."); + } + + for (auto i = 0u; i < batch->columns().size(); ++i) { + auto column = batch->column(i); + auto& field = batch->schema()->field(i); + + valueArrays.push_back(nullptr); + + switch (field->type()->id()) { + case arrow::Type::FIXED_SIZE_LIST: { + auto list = std::static_pointer_cast(column); + valueArrays.back() = list->values(); + } break; + case arrow::Type::LIST: { + auto list = std::static_pointer_cast(column); + valueArrays.back() = list; + } break; + case arrow::Type::BOOL: { + // In case of arrays of booleans, we need to go back to their + // char based representation for ROOT to save them. 
+ auto boolArray = std::static_pointer_cast(column); + + int64_t length = boolArray->length(); + arrow::UInt8Builder builder; + auto ok = builder.Reserve(length); + + for (int64_t i = 0; i < length; ++i) { + if (boolArray->IsValid(i)) { + // Expand each boolean value (true/false) to uint8 (1/0) + uint8_t value = boolArray->Value(i) ? 1 : 0; + auto ok = builder.Append(value); + } else { + // Append null for invalid entries + auto ok = builder.AppendNull(); + } + } + valueArrays.back() = *builder.Finish(); + } break; + default: + valueArrays.back() = column; + } + } + + int64_t pos = 0; + while (pos < batch->num_rows()) { + for (size_t bi = 0; bi < branches.size(); ++bi) { + auto* branch = branches[bi]; + auto* sizeBranch = sizesBranches[bi]; + auto array = batch->column(bi); + auto& field = batch->schema()->field(bi); + auto& listSize = listSizes[bi]; + auto valueType = valueTypes[bi]; + auto valueArray = valueArrays[bi]; + + switch (field->type()->id()) { + case arrow::Type::LIST: { + auto list = std::static_pointer_cast(array); + listSize = list->value_length(pos); + uint8_t const* buffer = std::static_pointer_cast(valueArray)->values()->data() + array->offset() + list->value_offset(pos) * valueType->byte_width(); + branch->SetAddress((void*)buffer); + sizeBranch->SetAddress(&listSize); + }; + break; + case arrow::Type::FIXED_SIZE_LIST: + default: { + uint8_t const* buffer = std::static_pointer_cast(valueArray)->values()->data() + array->offset() + pos * listSize * valueType->byte_width(); + branch->SetAddress((void*)buffer); + }; + } + } + tree->Fill(); + ++pos; + } + return arrow::Status::OK(); + } + + arrow::Future<> FinishInternal() override + { + auto treeStream = std::dynamic_pointer_cast(destination_); + TTree* tree = treeStream->GetTree(); + tree->Write("", TObject::kOverwrite); + tree->SetDirectory(nullptr); + + return {}; + }; +}; +arrow::Result> TTreeFileFormat::MakeWriter(std::shared_ptr destination, std::shared_ptr schema, std::shared_ptr options, 
arrow::fs::FileLocator destination_locator) const +{ + auto writer = std::make_shared(schema, options, destination, destination_locator); + return std::dynamic_pointer_cast(writer); +} + +std::shared_ptr TTreeFileFormat::DefaultWriteOptions() +{ + std::shared_ptr options( + new TTreeFileWriteOptions(shared_from_this())); + return options; +} + +TTreeFileSystem::~TTreeFileSystem() = default; + +DEFINE_DPL_PLUGINS_BEGIN +DEFINE_DPL_PLUGIN_INSTANCE(TTreeObjectReadingImplementation, RootObjectReadingImplementation); +DEFINE_DPL_PLUGINS_END +} // namespace o2::framework diff --git a/Framework/Core/CMakeLists.txt b/Framework/Core/CMakeLists.txt index 5cdd1241ecfb0..c1214a8f56beb 100644 --- a/Framework/Core/CMakeLists.txt +++ b/Framework/Core/CMakeLists.txt @@ -270,6 +270,10 @@ o2_add_test(Timers NAME test_Framework_test_Timers LABELS framework PUBLIC_LINK_LIBRARIES O2::Framework) +o2_add_executable(framework-ao2d-to-ao3d + SOURCES test/o2AO2DToAO3D.cxx + PUBLIC_LINK_LIBRARIES O2::Framework) + # FIXME: make this a proper test, when it actually does not hang. 
o2_add_executable(test-framework-ConsumeWhenAllOrdered SOURCES test/test_ConsumeWhenAllOrdered.cxx @@ -299,6 +303,7 @@ add_executable(o2-test-framework-root target_link_libraries(o2-test-framework-root PRIVATE O2::Framework) target_link_libraries(o2-test-framework-root PRIVATE O2::Catch2) target_link_libraries(o2-test-framework-root PRIVATE ROOT::ROOTDataFrame) +target_link_libraries(o2-test-framework-root PRIVATE ROOT::ROOTNTuple) set_property(TARGET o2-test-framework-root PROPERTY RUNTIME_OUTPUT_DIRECTORY ${outdir}) add_test(NAME framework:root COMMAND o2-test-framework-root --skip-benchmarks) add_test(NAME framework:crash COMMAND sh -e -c "PATH=${CMAKE_RUNTIME_OUTPUT_DIRECTORY}:$PATH ${CMAKE_CURRENT_LIST_DIR}/test/test_AllCrashTypes.sh") diff --git a/Framework/Core/include/Framework/Plugins.h b/Framework/Core/include/Framework/Plugins.h index 23d55a512e1fa..925943c6bffc3 100644 --- a/Framework/Core/include/Framework/Plugins.h +++ b/Framework/Core/include/Framework/Plugins.h @@ -36,6 +36,14 @@ enum struct DplPluginKind : int { // set, you might want to load metadata from it and attach it to the // configuration. Capability, + // A RootObjectReadingCapability is used to discover if there is away + // to read and understand an object serialised with ROOT. + RootObjectReadingCapability, + + // A RootObjectReadingImplementation is actually used to read said object + // using the arrow dataset API + RootObjectReadingImplementation, + // A plugin which was not initialised properly. 
Unknown }; diff --git a/Framework/Core/include/Framework/RootArrowFilesystem.h b/Framework/Core/include/Framework/RootArrowFilesystem.h index 48d817bc9ddf2..8744656e7d55d 100644 --- a/Framework/Core/include/Framework/RootArrowFilesystem.h +++ b/Framework/Core/include/Framework/RootArrowFilesystem.h @@ -11,6 +11,7 @@ #ifndef O2_FRAMEWORK_ROOT_ARROW_FILESYSTEM_H_ #define O2_FRAMEWORK_ROOT_ARROW_FILESYSTEM_H_ +#include #include #include #include @@ -18,23 +19,12 @@ #include class TFile; -class TBranch; -class TTree; class TBufferFile; class TDirectoryFile; namespace o2::framework { -class TTreeFileWriteOptions : public arrow::dataset::FileWriteOptions -{ - public: - TTreeFileWriteOptions(std::shared_ptr format) - : FileWriteOptions(format) - { - } -}; - // This is to avoid having to implement a bunch of unimplemented methods // for all the possible virtual filesystem we can invent on top of ROOT // data structures. @@ -79,46 +69,43 @@ class VirtualRootFileSystemBase : public arrow::fs::FileSystem const std::shared_ptr& metadata) override; }; -// A filesystem which allows me to get a TTree -class TTreeFileSystem : public VirtualRootFileSystemBase -{ - public: - ~TTreeFileSystem() override; - - std::shared_ptr GetSubFilesystem(arrow::dataset::FileSource source) override - { - return std::dynamic_pointer_cast(shared_from_this()); - }; - - arrow::Result> OpenOutputStream( - const std::string& path, - const std::shared_ptr& metadata) override; - - virtual TTree* GetTree(arrow::dataset::FileSource source) = 0; +struct RootArrowFactory final { + std::function()> options = nullptr; + std::function()> format = nullptr; + std::function(void*)> getSubFilesystem = nullptr; }; -class SingleTreeFileSystem : public TTreeFileSystem -{ - public: - SingleTreeFileSystem(TTree* tree) - : TTreeFileSystem(), - mTree(tree) - { - } +struct RootArrowFactoryPlugin { + virtual RootArrowFactory* create() = 0; +}; - std::string type_name() const override - { - return "ttree"; - } +// A registry 
for all the possible ways of encoding a table in a TFile +struct RootObjectReadingCapability { + // The unique name of this capability + std::string name = "unknown"; + // Given a TFile, return the object which this capability support + // Use a void * in order not to expose the kind of object to the + // generic reading code. This is also where we load the plugin + // which will be used for the actual creation. + std::function getHandle; + // Same as the above, but uses a TBufferFile as storage + std::function getBufferHandle; + // This must be implemented to load the actual RootArrowFactory plugin which + // implements this capability. This way the detection of the file format + // (via get handle) does not need to know about the actual code which performs + // the serialization (and might depend on e.g. RNTuple). + std::function factory; +}; - TTree* GetTree(arrow::dataset::FileSource) override - { - // Simply return the only TTree we have - return mTree; - } +struct RootObjectReadingCapabilityPlugin { + virtual RootObjectReadingCapability* create() = 0; +}; - private: - TTree* mTree; +// This acts as registry of all the capabilities (i.e. the ability to +// associate a given object in a root file to the serialization plugin) and +// the factory (i.e. 
the serialization plugin) +struct RootObjectReadingFactory { + std::vector capabilities; }; class TFileFileSystem : public VirtualRootFileSystemBase @@ -126,7 +113,7 @@ class TFileFileSystem : public VirtualRootFileSystemBase public: arrow::Result GetFileInfo(const std::string& path) override; - TFileFileSystem(TDirectoryFile* f, size_t readahead); + TFileFileSystem(TDirectoryFile* f, size_t readahead, RootObjectReadingFactory&); std::string type_name() const override { @@ -147,12 +134,13 @@ class TFileFileSystem : public VirtualRootFileSystemBase private: TDirectoryFile* mFile; + RootObjectReadingFactory& mObjectFactory; }; class TBufferFileFS : public VirtualRootFileSystemBase { public: - TBufferFileFS(TBufferFile* f); + TBufferFileFS(TBufferFile* f, RootObjectReadingFactory&); arrow::Result GetFileInfo(const std::string& path) override; std::string type_name() const override @@ -165,68 +153,7 @@ class TBufferFileFS : public VirtualRootFileSystemBase private: TBufferFile* mBuffer; std::shared_ptr mFilesystem; -}; - -class TTreeFileFragment : public arrow::dataset::FileFragment -{ - public: - TTreeFileFragment(arrow::dataset::FileSource source, - std::shared_ptr format, - arrow::compute::Expression partition_expression, - std::shared_ptr physical_schema) - : FileFragment(std::move(source), std::move(format), std::move(partition_expression), std::move(physical_schema)) - { - } -}; - -class TTreeFileFormat : public arrow::dataset::FileFormat -{ - size_t& mTotCompressedSize; - size_t& mTotUncompressedSize; - - public: - TTreeFileFormat(size_t& totalCompressedSize, size_t& totalUncompressedSize) - : FileFormat({}), - mTotCompressedSize(totalCompressedSize), - mTotUncompressedSize(totalUncompressedSize) - { - } - - ~TTreeFileFormat() override = default; - - std::string type_name() const override - { - return "ttree"; - } - - bool Equals(const FileFormat& other) const override - { - return other.type_name() == this->type_name(); - } - - arrow::Result IsSupported(const 
arrow::dataset::FileSource& source) const override - { - auto fs = std::dynamic_pointer_cast(source.filesystem()); - auto subFs = fs->GetSubFilesystem(source); - if (std::dynamic_pointer_cast(subFs)) { - return true; - } - return false; - } - - arrow::Result> Inspect(const arrow::dataset::FileSource& source) const override; - /// \brief Create a FileFragment for a FileSource. - arrow::Result> MakeFragment( - arrow::dataset::FileSource source, arrow::compute::Expression partition_expression, - std::shared_ptr physical_schema) override; - - arrow::Result> MakeWriter(std::shared_ptr destination, std::shared_ptr schema, std::shared_ptr options, arrow::fs::FileLocator destination_locator) const override; - - std::shared_ptr DefaultWriteOptions() override; - - arrow::Result ScanBatchesAsync( - const std::shared_ptr& options, - const std::shared_ptr& fragment) const override; + RootObjectReadingFactory& mObjectFactory; }; // An arrow outputstream which allows to write to a TDirectoryFile. @@ -255,33 +182,6 @@ class TDirectoryFileOutputStream : public arrow::io::OutputStream TDirectoryFile* mDirectory; }; -// An arrow outputstream which allows to write to a TTree. Eventually -// with a prefix for the branches. 
-class TTreeOutputStream : public arrow::io::OutputStream -{ - public: - TTreeOutputStream(TTree*, std::string branchPrefix); - - arrow::Status Close() override; - - arrow::Result Tell() const override; - - arrow::Status Write(const void* data, int64_t nbytes) override; - - bool closed() const override; - - TBranch* CreateBranch(char const* branchName, char const* sizeBranch); - - TTree* GetTree() - { - return mTree; - } - - private: - TTree* mTree; - std::string mBranchPrefix; -}; - } // namespace o2::framework #endif // O2_FRAMEWORK_ROOT_ARROW_FILESYSTEM_H_ diff --git a/Framework/Core/src/Plugin.cxx b/Framework/Core/src/Plugin.cxx index 0d225b81c0581..af71db4af3445 100644 --- a/Framework/Core/src/Plugin.cxx +++ b/Framework/Core/src/Plugin.cxx @@ -1,4 +1,4 @@ -// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// Copyright 2019-2024 CERN and copyright holders of ALICE O2. // See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. // All rights not expressly granted are reserved. 
// @@ -11,10 +11,15 @@ #include "Framework/Plugins.h" #include "Framework/ConfigParamDiscovery.h" #include "Framework/ConfigParamRegistry.h" +#include "Framework/RootArrowFilesystem.h" #include "Framework/Logger.h" #include "Framework/Capability.h" #include "Framework/Signpost.h" #include "Framework/VariantJSONHelpers.h" +#include "Framework/PluginManager.h" +#include +#include +#include #include #include @@ -168,11 +173,75 @@ struct DiscoverAODOptionsInCommandLine : o2::framework::ConfigDiscoveryPlugin { } }; +struct ImplementationContext { + std::vector implementations; +}; + +std::function getHandleByClass(char const* classname) +{ + return [classname](TDirectoryFile* file, std::string const& path) { return file->GetObjectChecked(path.c_str(), TClass::GetClass(classname)); }; +} + +std::function getBufferHandleByClass(char const* classname) +{ + return [classname](TBufferFile* buffer, std::string const& path) { buffer->Reset(); return buffer->ReadObjectAny(TClass::GetClass(classname)); }; +} + +void lazyLoadFactory(std::vector& implementations, char const* specs) +{ + // Lazy loading of the plugin so that we do not bring in RNTuple / TTree if not needed + if (implementations.empty()) { + std::vector plugins; + auto morePlugins = PluginManager::parsePluginSpecString(specs); + for (auto& extra : morePlugins) { + plugins.push_back(extra); + } + PluginManager::loadFromPlugin(plugins, implementations); + if (implementations.empty()) { + return; + } + } +} + +struct RNTupleObjectReadingCapability : o2::framework::RootObjectReadingCapabilityPlugin { + RootObjectReadingCapability* create() override + { + auto context = new ImplementationContext; + + return new RootObjectReadingCapability{ + .name = "rntuple", + .getHandle = getHandleByClass("ROOT::Experimental::RNTuple"), + .getBufferHandle = getBufferHandleByClass("ROOT::Experimental::RNTuple"), + .factory = [context]() -> RootArrowFactory& { + lazyLoadFactory(context->implementations, 
"O2FrameworkAnalysisRNTupleSupport:RNTupleObjectReadingImplementation"); + return context->implementations.back(); + }}; + } +}; + +struct TTreeObjectReadingCapability : o2::framework::RootObjectReadingCapabilityPlugin { + RootObjectReadingCapability* create() override + { + auto context = new ImplementationContext; + + return new RootObjectReadingCapability{ + .name = "ttree", + .getHandle = getHandleByClass("TTree"), + .getBufferHandle = getBufferHandleByClass("TTree"), + .factory = [context]() -> RootArrowFactory& { + lazyLoadFactory(context->implementations, "O2FrameworkAnalysisTTreeSupport:TTreeObjectReadingImplementation"); + return context->implementations.back(); + }}; + } +}; + DEFINE_DPL_PLUGINS_BEGIN DEFINE_DPL_PLUGIN_INSTANCE(DiscoverMetadataInAODCapability, Capability); DEFINE_DPL_PLUGIN_INSTANCE(DiscoverMetadataInCommandLineCapability, Capability); DEFINE_DPL_PLUGIN_INSTANCE(DiscoverAODOptionsInCommandLineCapability, Capability); DEFINE_DPL_PLUGIN_INSTANCE(DiscoverMetadataInCommandLine, ConfigDiscovery); DEFINE_DPL_PLUGIN_INSTANCE(DiscoverAODOptionsInCommandLine, ConfigDiscovery); +DEFINE_DPL_PLUGIN_INSTANCE(RNTupleObjectReadingCapability, RootObjectReadingCapability); +DEFINE_DPL_PLUGIN_INSTANCE(TTreeObjectReadingCapability, RootObjectReadingCapability); DEFINE_DPL_PLUGINS_END } // namespace o2::framework diff --git a/Framework/Core/src/RootArrowFilesystem.cxx b/Framework/Core/src/RootArrowFilesystem.cxx index 5f2d21d942d37..545ba6f0afb71 100644 --- a/Framework/Core/src/RootArrowFilesystem.cxx +++ b/Framework/Core/src/RootArrowFilesystem.cxx @@ -9,9 +9,7 @@ // granted to it by virtue of its status as an Intergovernmental Organization // or submit itself to any jurisdiction. 
#include "Framework/RootArrowFilesystem.h" -#include "Framework/Endian.h" #include "Framework/RuntimeError.h" -#include "Framework/Signpost.h" #include #include #include @@ -19,93 +17,48 @@ #include #include #include -#include #include -#include #include #include #include #include #include #include -#include #include -#include -#include +template class + std::shared_ptr; -O2_DECLARE_DYNAMIC_LOG(root_arrow_fs); - -namespace -{ -struct BranchInfo { - std::string name; - TBranch* ptr; - bool mVLA; -}; -} // namespace - -auto arrowTypeFromROOT(EDataType type, int size) -{ - auto typeGenerator = [](std::shared_ptr const& type, int size) -> std::shared_ptr { - switch (size) { - case -1: - return arrow::list(type); - case 1: - return std::move(type); - default: - return arrow::fixed_size_list(type, size); - } - }; - - switch (type) { - case EDataType::kBool_t: - return typeGenerator(arrow::boolean(), size); - case EDataType::kUChar_t: - return typeGenerator(arrow::uint8(), size); - case EDataType::kUShort_t: - return typeGenerator(arrow::uint16(), size); - case EDataType::kUInt_t: - return typeGenerator(arrow::uint32(), size); - case EDataType::kULong64_t: - return typeGenerator(arrow::uint64(), size); - case EDataType::kChar_t: - return typeGenerator(arrow::int8(), size); - case EDataType::kShort_t: - return typeGenerator(arrow::int16(), size); - case EDataType::kInt_t: - return typeGenerator(arrow::int32(), size); - case EDataType::kLong64_t: - return typeGenerator(arrow::int64(), size); - case EDataType::kFloat_t: - return typeGenerator(arrow::float32(), size); - case EDataType::kDouble_t: - return typeGenerator(arrow::float64(), size); - default: - throw o2::framework::runtime_error_f("Unsupported branch type: %d", static_cast(type)); - } -} namespace o2::framework { using arrow::Status; -TFileFileSystem::TFileFileSystem(TDirectoryFile* f, size_t readahead) +TFileFileSystem::TFileFileSystem(TDirectoryFile* f, size_t readahead, RootObjectReadingFactory& factory) : 
VirtualRootFileSystemBase(), - mFile(f) + mFile(f), + mObjectFactory(factory) { ((TFile*)mFile)->SetReadaheadSize(50 * 1024 * 1024); } std::shared_ptr TFileFileSystem::GetSubFilesystem(arrow::dataset::FileSource source) { - auto tree = (TTree*)mFile->GetObjectChecked(source.path().c_str(), TClass::GetClass()); - if (tree) { - return std::shared_ptr(new SingleTreeFileSystem(tree)); + // We use a plugin to create the actual objects inside the + // file, so that we can support TTree and RNTuple at the same time + // without having to depend on both. + for (auto& capability : mObjectFactory.capabilities) { + void* handle = capability.getHandle(mFile, source.path()); + if (!handle) { + continue; + } + if (handle) { + return capability.factory().getSubFilesystem(handle); + } } auto directory = (TDirectoryFile*)mFile->GetObjectChecked(source.path().c_str(), TClass::GetClass()); if (directory) { - return std::shared_ptr(new TFileFileSystem(directory, 50 * 1024 * 1024)); + return std::shared_ptr(new TFileFileSystem(directory, 50 * 1024 * 1024, mObjectFactory)); } throw runtime_error_f("Unsupported file layout"); } @@ -120,10 +73,14 @@ arrow::Result TFileFileSystem::GetFileInfo(const std::strin auto fs = GetSubFilesystem(source); // For now we only support single trees. - if (std::dynamic_pointer_cast(fs)) { - result.set_type(arrow::fs::FileType::File); + if (std::dynamic_pointer_cast(fs)) { + result.set_type(arrow::fs::FileType::Directory); return result; } + // Everything else is a file, if it was created. 
+ if (fs.get()) { + result.set_type(arrow::fs::FileType::File); + } return result; } @@ -137,7 +94,7 @@ arrow::Result> TFileFileSystem::OpenOut auto* dir = dynamic_cast(this->GetFile()->Get(path.c_str())); if (!dir) { - throw runtime_error_f("Unable to open directory %s in file %s", path.c_str(), GetFile()->GetName()); + return arrow::Status::Invalid(fmt::format("Unable to open directory {} in file {} ", path.c_str(), GetFile()->GetName())); } auto stream = std::make_shared(dir); return stream; @@ -219,81 +176,6 @@ arrow::Result> VirtualRootFileSystemBas return arrow::Status::NotImplemented("No random access file system"); } -arrow::Result> TTreeFileFormat::Inspect(const arrow::dataset::FileSource& source) const -{ - arrow::Schema schema{{}}; - auto fs = std::dynamic_pointer_cast(source.filesystem()); - // Actually get the TTree from the ROOT file. - auto treeFs = std::dynamic_pointer_cast(fs->GetSubFilesystem(source)); - if (!treeFs.get()) { - throw runtime_error_f("Unknown filesystem %s\n", source.filesystem()->type_name().c_str()); - } - TTree* tree = treeFs->GetTree(source); - - auto branches = tree->GetListOfBranches(); - auto n = branches->GetEntries(); - - std::vector branchInfos; - for (auto i = 0; i < n; ++i) { - auto branch = static_cast(branches->At(i)); - auto name = std::string{branch->GetName()}; - auto pos = name.find("_size"); - if (pos != std::string::npos) { - name.erase(pos); - branchInfos.emplace_back(BranchInfo{name, (TBranch*)nullptr, true}); - } else { - auto lookup = std::find_if(branchInfos.begin(), branchInfos.end(), [&](BranchInfo const& bi) { - return bi.name == name; - }); - if (lookup == branchInfos.end()) { - branchInfos.emplace_back(BranchInfo{name, branch, false}); - } else { - lookup->ptr = branch; - } - } - } - - std::vector> fields; - tree->SetCacheSize(25000000); - for (auto& bi : branchInfos) { - static TClass* cls; - EDataType type; - bi.ptr->GetExpectedType(cls, type); - auto listSize = -1; - if (!bi.mVLA) { - listSize = 
static_cast(bi.ptr->GetListOfLeaves()->At(0))->GetLenStatic(); - } - auto field = std::make_shared(bi.ptr->GetName(), arrowTypeFromROOT(type, listSize)); - fields.push_back(field); - - tree->AddBranchToCache(bi.ptr); - if (strncmp(bi.ptr->GetName(), "fIndexArray", strlen("fIndexArray")) == 0) { - std::string sizeBranchName = bi.ptr->GetName(); - sizeBranchName += "_size"; - auto* sizeBranch = (TBranch*)tree->GetBranch(sizeBranchName.c_str()); - if (sizeBranch) { - tree->AddBranchToCache(sizeBranch); - } - } - } - tree->StopCacheLearningPhase(); - - return std::make_shared(fields); -} - -/// \brief Create a FileFragment for a FileSource. -arrow::Result> TTreeFileFormat::MakeFragment( - arrow::dataset::FileSource source, arrow::compute::Expression partition_expression, - std::shared_ptr physical_schema) -{ - std::shared_ptr format = std::make_shared(mTotCompressedSize, mTotUncompressedSize); - - auto fragment = std::make_shared(std::move(source), std::move(format), - std::move(partition_expression), - std::move(physical_schema)); - return std::dynamic_pointer_cast(fragment); -} - // An arrow outputstream which allows to write to a ttree TDirectoryFileOutputStream::TDirectoryFileOutputStream(TDirectoryFile* f) : mDirectory(f) @@ -321,544 +203,14 @@ bool TDirectoryFileOutputStream::closed() const return mDirectory->GetFile()->IsOpen() == false; } -// An arrow outputstream which allows to write to a ttree -// @a branch prefix is to be used to identify a set of branches which all belong to -// the same table. 
-TTreeOutputStream::TTreeOutputStream(TTree* f, std::string branchPrefix) - : mTree(f), - mBranchPrefix(std::move(branchPrefix)) -{ -} - -arrow::Status TTreeOutputStream::Close() -{ - if (mTree->GetCurrentFile() == nullptr) { - return arrow::Status::Invalid("Cannot close a tree not attached to a file"); - } - mTree->GetCurrentFile()->Close(); - return arrow::Status::OK(); -} - -arrow::Result TTreeOutputStream::Tell() const -{ - return arrow::Result(arrow::Status::NotImplemented("Cannot move")); -} - -arrow::Status TTreeOutputStream::Write(const void* data, int64_t nbytes) -{ - return arrow::Status::NotImplemented("Cannot write raw bytes to a TTree"); -} - -bool TTreeOutputStream::closed() const -{ - // A standalone tree is never closed. - if (mTree->GetCurrentFile() == nullptr) { - return false; - } - return mTree->GetCurrentFile()->IsOpen() == false; -} - -TBranch* TTreeOutputStream::CreateBranch(char const* branchName, char const* sizeBranch) -{ - return mTree->Branch((mBranchPrefix + "/" + branchName).c_str(), (char*)nullptr, (mBranchPrefix + sizeBranch).c_str()); -} - -char const* rootSuffixFromArrow(arrow::Type::type id) -{ - switch (id) { - case arrow::Type::BOOL: - return "/O"; - case arrow::Type::UINT8: - return "/b"; - case arrow::Type::UINT16: - return "/s"; - case arrow::Type::UINT32: - return "/i"; - case arrow::Type::UINT64: - return "/l"; - case arrow::Type::INT8: - return "/B"; - case arrow::Type::INT16: - return "/S"; - case arrow::Type::INT32: - return "/I"; - case arrow::Type::INT64: - return "/L"; - case arrow::Type::FLOAT: - return "/F"; - case arrow::Type::DOUBLE: - return "/D"; - default: - throw runtime_error("Unsupported arrow column type"); - } -} - -class TTreeFileWriter : public arrow::dataset::FileWriter -{ - std::vector branches; - std::vector sizesBranches; - std::vector> valueArrays; - std::vector> sizeArrays; - std::vector> valueTypes; - - std::vector valuesIdealBasketSize; - std::vector sizeIdealBasketSize; - - std::vector 
typeSizes; - std::vector listSizes; - bool firstBasket = true; - - // This is to create a batsket size according to the first batch. - void finaliseBasketSize(std::shared_ptr firstBatch) - { - O2_SIGNPOST_ID_FROM_POINTER(sid, root_arrow_fs, this); - O2_SIGNPOST_START(root_arrow_fs, sid, "finaliseBasketSize", "First batch with %lli rows received and %zu columns", - firstBatch->num_rows(), firstBatch->columns().size()); - for (size_t i = 0; i < branches.size(); i++) { - auto* branch = branches[i]; - auto* sizeBranch = sizesBranches[i]; - - int valueSize = valueTypes[i]->byte_width(); - if (listSizes[i] == 1) { - O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s exists and uses %d bytes per entry for %lli entries.", - branch->GetName(), valueSize, firstBatch->num_rows()); - assert(sizeBranch == nullptr); - branch->SetBasketSize(1024 + firstBatch->num_rows() * valueSize); - } else if (listSizes[i] == -1) { - O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s exists and uses %d bytes per entry.", - branch->GetName(), valueSize); - // This should probably lookup the - auto column = firstBatch->GetColumnByName(schema_->field(i)->name()); - auto list = std::static_pointer_cast(column); - O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s needed. Associated size branch %s and there are %lli entries of size %d in that list.", - branch->GetName(), sizeBranch->GetName(), list->length(), valueSize); - branch->SetBasketSize(1024 + firstBatch->num_rows() * valueSize * list->length()); - sizeBranch->SetBasketSize(1024 + firstBatch->num_rows() * 4); - } else { - O2_SIGNPOST_EVENT_EMIT(root_arrow_fs, sid, "finaliseBasketSize", "Branch %s needed. 
There are %lli entries per array of size %d in that list.", - branch->GetName(), listSizes[i], valueSize); - assert(sizeBranch == nullptr); - branch->SetBasketSize(1024 + firstBatch->num_rows() * valueSize * listSizes[i]); - } - - auto field = firstBatch->schema()->field(i); - if (field->name().starts_with("fIndexArray")) { - // One int per array to keep track of the size - int idealBasketSize = 4 * firstBatch->num_rows() + 1024 + field->type()->byte_width() * firstBatch->num_rows(); // minimal additional size needed, otherwise we get 2 baskets - int basketSize = std::max(32000, idealBasketSize); // keep a minimum value - sizeBranch->SetBasketSize(basketSize); - branch->SetBasketSize(basketSize); - } - } - O2_SIGNPOST_END(root_arrow_fs, sid, "finaliseBasketSize", "Done"); - } - - public: - // Create the TTree based on the physical_schema, not the one in the batch. - // The write method will have to reconcile the two schemas. - TTreeFileWriter(std::shared_ptr schema, std::shared_ptr options, - std::shared_ptr destination, - arrow::fs::FileLocator destination_locator) - : FileWriter(schema, options, destination, destination_locator) - { - // Batches have the same number of entries for each column. - auto directoryStream = std::dynamic_pointer_cast(destination_); - auto treeStream = std::dynamic_pointer_cast(destination_); - - if (directoryStream.get()) { - TDirectoryFile* dir = directoryStream->GetDirectory(); - dir->cd(); - auto* tree = new TTree(destination_locator_.path.c_str(), ""); - treeStream = std::make_shared(tree, ""); - } else if (treeStream.get()) { - // We already have a tree stream, let's derive a new one - // with the destination_locator_.path as prefix for the branches - // This way we can multiplex multiple tables in the same tree. - auto tree = treeStream->GetTree(); - treeStream = std::make_shared(tree, destination_locator_.path); - } else { - // I could simply set a prefix here to merge to an already existing tree. 
- throw std::runtime_error("Unsupported backend."); - } - - for (auto i = 0u; i < schema->fields().size(); ++i) { - auto& field = schema->field(i); - listSizes.push_back(1); - - int valuesIdealBasketSize = 0; - // Construct all the needed branches. - switch (field->type()->id()) { - case arrow::Type::FIXED_SIZE_LIST: { - listSizes.back() = std::static_pointer_cast(field->type())->list_size(); - valuesIdealBasketSize = 1024 + valueTypes.back()->byte_width() * listSizes.back(); - valueTypes.push_back(field->type()->field(0)->type()); - sizesBranches.push_back(nullptr); - std::string leafList = fmt::format("{}[{}]{}", field->name(), listSizes.back(), rootSuffixFromArrow(valueTypes.back()->id())); - branches.push_back(treeStream->CreateBranch(field->name().c_str(), leafList.c_str())); - } break; - case arrow::Type::LIST: { - valueTypes.push_back(field->type()->field(0)->type()); - std::string leafList = fmt::format("{}[{}_size]{}", field->name(), field->name(), rootSuffixFromArrow(valueTypes.back()->id())); - listSizes.back() = -1; // VLA, we need to calculate it on the fly; - std::string sizeLeafList = field->name() + "_size/I"; - sizesBranches.push_back(treeStream->CreateBranch((field->name() + "_size").c_str(), sizeLeafList.c_str())); - branches.push_back(treeStream->CreateBranch(field->name().c_str(), leafList.c_str())); - // Notice that this could be replaced by a better guess of the - // average size of the list elements, but this is not trivial. 
- } break; - default: { - valueTypes.push_back(field->type()); - std::string leafList = field->name() + rootSuffixFromArrow(valueTypes.back()->id()); - sizesBranches.push_back(nullptr); - branches.push_back(treeStream->CreateBranch(field->name().c_str(), leafList.c_str())); - } break; - } - } - // We create the branches from the schema - } - - arrow::Status Write(const std::shared_ptr& batch) override - { - if (firstBasket) { - firstBasket = false; - finaliseBasketSize(batch); - } - - // Support writing empty tables - if (batch->columns().empty() || batch->num_rows() == 0) { - return arrow::Status::OK(); - } - - // Batches have the same number of entries for each column. - auto directoryStream = std::dynamic_pointer_cast(destination_); - TTree* tree = nullptr; - if (directoryStream.get()) { - TDirectoryFile* dir = directoryStream->GetDirectory(); - tree = (TTree*)dir->Get(destination_locator_.path.c_str()); - } - auto treeStream = std::dynamic_pointer_cast(destination_); - - if (!tree) { - // I could simply set a prefix here to merge to an already existing tree. - throw std::runtime_error("Unsupported backend."); - } - - for (auto i = 0u; i < batch->columns().size(); ++i) { - auto column = batch->column(i); - auto& field = batch->schema()->field(i); - - valueArrays.push_back(nullptr); - - switch (field->type()->id()) { - case arrow::Type::FIXED_SIZE_LIST: { - auto list = std::static_pointer_cast(column); - valueArrays.back() = list->values(); - } break; - case arrow::Type::LIST: { - auto list = std::static_pointer_cast(column); - valueArrays.back() = list; - } break; - case arrow::Type::BOOL: { - // In case of arrays of booleans, we need to go back to their - // char based representation for ROOT to save them. 
- auto boolArray = std::static_pointer_cast(column); - - int64_t length = boolArray->length(); - arrow::UInt8Builder builder; - auto ok = builder.Reserve(length); - - for (int64_t i = 0; i < length; ++i) { - if (boolArray->IsValid(i)) { - // Expand each boolean value (true/false) to uint8 (1/0) - uint8_t value = boolArray->Value(i) ? 1 : 0; - auto ok = builder.Append(value); - } else { - // Append null for invalid entries - auto ok = builder.AppendNull(); - } - } - valueArrays.back() = *builder.Finish(); - } break; - default: - valueArrays.back() = column; - } - } - - int64_t pos = 0; - while (pos < batch->num_rows()) { - for (size_t bi = 0; bi < branches.size(); ++bi) { - auto* branch = branches[bi]; - auto* sizeBranch = sizesBranches[bi]; - auto array = batch->column(bi); - auto& field = batch->schema()->field(bi); - auto& listSize = listSizes[bi]; - auto valueType = valueTypes[bi]; - auto valueArray = valueArrays[bi]; - - switch (field->type()->id()) { - case arrow::Type::LIST: { - auto list = std::static_pointer_cast(array); - listSize = list->value_length(pos); - uint8_t const* buffer = std::static_pointer_cast(valueArray)->values()->data() + array->offset() + list->value_offset(pos) * valueType->byte_width(); - branch->SetAddress((void*)buffer); - sizeBranch->SetAddress(&listSize); - }; - break; - case arrow::Type::FIXED_SIZE_LIST: - default: { - uint8_t const* buffer = std::static_pointer_cast(valueArray)->values()->data() + array->offset() + pos * listSize * valueType->byte_width(); - branch->SetAddress((void*)buffer); - }; - } - } - tree->Fill(); - ++pos; - } - return arrow::Status::OK(); - } - - arrow::Future<> FinishInternal() override - { - auto treeStream = std::dynamic_pointer_cast(destination_); - TTree* tree = treeStream->GetTree(); - tree->Write("", TObject::kOverwrite); - tree->SetDirectory(nullptr); - - return {}; - }; -}; - -arrow::Result> TTreeFileFormat::MakeWriter(std::shared_ptr destination, std::shared_ptr schema, std::shared_ptr options, 
arrow::fs::FileLocator destination_locator) const -{ - auto writer = std::make_shared(schema, options, destination, destination_locator); - return std::dynamic_pointer_cast(writer); -} - -std::shared_ptr TTreeFileFormat::DefaultWriteOptions() -{ - std::shared_ptr options( - new TTreeFileWriteOptions(shared_from_this())); - return options; -} - -arrow::Result TTreeFileFormat::ScanBatchesAsync( - const std::shared_ptr& options, - const std::shared_ptr& fragment) const -{ - // Get the fragment as a TTreeFragment. This might be PART of a TTree. - auto treeFragment = std::dynamic_pointer_cast(fragment); - // This is the schema we want to read - auto dataset_schema = options->dataset_schema; - - auto generator = [pool = options->pool, treeFragment, dataset_schema, &totalCompressedSize = mTotCompressedSize, - &totalUncompressedSize = mTotUncompressedSize]() -> arrow::Future> { - auto schema = treeFragment->format()->Inspect(treeFragment->source()); - - std::vector> columns; - std::vector> fields = dataset_schema->fields(); - auto physical_schema = *treeFragment->ReadPhysicalSchema(); - - static TBufferFile buffer{TBuffer::EMode::kWrite, 4 * 1024 * 1024}; - auto containerFS = std::dynamic_pointer_cast(treeFragment->source().filesystem()); - auto fs = std::dynamic_pointer_cast(containerFS->GetSubFilesystem(treeFragment->source())); - - int64_t rows = -1; - TTree* tree = fs->GetTree(treeFragment->source()); - for (auto& field : fields) { - // The field actually on disk - auto physicalField = physical_schema->GetFieldByName(field->name()); - TBranch* branch = tree->GetBranch(physicalField->name().c_str()); - assert(branch); - buffer.Reset(); - auto totalEntries = branch->GetEntries(); - if (rows == -1) { - rows = totalEntries; - } - if (rows != totalEntries) { - throw runtime_error_f("Unmatching number of rows for branch %s", branch->GetName()); - } - arrow::Status status; - int readEntries = 0; - std::shared_ptr array; - auto listType = 
std::dynamic_pointer_cast(physicalField->type()); - if (physicalField->type() == arrow::boolean() || - (listType && physicalField->type()->field(0)->type() == arrow::boolean())) { - if (listType) { - std::unique_ptr builder = nullptr; - auto status = arrow::MakeBuilder(pool, physicalField->type()->field(0)->type(), &builder); - if (!status.ok()) { - throw runtime_error("Cannot create value builder"); - } - auto listBuilder = std::make_unique(pool, std::move(builder), listType->list_size()); - auto valueBuilder = listBuilder.get()->value_builder(); - // boolean array special case: we need to use builder to create the bitmap - status = valueBuilder->Reserve(totalEntries * listType->list_size()); - status &= listBuilder->Reserve(totalEntries); - if (!status.ok()) { - throw runtime_error("Failed to reserve memory for array builder"); - } - while (readEntries < totalEntries) { - auto readLast = branch->GetBulkRead().GetBulkEntries(readEntries, buffer); - readEntries += readLast; - status &= static_cast(valueBuilder)->AppendValues(reinterpret_cast(buffer.GetCurrent()), readLast * listType->list_size()); - } - status &= static_cast(listBuilder.get())->AppendValues(readEntries); - if (!status.ok()) { - throw runtime_error("Failed to append values to array"); - } - status &= listBuilder->Finish(&array); - if (!status.ok()) { - throw runtime_error("Failed to create array"); - } - } else if (listType == nullptr) { - std::unique_ptr builder = nullptr; - auto status = arrow::MakeBuilder(pool, physicalField->type(), &builder); - if (!status.ok()) { - throw runtime_error("Cannot create builder"); - } - auto valueBuilder = static_cast(builder.get()); - // boolean array special case: we need to use builder to create the bitmap - status = valueBuilder->Reserve(totalEntries); - if (!status.ok()) { - throw runtime_error("Failed to reserve memory for array builder"); - } - while (readEntries < totalEntries) { - auto readLast = branch->GetBulkRead().GetBulkEntries(readEntries, buffer); 
- readEntries += readLast; - status &= valueBuilder->AppendValues(reinterpret_cast(buffer.GetCurrent()), readLast); - } - if (!status.ok()) { - throw runtime_error("Failed to append values to array"); - } - status &= valueBuilder->Finish(&array); - if (!status.ok()) { - throw runtime_error("Failed to create array"); - } - } - } else { - // other types: use serialized read to build arrays directly. - auto typeSize = physicalField->type()->byte_width(); - // This is needed for branches which have not been persisted. - auto bytes = branch->GetTotBytes(); - auto branchSize = bytes ? bytes : 1000000; - auto&& result = arrow::AllocateResizableBuffer(branchSize, pool); - if (!result.ok()) { - throw runtime_error("Cannot allocate values buffer"); - } - std::shared_ptr arrowValuesBuffer = std::move(result).ValueUnsafe(); - auto ptr = arrowValuesBuffer->mutable_data(); - if (ptr == nullptr) { - throw runtime_error("Invalid buffer"); - } - - std::unique_ptr offsetBuffer = nullptr; - - uint32_t offset = 0; - int count = 0; - std::shared_ptr arrowOffsetBuffer; - std::span offsets; - int size = 0; - uint32_t totalSize = 0; - TBranch* mSizeBranch = nullptr; - int64_t listSize = 1; - if (auto fixedSizeList = std::dynamic_pointer_cast(physicalField->type())) { - listSize = fixedSizeList->list_size(); - typeSize = fixedSizeList->field(0)->type()->byte_width(); - } else if (auto vlaListType = std::dynamic_pointer_cast(physicalField->type())) { - listSize = -1; - typeSize = vlaListType->field(0)->type()->byte_width(); - } - if (listSize == -1) { - mSizeBranch = branch->GetTree()->GetBranch((std::string{branch->GetName()} + "_size").c_str()); - offsetBuffer = std::make_unique(TBuffer::EMode::kWrite, 4 * 1024 * 1024); - result = arrow::AllocateResizableBuffer((totalEntries + 1) * (int64_t)sizeof(int), pool); - if (!result.ok()) { - throw runtime_error("Cannot allocate offset buffer"); - } - arrowOffsetBuffer = std::move(result).ValueUnsafe(); - unsigned char* ptrOffset = 
arrowOffsetBuffer->mutable_data(); - auto* tPtrOffset = reinterpret_cast(ptrOffset); - offsets = std::span{tPtrOffset, tPtrOffset + totalEntries + 1}; - - // read sizes first - while (readEntries < totalEntries) { - auto readLast = mSizeBranch->GetBulkRead().GetEntriesSerialized(readEntries, *offsetBuffer); - readEntries += readLast; - for (auto i = 0; i < readLast; ++i) { - offsets[count++] = (int)offset; - offset += swap32_(reinterpret_cast(offsetBuffer->GetCurrent())[i]); - } - } - offsets[count] = (int)offset; - totalSize = offset; - readEntries = 0; - } - - while (readEntries < totalEntries) { - auto readLast = branch->GetBulkRead().GetEntriesSerialized(readEntries, buffer); - if (listSize == -1) { - size = offsets[readEntries + readLast] - offsets[readEntries]; - } else { - size = readLast * listSize; - } - readEntries += readLast; - swapCopy(ptr, buffer.GetCurrent(), size, typeSize); - ptr += (ptrdiff_t)(size * typeSize); - } - if (listSize >= 1) { - totalSize = readEntries * listSize; - } - std::shared_ptr varray; - switch (listSize) { - case -1: - varray = std::make_shared(physicalField->type()->field(0)->type(), totalSize, arrowValuesBuffer); - array = std::make_shared(physicalField->type(), readEntries, arrowOffsetBuffer, varray); - break; - case 1: - array = std::make_shared(physicalField->type(), readEntries, arrowValuesBuffer); - break; - default: - varray = std::make_shared(physicalField->type()->field(0)->type(), totalSize, arrowValuesBuffer); - array = std::make_shared(physicalField->type(), readEntries, varray); - } - } - - branch->SetStatus(false); - branch->DropBaskets("all"); - branch->Reset(); - branch->GetTransientBuffer(0)->Expand(0); - - columns.push_back(array); - } - auto batch = arrow::RecordBatch::Make(dataset_schema, rows, columns); - totalCompressedSize += tree->GetZipBytes(); - totalUncompressedSize += tree->GetTotBytes(); - return batch; - }; - return generator; -} - -arrow::Result> TTreeFileSystem::OpenOutputStream( - const 
std::string& path, - const std::shared_ptr& metadata) -{ - arrow::dataset::FileSource source{path, shared_from_this()}; - auto prefix = metadata->Get("branch_prefix"); - if (prefix.ok()) { - return std::make_shared(GetTree(source), *prefix); - } - return std::make_shared(GetTree(source), ""); -} - -TBufferFileFS::TBufferFileFS(TBufferFile* f) +TBufferFileFS::TBufferFileFS(TBufferFile* f, RootObjectReadingFactory& factory) : VirtualRootFileSystemBase(), mBuffer(f), - mFilesystem(nullptr) + mFilesystem(nullptr), + mObjectFactory(factory) { } -TTreeFileSystem::~TTreeFileSystem() = default; - arrow::Result TBufferFileFS::GetFileInfo(const std::string& path) { arrow::fs::FileInfo result; @@ -871,19 +223,26 @@ arrow::Result TBufferFileFS::GetFileInfo(const std::string& return result; } - // For now we only support single trees. - if (std::dynamic_pointer_cast(mFilesystem)) { - result.set_type(arrow::fs::FileType::File); + auto info = mFilesystem->GetFileInfo(path); + if (!info.ok()) { return result; } + + result.set_type(info->type()); return result; } std::shared_ptr TBufferFileFS::GetSubFilesystem(arrow::dataset::FileSource source) { - if (!mFilesystem.get()) { - auto tree = ((TTree*)mBuffer->ReadObject(TTree::Class())); - mFilesystem = std::make_shared(tree); + // We use a plugin to create the actual objects inside the + // file, so that we can support TTree and RNTuple at the same time + // without having to depend on both. + for (auto& capability : mObjectFactory.capabilities) { + void* handle = capability.getBufferHandle(mBuffer, source.path()); + if (handle) { + mFilesystem = capability.factory().getSubFilesystem(handle); + break; + } } return mFilesystem; } diff --git a/Framework/Core/test/o2AO2DToAO3D.cxx b/Framework/Core/test/o2AO2DToAO3D.cxx new file mode 100644 index 0000000000000..25fa292d66ed9 --- /dev/null +++ b/Framework/Core/test/o2AO2DToAO3D.cxx @@ -0,0 +1,165 @@ +// Copyright 2019-2024 CERN and copyright holders of ALICE O2. 
+// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
+// All rights not expressly granted are reserved.
+//
+// This software is distributed under the terms of the GNU General Public
+// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
+//
+// In applying this license CERN does not waive the privileges and immunities
+// granted to it by virtue of its status as an Intergovernmental Organization
+// or submit itself to any jurisdiction.
+#include "Framework/RootArrowFilesystem.h"
+#include "Framework/PluginManager.h"
+#include <TFile.h>
+#include <TMap.h>
+#include <TDirectory.h>
+#include <getopt.h>
+#include <arrow/dataset/file_base.h>
+#include <fmt/format.h>
+#include <iostream>
+#include <memory>
+
+int main(int argc, char** argv)
+{
+
+  char* input_file = nullptr;
+  char* output_file = nullptr;
+
+  // Define long options
+  static struct option long_options[] = {
+    {"input", required_argument, nullptr, 'i'},
+    {"output", required_argument, nullptr, 'o'},
+    {nullptr, 0, nullptr, 0} // End of options
+  };
+
+  int option_index = 0;
+  int c;
+
+  // Parse options
+  while ((c = getopt_long(argc, argv, "i:o:", long_options, &option_index)) != -1) {
+    switch (c) {
+      case 'i':
+        input_file = optarg;
+        break;
+      case 'o':
+        output_file = optarg;
+        break;
+      case '?':
+        // Unknown option
+        printf("Unknown option. Use --input <file> and --output <file>\n");
+        return 1;
+      default:
+        break;
+    }
+  }
+
+  // Check if input and output files are provided
+  if (input_file && output_file) {
+    printf("Input file: %s\n", input_file);
+    printf("Output file: %s\n", output_file);
+  } else {
+    fprintf(stderr, "Usage: %s --input <file> --output <file>\n", argv[0]);
+    return 1;
+  }
+
+  // Plugins which understand the formats we convert between.
+  std::vector<char const*> capabilitiesSpecs = {
+    "O2Framework:RNTupleObjectReadingCapability",
+    "O2Framework:TTreeObjectReadingCapability",
+  };
+
+  o2::framework::RootObjectReadingFactory factory;
+
+  std::vector<o2::framework::LoadablePlugin> plugins;
+  for (auto spec : capabilitiesSpecs) {
+    auto morePlugins = o2::framework::PluginManager::parsePluginSpecString(spec);
+    for (auto& extra : morePlugins) {
+      plugins.push_back(extra);
+    }
+  }
+
+  auto in = TFile::Open(input_file, "READ");
+  auto out = TFile::Open(output_file, "RECREATE");
+
+  auto fs = std::make_shared<o2::framework::TFileFileSystem>(in, 50 * 1024 * 1024, factory);
+  auto outFs = std::make_shared<o2::framework::TFileFileSystem>(out, 0, factory);
+
+  o2::framework::PluginManager::loadFromPlugin<o2::framework::RootObjectReadingCapability, o2::framework::RootObjectReadingCapabilityPlugin>(plugins, factory.capabilities);
+
+  // Plugins are hardcoded for now...
+  auto rNtupleFormat = factory.capabilities[0].factory().format();
+  auto format = factory.capabilities[1].factory().format();
+
+  for (TObject* dk : *in->GetListOfKeys()) {
+    if (dk->GetName() == std::string("metaData")) {
+      TMap* m = dynamic_cast<TMap*>(in->Get(dk->GetName()));
+      m->Print();
+      auto* copy = m->Clone("metaData");
+      out->WriteTObject(copy);
+      continue;
+    }
+    auto* d = (TDirectory*)in->Get(dk->GetName());
+    std::cout << "Processing: " << dk->GetName() << std::endl;
+    // For the moment RNTuple does not support TDirectory, so
+    // we write everything at toplevel.
+ auto destination = outFs->OpenOutputStream("/", {}); + if (!destination.ok()) { + std::cerr << "Could not open destination folder " << output_file << std::endl; + exit(1); + } + + for (TObject* tk : *d->GetListOfKeys()) { + auto sourceUrl = fmt::format("{}/{}", dk->GetName(), tk->GetName()); + // FIXME: there is no support for TDirectory yet. Let's write everything + // at the same level. + auto destUrl = fmt::format("/{}-{}", dk->GetName(), tk->GetName()); + arrow::dataset::FileSource source(sourceUrl, fs); + if (!format->IsSupported(source).ok()) { + std::cout << "Source " << source.path() << " is not supported" << std::endl; + continue; + } + std::cout << " Processing tree: " << tk->GetName() << std::endl; + auto schemaOpt = format->Inspect(source); + if (!schemaOpt.ok()) { + std::cout << "Could not inspect source " << source.path() << std::endl; + } + auto schema = *schemaOpt; + auto fragment = format->MakeFragment(source, {}, schema); + if (!fragment.ok()) { + std::cout << "Could not make fragment from " << source.path() << "with schema:" << schema->ToString() << std::endl; + continue; + } + auto options = std::make_shared(); + options->dataset_schema = schema; + auto scanner = format->ScanBatchesAsync(options, *fragment); + if (!scanner.ok()) { + std::cout << "Scanner not ok" << std::endl; + continue; + } + auto batches = (*scanner)(); + auto result = batches.result(); + if (!result.ok()) { + std::cout << "Could not get batches." << std::endl; + continue; + } + std::cout << " Found a table with " << (*result)->columns().size() << " columns " << (*result)->num_rows() << " rows." 
<< std::endl; + + if ((*result)->num_rows() == 0) { + std::cout << "Empty table, skipping for now" << std::endl; + continue; + } + arrow::fs::FileLocator locator{outFs, destUrl}; + std::cout << schema->ToString() << std::endl; + auto writer = rNtupleFormat->MakeWriter(*destination, schema, {}, locator); + auto success = writer->get()->Write(*result); + if (!success.ok()) { + std::cout << "Error while writing" << std::endl; + continue; + } + } + out->ls(); + auto rootDestination = std::dynamic_pointer_cast(*destination); + } + in->Close(); + out->Close(); +} diff --git a/Framework/Core/test/test_Root2ArrowTable.cxx b/Framework/Core/test/test_Root2ArrowTable.cxx index 8440e942903a5..8eb3a9825f0f7 100644 --- a/Framework/Core/test/test_Root2ArrowTable.cxx +++ b/Framework/Core/test/test_Root2ArrowTable.cxx @@ -14,6 +14,7 @@ #include "Framework/TableBuilder.h" #include "Framework/RootTableBuilderHelpers.h" #include "Framework/ASoA.h" +#include "Framework/PluginManager.h" #include "../src/ArrowDebugHelpers.h" #include @@ -26,6 +27,13 @@ #include #include #include +#include +#include +#include +#include +#include +#include +#include #include #include @@ -232,10 +240,31 @@ TEST_CASE("RootTree2Fragment") file->WriteObjectAny(&t1, t1.Class()); auto* fileRead = new TBufferFile(TBuffer::kRead, file->BufferSize(), file->Buffer(), false, nullptr); - size_t totalSizeCompressed = 0; - size_t totalSizeUncompressed = 0; - auto format = std::make_shared(totalSizeCompressed, totalSizeUncompressed); - auto fs = std::make_shared(fileRead); + std::vector capabilitiesSpecs = { + "O2Framework:RNTupleObjectReadingCapability", + "O2Framework:TTreeObjectReadingCapability", + }; + + std::vector plugins; + for (auto spec : capabilitiesSpecs) { + auto morePlugins = PluginManager::parsePluginSpecString(spec); + for (auto& extra : morePlugins) { + plugins.push_back(extra); + } + } + REQUIRE(plugins.size() == 2); + + RootObjectReadingFactory factory; + std::vector configDiscoverySpec = {}; + 
PluginManager::loadFromPlugin<RootObjectReadingCapability, RootObjectReadingCapabilityPlugin>(plugins, factory.capabilities);
+  REQUIRE(factory.capabilities.size() == 2);
+  REQUIRE(factory.capabilities[0].name == "rntuple");
+  REQUIRE(factory.capabilities[1].name == "ttree");
+
+  // Plugins are hardcoded for now...
+  auto format = factory.capabilities[1].factory().format();
+
+  auto fs = std::make_shared<TBufferFileFS>(fileRead, factory);
   arrow::dataset::FileSource source("p", fs);
   REQUIRE(format->IsSupported(source) == true);
@@ -439,10 +468,34 @@
   }
   f->Write();
-  size_t totalSizeCompressed = 0;
-  size_t totalSizeUncompressed = 0;
-  auto format = std::make_shared<TTreeFileFormat>(totalSizeCompressed, totalSizeUncompressed);
-  auto fs = std::make_shared<TFileFileSystem>(f, 50 * 1024 * 1024);
+  std::vector<char const*> capabilitiesSpecs = {
+    "O2Framework:RNTupleObjectReadingCapability",
+    "O2Framework:TTreeObjectReadingCapability",
+  };
+
+  RootObjectReadingFactory factory;
+
+  std::vector<LoadablePlugin> plugins;
+  for (auto spec : capabilitiesSpecs) {
+    auto morePlugins = PluginManager::parsePluginSpecString(spec);
+    for (auto& extra : morePlugins) {
+      plugins.push_back(extra);
+    }
+  }
+  REQUIRE(plugins.size() == 2);
+
+  PluginManager::loadFromPlugin<RootObjectReadingCapability, RootObjectReadingCapabilityPlugin>(plugins, factory.capabilities);
+
+  REQUIRE(factory.capabilities.size() == 2);
+  REQUIRE(factory.capabilities[0].name == "rntuple");
+  REQUIRE(factory.capabilities[1].name == "ttree");
+
+  // Plugins are hardcoded for now...
+ auto rNtupleFormat = factory.capabilities[0].factory().format(); + auto format = factory.capabilities[1].factory().format(); + + auto fs = std::make_shared(f, 50 * 1024 * 1024, factory); + arrow::dataset::FileSource source("DF_2/tracks", fs); REQUIRE(format->IsSupported(source) == true); auto schemaOpt = format->Inspect(source); @@ -464,7 +517,7 @@ TEST_CASE("RootTree2Dataset") validateContents(*result); auto* output = new TMemFile("foo", "RECREATE"); - auto outFs = std::make_shared(output, 0); + auto outFs = std::make_shared(output, 0, factory); // Open a stream at toplevel auto destination = outFs->OpenOutputStream("/", {}); @@ -503,4 +556,38 @@ TEST_CASE("RootTree2Dataset") REQUIRE((*resultWritten)->num_rows() == 100); validateContents(*resultWritten); } + arrow::fs::FileLocator rnTupleLocator{outFs, "/rntuple"}; + // We write an RNTuple in the same TMemFile, using /rntuple as a location + auto rntupleDestination = std::dynamic_pointer_cast(*destination); + + { + auto rNtupleWriter = rNtupleFormat->MakeWriter(*destination, schema, {}, rnTupleLocator); + auto rNtupleSuccess = rNtupleWriter->get()->Write(*result); + REQUIRE(rNtupleSuccess.ok()); + } + + // And now we can read back the RNTuple into a RecordBatch + arrow::dataset::FileSource writtenRntupleSource("/rntuple", outFs); + auto newRNTupleFS = outFs->GetSubFilesystem(writtenRntupleSource); + + REQUIRE(rNtupleFormat->IsSupported(writtenRntupleSource) == true); + + auto rntupleSchemaOpt = rNtupleFormat->Inspect(writtenRntupleSource); + REQUIRE(rntupleSchemaOpt.ok()); + auto rntupleSchemaWritten = *rntupleSchemaOpt; + REQUIRE(validateSchema(rntupleSchemaWritten)); + + auto rntupleFragmentWritten = rNtupleFormat->MakeFragment(writtenRntupleSource, {}, rntupleSchemaWritten); + REQUIRE(rntupleFragmentWritten.ok()); + auto rntupleOptionsWritten = std::make_shared(); + rntupleOptionsWritten->dataset_schema = rntupleSchemaWritten; + auto rntupleScannerWritten = 
rNtupleFormat->ScanBatchesAsync(rntupleOptionsWritten, *rntupleFragmentWritten); + REQUIRE(rntupleScannerWritten.ok()); + auto rntupleBatchesWritten = (*rntupleScannerWritten)(); + auto rntupleResultWritten = rntupleBatchesWritten.result(); + REQUIRE(rntupleResultWritten.ok()); + REQUIRE((*rntupleResultWritten)->columns().size() == 10); + REQUIRE(validateSchema((*rntupleResultWritten)->schema())); + REQUIRE((*rntupleResultWritten)->num_rows() == 100); + REQUIRE(validateContents(*rntupleResultWritten)); }