Skip to content

Commit

Permalink
Merge branch 'main' into vkozlov/jetstream-4-maxtext
Browse files Browse the repository at this point in the history
  • Loading branch information
yhtang authored Oct 10, 2024
2 parents ab21dae + e8043a5 commit d9f5c18
Show file tree
Hide file tree
Showing 27 changed files with 423 additions and 176 deletions.
7 changes: 0 additions & 7 deletions .github/container/Dockerfile.maxtext.amd64
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,6 @@ git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT}
echo "-r ${SRC_PATH_MAXTEXT}/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in
EOF

###############################################################################
## Apply patch
###############################################################################

ADD maxtext-mha.patch /opt
RUN cd "${SRC_PATH_MAXTEXT}" && patch -p1 < /opt/maxtext-mha.patch && git diff

###############################################################################
## Add test script to the path
###############################################################################
Expand Down
7 changes: 0 additions & 7 deletions .github/container/Dockerfile.maxtext.arm64
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,6 @@ git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT}
echo "-r ${SRC_PATH_MAXTEXT}/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in
EOF

###############################################################################
## Apply patch
###############################################################################

ADD maxtext-mha.patch /opt
RUN cd "${SRC_PATH_MAXTEXT}" && patch -p1 < /opt/maxtext-mha.patch && git diff

###############################################################################
## Add test script to the path
###############################################################################
Expand Down
3 changes: 3 additions & 0 deletions .github/container/Dockerfile.pax.amd64
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ for src in ${SRC_PATH_PAXML} ${SRC_PATH_PRAXIS}; do
pushd ${src}
sed -i "s| @ git+https://github.com/google/flax||g" requirements.in
sed -i "s| @ git+https://github.com/google/jax||g" requirements.in
## we pin etils because newer etils versions are not compatible with the
## version of TFDS required by Pax
sed -i "s/etils/etils==1.7.0/g" requirements.in
if git diff --quiet; then
echo "URL specs no longer present in select dependencies for ${src}"
exit 1
Expand Down
11 changes: 6 additions & 5 deletions .github/container/install-nsight.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ apt-get clean

rm -rf /var/lib/apt/lists/*

NSYS202451=/opt/nvidia/nsight-systems-cli/2024.5.1
if [[ -d "${NSYS202451}" ]]; then
# * can match at least sbsa-armv8 and x86
(cd ${NSYS202451}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
fi
for NSYS in /opt/nvidia/nsight-systems-cli/2024.5.1 /opt/nvidia/nsight-systems-cli/2024.6.1; do
if [[ -d "${NSYS}" ]]; then
# * can match at least sbsa-armv8 and x86
(cd ${NSYS}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
fi
done

# Install extra dependencies needed for `nsys recipe ...` commands. These are
# used by the nsys-jax wrapper script.
Expand Down
54 changes: 46 additions & 8 deletions .github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pathlib
from typing import Any

from .protobuf import HloProto, xla_module_metadata
from .protobuf import HloProto, _host_memory_space, xla_module_metadata
from .utils import make_child_mask, ProfilerData

pd.options.mode.copy_on_write = True
Expand Down Expand Up @@ -38,6 +38,11 @@ def align_profiler_data_timestamps(
# Determine which collective size will be used for the alignment
num_profiled_devices = len(comm_df.index.get_level_values("Device").unique())
max_collective_size = comm_df["CollectiveSize"].max()
if max_collective_size == 1:
print(
f"WARNING: cannot align {num_profiled_devices} devices because max collective size is 1"
)
return frames, {}
assert (
num_profiled_devices == max_collective_size
), f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
Expand Down Expand Up @@ -193,13 +198,51 @@ def _get_message_size(
"all-to-all",
"collective-broadcast",
"collective-permute-start",
"dynamic-slice",
"dynamic-update-slice",
"reduce-scatter",
}
), f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"

def _byte_size(inst) -> int:
size_bits = math.prod(
inst.shape.dimensions,
start=element_type_width(inst.shape.element_type),
)
size_bytes, rem = divmod(size_bits, 8)
assert rem == 0
return size_bytes

if comm_inst.opcode == "collective-permute-start":
# See https://openxla.org/xla/operation_semantics#collectivepermute, which
# generates pair-wise send+recv between devices
collective_size = 2
elif comm_inst.opcode in {"dynamic-slice", "dynamic-update-slice"}:
# Label host-device transfers orchestrated by dynamic[-update]-slice as single
# device collectives.
collective_size = 1
if comm_inst.opcode == "dynamic-update-slice":
# For dynamic-update-slice the second operand is the one being copied
_, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[1])
transfer_size = _byte_size(src_inst.proto())
else:
# For dynamic-slice the return type size is the transfer size
assert comm_inst.opcode == "dynamic-slice"
_, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[0])
transfer_size = _byte_size(comm_inst)
dest_on_host = _host_memory_space(comm_inst)
src_on_host = _host_memory_space(src_inst.proto())
            assert src_on_host != dest_on_host, (
                'dynamic[-update]-slice is only considered "communication" if it '
                "represents a host-device transfer"
            )
return (
transfer_size,
"device-to-host" if dest_on_host else "host-to-device",
1, # collective size
1.0, # bw_correction
1.0, # bus_correction
)
else:
# replica_groups is something like {{0,1},{4,5},{2,3},{6,7}}, if there are 8
# devices that are doing pair-wise collectives
Expand All @@ -220,17 +263,12 @@ def _get_message_size(
total_msg_size = 0
for operand_id in comm_inst.operand_ids:
_, operand = module_proto.find_instruction_by_id(operand_id)
msg_size_bits = math.prod(
operand.proto().shape.dimensions,
start=element_type_width(operand.proto().shape.element_type),
)
msg_size_bytes = _byte_size(operand.proto())
if comm_inst.opcode == "reduce-scatter":
# NCCL's convention is that the message size of a reduce-scatter is the size of output buffer:
# https://github.com/NVIDIA/nccl/blob/ab2b89c4c339bd7f816fbc114a4b05d386b66290/src/collectives.cc#L122
msg_size_bits, rem = divmod(msg_size_bits, collective_size)
msg_size_bytes, rem = divmod(msg_size_bytes, collective_size)
assert rem == 0
msg_size_bytes, rem = divmod(msg_size_bits, 8)
assert rem == 0
total_msg_size += msg_size_bytes

collective = comm_inst.opcode.removesuffix("-start")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def is_communication(row):
return _calculate_overlap(thunk_df)


compile_prefix = "XlaCompile:#module="


def _load_nvtx_gpu_proj_trace_single(
prefix: pathlib.Path,
file: pathlib.Path,
Expand Down Expand Up @@ -305,10 +308,21 @@ def _load_nvtx_gpu_proj_trace_single(
unique_pid_tid_pairs = module_df.loc[:, ("PID", "TID")].drop_duplicates()
if len(unique_pid_tid_pairs) == 1:
main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0]))
# If the profile only includes N>1 modules, we may still be able to identify the
# main thread as the one responsible for XlaCompile ranges projected onto the GPU
# timeline
compile_ranges = df.loc[~all_thunks, "Name"].str.startswith(
tsl_prefix + compile_prefix
)
compile_range_ids = compile_ranges[compile_ranges].index
unique_pid_tid_pairs = df.loc[compile_range_ids, ("PID", "TID")].drop_duplicates()
if len(unique_pid_tid_pairs) == 1:
main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0]))
assert len(main_pid_tid_candidates) < 2
if len(main_pid_tid_candidates) == 1:
# Possibly not correct if len(device_by_pid_tid) > 1
assert len(device_by_pid_tid) > 0
# Associate the main thread with the 0th device in device_by_pid_tid
main_thread_df = device_by_pid_tid.iloc[:1]
main_thread_df.index = pd.MultiIndex.from_tuples(
main_pid_tid_candidates, names=["PID", "TID"]
Expand Down Expand Up @@ -425,16 +439,13 @@ def _load_nvtx_gpu_proj_trace(
return output


compile_prefix = "TSL:XlaCompile:#module="


def _splice_parallel_ranges(compile_df: pd.DataFrame) -> pd.DataFrame:
# When parallel compilation is enabled, we end up with worker threads that
# emit NVTX ranges but which are not accounted for in the RangeStack tree.
# Splice these in under the relevant XlaCompile ranges in the RangeStack tree and
# drop everything else.
retain_mask = pd.Series(False, index=compile_df.index)
compile_mask = compile_df["Name"].str.startswith(compile_prefix)
compile_mask = compile_df["Name"].str.startswith("TSL:" + compile_prefix)
for compile_range in compile_df[compile_mask].itertuples():
# Identify the slice of `compile_df` that overlaps in time with this XlaCompile
# range
Expand Down
90 changes: 71 additions & 19 deletions .github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from collections import defaultdict
import functools
import lzma
import pathlib
import typing


def _host_memory_space(inst):
return inst.shape.layout.memory_space == 5


class StackFrame(typing.NamedTuple):
column: int
file: str
Expand All @@ -25,6 +28,35 @@ def __init__(self, wrapped_hlo_proto, proto):
# proto representing the actual collective, which will be different if the
# async launch is handled by an async-start op
# TODO: can any of copy-start, custom-call, recv, send represent communication?
# This also aims to identify, and (for now) flag as communication, kernels that
# implement device-to-host and host-to-device copies for memory offloading.
# For example, a device-to-host offload might look like
# computation {
# ...
# ROOT r1 = bf16[2,8,128,2048]{3,2,1,0:S(5)} dynamic-update-slice(...)
# }
# async_computation {
# ...
# ROOT r2 = bf16[2,8,128,2048]{3,2,1,0:S(5)} fusion(...), calls=computation
# }
# start = (...) async-start(...), calls=async_computation
# where the :S(5) annotation shows that a buffer is in host memory.
# A host-to-device load might look like
# computation {
# param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
# ...
# ROOT r1 = bf16[2,8,128,2048]{3,2,1,0} dynamic-slice(param_0, ...)
# }
# async_computation {
# param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
# ...
# ROOT r2 = bf16[2,8,128,2048]{3,2,1,0} fusion(param_0, ...), calls=computation
# }
# start = (...) async-start(...), calls=async_computation
# where the :S(5) memory space annotation is in a parameter instead of in the
# return value.
# For now, handling host-device kernels as single-device "collective"
# communication should be sufficient.
self._comm_proto = None
comm_opcodes = {
"all-gather",
Expand All @@ -39,25 +71,50 @@ def __init__(self, wrapped_hlo_proto, proto):
"all-reduce-start",
"collective-permute-start",
}

def _is_offloading_instruction(inst):
host_dest = _host_memory_space(inst)

def _host_operand(i):
_, op = wrapped_hlo_proto.find_instruction_by_id(inst.operand_ids[i])
return _host_memory_space(op.proto())

if inst.opcode == "dynamic-slice" and host_dest != _host_operand(0):
return True
elif (
inst.opcode == "dynamic-update-slice"
and host_dest == _host_operand(0)
and host_dest != _host_operand(1)
):
return True
return False

if self._proto.opcode in comm_opcodes | comm_start_opcodes:
self._comm_proto = self._proto
elif self._proto.opcode == "async-start":
elif self._proto.opcode in {"async-start", "fusion"}:
# fusion example:
# computation {
# param_0 = f32[...]{...:S(5)} parameter(0)
# ...
# ROOT dus = f32[...]{...:S(5)} dynamic-update-slice(param_0, ...)
# }
# inst = f32[256,128,128]{2,1,0:S(5)} fusion(...), calls=computation
# This might be thinly wrapping an opcode in `comm_opcodes`
other_opcodes = defaultdict(int)
for called_id in self._proto.called_computation_ids:
for called_inst in wrapped_hlo_proto.find_computation(
called_id
).instructions:
if called_inst.opcode in comm_opcodes:
def _visit_computation(computation_id):
computation = wrapped_hlo_proto.find_computation(computation_id)
for called_inst in computation.instructions:
for called_id in called_inst.called_computation_ids:
_visit_computation(called_id)
if called_inst.opcode in comm_opcodes or _is_offloading_instruction(
called_inst
):
assert (
self._comm_proto is None
), f"Found {called_inst.opcode} child having already found {self._comm_proto.opcode}"
self._comm_proto = called_inst
else:
other_opcodes[called_inst.opcode] += 1
assert (
other_opcodes.keys() == {"parameter"}
), f"async-start op {self._proto.name} wrapped too many opcode types ({dict(other_opcodes)}) in addition to {self._comm_proto}"

for called_id in self._proto.called_computation_ids:
_visit_computation(called_id)

def communication_proto(self):
return self._comm_proto
Expand All @@ -68,12 +125,7 @@ def is_communication(self) -> bool:
a little more complicated than you might hope, because async communications are
not handled uniformly.
"""
if self._comm_proto is None:
return False
assert (
self._comm_proto.channel_id != 0
), f"Got channel_id={self._comm_proto.channel_id} for {self._comm_proto.name}"
return True
return self._comm_proto is not None

def proto(self):
"""
Expand Down
5 changes: 5 additions & 0 deletions .github/container/manifest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,8 @@ orbax-checkpoint:
tracking_ref: main
latest_verified_commit: 16c2d409e365576284dbaf190ac002b24c1f927f
mode: pip-vcs
pathwaysutils:
url: https://github.com/google/pathways-utils.git
tracking_ref: main
latest_verified_commit: 359776d454940ffaa337c36d1df16308d44a95a9
mode: pip-vcs
15 changes: 0 additions & 15 deletions .github/container/maxtext-mha.patch

This file was deleted.

3 changes: 0 additions & 3 deletions .github/container/test-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,6 @@ else
fi

for t in $*; do
if [[ "$t" != "//tests:"* ]]; then
t="//tests:${t}"
fi
BAZEL_TARGET="${BAZEL_TARGET} $t"
done

Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/_ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -528,13 +528,14 @@ jobs:
STATISTICS_SCRIPT: |
summary_line=$(tail -n1 test-te.log)
errors=$(echo $summary_line | grep -oE '[0-9]+ error' | awk '{print $1} END { if (!NR) print 0}')
passed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "CollectReport" and .outcome == "passed") | .outcome' | wc -l)
failed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "CollectReport" and .outcome == "failed") | .outcome' | wc -l)
passed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "passed") | .outcome' | wc -l)
failed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "failed") | .outcome' | wc -l)
total_tests=$((failed_tests + passed_tests))
echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT
echo "ERRORS=${errors}" >> $GITHUB_OUTPUT
echo "PASSED_TESTS=${passed_tests}" >> $GITHUB_OUTPUT
echo "FAILED_TESTS=${failed_tests}" >> $GITHUB_OUTPUT
TIMEOUT_MINUTES: 120
ARTIFACTS: |
test-te.log
pytest-report.jsonl
Expand Down
Loading

0 comments on commit d9f5c18

Please sign in to comment.