-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DistRDF fixes #41
base: main
Are you sure you want to change the base?
DistRDF fixes #41
Changes from all commits
ff956be
bd2ba72
479b010
2fc21fc
66e0435
36cfbff
be98ba6
6c6cc54
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,5 @@ | ||
import argparse | ||
import os | ||
from pathlib import Path | ||
from time import time | ||
from typing import Optional, Tuple | ||
|
||
|
@@ -313,18 +312,38 @@ def book_histos( | |
return (results, ml_results) | ||
|
||
|
||
# Compile the C++ source at `library_path` through ROOT; raise RuntimeError if it fails. | ||
# A small C++ helper, CompileMacroWrapper, is declared to the ROOT interpreter. The | ||
# declaration sits behind an #ifndef/#define guard, so calling this function more | ||
# than once does not redefine the helper. The helper takes R__LOCKGUARD on | ||
# gInterpreterMutex before calling gSystem->CompileMacro with the "kO" option. | ||
def compile_macro_wrapper(library_path: str): | ||
ROOT.gInterpreter.Declare( | ||
''' | ||
#ifndef R__COMPILE_MACRO_WRAPPER | ||
#define R__COMPILE_MACRO_WRAPPER | ||
int CompileMacroWrapper(const std::string &library_path) | ||
{ | ||
R__LOCKGUARD(gInterpreterMutex); | ||
return gSystem->CompileMacro(library_path.c_str(), "kO"); | ||
} | ||
#endif // R__COMPILE_MACRO_WRAPPER | ||
''') | ||
|
||
# Any return value other than 1 from the wrapper is treated as a failed compilation. | ||
if ROOT.CompileMacroWrapper(library_path) != 1: | ||
raise RuntimeError("Failure in TSystem::CompileMacro!") | ||
|
||
def load_cpp(): | ||
"""Load C++ helper functions. Works for both local and distributed execution.""" | ||
try: | ||
# when using distributed RDataFrame 'helpers.cpp' is copied to the local_directory | ||
# of every worker (via `distribute_unique_paths`) | ||
localdir = get_worker().local_directory | ||
cpp_source = Path(localdir) / "helpers.h" | ||
this_worker = get_worker() | ||
except ValueError: | ||
# must be local execution | ||
cpp_source = "helpers.h" | ||
|
||
ROOT.gSystem.CompileMacro(str(cpp_source), "kO") | ||
print("Not on a worker") | ||
return | ||
|
||
if not hasattr(this_worker, "is_library_loaded"): | ||
print("Compiling the macro.") | ||
library_source = "helpers.h" | ||
local_dir = get_worker().local_directory | ||
library_path = os.path.join(local_dir, library_source) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shall we also check that |
||
compile_macro_wrapper(library_path) | ||
this_worker.is_library_loaded = True | ||
else: | ||
print("Didn't try to compile the macro.") | ||
|
||
|
||
def main() -> None: | ||
|
@@ -355,10 +374,12 @@ def main() -> None: | |
# Setup for distributed RDataFrame | ||
client = create_dask_client(args.scheduler, args.ncores, args.hosts) | ||
if args.inference: | ||
ROOT.RDF.Experimental.Distributed.initialize(load_cpp) | ||
if args.inference: | ||
# TODO: make ml.load_cpp working on distributed | ||
ROOT.RDF.Experimental.Distributed.initialize(ml.load_cpp, "./fastforest") | ||
def load_all(fastforest_path): | ||
load_cpp() | ||
ml.load_cpp(fastforest_path) | ||
Comment on lines
+377
to
+379
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This probably deserves a comment as to why it's needed. |
||
|
||
# TODO: make ml.load_cpp working on distributed | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it not working after this patch? |
||
ROOT.RDF.Experimental.Distributed.initialize(load_all, "./fastforest") | ||
else: | ||
ROOT.RDF.Experimental.Distributed.initialize(load_cpp) | ||
run_graphs = ROOT.RDF.Experimental.Distributed.RunGraphs | ||
|
@@ -379,10 +400,10 @@ def main() -> None: | |
ml_results += ml_hist_list | ||
|
||
# Select the right VariationsFor function depending on RDF or DistRDF | ||
if type(df).__module__ == "DistRDF.Proxy": | ||
variationsfor_func = ROOT.RDF.Experimental.Distributed.VariationsFor | ||
else: | ||
if args.scheduler == "mt": | ||
variationsfor_func = ROOT.RDF.Experimental.VariationsFor | ||
else: | ||
variationsfor_func = ROOT.RDF.Experimental.Distributed.VariationsFor | ||
for r in results + ml_results: | ||
if r.should_vary: | ||
r.histo = variationsfor_func(r.histo) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not really needed, more for debugging purposes