implement runner-initiated split path #54

Open · wants to merge 1 commit into base: master
245 changes: 216 additions & 29 deletions ray_beam_runner/portability/execution.py
@@ -24,8 +24,10 @@
import itertools
import logging
import random
import threading
import time
import typing
from typing import List
from typing import List, MutableMapping
from typing import Mapping
from typing import Optional
from typing import Generator
@@ -57,6 +59,7 @@ def ray_execute_bundle(
transform_buffer_coder: Mapping[str, typing.Tuple[bytes, str]],
expected_outputs: translations.DataOutput,
stage_timers: Mapping[translations.TimerFamilyId, bytes],
split_manager,
Contributor comment: I didn't see where the call to ray_execute_bundle() is updated to pass the new split_manager in this patch. Did I miss anything?
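For context on this parameter: judging from the _run_split_manager helper added later in this file, split_manager appears to be a generator-style callable, analogous to the split managers used by Beam's FnApiRunner tests. A minimal hypothetical sketch (the function name and the 0.5 fraction are invented for illustration and are not part of this patch):

def example_split_manager(num_elements):
    # Ask the SDK to split off half of the remaining work in this bundle.
    split_result = yield 0.5
    # split_result is the ProcessBundleSplitResponse (or None if the split
    # could not be performed); returning here ends the splitting session.

In _run_split_manager below, the runner drives such a generator with next() for the first fraction and send(split_result) for any subsequent ones.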

instruction_request_repr: Mapping[str, typing.Any],
dry_run=False,
) -> Generator:
@@ -83,8 +86,6 @@ def ray_execute_bundle(
runner_context, instruction_request_repr["process_descriptor_id"]
)

_send_timers(worker_handler, input_bundle, stage_timers, process_bundle_id)

input_data = {
k: _fetch_decode_data(
runner_context,
@@ -95,19 +96,35 @@
for k, objrefs in input_bundle.input_data.items()
}

for transform_id, elements in input_data.items():
data_out = worker_handler.data_conn.output_stream(
process_bundle_id, transform_id
)
for byte_stream in elements:
data_out.write(byte_stream)
data_out.close()

expect_reads: List[typing.Union[str, translations.TimerFamilyId]] = list(
expected_outputs.keys()
)
expect_reads.extend(list(stage_timers.keys()))

split_results = []
split_manager_thread = None
if split_manager:
# TODO(iasoon): synchronization can probably be handled
# more cleanly.
split_manager_started_event = threading.Event()
split_manager_thread = threading.Thread(
target=_run_split_manager,
args=(
runner_context,
worker_handler,
split_manager,
input_data,
transform_buffer_coder,
instruction_request,
split_results,
split_manager_started_event,
),
)
split_manager_thread.start()
split_manager_started_event.wait()

_send_timers(worker_handler, input_bundle, stage_timers, process_bundle_id)
_send_input_data(worker_handler, input_data, process_bundle_id)
result_future = worker_handler.control_conn.push(instruction_request)

for output in worker_handler.data_conn.input_elements(
@@ -125,6 +142,8 @@ def ray_execute_bundle(
output_buffers[expected_outputs[output.transform_id]].append(output.data)

result: beam_fn_api_pb2.InstructionResponse = result_future.get()
if split_manager_thread:
split_manager_thread.join()

if result.process_bundle.requires_finalization:
finalize_request = beam_fn_api_pb2.InstructionRequest(
@@ -151,14 +170,27 @@ def ray_execute_bundle(
process_bundle_descriptor = runner_context.worker_manager.process_bundle_descriptor(
instruction_request_repr["process_descriptor_id"]
)
delayed_applications = _retrieve_delayed_applications(

deferred_inputs = {}

_add_delayed_applications_to_deferred_inputs(
result,
process_bundle_descriptor,
runner_context,
deferred_inputs,
)

returns.append(len(delayed_applications))
for pcoll, buffer in delayed_applications.items():
_add_residuals_and_channel_splits_to_deferred_inputs(
runner_context,
input_bundle.input_data,
transform_buffer_coder,
process_bundle_descriptor,
split_results,
deferred_inputs,
)

returns.append(len(deferred_inputs))
for pcoll, buffer in deferred_inputs.items():
returns.append(pcoll)
returns.append(buffer)

@@ -206,37 +238,101 @@ def _get_source_transform_name(
raise RuntimeError("No IO transform feeds %s" % transform_id)


def _retrieve_delayed_applications(
def _add_delayed_application_to_deferred_inputs(
process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
delayed_application: beam_fn_api_pb2.DelayedBundleApplication,
deferred_inputs: MutableMapping[str, List[bytes]],
):
# TODO(pabloem): Time delay needed for streaming. For now we'll ignore it.
# time_delay = delayed_application.requested_time_delay
source_transform = _get_source_transform_name(
process_bundle_descriptor,
delayed_application.application.transform_id,
delayed_application.application.input_id,
)

if source_transform not in deferred_inputs:
deferred_inputs[source_transform] = []
deferred_inputs[source_transform].append(delayed_application.application.element)


def _add_delayed_applications_to_deferred_inputs(
bundle_result: beam_fn_api_pb2.InstructionResponse,
process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
runner_context: "RayRunnerExecutionContext",
deferred_inputs: MutableMapping[str, List[bytes]],
):
"""Extract delayed applications from a bundle run.

A delayed application represents a user-initiated checkpoint, where user code
delays the consumption of a data element to checkpoint the previous elements
in a bundle.
"""
delayed_bundles = {}
for delayed_application in bundle_result.process_bundle.residual_roots:
# TODO(pabloem): Time delay needed for streaming. For now we'll ignore it.
# time_delay = delayed_application.requested_time_delay
source_transform = _get_source_transform_name(
_add_delayed_application_to_deferred_inputs(
process_bundle_descriptor,
delayed_application.application.transform_id,
delayed_application.application.input_id,
delayed_application,
deferred_inputs,
)

if source_transform not in delayed_bundles:
delayed_bundles[source_transform] = []
delayed_bundles[source_transform].append(
delayed_application.application.element
)

for consumer, data in delayed_bundles.items():
delayed_bundles[consumer] = [data]
def _add_residuals_and_channel_splits_to_deferred_inputs(
runner_context: "RayRunnerExecutionContext",
raw_inputs: Mapping[str, List[ray.ObjectRef]],
transform_buffer_coder: Mapping[str, typing.Tuple[bytes, str]],
process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
splits: List[beam_fn_api_pb2.ProcessBundleSplitResponse],
deferred_inputs: MutableMapping[str, List[bytes]],
):
prev_split_point = {} # transform id -> first residual offset
for split in splits:
for delayed_application in split.residual_roots:
_add_delayed_application_to_deferred_inputs(
process_bundle_descriptor,
delayed_application,
deferred_inputs,
)
for channel_split in split.channel_splits:
# Decode all input elements
byte_stream = b"".join(
(
element
for block in ray.get(raw_inputs[channel_split.transform_id])
for element in block
)
)
input_coder_id = transform_buffer_coder[channel_split.transform_id][1]
input_coder = runner_context.pipeline_context.coders[input_coder_id]

buffer_id = transform_buffer_coder[channel_split.transform_id][0]
if buffer_id.startswith(b"group:"):
coder_impl = coders.WindowedValueCoder(
coders.TupleCoder(
(
input_coder.wrapped_value_coder._coders[0],
input_coder.wrapped_value_coder._coders[1]._elem_coder,
)
),
input_coder.window_coder,
).get_impl()
else:
coder_impl = input_coder.get_impl()

all_elements = list(coder_impl.decode_all(byte_stream))

# split at first_residual_element index
end = prev_split_point.get(channel_split.transform_id, len(all_elements))
residual_elements = all_elements[channel_split.first_residual_element : end]
prev_split_point[
channel_split.transform_id
] = channel_split.first_residual_element

return delayed_bundles
if residual_elements:
encoded_residual = coder_impl.encode_all(residual_elements)

if channel_split.transform_id not in deferred_inputs:
deferred_inputs[channel_split.transform_id] = []
deferred_inputs[channel_split.transform_id].append(encoded_residual)
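To make the channel-split bookkeeping above concrete, here is a small self-contained sketch of the same prev_split_point / first_residual_element slicing with hypothetical numbers; the helper name and example indices are invented for illustration and are not part of this patch.

def residuals_from_splits(all_elements, first_residual_indices):
    # Mirrors the loop above with plain lists instead of Beam coders.
    prev_split_point = len(all_elements)
    residuals = []
    for first_residual in first_residual_indices:
        # Elements from this split point up to the previous split's start
        # (or the end of the bundle, for the first split) are deferred.
        residuals.append(all_elements[first_residual:prev_split_point])
        prev_split_point = first_residual
    return residuals

# Ten elements, split first at index 6, then the primary split again at 3:
# elements 6..9 and then 3..5 are re-encoded and retried as deferred inputs.
assert residuals_from_splits(list(range(10)), [6, 3]) == [
    [6, 7, 8, 9],
    [3, 4, 5],
]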


def _get_input_id(buffer_id, transform_name):
@@ -316,6 +412,97 @@ def _send_timers(
timer_out.close()


def _send_input_data(
worker_handler: worker_handlers.WorkerHandler,
input_data: Mapping[str, fn_execution.PartitionableBuffer],
process_bundle_id,
):
for transform_id, elements in input_data.items():
data_out = worker_handler.data_conn.output_stream(
process_bundle_id, transform_id
)
for byte_stream in elements:
data_out.write(byte_stream)
data_out.close()


def _run_split_manager(
runner_context: "RayRunnerExecutionContext",
worker_handler: worker_handlers.WorkerHandler,
split_manager,
inputs: Mapping[str, fn_execution.PartitionableBuffer],
transform_buffer_coder: Mapping[str, typing.Tuple[bytes, str]],
instruction_request,
split_results_buf: List[beam_fn_api_pb2.ProcessBundleSplitResponse],
split_manager_started_event: threading.Event,
):
read_transform_id, buffer_data = translations.only_element(inputs.items())
byte_stream = b"".join(buffer_data or [])
coder_id = transform_buffer_coder[read_transform_id][1]
coder_impl = runner_context.pipeline_context.coders[coder_id].get_impl()
num_elements = len(list(coder_impl.decode_all(byte_stream)))

# Start the split manager in case it wants to set any breakpoints.
split_manager_generator = split_manager(num_elements)
try:
split_fraction = next(split_manager_generator)
done = False
except StopIteration:
split_fraction = None
done = True

split_manager_started_event.set()

assert worker_handler is not None

# Execute the requested splits.
while not done:
if split_fraction is None:
split_result = None
else:
DesiredSplit = beam_fn_api_pb2.ProcessBundleSplitRequest.DesiredSplit
split_request = beam_fn_api_pb2.InstructionRequest(
process_bundle_split=beam_fn_api_pb2.ProcessBundleSplitRequest(
instruction_id=instruction_request.instruction_id,
desired_splits={
read_transform_id: DesiredSplit(
fraction_of_remainder=split_fraction,
estimated_input_elements=num_elements,
)
},
)
)
split_response = worker_handler.control_conn.push(
split_request
).get() # type: beam_fn_api_pb2.InstructionResponse
for t in (0.05, 0.1, 0.2):
if (
"Unknown process bundle" in split_response.error
or split_response.process_bundle_split
== beam_fn_api_pb2.ProcessBundleSplitResponse()
):
time.sleep(t)
split_response = worker_handler.control_conn.push(
split_request
).get()
if (
"Unknown process bundle" in split_response.error
or split_response.process_bundle_split
== beam_fn_api_pb2.ProcessBundleSplitResponse()
):
# It may have finished too fast.
split_result = None
elif split_response.error:
raise RuntimeError(split_response.error)
else:
split_result = split_response.process_bundle_split
split_results_buf.append(split_result)
try:
split_fraction = split_manager_generator.send(split_result)
except StopIteration:
break


@ray.remote
class _RayRunnerStats:
def __init__(self):