Skip to content

Commit

Permalink
Fixes #14: Refactoring ray portable runner (#18)
Browse files Browse the repository at this point in the history
* Refactoring ray portable runner

* Supporting SDF-initiated checkpoint

* Supporting SDF with SDF-initiated splitting.

* Fix formatting

* Fix smaller issues

* Adding portability tests to CI
  • Loading branch information
pabloem authored Jun 16, 2022
1 parent 95f7cc0 commit f1b8fde
Show file tree
Hide file tree
Showing 9 changed files with 2,973 additions and 2,838 deletions.
3 changes: 2 additions & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@ ignore =
I
N
avoid-escape = no
per-file-ignores =
*ray_runner_test.py: B008
6 changes: 6 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ jobs:
- name: Format
run: |
bash scripts/format.sh
- name: Install Ray Beam Runner
run: |
pip install -e .[test]
- name: Run Portability tests
run: |
pytest -r A ray_beam_runner/portability/ray_runner_test.py ray_beam_runner/portability/execution_test.py
LicenseCheck:
name: License Check
Expand Down
272 changes: 154 additions & 118 deletions ray_beam_runner/portability/context_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,137 +15,173 @@
# limitations under the License.
#
import typing
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.portability.api import endpoints_pb2
from apache_beam.runners.portability.fn_api_runner import execution as fn_execution
from apache_beam.runners.portability.fn_api_runner import translations
from apache_beam.runners.portability.fn_api_runner import worker_handlers
from apache_beam.runners.portability.fn_api_runner.execution import PartitionableBuffer
from apache_beam.runners.portability.fn_api_runner.fn_runner import OutputTimers
from apache_beam.runners.portability.fn_api_runner.translations import DataOutput
from apache_beam.runners.portability.fn_api_runner.translations import TimerFamilyId
from apache_beam.runners.worker import bundle_processor
from apache_beam.utils import proto_utils

import ray
from ray_beam_runner.portability.execution import RayRunnerExecutionContext

class RayBundleContextManager:
# Per-stage execution manager for the Ray portable runner: holds the shared
# runner context plus one fused stage, and lazily builds the
# ProcessBundleDescriptor that SDK workers use to process bundles.
# NOTE: placed in the Ray object store once, at class-definition time, so all
# tasks share a single copy of the encoded impulse value.
ENCODED_IMPULSE_REFERENCE = ray.put([fn_execution.ENCODED_IMPULSE_VALUE])


def __init__(self,
execution_context: RayRunnerExecutionContext,
stage: translations.Stage,
) -> None:
# execution_context: runner-wide shared state (uids, coders, buffers).
# stage: the fused pipeline stage this manager executes.
self.execution_context = execution_context
self.stage = stage
# self.extract_bundle_inputs_and_outputs()
self.bundle_uid = self.execution_context.next_uid()

# Properties that are lazily initialized
self._process_bundle_descriptor = None # type: Optional[beam_fn_api_pb2.ProcessBundleDescriptor]
self._worker_handlers = None # type: Optional[List[worker_handlers.WorkerHandler]]
# a mapping of {(transform_id, timer_family_id): timer_coder_id}. The map
# is built after self._process_bundle_descriptor is initialized.
# This field can be used to tell whether current bundle has timers.
self._timer_coder_ids = None # type: Optional[Dict[Tuple[str, str], str]]

def __reduce__(self):
# Pickle support: serialize only (execution_context, stage); the lambda
# rebuilds an equivalent manager (lazy fields are re-derived on demand).
data = (self.execution_context,
self.stage)
deserializer = lambda args: RayBundleContextManager(args[0], args[1])
return (deserializer, data)

@property
def worker_handlers(self) -> List[worker_handlers.WorkerHandler]:
# No gRPC worker handlers are used by this runner.
return []

def data_api_service_descriptor(self) -> Optional[endpoints_pb2.ApiServiceDescriptor]:
# Placeholder descriptor; presumably never dialed since data moves
# through the Ray object store — TODO confirm against the runner.
return endpoints_pb2.ApiServiceDescriptor(url='fake')

def state_api_service_descriptor(self) -> Optional[endpoints_pb2.ApiServiceDescriptor]:
# No state API service is exposed.
return None

@property
def process_bundle_descriptor(self):
# type: () -> beam_fn_api_pb2.ProcessBundleDescriptor
# Lazily build and cache the descriptor; also populates
# self._timer_coder_ids as a side effect on first access.
if self._process_bundle_descriptor is None:
self._process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor.FromString(
self._build_process_bundle_descriptor())
# Borrowed unbound from upstream Beam's BundleContextManager and
# invoked with this object as `self`.
self._timer_coder_ids = fn_execution.BundleContextManager._build_timer_coders_id_map(self)
return self._process_bundle_descriptor

def _build_process_bundle_descriptor(self):
# Serialize a ProcessBundleDescriptor covering this stage's transforms
# and all pipeline components from the shared execution context.
# Cannot be invoked until *after* _extract_endpoints is called.
# Always populate the timer_api_service_descriptor.
pbd = beam_fn_api_pb2.ProcessBundleDescriptor(
id=self.bundle_uid,
transforms={
transform.unique_name: transform
for transform in self.stage.transforms
},
pcollections=dict(
self.execution_context.pipeline_components.pcollections.items()),
coders=dict(self.execution_context.pipeline_components.coders.items()),
windowing_strategies=dict(
self.execution_context.pipeline_components.windowing_strategies.
items()),
environments=dict(
self.execution_context.pipeline_components.environments.items()),
state_api_service_descriptor=self.state_api_service_descriptor(),
timer_api_service_descriptor=self.data_api_service_descriptor())

return pbd.SerializeToString()

def extract_bundle_inputs_and_outputs(self):
# type: () -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[TimerFamilyId, bytes]]

"""Returns maps of transform names to PCollection identifiers.
Also mutates IO stages to point to the data ApiServiceDescriptor.
Returns:
A tuple of (data_input, data_output, expected_timer_output) dictionaries.
`data_input` is a dictionary mapping (transform_name, output_name) to a
PCollection buffer; `data_output` is a dictionary mapping
(transform_name, output_name) to a PCollection ID.
`expected_timer_output` is a dictionary mapping transform_id and
timer family ID to a buffer id for timers.
"""
transform_to_buffer_coder: typing.Dict[str, typing.Tuple[bytes, str]] = {}
data_output = {} # type: DataOutput
expected_timer_output = {} # type: OutputTimers
for transform in self.stage.transforms:
if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
bundle_processor.DATA_OUTPUT_URN):
pcoll_id = transform.spec.payload
if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
coder_id = self.execution_context.data_channel_coders[translations.only_element(
transform.outputs.values())]
if pcoll_id == translations.IMPULSE_BUFFER:
# Synthetic impulse input: append the encoded impulse value
# to a buffer actor keyed by this transform's name, then use
# that name as the PCollection id.
buffer_actor = ray.get(self.execution_context.pcollection_buffers.get.remote(
transform.unique_name))
ray.get(buffer_actor.append.remote(fn_execution.ENCODED_IMPULSE_VALUE))
pcoll_id = transform.unique_name.encode('utf8')
else:
pass
transform_to_buffer_coder[transform.unique_name] = (
pcoll_id,
self.execution_context.safe_coders.get(coder_id, coder_id)
)
elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
data_output[transform.unique_name] = pcoll_id
coder_id = self.execution_context.data_channel_coders[translations.only_element(
transform.inputs.values())]
else:
raise NotImplementedError
# TODO(pabloem): Figure out when we DO and we DONT need this particular rewrite of coders.
data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
# data_spec.api_service_descriptor.url = 'fake'
transform.spec.payload = data_spec.SerializeToString()
elif transform.spec.urn in translations.PAR_DO_URNS:
payload = proto_utils.parse_Bytes(
transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
for timer_family_id in payload.timer_family_specs.keys():
expected_timer_output[(transform.unique_name, timer_family_id)] = (
translations.create_buffer_id(timer_family_id, 'timers'))
return transform_to_buffer_coder, data_output, expected_timer_output
class RayBundleContextManager:
    """Per-stage execution manager for the Ray portable runner.

    Holds the runner-wide ``RayRunnerExecutionContext`` plus one fused
    ``translations.Stage``, lazily builds the ``ProcessBundleDescriptor``
    SDK workers need, and (via ``setup``) computes the stage's input,
    output, and timer maps while rewriting IO transform payloads.
    """

    def __init__(
        self,
        execution_context: RayRunnerExecutionContext,
        stage: translations.Stage,
    ) -> None:
        """Create a manager for one stage.

        Args:
          execution_context: runner-wide shared state (uids, coders, buffers).
          stage: the fused pipeline stage this manager executes.
        """
        self.execution_context = execution_context
        self.stage = stage
        self.bundle_uid = self.execution_context.next_uid()

        # Properties that are lazily initialized
        self._process_bundle_descriptor = (
            None
        )  # type: Optional[beam_fn_api_pb2.ProcessBundleDescriptor]
        self._worker_handlers = (
            None
        )  # type: Optional[List[worker_handlers.WorkerHandler]]
        # a mapping of {(transform_id, timer_family_id): timer_coder_id}. The map
        # is built after self._process_bundle_descriptor is initialized.
        # This field can be used to tell whether current bundle has timers.
        self._timer_coder_ids = None  # type: Optional[Dict[Tuple[str, str], str]]

    def __reduce__(self):
        """Pickle support: rebuild from (execution_context, stage) only.

        Lazily-derived fields are recomputed on demand after unpickling.
        """
        data = (self.execution_context, self.stage)

        def deserializer(args):
            # Bug fix: the reconstructed instance must be *returned*.
            # Without the `return`, unpickling produced None instead of a
            # RayBundleContextManager.
            return RayBundleContextManager(args[0], args[1])

        return (deserializer, data)

    @property
    def worker_handlers(self) -> List[worker_handlers.WorkerHandler]:
        # No gRPC worker handlers are used by this runner.
        return []

    def data_api_service_descriptor(
        self,
    ) -> Optional[endpoints_pb2.ApiServiceDescriptor]:
        # Placeholder descriptor; presumably never dialed since data moves
        # through the Ray object store — TODO confirm against the runner.
        return endpoints_pb2.ApiServiceDescriptor(url="fake")

    def state_api_service_descriptor(
        self,
    ) -> Optional[endpoints_pb2.ApiServiceDescriptor]:
        # No state API service is exposed.
        return None

    @property
    def process_bundle_descriptor(self) -> beam_fn_api_pb2.ProcessBundleDescriptor:
        """Lazily build and cache the stage's ProcessBundleDescriptor.

        Also populates ``self._timer_coder_ids`` on first access.
        """
        if self._process_bundle_descriptor is None:
            self._process_bundle_descriptor = (
                beam_fn_api_pb2.ProcessBundleDescriptor.FromString(
                    self._build_process_bundle_descriptor()
                )
            )
            # Borrowed unbound from upstream Beam's BundleContextManager and
            # invoked with this object as `self`.
            self._timer_coder_ids = (
                fn_execution.BundleContextManager._build_timer_coders_id_map(self)
            )
        return self._process_bundle_descriptor

    def _build_process_bundle_descriptor(self):
        """Serialize a ProcessBundleDescriptor for this stage's transforms."""
        # Cannot be invoked until *after* _extract_endpoints is called.
        # Always populate the timer_api_service_descriptor.
        pbd = beam_fn_api_pb2.ProcessBundleDescriptor(
            id=self.bundle_uid,
            transforms={
                transform.unique_name: transform for transform in self.stage.transforms
            },
            pcollections=dict(
                self.execution_context.pipeline_components.pcollections.items()
            ),
            coders=dict(self.execution_context.pipeline_components.coders.items()),
            windowing_strategies=dict(
                self.execution_context.pipeline_components.windowing_strategies.items()
            ),
            environments=dict(
                self.execution_context.pipeline_components.environments.items()
            ),
            state_api_service_descriptor=self.state_api_service_descriptor(),
            timer_api_service_descriptor=self.data_api_service_descriptor(),
        )

        return pbd.SerializeToString()

    def get_bundle_inputs_and_outputs(
        self,
    ) -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[TimerFamilyId, bytes]]:
        """Returns maps of transform names to PCollection identifiers.

        Must be called after ``setup()``, which computes and caches the maps.

        Returns:
          A tuple of (transform_to_buffer_coder, data_output, stage_timers).
          ``transform_to_buffer_coder`` maps an input transform name to its
          (PCollection buffer id, coder id) pair; ``data_output`` maps an
          output transform name to a PCollection ID; ``stage_timers`` maps
          (transform_id, timer_family_id) to a timer buffer id.
        """
        return self.transform_to_buffer_coder, self.data_output, self.stage_timers

    def setup(self):
        """Compute the stage's input/output/timer maps.

        Mutates IO transforms in place to carry a ``RemoteGrpcPort`` payload
        pointing at the data coder, and caches the resulting maps on
        ``self`` for ``get_bundle_inputs_and_outputs``.
        """
        transform_to_buffer_coder: typing.Dict[str, typing.Tuple[bytes, str]] = {}
        data_output = {}  # type: DataOutput
        expected_timer_output = {}  # type: OutputTimers
        for transform in self.stage.transforms:
            if transform.spec.urn in (
                bundle_processor.DATA_INPUT_URN,
                bundle_processor.DATA_OUTPUT_URN,
            ):
                pcoll_id = transform.spec.payload
                if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
                    coder_id = self.execution_context.data_channel_coders[
                        translations.only_element(transform.outputs.values())
                    ]
                    if pcoll_id == translations.IMPULSE_BUFFER:
                        # Synthetic impulse input: seed a buffer named after
                        # this transform with the shared encoded impulse ref.
                        # NOTE(review): ENCODED_IMPULSE_REFERENCE must be
                        # defined/imported at module level (the former class
                        # attribute was removed) — confirm.
                        pcoll_id = transform.unique_name.encode("utf8")
                        self.execution_context.pcollection_buffers.put.remote(
                            pcoll_id, [ENCODED_IMPULSE_REFERENCE]
                        )
                    else:
                        pass
                    transform_to_buffer_coder[transform.unique_name] = (
                        pcoll_id,
                        self.execution_context.safe_coders.get(coder_id, coder_id),
                    )
                elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
                    data_output[transform.unique_name] = pcoll_id
                    coder_id = self.execution_context.data_channel_coders[
                        translations.only_element(transform.inputs.values())
                    ]
                else:
                    raise NotImplementedError
                data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
                transform.spec.payload = data_spec.SerializeToString()
            elif transform.spec.urn in translations.PAR_DO_URNS:
                payload = proto_utils.parse_Bytes(
                    transform.spec.payload, beam_runner_api_pb2.ParDoPayload
                )
                # Record an expected timer buffer per declared timer family.
                for timer_family_id in payload.timer_family_specs.keys():
                    expected_timer_output[
                        (transform.unique_name, timer_family_id)
                    ] = translations.create_buffer_id(timer_family_id, "timers")
        self.transform_to_buffer_coder, self.data_output, self.stage_timers = (
            transform_to_buffer_coder,
            data_output,
            expected_timer_output,
        )
Loading

0 comments on commit f1b8fde

Please sign in to comment.