Skip to content

Commit

Permalink
Ensuring tests can run without passing context around
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloem committed Aug 31, 2022
1 parent 521acdb commit 0139c38
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
6 changes: 2 additions & 4 deletions ray_beam_runner/portability/context_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import typing
from typing import Dict
from typing import List
Expand All @@ -33,11 +34,8 @@
from apache_beam.runners.worker import bundle_processor
from apache_beam.utils import proto_utils

import ray
from ray_beam_runner.portability.execution import RayRunnerExecutionContext

ENCODED_IMPULSE_REFERENCE = ray.put([fn_execution.ENCODED_IMPULSE_VALUE])


class RayBundleContextManager:
def __init__(
Expand Down Expand Up @@ -155,7 +153,7 @@ def setup(self):
if pcoll_id == translations.IMPULSE_BUFFER:
pcoll_id = transform.unique_name.encode("utf8")
self.execution_context.pcollection_buffers.put.remote(
pcoll_id, [ENCODED_IMPULSE_REFERENCE]
pcoll_id, [self.execution_context.encoded_impulse_ref]
)
else:
pass
Expand Down
1 change: 1 addition & 0 deletions ray_beam_runner/portability/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ def __init__(
self._uid = 0
self.worker_manager = worker_manager or RayWorkerHandlerManager()
self.timer_coder_ids = self._build_timer_coders_id_map()
self.encoded_impulse_ref = ray.put([fn_execution.ENCODED_IMPULSE_VALUE])

@property
def watermark_manager(self):
Expand Down
27 changes: 16 additions & 11 deletions ray_beam_runner/portability/ray_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import gc
import logging
import os
import pytest
import random
import re
import shutil
Expand Down Expand Up @@ -94,10 +95,12 @@ def contains_labels(mi, labels):


class RayFnApiRunnerTest(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
def setUp(self) -> None:
if not ray.is_initialized():
ray.init()
ray.init(num_cpus=1)

def tearDown(self) -> None:
ray.shutdown()

def create_pipeline(self, is_drain=False):
return beam.Pipeline(
Expand Down Expand Up @@ -1210,10 +1213,12 @@ def test_pack_combiners(self):
# the sampling counter.
@unittest.skip("Metrics not yet supported.")
class RayRunnerMetricsTest(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
def setUp(self) -> None:
if not ray.is_initialized():
ray.init()
ray.init(num_cpus=1)

def tearDown(self) -> None:
ray.shutdown()

def assert_has_counter(self, mon_infos, urn, labels, value=None, ge_value=None):
found = 0
Expand Down Expand Up @@ -1631,10 +1636,12 @@ def has_mi_for_ptransform(mon_infos, ptransform):

@unittest.skip("Runner-initiated splitting not yet supported")
class RayRunnerSplitTest(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
def setUp(self) -> None:
if not ray.is_initialized():
ray.init()
ray.init(num_cpus=1)

def tearDown(self) -> None:
ray.shutdown()

def create_pipeline(self, is_drain=False):
return beam.Pipeline(
Expand Down Expand Up @@ -2073,8 +2080,6 @@ def process(self, element, *side_inputs):
yield self._name


logging.getLogger().setLevel(logging.INFO)

if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
unittest.main()

0 comments on commit 0139c38

Please sign in to comment.