Skip to content

Commit

Permalink
added vector support to extractObjectFrom function
Browse files Browse the repository at this point in the history
  • Loading branch information
Michal committed Jan 18, 2024
1 parent 9a79617 commit 981fb00
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 23 deletions.
1 change: 1 addition & 0 deletions Utilities/Mergers/include/Mergers/LinkDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
#pragma link C++ class o2::mergers::MergeInterface + ;
#pragma link C++ class o2::mergers::CustomMergeableObject + ;
#pragma link C++ class o2::mergers::CustomMergeableTObject + ;
#pragma link C++ class std::vector < TObject*> + ;

#endif
2 changes: 2 additions & 0 deletions Utilities/Mergers/include/Mergers/MergerAlgorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "Mergers/MergeInterface.h"

class TObject;
class VectorOfTObject;

namespace o2::mergers::algorithm
{
Expand All @@ -33,6 +34,7 @@ void merge(TObject* const target, TObject* const other);
/// of targets vector.
void merge(std::vector<TObject*>& targets, const std::vector<TObject*>& others);
void deleteTCollections(TObject* obj);
void deleteVectorTObject(VectorOfTObject* vec);

} // namespace o2::mergers::algorithm

Expand Down
7 changes: 5 additions & 2 deletions Utilities/Mergers/include/Mergers/ObjectStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <variant>
#include <memory>
#include <vector>
#include "Framework/DataRef.h"

class TObject;
Expand All @@ -29,8 +30,10 @@ namespace o2::mergers
class MergeInterface;

using TObjectPtr = std::shared_ptr<TObject>;
using VectorOfTObject = std::vector<TObject*>;
using VectorOfTObjectPtr = std::shared_ptr<VectorOfTObject>;
using MergeInterfacePtr = std::shared_ptr<MergeInterface>;
using ObjectStore = std::variant<std::monostate, TObjectPtr, MergeInterfacePtr>;
using ObjectStore = std::variant<std::monostate, TObjectPtr, VectorOfTObjectPtr, MergeInterfacePtr>;

namespace object_store_helpers
{
Expand All @@ -42,4 +45,4 @@ ObjectStore extractObjectFrom(const framework::DataRef& ref);

} // namespace o2::mergers

#endif //O2_OBJECTSTORE_H
#endif // O2_OBJECTSTORE_H
7 changes: 7 additions & 0 deletions Utilities/Mergers/src/FullHistoryMerger.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,13 @@ void FullHistoryMerger::mergeCache()
target->merge(other.get());
mObjectsMerged++;
}
} else if (std::holds_alternative<VectorOfTObjectPtr>(mMergedObject)) {
auto target = std::get<VectorOfTObjectPtr>(mMergedObject);
for (auto& [_, entry] : mCache) {
auto other = std::get<VectorOfTObjectPtr>(entry);
algorithm::merge(*target.get(), *other.get());
mObjectsMerged += target->size();
}
}
}

Expand Down
12 changes: 11 additions & 1 deletion Utilities/Mergers/src/MergerAlgorithm.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "Mergers/MergeInterface.h"
#include "Framework/Logger.h"
#include "Mergers/ObjectStore.h"

#include <TH1.h>
#include <TH2.h>
Expand Down Expand Up @@ -132,7 +133,7 @@ void merge(std::vector<TObject*>& targets, const std::vector<TObject*>& others)
{
for (const auto& other : others) {
if (const auto target_same_name = std::find_if(targets.begin(), targets.end(),
[&other](const auto& target) { return other->GetName() == target.GetName(); });
[&other](const auto& target) { return other->GetName() == target->GetName(); });
target_same_name != targets.end()) {
merge(*target_same_name, other);
} else {
Expand All @@ -156,4 +157,13 @@ void deleteTCollections(TObject* obj)
}
}

void deleteVectorTObject(VectorOfTObject* vec)
{
for (auto& tObject : *vec) {
if (tObject != nullptr) {
deleteTCollections(tObject);
}
}
}

} // namespace o2::mergers::algorithm
67 changes: 47 additions & 20 deletions Utilities/Mergers/src/ObjectStore.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,47 @@
#include "Mergers/MergeInterface.h"
#include "Mergers/MergerAlgorithm.h"
#include <TObject.h>
#include <string_view>

namespace o2::mergers
{

namespace object_store_helpers
{

constexpr static std::string_view errorPrefix = "Could not extract object to be merged: ";

template <typename... Args>
static std::string concat(Args&&... arguments)
{
std::ostringstream ss;
(ss << ... << arguments);
return std::move(ss.str());
}

template <typename TypeToRead>
void* readObject(TypeToRead&& type, o2::framework::FairTMessage& ftm)
{
using namespace std::string_view_literals;
auto* object = ftm.ReadObjectAny(type);

if (object == nullptr) {
throw std::runtime_error(concat(errorPrefix, "Failed to read object with name '"sv, type->GetName(), "' from message using ROOT serialization."sv));
}
return object;
}

MergeInterface* castToMergeInterface(bool inheritsFromTObject, void* object, TClass* storedClass)
{
using namespace std::string_view_literals;
MergeInterface* objectAsMergeInterface = inheritsFromTObject ? dynamic_cast<MergeInterface*>(static_cast<TObject*>(object)) : static_cast<MergeInterface*>(object);
if (objectAsMergeInterface == nullptr) {
throw std::runtime_error(concat(errorPrefix, "Could not cast '"sv, storedClass->GetName(), "' to MergeInterface"sv));
}

return objectAsMergeInterface;
}

ObjectStore extractObjectFrom(const framework::DataRef& ref)
{
// We do extraction on the low level to efficiently determine if the message
Expand All @@ -34,42 +68,35 @@ ObjectStore extractObjectFrom(const framework::DataRef& ref)
// framework::DataRefUtils::as<MergeInterface>(ref)
// it could cause a memory leak if `ref` contained a non-owning TCollection.
// This way we also avoid doing most of the checks twice.
const static std::string errorPrefix = "Could not extract object to be merged: ";

using namespace std::string_view_literals;
using DataHeader = o2::header::DataHeader;
auto header = framework::DataRefUtils::getHeader<const DataHeader*>(ref);
if (header->payloadSerializationMethod != o2::header::gSerializationMethodROOT) {
throw std::runtime_error(errorPrefix + "It is not ROOT-serialized");
if (framework::DataRefUtils::getHeader<const DataHeader*>(ref)->payloadSerializationMethod != o2::header::gSerializationMethodROOT) {
throw std::runtime_error(concat(errorPrefix, "It is not ROOT-serialized"sv));
}

o2::framework::FairTMessage ftm(const_cast<char*>(ref.payload), o2::framework::DataRefUtils::getPayloadSize(ref));
auto* storedClass = ftm.GetClass();
if (storedClass == nullptr) {
throw std::runtime_error(errorPrefix + "Unknown stored class");
throw std::runtime_error(concat(errorPrefix, "Unknown stored class"sv));
}

auto* mergeInterfaceClass = TClass::GetClass(typeid(MergeInterface));
auto* tObjectClass = TClass::GetClass(typeid(TObject));
if (storedClass->InheritsFrom(TClass::GetClass(typeid(VectorOfTObject)))) {
auto* object = readObject(storedClass, ftm);
return VectorOfTObjectPtr(static_cast<VectorOfTObject*>(object), algorithm::deleteVectorTObject);
}

bool inheritsFromMergeInterface = storedClass->InheritsFrom(mergeInterfaceClass);
bool inheritsFromTObject = storedClass->InheritsFrom(tObjectClass);
const bool inheritsFromMergeInterface = storedClass->InheritsFrom(TClass::GetClass(typeid(MergeInterface)));
const bool inheritsFromTObject = storedClass->InheritsFrom(TClass::GetClass(typeid(TObject)));

if (!inheritsFromMergeInterface && !inheritsFromTObject) {
throw std::runtime_error(
errorPrefix + "Class '" + storedClass->GetName() + "'does not inherit from MergeInterface nor TObject");
throw std::runtime_error(concat(errorPrefix, "Class '"sv, storedClass->GetName(), "'does not inherit from MergeInterface nor TObject"sv));
}

auto* object = ftm.ReadObjectAny(storedClass);
if (object == nullptr) {
throw std::runtime_error(
errorPrefix + "Failed to read object with name '" + storedClass->GetName() + "' from message using ROOT serialization.");
}
auto* object = readObject(storedClass, ftm);

if (inheritsFromMergeInterface) {
MergeInterface* objectAsMergeInterface = inheritsFromTObject ? dynamic_cast<MergeInterface*>(static_cast<TObject*>(object)) : static_cast<MergeInterface*>(object);
if (objectAsMergeInterface == nullptr) {
throw std::runtime_error(errorPrefix + "Could not cast '" + storedClass->GetName() + "' to MergeInterface");
}
auto* objectAsMergeInterface = castToMergeInterface(inheritsFromTObject, object, storedClass);
objectAsMergeInterface->postDeserialization();
return MergeInterfacePtr(objectAsMergeInterface);
} else {
Expand Down

0 comments on commit 981fb00

Please sign in to comment.