From 0139c38ad19971f5d2f7a81581b6355912c9caea Mon Sep 17 00:00:00 2001 From: Pablo E Date: Wed, 31 Aug 2022 11:16:05 -0700 Subject: [PATCH] Ensuring tests can run without passing context around --- .../portability/context_management.py | 6 ++--- ray_beam_runner/portability/execution.py | 1 + .../portability/ray_runner_test.py | 27 +++++++++++-------- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/ray_beam_runner/portability/context_management.py b/ray_beam_runner/portability/context_management.py index 5b12f46..e4e1fcc 100644 --- a/ray_beam_runner/portability/context_management.py +++ b/ray_beam_runner/portability/context_management.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import logging import typing from typing import Dict from typing import List @@ -33,11 +34,8 @@ from apache_beam.runners.worker import bundle_processor from apache_beam.utils import proto_utils -import ray from ray_beam_runner.portability.execution import RayRunnerExecutionContext -ENCODED_IMPULSE_REFERENCE = ray.put([fn_execution.ENCODED_IMPULSE_VALUE]) - class RayBundleContextManager: def __init__( @@ -155,7 +153,7 @@ def setup(self): if pcoll_id == translations.IMPULSE_BUFFER: pcoll_id = transform.unique_name.encode("utf8") self.execution_context.pcollection_buffers.put.remote( - pcoll_id, [ENCODED_IMPULSE_REFERENCE] + pcoll_id, [self.execution_context.encoded_impulse_ref] ) else: pass diff --git a/ray_beam_runner/portability/execution.py b/ray_beam_runner/portability/execution.py index 4c01277..e03b4ab 100644 --- a/ray_beam_runner/portability/execution.py +++ b/ray_beam_runner/portability/execution.py @@ -421,6 +421,7 @@ def __init__( self._uid = 0 self.worker_manager = worker_manager or RayWorkerHandlerManager() self.timer_coder_ids = self._build_timer_coders_id_map() + self.encoded_impulse_ref = ray.put([fn_execution.ENCODED_IMPULSE_VALUE]) @property def watermark_manager(self): diff --git a/ray_beam_runner/portability/ray_runner_test.py b/ray_beam_runner/portability/ray_runner_test.py index 41a90dd..a096d3d 100644 --- a/ray_beam_runner/portability/ray_runner_test.py +++ b/ray_beam_runner/portability/ray_runner_test.py @@ -20,6 +20,7 @@ import gc import logging import os +import pytest import random import re import shutil @@ -94,10 +95,12 @@ def contains_labels(mi, labels): class RayFnApiRunnerTest(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: + def setUp(self) -> None: if not ray.is_initialized(): - ray.init() + ray.init(num_cpus=1) + + def tearDown(self) -> None: + ray.shutdown() def create_pipeline(self, is_drain=False): return beam.Pipeline( @@ -1210,10 +1213,12 @@ def test_pack_combiners(self): # the sampling counter. @unittest.skip("Metrics not yet supported.") class RayRunnerMetricsTest(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: + def setUp(self) -> None: if not ray.is_initialized(): - ray.init() + ray.init(num_cpus=1) + + def tearDown(self) -> None: + ray.shutdown() def assert_has_counter(self, mon_infos, urn, labels, value=None, ge_value=None): found = 0 @@ -1631,10 +1636,12 @@ def has_mi_for_ptransform(mon_infos, ptransform): @unittest.skip("Runner-initiated splitting not yet supported") class RayRunnerSplitTest(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: + def setUp(self) -> None: if not ray.is_initialized(): - ray.init() + ray.init(num_cpus=1) + + def tearDown(self) -> None: + ray.shutdown() def create_pipeline(self, is_drain=False): return beam.Pipeline( @@ -2073,8 +2080,6 @@ def process(self, element, *side_inputs): yield self._name -logging.getLogger().setLevel(logging.INFO) - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) unittest.main()