This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

feat(compiler): implement last-seen heuristic for resharding source #917

Open · wants to merge 2 commits into base: main
alpa/pipeline_parallel/cross_mesh_resharding.py · 55 changes: 41 additions & 14 deletions
@@ -5,8 +5,10 @@
 import math
 import random
 import time
-from typing import List, Any
+from typing import Dict, List, Any, Sequence
+from alpa.pipeline_parallel.schedules import PipelineSchedule

+from jax.core import Var
 from jax.interpreters import pxla
 import numpy as np
 import ray
@@ -682,7 +684,8 @@ class ReshardingTaskSpec:
     VirtualDistributedArray.
     """

-    def __init__(self, src_array, dst_array, final_dst_spec):
+    def __init__(self, src_array: VirtualDistributedArray,
+                 dst_array: VirtualDistributedArray, final_dst_spec):
         self.src = src_array
         self.dst = dst_array
         self._dst_tile_to_src_tiles_map = None
@@ -949,7 +952,8 @@ class CrossMeshCommunicator:
         schedule (Any): the pipelining schedule for these stages.
     """

-    def __init__(self, sharded_stages, schedule):
+    def __init__(self, sharded_stages: Sequence[XlaShardedPipelineComputation],
+                 schedule: PipelineSchedule):
         if not isinstance(sharded_stages, list):
             raise RuntimeError("Require a list of stages.")
         for s in sharded_stages:
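One thing the new annotation surfaces: `sharded_stages` is typed as a `Sequence`, but the runtime guard above still insists on a `list`, so a tuple would satisfy a type checker yet raise at runtime. A minimal self-contained illustration of that gap (the `check` function is a toy stand-in, not Alpa code):

```python
from typing import Sequence

def check(sharded_stages: Sequence[int]) -> None:
    # Mirrors the guard above: the annotation admits any Sequence,
    # but the runtime check only accepts a list.
    if not isinstance(sharded_stages, list):
        raise RuntimeError("Require a list of stages.")

check([1, 2, 3])      # a list passes
try:
    check((1, 2, 3))  # a tuple type-checks as a Sequence, yet raises here
except RuntimeError:
    print("tuple rejected at runtime")
```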
@@ -1091,6 +1095,9 @@ def _create_resharding_specs(self):
             [{} for _ in range(self.num_mesh)] for _ in range(self.num_mesh)
         ]

+        # For each var, remember the index of the last stage that consumed it
+        # as an input; later reshards of that var are sourced from that stage.
+        last_seen: Dict[Var, int] = {}
         # find stages that will communicate
         pairs = np.argwhere(deps > 0)
         for i in range(pairs.shape[0]):
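The `last_seen` map added above drives the bookkeeping in the next hunk, and the update order matters: the previous consumer's stage index must be read out of the map before the entry is overwritten with the new destination stage. A tiny sketch of that read-before-overwrite pattern in isolation (string keys stand in for JAX `Var` objects):

```python
from typing import Dict, Optional

last_seen: Dict[str, int] = {}

def record_consumer(var: str, dst_stage_index: int) -> Optional[int]:
    # Read the previous consumer first, then overwrite the entry.
    prev = last_seen.get(var)
    last_seen[var] = dst_stage_index
    return prev

assert record_consumer("w", 1) is None  # first reshard: no previous consumer
assert record_consumer("w", 2) == 1     # next reshard: stage 1 saw it last
```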
@@ -1116,29 +1123,48 @@
             out_sharding_specs = src_stage.output_sharding_specs
             in_sharding_specs = dst_stage.input_sharding_specs

-            # Make a ReshardSpec for each VirtualDistributedArray
+            # Make a ReshardingTaskSpec for each VirtualDistributedArray
             for var, out_var_index, in_var_index in zip(resharding_vars,
                                                         out_var_indices,
                                                         in_var_indices):
-                src_sharding_spec = out_sharding_specs[out_var_index]
-                dst_sharding_spec = in_sharding_specs[in_var_index]
-
+                if var in last_seen:
+                    last_seen_stage_index = last_seen[var]
+                    last_seen[var] = dst_stage_index
+
+                    last_seen_stage = stages[last_seen_stage_index]
+                    last_seen_mesh_index = stage_placements[
+                        last_seen_stage_index]
+                    last_seen_mesh = meshes[last_seen_mesh_index]
+                    last_seen_var_index = last_seen_stage.invars.index(var)
+                    last_seen_sharding_spec = last_seen_stage.input_sharding_specs[
+                        last_seen_var_index]
+                    final_src_array = VirtualDistributedArray(
+                        device_mesh=last_seen_mesh,
+                        aval=var.aval,
+                        sharding_spec=last_seen_sharding_spec)
+                    final_src_mesh_index = last_seen_mesh_index
+                else:
+                    last_seen[var] = dst_stage_index
+                    src_sharding_spec = out_sharding_specs[out_var_index]
+                    final_src_array = VirtualDistributedArray(
+                        device_mesh=src_mesh,
+                        aval=var.aval,
+                        sharding_spec=src_sharding_spec)
+                    final_src_mesh_index = src_mesh_index
+
+                dst_sharding_spec = in_sharding_specs[in_var_index]
                 final_dst_spec = dst_sharding_spec
                 if global_config.resharding_mode == "send_recv":
                     dst_sharding_spec = self._rewrite_allgather_spec(
                         dst_sharding_spec, dst_mesh.num_hosts, var.aval.shape)

-                src_array = VirtualDistributedArray(
-                    device_mesh=src_mesh,
-                    aval=var.aval,
-                    sharding_spec=src_sharding_spec)
                 dst_array = VirtualDistributedArray(
                     device_mesh=dst_mesh,
                     aval=var.aval,
                     sharding_spec=dst_sharding_spec)
-                task_spec = ReshardingTaskSpec(src_array, dst_array,
+                task_spec = ReshardingTaskSpec(final_src_array, dst_array,
                                                final_dst_spec)
-                self.resharding_specs[src_mesh_index][dst_mesh_index][
+                self.resharding_specs[final_src_mesh_index][dst_mesh_index][
                     var] = task_spec

     def task_spec_iter(self):
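Note that `last_seen_stage` must be resolved from `stages[last_seen_stage_index]` before its `invars` and `input_sharding_specs` are indexed; the hunk is ordered accordingly. Putting the hunk's control flow in one place: if a variable was already consumed by an earlier stage, the reshard is sourced from that stage's mesh and input sharding spec; otherwise it is sourced from the producer as before. A self-contained sketch of just that decision, with plain ints and strings standing in for Alpa's `Var`, stage, and mesh objects (illustration only, not the module's API):

```python
from typing import Dict, List

# Toy placement data: stage i runs on mesh stage_to_mesh[i].
stage_to_mesh: List[int] = [0, 1, 2]

last_seen: Dict[str, int] = {}

def choose_source(var: str, producer_mesh: int, dst_stage_index: int) -> int:
    """Return the mesh index to reshard `var` from: the mesh of the stage
    that consumed it last, falling back to the producer's mesh."""
    if var in last_seen:
        src_mesh = stage_to_mesh[last_seen[var]]
    else:
        src_mesh = producer_mesh
    last_seen[var] = dst_stage_index  # this consumer becomes the next source
    return src_mesh

# Stage 1 pulls "x" from its producer on mesh 0; stage 2 then pulls "x"
# from mesh 1, where stage 1 (the last consumer) lives.
assert choose_source("x", producer_mesh=0, dst_stage_index=1) == 0
assert choose_source("x", producer_mesh=0, dst_stage_index=2) == 1
```

The intended effect is to shorten transfer paths: once a variable has been shipped to a later mesh, subsequent consumers fetch it from there instead of repeatedly from the original producer.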
@@ -1425,7 +1451,8 @@ def _generate_broadcast_resharding_strategy_by_loads(
         return strategy

     @staticmethod
-    def _args_between(src_stage, dst_stage):
+    def _args_between(src_stage: XlaShardedPipelineComputation,
+                      dst_stage: XlaShardedPipelineComputation):
         """Find the variable exchanged between stages."""
         resharding_vars = []
         src_indices = []
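The diff cuts off before the body of `_args_between` completes. Judging from the names above (`resharding_vars`, `src_indices`), it pairs each output of the source stage with the matching input of the destination stage. A hedged stand-alone sketch of that matching, with string lists in place of the stages' `outvars`/`invars` (an assumption about the body, not a copy of it):

```python
from typing import List, Tuple

def args_between(src_outvars: List[str],
                 dst_invars: List[str]
                ) -> Tuple[List[str], List[int], List[int]]:
    """Find the variables a source stage produces that a destination stage
    consumes, with each variable's position on both sides."""
    resharding_vars: List[str] = []
    src_indices: List[int] = []
    dst_indices: List[int] = []
    for i, var in enumerate(src_outvars):
        if var in dst_invars:
            resharding_vars.append(var)
            src_indices.append(i)
            dst_indices.append(dst_invars.index(var))
    return resharding_vars, src_indices, dst_indices

assert args_between(["a", "b", "c"], ["c", "a"]) == (["a", "c"], [0, 2], [1, 0])
```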