From f12fd6bfc9a1bd2853c38741652fef647d4ee0df Mon Sep 17 00:00:00 2001 From: Ilion Beyst Date: Fri, 18 Nov 2022 12:05:55 +0100 Subject: [PATCH] implement runner-initiated split path --- ray_beam_runner/portability/execution.py | 245 +++++++++++++++--- ray_beam_runner/portability/ray_fn_runner.py | 37 ++- .../portability/ray_runner_test.py | 112 +++++--- 3 files changed, 314 insertions(+), 80 deletions(-) diff --git a/ray_beam_runner/portability/execution.py b/ray_beam_runner/portability/execution.py index 4db8940..6dd26fd 100644 --- a/ray_beam_runner/portability/execution.py +++ b/ray_beam_runner/portability/execution.py @@ -24,8 +24,10 @@ import itertools import logging import random +import threading +import time import typing -from typing import List +from typing import List, MutableMapping from typing import Mapping from typing import Optional from typing import Generator @@ -57,6 +59,7 @@ def ray_execute_bundle( transform_buffer_coder: Mapping[str, typing.Tuple[bytes, str]], expected_outputs: translations.DataOutput, stage_timers: Mapping[translations.TimerFamilyId, bytes], + split_manager, instruction_request_repr: Mapping[str, typing.Any], dry_run=False, ) -> Generator: @@ -83,8 +86,6 @@ def ray_execute_bundle( runner_context, instruction_request_repr["process_descriptor_id"] ) - _send_timers(worker_handler, input_bundle, stage_timers, process_bundle_id) - input_data = { k: _fetch_decode_data( runner_context, @@ -95,19 +96,35 @@ def ray_execute_bundle( for k, objrefs in input_bundle.input_data.items() } - for transform_id, elements in input_data.items(): - data_out = worker_handler.data_conn.output_stream( - process_bundle_id, transform_id - ) - for byte_stream in elements: - data_out.write(byte_stream) - data_out.close() - expect_reads: List[typing.Union[str, translations.TimerFamilyId]] = list( expected_outputs.keys() ) expect_reads.extend(list(stage_timers.keys())) + split_results = [] + split_manager_thread = None + if split_manager: + # TODO(iasoon): synchronization can probably be handled + # more cleanly. 
+ split_manager_started_event = threading.Event() + split_manager_thread = threading.Thread( + target=_run_split_manager, + args=( + runner_context, + worker_handler, + split_manager, + input_data, + transform_buffer_coder, + instruction_request, + split_results, + split_manager_started_event, + ), + ) + split_manager_thread.start() + split_manager_started_event.wait() + + _send_timers(worker_handler, input_bundle, stage_timers, process_bundle_id) + _send_input_data(worker_handler, input_data, process_bundle_id) result_future = worker_handler.control_conn.push(instruction_request) for output in worker_handler.data_conn.input_elements( @@ -125,6 +142,8 @@ def ray_execute_bundle( output_buffers[expected_outputs[output.transform_id]].append(output.data) result: beam_fn_api_pb2.InstructionResponse = result_future.get() + if split_manager_thread: + split_manager_thread.join() if result.process_bundle.requires_finalization: finalize_request = beam_fn_api_pb2.InstructionRequest( @@ -151,14 +170,27 @@ def ray_execute_bundle( process_bundle_descriptor = runner_context.worker_manager.process_bundle_descriptor( instruction_request_repr["process_descriptor_id"] ) - delayed_applications = _retrieve_delayed_applications( + + deferred_inputs = {} + + _add_delayed_applications_to_deferred_inputs( result, process_bundle_descriptor, runner_context, + deferred_inputs, ) - returns.append(len(delayed_applications)) - for pcoll, buffer in delayed_applications.items(): + _add_residuals_and_channel_splits_to_deferred_inputs( + runner_context, + input_bundle.input_data, + transform_buffer_coder, + process_bundle_descriptor, + split_results, + deferred_inputs, + ) + + returns.append(len(deferred_inputs)) + for pcoll, buffer in deferred_inputs.items(): returns.append(pcoll) returns.append(buffer) @@ -206,10 +238,29 @@ def _get_source_transform_name( raise RuntimeError("No IO transform feeds %s" % transform_id) -def _retrieve_delayed_applications( +def _add_delayed_application_to_deferred_inputs( + process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, + delayed_application: beam_fn_api_pb2.DelayedBundleApplication, + deferred_inputs: MutableMapping[str, List[bytes]], +): + # TODO(pabloem): Time delay needed for streaming. For now we'll ignore it. + # time_delay = delayed_application.requested_time_delay + source_transform = _get_source_transform_name( + process_bundle_descriptor, + delayed_application.application.transform_id, + delayed_application.application.input_id, + ) + + if source_transform not in deferred_inputs: + deferred_inputs[source_transform] = [] + deferred_inputs[source_transform].append(delayed_application.application.element) + + +def _add_delayed_applications_to_deferred_inputs( bundle_result: beam_fn_api_pb2.InstructionResponse, process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, runner_context: "RayRunnerExecutionContext", + deferred_inputs: MutableMapping[str, List[bytes]], ): """Extract delayed applications from a bundle run. @@ -217,26 +268,71 @@ def _retrieve_delayed_applications( delays the consumption of a data element to checkpoint the previous elements in a bundle. """ - delayed_bundles = {} for delayed_application in bundle_result.process_bundle.residual_roots: - # TODO(pabloem): Time delay needed for streaming. For now we'll ignore it. 
- # time_delay = delayed_application.requested_time_delay - source_transform = _get_source_transform_name( + _add_delayed_application_to_deferred_inputs( process_bundle_descriptor, - delayed_application.application.transform_id, - delayed_application.application.input_id, + delayed_application, + deferred_inputs, ) - if source_transform not in delayed_bundles: - delayed_bundles[source_transform] = [] - delayed_bundles[source_transform].append( - delayed_application.application.element - ) - for consumer, data in delayed_bundles.items(): - delayed_bundles[consumer] = [data] +def _add_residuals_and_channel_splits_to_deferred_inputs( + runner_context: "RayRunnerExecutionContext", + raw_inputs: Mapping[str, List[ray.ObjectRef]], + transform_buffer_coder: Mapping[str, typing.Tuple[bytes, str]], + process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, + splits: List[beam_fn_api_pb2.ProcessBundleSplitResponse], + deferred_inputs: MutableMapping[str, List[bytes]], +): + prev_split_point = {} # transform id -> first residual offset + for split in splits: + for delayed_application in split.residual_roots: + _add_delayed_application_to_deferred_inputs( + process_bundle_descriptor, + delayed_application, + deferred_inputs, + ) + for channel_split in split.channel_splits: + # Decode all input elements + byte_stream = b"".join( + ( + element + for block in ray.get(raw_inputs[channel_split.transform_id]) + for element in block + ) + ) + input_coder_id = transform_buffer_coder[channel_split.transform_id][1] + input_coder = runner_context.pipeline_context.coders[input_coder_id] + + buffer_id = transform_buffer_coder[channel_split.transform_id][0] + if buffer_id.startswith(b"group:"): + coder_impl = coders.WindowedValueCoder( + coders.TupleCoder( + ( + input_coder.wrapped_value_coder._coders[0], + input_coder.wrapped_value_coder._coders[1]._elem_coder, + ) + ), + input_coder.window_coder, + ).get_impl() + else: + coder_impl = input_coder.get_impl() + + all_elements = list(coder_impl.decode_all(byte_stream)) + + # split at first_residual_element index + end = prev_split_point.get(channel_split.transform_id, len(all_elements)) + residual_elements = all_elements[channel_split.first_residual_element : end] + prev_split_point[ + channel_split.transform_id + ] = channel_split.first_residual_element - return delayed_bundles + if residual_elements: + encoded_residual = coder_impl.encode_all(residual_elements) + + if channel_split.transform_id not in deferred_inputs: + deferred_inputs[channel_split.transform_id] = [] + deferred_inputs[channel_split.transform_id].append(encoded_residual) def _get_input_id(buffer_id, transform_name): @@ -316,6 +412,97 @@ def _send_timers( timer_out.close() +def _send_input_data( + worker_handler: worker_handlers.WorkerHandler, + input_data: Mapping[str, fn_execution.PartitionableBuffer], + process_bundle_id, +): + for transform_id, elements in input_data.items(): + data_out = worker_handler.data_conn.output_stream( + process_bundle_id, transform_id + ) + for byte_stream in elements: + data_out.write(byte_stream) + data_out.close() + + +def _run_split_manager( + runner_context: "RayRunnerExecutionContext", + worker_handler: worker_handlers.WorkerHandler, + split_manager, + inputs: Mapping[str, fn_execution.PartitionableBuffer], + transform_buffer_coder: Mapping[str, typing.Tuple[bytes, str]], + instruction_request, + split_results_buf: List[beam_fn_api_pb2.ProcessBundleSplitResponse], + split_manager_started_event: threading.Event, +): + read_transform_id, buffer_data = 
translations.only_element(inputs.items()) + byte_stream = b"".join(buffer_data or []) + coder_id = transform_buffer_coder[read_transform_id][1] + coder_impl = runner_context.pipeline_context.coders[coder_id].get_impl() + num_elements = len(list(coder_impl.decode_all(byte_stream))) + + # Start the split manager in case it wants to set any breakpoints. + split_manager_generator = split_manager(num_elements) + try: + split_fraction = next(split_manager_generator) + done = False + except StopIteration: + split_fraction = None + done = True + + split_manager_started_event.set() + + assert worker_handler is not None + + # Execute the requested splits. + while not done: + if split_fraction is None: + split_result = None + else: + DesiredSplit = beam_fn_api_pb2.ProcessBundleSplitRequest.DesiredSplit + split_request = beam_fn_api_pb2.InstructionRequest( + process_bundle_split=beam_fn_api_pb2.ProcessBundleSplitRequest( + instruction_id=instruction_request.instruction_id, + desired_splits={ + read_transform_id: DesiredSplit( + fraction_of_remainder=split_fraction, + estimated_input_elements=num_elements, + ) + }, + ) + ) + split_response = worker_handler.control_conn.push( + split_request + ).get() # type: beam_fn_api_pb2.InstructionResponse + for t in (0.05, 0.1, 0.2): + if ( + "Unknown process bundle" in split_response.error + or split_response.process_bundle_split + == beam_fn_api_pb2.ProcessBundleSplitResponse() + ): + time.sleep(t) + split_response = worker_handler.control_conn.push( + split_request + ).get() + if ( + "Unknown process bundle" in split_response.error + or split_response.process_bundle_split + == beam_fn_api_pb2.ProcessBundleSplitResponse() + ): + # It may have finished too fast. + split_result = None + elif split_response.error: + raise RuntimeError(split_response.error) + else: + split_result = split_response.process_bundle_split + split_results_buf.append(split_result) + try: + split_fraction = split_manager_generator.send(split_result) + except StopIteration: + break + + @ray.remote class _RayRunnerStats: def __init__(self): diff --git a/ray_beam_runner/portability/ray_fn_runner.py b/ray_beam_runner/portability/ray_fn_runner.py index d5694d1..aca4710 100644 --- a/ray_beam_runner/portability/ray_fn_runner.py +++ b/ray_beam_runner/portability/ray_fn_runner.py @@ -118,6 +118,23 @@ def _pipeline_checks( return pipeline_proto +def _select_split_manager( + process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, +): + """Return the split manager to use for a certain ProcessBundleDescriptor""" + unique_names = { + t.unique_name for t in process_bundle_descriptor.transforms.values() + } + for stage_name, candidate in reversed(fn_runner._split_managers): + if stage_name in unique_names or (stage_name + "/Process") in unique_names: + split_manager = candidate + break + else: + split_manager = None + + return split_manager + + class RayFnApiRunner(runner.PipelineRunner): def __init__( self, @@ -252,7 +269,7 @@ def _run_stage( ( last_result, fired_timers, - delayed_applications, + deferred_inputs, bundle_outputs, ) = self._run_bundle( runner_execution_context, @@ -261,7 +278,7 @@ def _run_stage( ) final_result = merge_stage_results(final_result, last_result) - if not delayed_applications and not fired_timers: + if not deferred_inputs and not fired_timers: break else: # TODO: Enable following assertion after watermarking is implemented @@ -275,7 +292,9 @@ def _run_stage( # runner_execution_context.watermark_manager # .get_stage_node.remote( # bundle_context_manager.stage.name))) 
- input_data = delayed_applications + + # + input_data = {k: [v] for k, v in deferred_inputs.items()} input_timers = fired_timers # Store the required downstream side inputs into state so it is accessible @@ -310,6 +329,7 @@ def _run_bundle( ) process_bundle_descriptor = bundle_context_manager.process_bundle_descriptor + split_manager = _select_split_manager(process_bundle_descriptor) # TODO(pabloem): Are there two different IDs? the Bundle ID and PBD ID? process_bundle_id = "bundle_%s" % process_bundle_descriptor.id @@ -321,6 +341,7 @@ def _run_bundle( transform_to_buffer_coder, data_output, stage_timers, + split_manager, instruction_request_repr={ "instruction_id": process_bundle_id, "process_descriptor_id": pbd_id, @@ -340,12 +361,12 @@ def _run_bundle( output.append(pcoll) runner_execution_context.pcollection_buffers.put(pcoll, [data_ref]) - delayed_applications = {} - num_delayed_applications = ray.get(next(result_generator)) - for _ in range(num_delayed_applications): + deferred_inputs = {} + num_deferred_inputs = ray.get(next(result_generator)) + for _ in range(num_deferred_inputs): pcoll = ray.get(next(result_generator)) data_ref = next(result_generator) - delayed_applications[pcoll] = data_ref + deferred_inputs[pcoll] = data_ref runner_execution_context.pcollection_buffers.put(pcoll, [data_ref]) ( @@ -365,7 +386,7 @@ def _run_bundle( # coder_impl=bundle_context_manager.get_input_coder_impl( # other_input)) - return result, newly_set_timers, delayed_applications, output + return result, newly_set_timers, deferred_inputs, output @staticmethod def _collect_written_timers( diff --git a/ray_beam_runner/portability/ray_runner_test.py b/ray_beam_runner/portability/ray_runner_test.py index 2ca1bcb..ece8625 100644 --- a/ray_beam_runner/portability/ray_runner_test.py +++ b/ray_beam_runner/portability/ray_runner_test.py @@ -1617,18 +1617,18 @@ def has_mi_for_ptransform(mon_infos, ptransform): raise -@unittest.skip("Runner-initiated splitting not yet supported") class RayRunnerSplitTest(unittest.TestCase): def setUp(self) -> None: - if not ray.is_initialized(): - ray.init(num_cpus=1, include_dashboard=False) + ray.init(num_cpus=1, include_dashboard=False, ignore_reinit_error=True) def tearDown(self) -> None: ray.shutdown() def create_pipeline(self, is_drain=False): return beam.Pipeline( - runner=ray_beam_runner.portability.ray_fn_runner.RayFnApiRunner() + runner=ray_beam_runner.portability.ray_fn_runner.RayFnApiRunner( + is_drain=is_drain + ) ) def test_checkpoint(self): @@ -1646,19 +1646,32 @@ def split_manager(num_elements): # Split as close to current as possible. split_result = yield 0.0 # Verify we split at exactly the first element. - self.verify_channel_split(split_result, 0, 1) + verify_channel_split(split_result, 0, 1) # Continue processing. 
breakpoint.clear() self.run_split_pipeline(split_manager, list("abc"), element_counter) def test_split_half(self): + @ray.remote(num_cpus=0) + class ListActor: + def __init__(self): + self._list = [] + + def append(self, elem): + self._list.append(elem) + + def get(self): + return self._list + + seen_bundle_sizes_actor = ListActor.remote() + total_num_elements = 25 seen_bundle_sizes = [] element_counter = ElementCounter() def split_manager(num_elements): - seen_bundle_sizes.append(num_elements) + seen_bundle_sizes_actor.append.remote(num_elements) if num_elements == total_num_elements: element_counter.reset() breakpoint = element_counter.set_breakpoint(5) @@ -1666,14 +1679,15 @@ def split_manager(num_elements): breakpoint.wait() # Split the remainder (20, then 10, elements) in half. split1 = yield 0.5 - self.verify_channel_split(split1, 14, 15) # remainder is 15 to end + verify_channel_split(split1, 14, 15) # remainder is 15 to end split2 = yield 0.5 - self.verify_channel_split(split2, 9, 10) # remainder is 10 to end + verify_channel_split(split2, 9, 10) # remainder is 10 to end breakpoint.clear() self.run_split_pipeline( split_manager, range(total_num_elements), element_counter ) + seen_bundle_sizes = ray.get(seen_bundle_sizes_actor.get.remote()) self.assertEqual([25, 15], seen_bundle_sizes) def run_split_pipeline(self, split_manager, elements, element_counter=None): @@ -1710,21 +1724,33 @@ def split_manager(num_elements): def run_sdf_split_half(self, is_drain=False): element_counter = ElementCounter() - is_first_bundle = True + + @ray.remote(num_cpus=0) + class IsFirstBundleActor: + def __init__(self): + self._seen_first = False + + # return whether the caller is the first to call this method + def check_first(self): + if not self._seen_first: + self._seen_first = True + return True + return False + + is_first_bundle_actor = IsFirstBundleActor.remote() def split_manager(num_elements): - nonlocal is_first_bundle + is_first_bundle = ray.get(is_first_bundle_actor.check_first.remote()) if is_first_bundle and num_elements > 0: - is_first_bundle = False breakpoint = element_counter.set_breakpoint(1) yield breakpoint.wait() split1 = yield 0.5 split2 = yield 0.5 split3 = yield 0.5 - self.verify_channel_split(split1, 0, 1) - self.verify_channel_split(split2, -1, 1) - self.verify_channel_split(split3, -1, 1) + verify_channel_split(split1, 0, 1) + verify_channel_split(split2, -1, 1) + verify_channel_split(split3, -1, 1) breakpoint.clear() elements = [4, 4] @@ -1765,7 +1791,6 @@ def split_manager(num_elements): _LOGGER.error("test_split_crazy_sdf.seed = %s", seed) raise - @unittest.skip("SDF not yet supported") def test_nosplit_sdf(self): def split_manager(num_elements): yield @@ -1776,27 +1801,21 @@ def split_manager(num_elements): split_manager, elements, ElementCounter(), expected_groups ) - @unittest.skip("SDF not yet supported") def test_checkpoint_sdf(self): self.run_sdf_checkpoint(is_drain=False) - @unittest.skip("SDF not yet supported") def test_checkpoint_draining_sdf(self): self.run_sdf_checkpoint(is_drain=True) - @unittest.skip("SDF not yet supported") def test_split_half_sdf(self): self.run_sdf_split_half(is_drain=False) - @unittest.skip("SDF not yet supported") def test_split_half_draining_sdf(self): self.run_sdf_split_half(is_drain=True) - @unittest.skip("SDF not yet supported") def test_split_crazy_sdf(self, seed=None): self.run_split_crazy_sdf(seed=seed, is_drain=False) - @unittest.skip("SDF not yet supported") def test_split_crazy_draining_sdf(self, seed=None): 
self.run_split_crazy_sdf(seed=seed, is_drain=True) @@ -1858,29 +1877,31 @@ def process( grouped, equal_to(expected_groups), label="CheckGrouped" ) - def verify_channel_split(self, split_result, last_primary, first_residual): - self.assertEqual(1, len(split_result.channel_splits), split_result) - (channel_split,) = split_result.channel_splits - self.assertEqual(last_primary, channel_split.last_primary_element) - self.assertEqual(first_residual, channel_split.first_residual_element) - # There should be a primary and residual application for each element - # not covered above. - self.assertEqual( - first_residual - last_primary - 1, - len(split_result.primary_roots), - split_result.primary_roots, - ) - self.assertEqual( - first_residual - last_primary - 1, - len(split_result.residual_roots), - split_result.residual_roots, - ) + +# use regular asserts here since unittest.TestCase cannot be pickled +def verify_channel_split(split_result, last_primary, first_residual): + assert 1 == len(split_result.channel_splits), split_result + (channel_split,) = split_result.channel_splits + assert last_primary == channel_split.last_primary_element + assert first_residual == channel_split.first_residual_element + # There should be a primary and residual application for each element + # not covered above. + assert first_residual - last_primary - 1 == len( + split_result.primary_roots + ), split_result.primary_roots + assert first_residual - last_primary - 1 == len( + split_result.residual_roots + ), split_result.residual_roots class ElementCounter(object): """Used to wait until a certain number of elements are seen.""" - def __init__(self): + def __init__(self, name=None): + if name is None: + name = uuid.uuid4().hex + self._name = name + self._cv = threading.Condition() self.reset() @@ -1920,16 +1941,21 @@ def clear(): return Breakpoint() def __reduce__(self): - # Ensure we get the same element back through a pickling round-trip. - name = uuid.uuid4().hex - _pickled_element_counters[name] = self - return _unpickle_element_counter, (name,) + # Ensure that when we unpickle a counter multiple times + # within the same process, all instances will be backed by the same object. + # This is required for the per-process locks to work. + # For the current test suite, having counters that work across + # process boundaries is not required. + _pickled_element_counters[self._name] = self + return _unpickle_element_counter, (self._name,) _pickled_element_counters = {} # type: Dict[str, ElementCounter] def _unpickle_element_counter(name): + if name not in _pickled_element_counters: + _pickled_element_counters[name] = ElementCounter(name=name) return _pickled_element_counters[name]
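
Reviewer note (not part of the patch): below is a minimal usage sketch of driving a runner-initiated split through the new path. It assumes Beam's fn_runner.split_manager context manager, i.e. the module-level registry that _select_split_manager above consults via fn_runner._split_managers, and the "Identity" stage label is only an illustrative placeholder that must match a transform name in the pipeline being split; treat it as a sketch under those assumptions rather than a supported API of this repository.

    # Hypothetical end-to-end sketch of a runner-initiated split with the Ray runner.
    # Assumptions (not guaranteed by this patch): fn_runner.split_manager exists as a
    # context manager in the installed Beam version, and "Identity" matches a transform
    # unique_name in the fused stage that should be split.
    import apache_beam as beam
    from apache_beam.runners.portability.fn_api_runner import fn_runner

    import ray
    from ray_beam_runner.portability.ray_fn_runner import RayFnApiRunner


    def split_manager(num_elements):
        # Called once per bundle with the number of buffered input elements for the
        # stage's read transform. Each `yield fraction` asks the SDK worker to split
        # the remainder of the bundle at that fraction and receives the split
        # response (or None if the bundle already finished) back.
        if num_elements > 1:
            split_result = yield 0.5  # split the remaining work roughly in half
            # split_result.channel_splits / residual_roots describe what was deferred.


    ray.init(ignore_reinit_error=True)
    with fn_runner.split_manager("Identity", split_manager):
        with beam.Pipeline(runner=RayFnApiRunner()) as p:
            (
                p
                | beam.Create(list(range(10)))
                | "Identity" >> beam.Map(lambda x: x)
                | beam.Map(lambda x: x * 2)
            )

With the patch applied, each yielded fraction is forwarded to the SDK worker as a ProcessBundleSplitRequest from _run_split_manager; empty or "Unknown process bundle" responses are retried briefly, and any channel-split residuals are re-encoded by _add_residuals_and_channel_splits_to_deferred_inputs and fed back into the next bundle as deferred inputs.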