Complete user-initiated SDF functionality (#52)
* correctly set is_drain parameter

* enable passing runner tests

* Support deferred applications in drain mode

* implement bundle finalization
iasoon authored Nov 10, 2022
1 parent 3339959 commit 5dd1eb4
Showing 3 changed files with 67 additions and 25 deletions.
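At a glance: the runner now forwards an is_drain flag into pipeline translation, routes deferred (residual) SDF applications back to their originating source transform, and honors SDK-requested bundle finalization. A minimal sketch of the user-facing finalization hook these changes support (the DoFn below is illustrative, not code from this commit):

    import apache_beam as beam

    class CleanupDoFn(beam.DoFn):
        # BundleFinalizerParam lets a DoFn register callbacks that the runner
        # should invoke once the bundle's outputs are safely committed.
        def process(self, element, bundle_finalizer=beam.DoFn.BundleFinalizerParam):
            bundle_finalizer.register(lambda: print("bundle committed"))
            yield element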
74 changes: 59 additions & 15 deletions ray_beam_runner/portability/execution.py
@@ -125,6 +125,17 @@ def ray_execute_bundle(
        output_buffers[expected_outputs[output.transform_id]].append(output.data)

    result: beam_fn_api_pb2.InstructionResponse = result_future.get()

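    # A DoFn can request finalization (e.g. via beam.DoFn.BundleFinalizerParam);
    # when the SDK reports requires_finalization, send a FinalizeBundleRequest
    # so the registered callbacks run.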
    if result.process_bundle.requires_finalization:
        finalize_request = beam_fn_api_pb2.InstructionRequest(
            finalize_bundle=beam_fn_api_pb2.FinalizeBundleRequest(
                instruction_id=process_bundle_id
            )
        )
        finalize_response = worker_handler.control_conn.push(finalize_request).get()
        if finalize_response.error:
            raise RuntimeError(finalize_response.error)

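    # Yield the serialized InstructionResponse first, then the number of
    # output buffers, followed by the buffered output data.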
    returns = [result.SerializeToString()]

    returns.append(len(output_buffers))
@@ -155,6 +166,46 @@ def ray_execute_bundle(
        yield ret


def _get_source_transform_name(
    process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
    transform_id: str,
    input_id: str,
) -> str:
    """Find the name of the source PTransform that feeds into the given
    (transform_id, input_id)."""
    input_pcoll = process_bundle_descriptor.transforms[transform_id].inputs[input_id]
    for ptransform_id, ptransform in process_bundle_descriptor.transforms.items():
        # The GrpcRead is directly followed by the SDF/Process.
        if (
            ptransform.spec.urn == bundle_processor.DATA_INPUT_URN
            and input_pcoll in ptransform.outputs.values()
        ):
            return ptransform_id

        # The GrpcRead is followed by SDF/Truncate -> SDF/Process.
        # We need to traverse the TRUNCATE_SIZED_RESTRICTION node in order
        # to find the original source PTransform.
        if (
            ptransform.spec.urn
            == common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn
            and input_pcoll in ptransform.outputs.values()
        ):
            input_pcoll_ = translations.only_element(
                process_bundle_descriptor.transforms[ptransform_id].inputs.values()
            )
            for (
                ptransform_id_2,
                ptransform_2,
            ) in process_bundle_descriptor.transforms.items():
                if (
                    ptransform_2.spec.urn == bundle_processor.DATA_INPUT_URN
                    and input_pcoll_ in ptransform_2.outputs.values()
                ):
                    return ptransform_id_2

    raise RuntimeError("No IO transform feeds %s" % transform_id)


def _retrieve_delayed_applications(
    bundle_result: beam_fn_api_pb2.InstructionResponse,
    process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
@@ -170,22 +221,15 @@ def _retrieve_delayed_applications(
    for delayed_application in bundle_result.process_bundle.residual_roots:
        # TODO(pabloem): Time delay needed for streaming. For now we'll ignore it.
        # time_delay = delayed_application.requested_time_delay
-        transform = process_bundle_descriptor.transforms[
-            delayed_application.application.transform_id
-        ]
-        pcoll_name = transform.inputs[delayed_application.application.input_id]
-
-        consumer_transform = translations.only_element(
-            [
-                read_id
-                for read_id, proto in process_bundle_descriptor.transforms.items()
-                if proto.spec.urn == bundle_processor.DATA_INPUT_URN
-                and pcoll_name in proto.outputs.values()
-            ]
+        source_transform = _get_source_transform_name(
+            process_bundle_descriptor,
+            delayed_application.application.transform_id,
+            delayed_application.application.input_id,
        )
+
-        if consumer_transform not in delayed_bundles:
-            delayed_bundles[consumer_transform] = []
-        delayed_bundles[consumer_transform].append(
+        if source_transform not in delayed_bundles:
+            delayed_bundles[source_transform] = []
+        delayed_bundles[source_transform].append(
            delayed_application.application.element
        )

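To make the lookup in _get_source_transform_name concrete, here is a hedged sketch against hand-built protos (the transform names "read" and "sdf-process" and the invocation are illustrative, not from this commit):

    from apache_beam.portability.api import beam_fn_api_pb2, beam_runner_api_pb2
    from apache_beam.runners.worker import bundle_processor

    descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
        transforms={
            # A gRPC read (DATA_INPUT_URN) produces pcollection "pc-1" ...
            "read": beam_runner_api_pb2.PTransform(
                spec=beam_runner_api_pb2.FunctionSpec(
                    urn=bundle_processor.DATA_INPUT_URN
                ),
                outputs={"out": "pc-1"},
            ),
            # ... which the SDF process transform consumes.
            "sdf-process": beam_runner_api_pb2.PTransform(inputs={"in": "pc-1"}),
        },
    )

    # A residual root naming ("sdf-process", "in") is traced back to "read".
    assert _get_source_transform_name(descriptor, "sdf-process", "in") == "read"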
5 changes: 4 additions & 1 deletion ray_beam_runner/portability/ray_fn_runner.py
@@ -121,18 +121,21 @@ def _pipeline_checks(
class RayFnApiRunner(runner.PipelineRunner):
    def __init__(
        self,
        is_drain=False,
    ) -> None:

        """Creates a new Ray Runner instance.

        Args:
          progress_request_frequency: The frequency (in seconds) with which the
            runner requests progress from the SDK.
          is_drain: whether to expand the SDF graph in drain mode.
        """
        super().__init__()
        # TODO: figure out if this is necessary (probably, later)
        self._progress_frequency = None
        self._cache_token_generator = fn_runner.FnApiRunner.get_cache_token_generator()
        self._is_drain = is_drain

@staticmethod
def supported_requirements():
@@ -183,7 +186,7 @@ def run_pipeline(
                ]
            ),
            use_state_iterables=False,
-            is_drain=False,
+            is_drain=self._is_drain,
        )
        return self.execute_pipeline(stage_context, stages)

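With the flag plumbed through to create_and_optimize_stages, drain-mode SDF expansion is chosen at runner construction time. A minimal usage sketch (pipeline contents illustrative):

    import apache_beam as beam
    from ray_beam_runner.portability.ray_fn_runner import RayFnApiRunner

    # is_drain=True asks pipeline translation to expand SDFs with
    # TRUNCATE_SIZED_RESTRICTION so restrictions finish early on drain.
    with beam.Pipeline(runner=RayFnApiRunner(is_drain=True)) as p:
        _ = p | beam.Create([1, 2, 3]) | beam.Map(lambda x: x * 2)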
13 changes: 4 additions & 9 deletions ray_beam_runner/portability/ray_runner_test.py
@@ -102,7 +102,9 @@ def tearDown(self) -> None:

    def create_pipeline(self, is_drain=False):
        return beam.Pipeline(
-            runner=ray_beam_runner.portability.ray_fn_runner.RayFnApiRunner()
+            runner=ray_beam_runner.portability.ray_fn_runner.RayFnApiRunner(
+                is_drain=is_drain
+            )
        )

    def test_assert_that(self):
@@ -684,7 +686,6 @@ def process(
            actual = p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn())
            assert_that(actual, equal_to(list("".join(data))))

-    @unittest.skip("SDF not yet supported")
    def test_sdf_with_dofn_as_watermark_estimator(self):
        class ExpandingStringsDoFn(beam.DoFn, beam.WatermarkEstimatorProvider):
            def initial_estimator_state(self, element, restriction):
@@ -758,11 +759,9 @@ def process(
    def test_sdf_with_sdf_initiated_checkpointing(self):
        self.run_sdf_initiated_checkpointing(is_drain=False)

-    @unittest.skip("SDF not yet supported")
    def test_draining_sdf_with_sdf_initiated_checkpointing(self):
        self.run_sdf_initiated_checkpointing(is_drain=True)

@unittest.skip("SDF not yet supported")
def test_sdf_default_truncate_when_bounded(self):
class SimleSDF(beam.DoFn):
def process(
@@ -782,7 +781,6 @@ def process(
            actual = p | beam.Create([10]) | beam.ParDo(SimleSDF())
            assert_that(actual, equal_to(range(10)))

@unittest.skip("SDF not yet supported")
def test_sdf_default_truncate_when_unbounded(self):
class SimleSDF(beam.DoFn):
def process(
@@ -802,7 +800,6 @@ def process(
            actual = p | beam.Create([10]) | beam.ParDo(SimleSDF())
            assert_that(actual, equal_to([]))

@unittest.skip("SDF not yet supported")
def test_sdf_with_truncate(self):
class SimleSDF(beam.DoFn):
def process(
@@ -1042,7 +1039,6 @@ def process(self, element, bundle_finalizer=beam.DoFn.BundleFinalizerParam):
            )
            assert_that(res, equal_to(["1", "2"]))

-    @unittest.skip("SDF not yet supported")
    def test_register_finalizations(self):
        event_recorder = EventRecorder(tempfile.gettempdir())

@@ -1086,7 +1082,6 @@ def process(

        event_recorder.cleanup()

@unittest.skip("Combiners not yet supported")
def test_sdf_synthetic_source(self):
common_attrs = {
"key_size": 1,
@@ -1188,7 +1183,7 @@ def expand(self, pcoll):
            any(re.match(packed_step_name_regex, s) for s in step_names)
        )

@unittest.skip("Combiners not yet supported")
@unittest.skip("Metrics not yet supported")
def test_pack_combiners(self):
self._test_pack_combiners(assert_using_counter_names=True)

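For reference, the splittable DoFns exercised by the re-enabled tests follow this general shape (a sketch assuming Beam's public SDF API, not code from the test file):

    import apache_beam as beam
    from apache_beam.io.restriction_trackers import (
        OffsetRange,
        OffsetRestrictionTracker,
    )

    class CountProvider(beam.transforms.core.RestrictionProvider):
        def initial_restriction(self, element):
            return OffsetRange(0, element)

        def create_tracker(self, restriction):
            return OffsetRestrictionTracker(restriction)

        def restriction_size(self, element, restriction):
            return restriction.size()

    class EmitOffsetsFn(beam.DoFn):
        # The tracker mediates claims, so the runner can split or truncate the
        # restriction (e.g. on drain) between try_claim calls.
        def process(
            self, element, tracker=beam.DoFn.RestrictionParam(CountProvider())
        ):
            cur = tracker.current_restriction().start
            while tracker.try_claim(cur):
                yield cur
                cur += 1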
