diff --git a/ray_beam_runner/portability/context_management.py b/ray_beam_runner/portability/context_management.py index 5b12f46..e368849 100644 --- a/ray_beam_runner/portability/context_management.py +++ b/ray_beam_runner/portability/context_management.py @@ -33,11 +33,8 @@ from apache_beam.runners.worker import bundle_processor from apache_beam.utils import proto_utils -import ray from ray_beam_runner.portability.execution import RayRunnerExecutionContext -ENCODED_IMPULSE_REFERENCE = ray.put([fn_execution.ENCODED_IMPULSE_VALUE]) - class RayBundleContextManager: def __init__( @@ -155,7 +152,7 @@ def setup(self): if pcoll_id == translations.IMPULSE_BUFFER: pcoll_id = transform.unique_name.encode("utf8") self.execution_context.pcollection_buffers.put.remote( - pcoll_id, [ENCODED_IMPULSE_REFERENCE] + pcoll_id, [self.execution_context.encoded_impulse_ref] ) else: pass diff --git a/ray_beam_runner/portability/execution.py b/ray_beam_runner/portability/execution.py index 4c01277..e03b4ab 100644 --- a/ray_beam_runner/portability/execution.py +++ b/ray_beam_runner/portability/execution.py @@ -421,6 +421,7 @@ def __init__( self._uid = 0 self.worker_manager = worker_manager or RayWorkerHandlerManager() self.timer_coder_ids = self._build_timer_coders_id_map() + self.encoded_impulse_ref = ray.put([fn_execution.ENCODED_IMPULSE_VALUE]) @property def watermark_manager(self): diff --git a/ray_beam_runner/portability/ray_runner_test.py b/ray_beam_runner/portability/ray_runner_test.py index 41a90dd..ff359cb 100644 --- a/ray_beam_runner/portability/ray_runner_test.py +++ b/ray_beam_runner/portability/ray_runner_test.py @@ -94,10 +94,12 @@ def contains_labels(mi, labels): class RayFnApiRunnerTest(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: + def setUp(self) -> None: if not ray.is_initialized(): - ray.init() + ray.init(num_cpus=1, include_dashboard=False) + + def tearDown(self) -> None: + ray.shutdown() def create_pipeline(self, is_drain=False): return beam.Pipeline( @@ -1210,10 +1212,12 @@ def test_pack_combiners(self): # the sampling counter. @unittest.skip("Metrics not yet supported.") class RayRunnerMetricsTest(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: + def setUp(self) -> None: if not ray.is_initialized(): - ray.init() + ray.init(num_cpus=1, include_dashboard=False) + + def tearDown(self) -> None: + ray.shutdown() def assert_has_counter(self, mon_infos, urn, labels, value=None, ge_value=None): found = 0 @@ -1631,10 +1635,12 @@ def has_mi_for_ptransform(mon_infos, ptransform): @unittest.skip("Runner-initiated splitting not yet supported") class RayRunnerSplitTest(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: + def setUp(self) -> None: if not ray.is_initialized(): - ray.init() + ray.init(num_cpus=1, include_dashboard=False) + + def tearDown(self) -> None: + ray.shutdown() def create_pipeline(self, is_drain=False): return beam.Pipeline( @@ -2073,8 +2079,6 @@ def process(self, element, *side_inputs): yield self._name -logging.getLogger().setLevel(logging.INFO) - if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) unittest.main()