diff --git a/ray_beam_runner/portability/context_management.py b/ray_beam_runner/portability/context_management.py
index 5b12f46..af2105a 100644
--- a/ray_beam_runner/portability/context_management.py
+++ b/ray_beam_runner/portability/context_management.py
@@ -163,6 +163,7 @@ def setup(self):
                         pcoll_id,
                         self.execution_context.safe_coders.get(coder_id, coder_id),
                     )
+
                 elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
                     data_output[transform.unique_name] = pcoll_id
                     coder_id = self.execution_context.data_channel_coders[
@@ -170,8 +171,10 @@ def setup(self):
                     ]
                 else:
                     raise NotImplementedError
-                data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
-                transform.spec.payload = data_spec.SerializeToString()
+                if pcoll_id != translations.IMPULSE_BUFFER:
+                    data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
+                    transform.spec.payload = data_spec.SerializeToString()
+
             elif transform.spec.urn in translations.PAR_DO_URNS:
                 payload = proto_utils.parse_Bytes(
                     transform.spec.payload, beam_runner_api_pb2.ParDoPayload
diff --git a/ray_beam_runner/portability/execution.py b/ray_beam_runner/portability/execution.py
index 5e0a551..b1f80c2 100644
--- a/ray_beam_runner/portability/execution.py
+++ b/ray_beam_runner/portability/execution.py
@@ -345,20 +345,9 @@ def get(self, pcoll) -> List[ray.ObjectRef]:
 
 @ray.remote
 class RayWatermarkManager(watermark_manager.WatermarkManager):
-    def __init__(self):
-        # the original WatermarkManager performs a lot of computation
-        # in its __init__ method. Because Ray calls __init__ whenever
-        # it deserializes an object, we'll move its setup elsewhere.
-        self._initialized = False
-        self._pcollections_by_name = {}
-        self._stages_by_name = {}
-
-    def setup(self, stages):
-        if self._initialized:
-            return
-        logging.debug("initialized the RayWatermarkManager")
-        self._initialized = True
-        watermark_manager.WatermarkManager.setup(self, stages)
+    def set_pcoll_produced_watermark(self, name, watermark):
+        element = self._pcollections_by_name[name]
+        element.set_produced_watermark(watermark)
 
 
 class RayRunnerExecutionContext(object):
@@ -371,6 +360,7 @@ def __init__(
         state_servicer: Optional[RayStateManager] = None,
         worker_manager: Optional[RayWorkerHandlerManager] = None,
         pcollection_buffers: PcollectionBufferManager = None,
+        watermark_manager: Optional[RayWatermarkManager] = None,
     ) -> None:
         ray.util.register_serializer(
             beam_runner_api_pb2.Components,
@@ -405,7 +395,9 @@ def __init__(
             for t in s.transforms
             if t.spec.urn == bundle_processor.DATA_INPUT_URN
         }
-        self._watermark_manager = RayWatermarkManager.remote()
+        self.watermark_manager = watermark_manager or RayWatermarkManager.remote(
+            self.stages
+        )
         self.pipeline_context = pipeline_context.PipelineContext(pipeline_components)
         self.safe_windowing_strategies = {
             # TODO: Enable safe_windowing_strategy after
@@ -419,14 +411,6 @@ def __init__(
         self.worker_manager = worker_manager or RayWorkerHandlerManager()
         self.timer_coder_ids = self._build_timer_coders_id_map()
 
-    @property
-    def watermark_manager(self):
-        # We don't need to wait for this line to execute with ray.get,
-        # because any further calls to the watermark manager actor will
-        # have to wait for it.
-        self._watermark_manager.setup.remote(self.stages)
-        return self._watermark_manager
-
     @staticmethod
     def next_uid():
         # TODO(pabloem): Use stats actor for UIDs.
@@ -464,6 +448,7 @@ def __reduce__(self):
             self.state_servicer,
             self.worker_manager,
             self.pcollection_buffers,
+            self.watermark_manager,
         )
 
         def deserializer(*args):
@@ -475,6 +460,7 @@ def deserializer(*args):
                 args[4],
                 args[5],
                 args[6],
+                args[7],
             )
 
         return (deserializer, data)
diff --git a/ray_beam_runner/portability/ray_fn_runner.py b/ray_beam_runner/portability/ray_fn_runner.py
index 1388afb..7d6daf9 100644
--- a/ray_beam_runner/portability/ray_fn_runner.py
+++ b/ray_beam_runner/portability/ray_fn_runner.py
@@ -42,7 +42,7 @@
 from apache_beam.runners.portability.fn_api_runner import translations
 from apache_beam.runners.portability.fn_api_runner.execution import ListBuffer
 from apache_beam.transforms import environments
-from apache_beam.utils import proto_utils
+from apache_beam.utils import proto_utils, timestamp
 
 import ray
 from ray_beam_runner.portability.context_management import RayBundleContextManager
@@ -227,7 +227,9 @@ def _run_stage(
           bundle_context_manager (execution.BundleContextManager): A description of
             the stage to execute, and its context.
         """
+
         bundle_context_manager.setup()
+
         runner_execution_context.worker_manager.register_process_bundle_descriptor(
             bundle_context_manager.process_bundle_descriptor
         )
@@ -246,6 +248,8 @@ def _run_stage(
             for k in bundle_context_manager.transform_to_buffer_coder
         }
 
+        watermark_manager = runner_execution_context.watermark_manager
+
         final_result = None  # type: Optional[beam_fn_api_pb2.InstructionResponse]
 
         while True:
@@ -262,19 +266,26 @@ def _run_stage(
             final_result = merge_stage_results(final_result, last_result)
 
             if not delayed_applications and not fired_timers:
+                # Processing has completed; marking all outputs as completed
+                for output_pc in bundle_outputs:
+                    _, update_output_pc = translations.split_buffer_id(output_pc)
+                    watermark_manager.set_pcoll_produced_watermark.remote(
+                        update_output_pc, timestamp.MAX_TIMESTAMP
+                    )
                 break
             else:
-                # TODO: Enable following assertion after watermarking is implemented
-                # assert (ray.get(
-                # runner_execution_context.watermark_manager
-                # .get_stage_node.remote(
-                # bundle_context_manager.stage.name)).output_watermark()
-                # < timestamp.MAX_TIMESTAMP), (
-                # 'wrong timestamp for %s. '
-                # % ray.get(
-                # runner_execution_context.watermark_manager
-                # .get_stage_node.remote(
-                # bundle_context_manager.stage.name)))
+                assert (
+                    ray.get(
+                        watermark_manager.get_stage_node.remote(
+                            bundle_context_manager.stage.name
+                        )
+                    ).output_watermark()
+                    < timestamp.MAX_TIMESTAMP
+                ), "wrong timestamp for %s. " % ray.get(
+                    watermark_manager.get_stage_node.remote(
+                        bundle_context_manager.stage.name
+                    )
+                )
                 input_data = delayed_applications
                 input_timers = fired_timers
 
@@ -288,6 +299,20 @@ def _run_stage(
         # TODO(pabloem): Make sure that side inputs are being stored somewhere.
         # runner_execution_context.commit_side_inputs_to_state(data_side_input)
 
+        # assert that the output watermark was correctly set for this stage
+        stage_node = ray.get(
+            runner_execution_context.watermark_manager.get_stage_node.remote(
+                bundle_context_manager.stage.name
+            )
+        )
+        assert (
+            stage_node.output_watermark() == timestamp.MAX_TIMESTAMP
+        ), "wrong output watermark for %s. Expected %s, but got %s." % (
+            stage_node,
+            timestamp.MAX_TIMESTAMP,
+            stage_node.output_watermark(),
+        )
+
         return final_result
 
     def _run_bundle(
@@ -346,6 +371,21 @@ def _run_bundle(
         #     coder_impl=bundle_context_manager.get_input_coder_impl(
         #         other_input))
 
+        # TODO: replace placeholder sets when timers are implemented
+        watermark_updates = fn_runner.FnApiRunner._build_watermark_updates(
+            runner_execution_context,
+            transform_to_buffer_coder.keys(),
+            set(),  # expected_timers
+            set(),  # pcolls_with_da
+            delayed_applications.keys(),
+            set(),  # watermarks_by_transform_and_timer_family
+        )
+
+        for pc_name, watermark in watermark_updates.items():
+            runner_execution_context.watermark_manager.set_pcoll_watermark.remote(
+                pc_name, watermark
+            )
+
         newly_set_timers = {}
         return result, newly_set_timers, delayed_applications, output
 