Skip to content

Commit

Permalink
implement basic watermarking
Browse files Browse the repository at this point in the history
  • Loading branch information
iasoon committed Jun 16, 2022
1 parent f1b8fde commit 015df34
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 37 deletions.
7 changes: 5 additions & 2 deletions ray_beam_runner/portability/context_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,18 @@ def setup(self):
pcoll_id,
self.execution_context.safe_coders.get(coder_id, coder_id),
)

elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
data_output[transform.unique_name] = pcoll_id
coder_id = self.execution_context.data_channel_coders[
translations.only_element(transform.inputs.values())
]
else:
raise NotImplementedError
data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
transform.spec.payload = data_spec.SerializeToString()
if pcoll_id != translations.IMPULSE_BUFFER:
data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
transform.spec.payload = data_spec.SerializeToString()

elif transform.spec.urn in translations.PAR_DO_URNS:
payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload
Expand Down
32 changes: 9 additions & 23 deletions ray_beam_runner/portability/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,20 +345,9 @@ def get(self, pcoll) -> List[ray.ObjectRef]:

@ray.remote
class RayWatermarkManager(watermark_manager.WatermarkManager):
def __init__(self):
# the original WatermarkManager performs a lot of computation
# in its __init__ method. Because Ray calls __init__ whenever
# it deserializes an object, we'll move its setup elsewhere.
self._initialized = False
self._pcollections_by_name = {}
self._stages_by_name = {}

def setup(self, stages):
if self._initialized:
return
logging.debug("initialized the RayWatermarkManager")
self._initialized = True
watermark_manager.WatermarkManager.setup(self, stages)
def set_pcoll_produced_watermark(self, name, watermark):
element = self._pcollections_by_name[name]
element.set_produced_watermark(watermark)


class RayRunnerExecutionContext(object):
Expand All @@ -371,6 +360,7 @@ def __init__(
state_servicer: Optional[RayStateManager] = None,
worker_manager: Optional[RayWorkerHandlerManager] = None,
pcollection_buffers: PcollectionBufferManager = None,
watermark_manager: Optional[RayWatermarkManager] = None,
) -> None:
ray.util.register_serializer(
beam_runner_api_pb2.Components,
Expand Down Expand Up @@ -405,7 +395,9 @@ def __init__(
for t in s.transforms
if t.spec.urn == bundle_processor.DATA_INPUT_URN
}
self._watermark_manager = RayWatermarkManager.remote()
self.watermark_manager = watermark_manager or RayWatermarkManager.remote(
self.stages
)
self.pipeline_context = pipeline_context.PipelineContext(pipeline_components)
self.safe_windowing_strategies = {
# TODO: Enable safe_windowing_strategy after
Expand All @@ -419,14 +411,6 @@ def __init__(
self.worker_manager = worker_manager or RayWorkerHandlerManager()
self.timer_coder_ids = self._build_timer_coders_id_map()

@property
def watermark_manager(self):
# We don't need to wait for this line to execute with ray.get,
# because any further calls to the watermark manager actor will
# have to wait for it.
self._watermark_manager.setup.remote(self.stages)
return self._watermark_manager

@staticmethod
def next_uid():
# TODO(pabloem): Use stats actor for UIDs.
Expand Down Expand Up @@ -464,6 +448,7 @@ def __reduce__(self):
self.state_servicer,
self.worker_manager,
self.pcollection_buffers,
self.watermark_manager,
)

def deserializer(*args):
Expand All @@ -475,6 +460,7 @@ def deserializer(*args):
args[4],
args[5],
args[6],
args[7],
)

return (deserializer, data)
Expand Down
64 changes: 52 additions & 12 deletions ray_beam_runner/portability/ray_fn_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from apache_beam.runners.portability.fn_api_runner import translations
from apache_beam.runners.portability.fn_api_runner.execution import ListBuffer
from apache_beam.transforms import environments
from apache_beam.utils import proto_utils
from apache_beam.utils import proto_utils, timestamp

import ray
from ray_beam_runner.portability.context_management import RayBundleContextManager
Expand Down Expand Up @@ -227,7 +227,9 @@ def _run_stage(
bundle_context_manager (execution.BundleContextManager): A description of
the stage to execute, and its context.
"""

bundle_context_manager.setup()

runner_execution_context.worker_manager.register_process_bundle_descriptor(
bundle_context_manager.process_bundle_descriptor
)
Expand All @@ -246,6 +248,8 @@ def _run_stage(
for k in bundle_context_manager.transform_to_buffer_coder
}

watermark_manager = runner_execution_context.watermark_manager

final_result = None # type: Optional[beam_fn_api_pb2.InstructionResponse]

while True:
Expand All @@ -262,19 +266,26 @@ def _run_stage(

final_result = merge_stage_results(final_result, last_result)
if not delayed_applications and not fired_timers:
# Processing has completed; marking all outputs as completed
for output_pc in bundle_outputs:
_, update_output_pc = translations.split_buffer_id(output_pc)
watermark_manager.set_pcoll_produced_watermark.remote(
update_output_pc, timestamp.MAX_TIMESTAMP
)
break
else:
# TODO: Enable following assertion after watermarking is implemented
# assert (ray.get(
# runner_execution_context.watermark_manager
# .get_stage_node.remote(
# bundle_context_manager.stage.name)).output_watermark()
# < timestamp.MAX_TIMESTAMP), (
# 'wrong timestamp for %s. '
# % ray.get(
# runner_execution_context.watermark_manager
# .get_stage_node.remote(
# bundle_context_manager.stage.name)))
assert (
ray.get(
watermark_manager.get_stage_node.remote(
bundle_context_manager.stage.name
)
).output_watermark()
< timestamp.MAX_TIMESTAMP
), "wrong timestamp for %s. " % ray.get(
watermark_manager.get_stage_node.remote(
bundle_context_manager.stage.name
)
)
input_data = delayed_applications
input_timers = fired_timers

Expand All @@ -288,6 +299,20 @@ def _run_stage(
# TODO(pabloem): Make sure that side inputs are being stored somewhere.
# runner_execution_context.commit_side_inputs_to_state(data_side_input)

# assert that the output watermark was correctly set for this stage
stage_node = ray.get(
runner_execution_context.watermark_manager.get_stage_node.remote(
bundle_context_manager.stage.name
)
)
assert (
stage_node.output_watermark() == timestamp.MAX_TIMESTAMP
), "wrong output watermark for %s. Expected %s, but got %s." % (
stage_node,
timestamp.MAX_TIMESTAMP,
stage_node.output_watermark(),
)

return final_result

def _run_bundle(
Expand Down Expand Up @@ -346,6 +371,21 @@ def _run_bundle(
# coder_impl=bundle_context_manager.get_input_coder_impl(
# other_input))

# TODO: replace placeholder sets when timers are implemented
watermark_updates = fn_runner.FnApiRunner._build_watermark_updates(
runner_execution_context,
transform_to_buffer_coder.keys(),
set(), # expected_timers
set(), # pcolls_with_da
delayed_applications.keys(),
set(), # watermarks_by_transform_and_timer_family
)

for pc_name, watermark in watermark_updates.items():
runner_execution_context.watermark_manager.set_pcoll_watermark.remote(
pc_name, watermark
)

newly_set_timers = {}
return result, newly_set_timers, delayed_applications, output

Expand Down

0 comments on commit 015df34

Please sign in to comment.