diff --git a/ray_beam_runner/portability/execution_test.py b/ray_beam_runner/portability/execution_test.py index 7b038af..0c28678 100644 --- a/ray_beam_runner/portability/execution_test.py +++ b/ray_beam_runner/portability/execution_test.py @@ -43,7 +43,7 @@ def test_data_stored_properly(self): for data in StateHandlerTest.SAMPLE_INPUT_DATA: sh.append_raw(StateHandlerTest.SAMPLE_STATE_KEY, data) - with sh.process_instruction_id("anyinstruction"): + with sh.process_instruction_id("otherinstruction"): continuation_token = None all_data = [] while True: diff --git a/ray_beam_runner/portability/state.py b/ray_beam_runner/portability/state.py index 74570b1..7f0bd62 100644 --- a/ray_beam_runner/portability/state.py +++ b/ray_beam_runner/portability/state.py @@ -61,7 +61,6 @@ def __init__(self): def get_raw( self, - bundle_id: str, state_key: str, continuation_token: Optional[bytes] = None, ) -> Tuple[bytes, Optional[bytes]]: @@ -70,7 +69,7 @@ def get_raw( else: continuation_token = 0 - full_state = self._data[(bundle_id, state_key)] + full_state = self._data[state_key] if len(full_state) == continuation_token: return b"", None @@ -81,11 +80,11 @@ def get_raw( return full_state[continuation_token], next_cont_token - def append_raw(self, bundle_id: str, state_key: str, data: bytes): - self._data[(bundle_id, state_key)].append(data) + def append_raw(self, state_key: str, data: bytes): + self._data[state_key].append(data) - def clear(self, bundle_id: str, state_key: str): - self._data[(bundle_id, state_key)] = [] + def clear(self, state_key: str): + self._data[state_key] = [] class RayStateManager(sdk_worker.StateHandler): @@ -105,7 +104,6 @@ def get_raw( assert self._instruction_id is not None return ray.get( self._state_actor.get_raw.remote( - self._instruction_id, RayStateManager._to_key(state_key), continuation_token, ) @@ -115,20 +113,20 @@ def append_raw(self, state_key: beam_fn_api_pb2.StateKey, data: bytes) -> RayFut assert self._instruction_id is not None return RayFuture( self._state_actor.append_raw.remote( - self._instruction_id, RayStateManager._to_key(state_key), data + RayStateManager._to_key(state_key), data ) ) def clear(self, state_key: beam_fn_api_pb2.StateKey) -> RayFuture: assert self._instruction_id is not None return RayFuture( - self._state_actor.clear.remote( - self._instruction_id, RayStateManager._to_key(state_key) - ) + self._state_actor.clear.remote(RayStateManager._to_key(state_key)) ) @contextlib.contextmanager def process_instruction_id(self, bundle_id: str) -> Iterator[None]: + # Instruction id is not being used right now, + # we only assert that it has been set before accessing state. self._instruction_id = bundle_id yield self._instruction_id = None