This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

task 801 - launch physical meshes after compilation #938

Open. Wants to merge 2 commits into main.
Changes from 1 commit
7 changes: 5 additions & 2 deletions alpa/create_state_parallel.py
@@ -7,7 +7,7 @@
from jax.tree_util import tree_flatten, tree_unflatten, PyTreeDef
import numpy as np

from alpa.device_mesh import ReplicatedDistributedArray, PhysicalDeviceMeshGroup
from alpa.device_mesh import ReplicatedDistributedArray, PhysicalDeviceMeshGroup, VirtualMeshGroup
from alpa.global_env import global_config
from alpa.mesh_executable import (NormalMeshDriverExecutable,
GradAccMeshDriverExecutable)
@@ -30,12 +30,14 @@ class CreateStateExecutable(PipeshardDriverExecutable):

def __init__(self,
mesh_group: PhysicalDeviceMeshGroup,
virtual_mesh_group: VirtualMeshGroup,
pipeshard_config: PipeshardConfig,
target_placement_specs: Sequence[PlacementSpec],
in_tree: PyTreeDef,
out_tree: Optional[PyTreeDef] = None,
static_argnums: Optional[Sequence[int]] = None):
super().__init__(mesh_group=mesh_group,
virtual_mesh_group=virtual_mesh_group,
pipeshard_config=pipeshard_config,
num_batch=1,
layer_option=None,
@@ -134,13 +136,14 @@ def compile_create_state_executable(fun, in_tree, out_tree_thunk,
sliced_eqns)

# Compile a pipeshard executable with predefined output shardings
pipeshard_config = compile_pipeshard_executable_internal(
pipeshard_config, _, virtual_mesh_group = compile_pipeshard_executable_internal(
new_jaxpr, None, 1, [False] * len(avals), [False] * len(avals),
executable.mesh_group.parent, 1, "inference",
AutoShardingOption(enable_auto_sharding=False),
UniformStageOption(), name, None, output_shardings, None, None)

return CreateStateExecutable(mesh_group=executable.mesh_group,
virtual_mesh_group=virtual_mesh_group,
pipeshard_config=pipeshard_config,
target_placement_specs=placement_specs,
in_tree=in_tree,
72 changes: 72 additions & 0 deletions alpa/device_mesh.py
@@ -2305,6 +2305,78 @@ def profile_all(self, *args, **kwargs):
return mesh_profiling.profile_all(self, *args, **kwargs)


#TODO Github Task - CustomVirtualMesh for interfaces
class VirtualWorker:
def __init__(self, index):
self.index = index
# Additional attributes or methods of virtual workers

class CustomVirtualMesh(VirtualPhysicalMesh):
def __init__(self,
host_ids: Sequence[int],
host_info: Sequence[dict],
num_devices_per_host,
parent: "VirtualPhysicalMesh" = None,
devices: Sequence[Sequence[int]] = None,
mesh_id: int = None
):
super().__init__(host_ids, host_info, num_devices_per_host, parent, devices)
self.host_ips = []
Collaborator: This member seems never used. If so, please remove it.

self.workers = [] # Virtual workers
self.mesh_id = mesh_id

for host_id in host_ids:
self.host_ips.append(host_info[host_id]['NodeName'])
self.workers.append(VirtualWorker(mesh_id))


#TODO Github Task - VirtualMeshGroup for interfaces
class VirtualMeshGroup:
def __init__(self, sliced_virtual_meshes: List[VirtualPhysicalMesh]):
self.sliced_virtual_meshes = self.get_virtual_meshes(sliced_virtual_meshes)
self.collective_groups: List[List[Any]] = [
[None for _ in range(len(self))] for _ in range(len(self))
]
self.launched_nccl = False

def __getitem__(self, index):
return self.sliced_virtual_meshes[index]

def __len__(self):
return len(self.sliced_virtual_meshes)

def index(self, *args, **kwargs):
return self.sliced_virtual_meshes.index(*args, **kwargs)

def get_virtual_meshes(self, sliced_virtual_meshes):
custom_sliced_virtual_meshes = []
for mesh_idx, mesh in enumerate(sliced_virtual_meshes):
custom_mesh = CustomVirtualMesh(mesh.host_ids, mesh.host_info, mesh.num_devices_per_host, mesh.parent, mesh.devices, mesh_idx)
custom_sliced_virtual_meshes.append(custom_mesh)
return custom_sliced_virtual_meshes

def establish_nccl_group(self,
src_mesh_id: int,
dst_mesh_id: int,
instantiate=False
):
"""Establish NCCL group between two meshes."""
# pylint: disable=import-outside-toplevel
from alpa.pipeline_parallel.cross_mesh_resharding import CollectiveGroup

assert src_mesh_id < dst_mesh_id
if self.collective_groups[src_mesh_id][dst_mesh_id] is not None:
# Already established
return
src_mesh = self.sliced_virtual_meshes[src_mesh_id]
dst_mesh = self.sliced_virtual_meshes[dst_mesh_id]
device_strs = OrderedSet(src_mesh.device_strs + dst_mesh.device_strs)
cg = CollectiveGroup(device_strs, src_mesh, dst_mesh)
self.collective_groups[src_mesh_id][dst_mesh_id] = cg
self.collective_groups[dst_mesh_id][src_mesh_id] = cg



# Global runtime objects
global_cluster: DeviceCluster = None
global_physical_mesh: PhysicalDeviceMesh = None
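A minimal usage sketch of the VirtualMeshGroup added above (not part of this diff; the helper name plan_cross_mesh_groups is hypothetical). It pre-registers a CollectiveGroup for every mesh pair while deferring real NCCL communicator creation until the physical meshes are launched, which is the behavior this PR is after.

from alpa.device_mesh import VirtualMeshGroup

def plan_cross_mesh_groups(sliced_virtual_meshes):
    # Wrap the sliced virtual meshes produced by stage construction.
    group = VirtualMeshGroup(sliced_virtual_meshes)
    # establish_nccl_group requires src_mesh_id < dst_mesh_id.
    for src in range(len(group)):
        for dst in range(src + 1, len(group)):
            # instantiate=False only records the CollectiveGroup; the NCCL
            # communicators are created later, once physical meshes exist.
            group.establish_nccl_group(src, dst, instantiate=False)
    return group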
28 changes: 22 additions & 6 deletions alpa/pipeline_parallel/compile_executable.py
@@ -10,7 +10,7 @@
from jax.interpreters import pxla
from jax.tree_util import PyTreeDef

from alpa.device_mesh import VirtualPhysicalMesh
from alpa.device_mesh import VirtualPhysicalMesh, VirtualMeshGroup
from alpa.global_env import global_config
from alpa.pipeline_parallel.pipeshard_executable import PipeshardDriverExecutable
from alpa.pipeline_parallel.runtime_emitter import (
@@ -108,14 +108,19 @@ def compile_pipeshard_executable(
in_tree, out_tree)
else:
parsed_ms_option = None
pipeshard_config = compile_pipeshard_executable_internal(
pipeshard_config, sliced_virtual_meshes, virtual_meshes = compile_pipeshard_executable_internal(
closed_jaxpr, full_batch_closed_jaxpr, micro_batch_size, donated_invars,
batch_invars, virtual_mesh, num_microbatch, pipeline_schedule,
default_as_option, stage_option, name_base, global_input_shardings,
None, stage_input_shardings, parsed_ms_option)

#TODO Github Task - Adding two lines here
if virtual_mesh.launched_physical_mesh_group is None:
virtual_mesh.get_physical_mesh_group(sliced_virtual_meshes)

executable = PipeshardDriverExecutable(
mesh_group=virtual_mesh.launched_physical_mesh_group,
virtual_mesh_group=virtual_meshes,
pipeshard_config=pipeshard_config,
num_batch=num_microbatch,
layer_option=layer_option,
@@ -147,6 +152,7 @@ def compile_pipeshard_executable_internal(
stage_input_shardings: Forcibly set sharding specs of input vars of
each stage.
"""
global virtual_meshes
global_invars = closed_jaxpr.jaxpr.invars
gensym_func = gensym([closed_jaxpr.jaxpr])
inference_mode = (pipeline_schedule == "inference")
@@ -245,8 +251,16 @@ def compile_pipeshard_executable_internal(
debug_compilation_time("shard stages")

# Launch the physical mesh group
if virtual_mesh.launched_physical_mesh_group is None:
virtual_mesh.get_physical_mesh_group(sliced_virtual_meshes)
# if virtual_mesh.launched_physical_mesh_group is None:
# virtual_mesh.get_physical_mesh_group(sliced_virtual_meshes)

nccl_instantiated = False
if 'virtual_meshes' in globals() and virtual_meshes is not None and virtual_mesh.launched_physical_mesh_group is not None:
nccl_instantiated = virtual_meshes.launched_nccl

virtual_meshes = VirtualMeshGroup(sliced_virtual_meshes)
virtual_meshes.launched_nccl = nccl_instantiated

debug_compilation_time("launch meshes")

# Wrap all things into a distributed runtime
@@ -256,7 +270,8 @@
grad_dummy_invars=accumulator_mapping,
global_outvars=global_outvars,
concat_vars_mapping=concat_vars_mapping,
mesh_group=virtual_mesh.launched_physical_mesh_group,
# mesh_group=virtual_mesh.launched_physical_mesh_group,
mesh_group=virtual_meshes,
schedule=schedule,
is_batch=batch_invars,
num_batch=num_microbatch,
Expand All @@ -274,7 +289,8 @@ def compile_pipeshard_executable_internal(
pipeshard_config = emitter_cls(**emitter_kwargs).compile()

debug_compilation_time("runtime emitter")
return pipeshard_config
return pipeshard_config, sliced_virtual_meshes, virtual_meshes



def split_and_process_layers(closed_jaxpr, full_batch_closed_jaxpr,
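Because the new return contract is split across two hunks above, here is a hedged, consolidated restatement (same names as the diff, nothing new introduced): compile_pipeshard_executable_internal now returns a 3-tuple rather than a bare PipeshardConfig, and the caller launches the physical mesh group only after compilation has finished.

# Sketch of the calling convention assumed by this PR.
pipeshard_config, sliced_virtual_meshes, virtual_meshes = (
    compile_pipeshard_executable_internal(
        closed_jaxpr, full_batch_closed_jaxpr, micro_batch_size,
        donated_invars, batch_invars, virtual_mesh, num_microbatch,
        pipeline_schedule, default_as_option, stage_option, name_base,
        global_input_shardings, None, stage_input_shardings,
        parsed_ms_option))

# Physical meshes are launched here, after compilation, instead of inside
# compile_pipeshard_executable_internal as before.
if virtual_mesh.launched_physical_mesh_group is None:
    virtual_mesh.get_physical_mesh_group(sliced_virtual_meshes)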
68 changes: 52 additions & 16 deletions alpa/pipeline_parallel/cross_mesh_resharding.py
@@ -15,7 +15,7 @@
from alpa.device_mesh import (DistributedArray, RemoteArrayRef,
ReshardingRecvSpec, ReshardingSendSpec,
ReshardingTileSpec, ReshardingBroadcastSpec,
_device_mesh_put_dummy, device_id_to_str)
_device_mesh_put_dummy, device_id_to_str, VirtualWorker)
from alpa.global_env import global_config
from alpa.mesh_executable import (UtilMeshWorkerExecutable,
next_mesh_executable_uuid)
@@ -195,6 +195,8 @@ def __init__(self, task_spec, collective_group, src_mesh, dst_mesh):
self.send_worker_task_ids = {}
self.recv_worker_task_ids = {}

self.task_dones = []
Collaborator: Remove unused code. For a series of tasks we always create a temporary task_dones list instead of reusing the same one.


# generate the above states
self._compile()
# print(self.__str__()+"\n")
@@ -220,6 +222,7 @@ def _compile(self):
"""
self._compile_send_recv_tasks()

#TODO Github task - moving this to pipeshard_executable
if not global_config.debug_with_pipeshard_runtime:
self.put_all_tasks()

@@ -229,19 +232,34 @@ def put_all_tasks(self):
"""
# put send and recv tasks
task_dones = []
temp_worker = None
for worker, task in self.sender_tasks.items():
uuid = next_resharding_task_uuid()
if isinstance(worker, VirtualWorker):
for actor, idx in self.collective_group.worker_to_rank_map.items():
if idx == worker.index:
temp_worker = actor
worker = temp_worker
self.send_worker_task_ids[worker] = uuid
task_dones.append(
worker.put_resharding_send_task.remote(
uuid, task, self.collective_group.group_name))

if not isinstance(worker, VirtualWorker):
task_dones.append(
worker.put_resharding_send_task.remote(
uuid, task, self.collective_group.group_name))
for worker, task in self.receiver_tasks.items():
uuid = next_resharding_task_uuid()
if isinstance(worker, VirtualWorker):
for actor, idx in self.collective_group.worker_to_rank_map.items():
if idx == worker.index:
temp_worker = actor
worker = temp_worker
self.recv_worker_task_ids[worker] = uuid
task_dones.append(
worker.put_resharding_recv_task.remote(
uuid, task, self.collective_group.group_name))
ray.get(task_dones)
if not isinstance(worker, VirtualWorker):
task_dones.append(
worker.put_resharding_recv_task.remote(
uuid, task, self.collective_group.group_name))
if len(task_dones) > 0:
ray.get(task_dones)

# put allgather tasks
task_dones = []
@@ -252,17 +270,28 @@ def put_all_tasks(self):
task_spec.dst_sharding_spec,
task_spec.final_dst_spec,
np.prod(self.dst_mesh.shape))

for worker in self.dst_mesh.workers:
task_dones.append(
worker.put_executable.remote(uuid, UtilMeshWorkerExecutable,
hlo))
ray.get(task_dones)
if isinstance(worker, VirtualWorker):
for actor, idx in self.collective_group.worker_to_rank_map.items():
if idx == worker.index:
temp_worker = actor
worker = temp_worker
if not isinstance(worker, VirtualWorker):
task_dones.append(
worker.put_executable.remote(uuid, UtilMeshWorkerExecutable,
hlo))
if len(task_dones) > 0:
ray.get(task_dones)

def create_resharding_communicators(self):
"""Create the NCCL communicators in advance."""
communicator_params = set()
for worker, recv_tasks in self.receiver_tasks.items():
dst_rank = self.collective_group.worker_to_rank_map[worker]
if isinstance(worker, VirtualWorker):
dst_rank = worker.index
else:
dst_rank = self.collective_group.worker_to_rank_map[worker]
for recv_task in recv_tasks:
dst_gpu_idx = recv_task.device_id
tile_specs = recv_task.tile_specs
@@ -456,11 +485,18 @@ def put_all_tasks(self):
task_dones = []
for worker, task in self._broadcast_tasks.items():
uuid = next_resharding_task_uuid()
if isinstance(worker, VirtualWorker):
for actor, idx in self.collective_group.worker_to_rank_map.items():
if idx == worker.index:
temp_worker = actor
worker = temp_worker
self.broadcast_worker_task_ids[worker] = uuid
# print(worker, uuid, task)
task_dones.append(
worker.put_resharding_broadcast_task.remote(
uuid, task, self.collective_group.group_name))
if not isinstance(worker, VirtualWorker):
task_dones.append(
worker.put_resharding_broadcast_task.remote(
uuid, task, self.collective_group.group_name))

ray.get(task_dones)

def _compile_broadcast_tasks(self):
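The VirtualWorker-to-actor lookup repeated in put_all_tasks and the broadcast path above could be factored into one helper. A hedged sketch follows (resolve_worker is hypothetical, not part of the diff); it assumes, as the diff does, that collective_group.worker_to_rank_map maps concrete Ray worker actors to their ranks and that VirtualWorker.index holds the matching value.

from alpa.device_mesh import VirtualWorker

def resolve_worker(worker, collective_group):
    # Return the concrete Ray actor behind a virtual worker, or None if the
    # physical mesh group has not been launched yet.
    if not isinstance(worker, VirtualWorker):
        return worker
    for actor, rank in collective_group.worker_to_rank_map.items():
        if rank == worker.index:
            return actor
    return None

With this helper each loop body reduces to worker = resolve_worker(worker, self.collective_group) followed by a None check before scheduling the remote task.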