From c9916d6f5f6a3b84368e65b2e8fd10b565994695 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 26 Jun 2024 21:07:21 +0200 Subject: [PATCH 1/2] Adjust rechunking code to nest less --- distributed/shuffle/_core.py | 33 +++++++++++++++++++++------ distributed/shuffle/_rechunk.py | 18 ++++++++------- distributed/shuffle/_worker_plugin.py | 4 +++- 3 files changed, 39 insertions(+), 16 deletions(-) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index 48510dfd41a..d0a9e08adcf 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -203,18 +203,32 @@ async def barrier(self, run_ids: Sequence[int]) -> int: return self.run_id async def _send( - self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes + self, + address: str, + input_partitions: list[_T_partition_id], + output_partitions: list[_T_partition_id], + shards: list[Any] | bytes, ) -> OKMessage | ErrorMessage: self.raise_if_closed() return await self.rpc(address).shuffle_receive( + input_partitions=input_partitions, + output_partitions=output_partitions, data=to_serialize(shards), shuffle_id=self.id, run_id=self.run_id, ) async def send( - self, address: str, shards: list[tuple[_T_partition_id, Any]] + self, address: str, sharded: list[tuple[_T_partition_id, Any]] ) -> OKMessage | ErrorMessage: + ipids = [] + opids = [] + shards = [] + for input_partition, inshards in sharded: + for output_partition, shard in inshards: + ipids.append(input_partition) + opids.append(output_partition) + shards.append(shard) if _mean_shard_size(shards) < 65536: # Don't send buffers individually over the tcp comms. # Instead, merge everything into an opaque bytes blob, send it all at once, @@ -226,7 +240,7 @@ async def send( shards_or_bytes = shards def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]: - return self._send(address, shards_or_bytes) + return self._send(address, ipids, opids, shards_or_bytes) return await retry( _send, @@ -308,13 +322,16 @@ def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing return self._disk_buffer.read("_".join(str(i) for i in id)) async def receive( - self, data: list[tuple[_T_partition_id, Any]] | bytes + self, + input_partitions: list[_T_partition_id], + output_partitions: list[_T_partition_type], + data: list[Any] | bytes, ) -> OKMessage | ErrorMessage: try: if isinstance(data, bytes): # Unpack opaque blob. See send() - data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data)) - await self._receive(data) + data = cast(list[Any], pickle.loads(data)) + await self._receive(input_partitions, output_partitions, data) return {"status": "OK"} except P2PConsistencyError as e: return error_message(e) @@ -336,7 +353,9 @@ def _get_assigned_worker(self, i: _T_partition_id) -> str: """Get the address of the worker assigned to the output partition""" @abc.abstractmethod - async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None: + async def _receive( + self, input_partitions: list[_T_partition_id], data: list[Any] + ) -> None: """Receive shards belonging to output partitions of this shuffle run""" def add_partition( diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index b33e90730b7..328b47300c7 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -658,20 +658,22 @@ def __init__( async def _receive( self, - data: list[tuple[NDIndex, list[tuple[NDIndex, tuple[NDIndex, np.ndarray]]]]], + input_partitions: list[NDIndex], + output_partitions: list[NDIndex], + data: list[tuple[NDIndex, np.ndarray]], ) -> None: self.raise_if_closed() # Repartition shards and filter out already received ones shards = defaultdict(list) - for d in data: - id1, payload = d - if id1 in self.received: + for ipid, opid, dat in zip(input_partitions, output_partitions, data): + if ipid in self.received: continue - self.received.add(id1) - for id2, shard in payload: - shards[id2].append(shard) - self.total_recvd += sizeof(d) + shards[opid].append(dat) + self.total_recvd += sizeof(dat) + self.received.update(input_partitions) + del input_partitions + del output_partitions del data if not shards: return diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index 57d2cfe3696..8f20f1695e0 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -311,6 +311,8 @@ async def shuffle_receive( self, shuffle_id: ShuffleId, run_id: int, + input_partitions, + output_partitions, data: list[tuple[int, Any]] | bytes, ) -> OKMessage | ErrorMessage: """ @@ -319,7 +321,7 @@ async def shuffle_receive( """ try: shuffle_run = await self._get_shuffle_run(shuffle_id, run_id) - return await shuffle_run.receive(data) + return await shuffle_run.receive(input_partitions, output_partitions, data) except P2PConsistencyError as e: return error_message(e) From 48be30f169e58c7d20949d2ce0e6699305d3a85e Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 27 Jun 2024 14:12:06 +0200 Subject: [PATCH 2/2] Oh well, let's try this --- distributed/shuffle/_core.py | 12 +++++++++--- distributed/shuffle/_rechunk.py | 7 ++++--- distributed/shuffle/_worker_plugin.py | 3 ++- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index d0a9e08adcf..4f799f5072c 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -207,12 +207,14 @@ async def _send( address: str, input_partitions: list[_T_partition_id], output_partitions: list[_T_partition_id], + locs: list[_T_partition_id], shards: list[Any] | bytes, ) -> OKMessage | ErrorMessage: self.raise_if_closed() return await self.rpc(address).shuffle_receive( input_partitions=input_partitions, output_partitions=output_partitions, + locs=locs, data=to_serialize(shards), shuffle_id=self.id, run_id=self.run_id, @@ -223,12 +225,15 @@ async def send( ) -> OKMessage | ErrorMessage: ipids = [] opids = [] + locs = [] shards = [] for input_partition, inshards in sharded: for output_partition, shard in inshards: + loc, data = shard ipids.append(input_partition) opids.append(output_partition) - shards.append(shard) + locs.append(loc) + shards.append(data) if _mean_shard_size(shards) < 65536: # Don't send buffers individually over the tcp comms. # Instead, merge everything into an opaque bytes blob, send it all at once, @@ -240,7 +245,7 @@ async def send( shards_or_bytes = shards def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]: - return self._send(address, ipids, opids, shards_or_bytes) + return self._send(address, ipids, opids, locs, shards_or_bytes) return await retry( _send, @@ -325,13 +330,14 @@ async def receive( self, input_partitions: list[_T_partition_id], output_partitions: list[_T_partition_type], + locs: list[_T_partition_id], data: list[Any] | bytes, ) -> OKMessage | ErrorMessage: try: if isinstance(data, bytes): # Unpack opaque blob. See send() data = cast(list[Any], pickle.loads(data)) - await self._receive(input_partitions, output_partitions, data) + await self._receive(input_partitions, output_partitions, locs, data) return {"status": "OK"} except P2PConsistencyError as e: return error_message(e) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 328b47300c7..1f78da3c369 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -660,16 +660,17 @@ async def _receive( self, input_partitions: list[NDIndex], output_partitions: list[NDIndex], - data: list[tuple[NDIndex, np.ndarray]], + locs: list[NDIndex], + data: list[np.ndarray], ) -> None: self.raise_if_closed() # Repartition shards and filter out already received ones shards = defaultdict(list) - for ipid, opid, dat in zip(input_partitions, output_partitions, data): + for ipid, opid, loc, dat in zip(input_partitions, output_partitions, locs, data): if ipid in self.received: continue - shards[opid].append(dat) + shards[opid].append((loc, dat)) self.total_recvd += sizeof(dat) self.received.update(input_partitions) del input_partitions diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index 8f20f1695e0..d59e6601285 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -313,6 +313,7 @@ async def shuffle_receive( run_id: int, input_partitions, output_partitions, + locs, data: list[tuple[int, Any]] | bytes, ) -> OKMessage | ErrorMessage: """ @@ -321,7 +322,7 @@ async def shuffle_receive( """ try: shuffle_run = await self._get_shuffle_run(shuffle_id, run_id) - return await shuffle_run.receive(input_partitions, output_partitions, data) + return await shuffle_run.receive(input_partitions, output_partitions, locs, data) except P2PConsistencyError as e: return error_message(e)