nsys-jax: add basic CI, support all-to-all and repeated thunks (#877)
Basic lint checks of jax_nsys and nsys-jax code.
Address some of @gspschmid's comments from #863.
Replaces #875, this time with a source branch that is not in a fork.
olupton authored Jun 5, 2024
1 parent 800e665 commit cd3f8fb
Showing 10 changed files with 249 additions and 79 deletions.
57 changes: 38 additions & 19 deletions .github/container/jax_nsys/Analysis.ipynb
@@ -7,7 +7,7 @@
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict, namedtuple\n",
"from collections import defaultdict\n",
"from jax_nsys import (\n",
" calculate_collective_metrics,\n",
" compile_protos,\n",
@@ -20,8 +20,9 @@
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import os\n",
"import pandas as pd\n",
"import sys"
"import pandas as pd # type: ignore\n",
"import sys\n",
"from typing import NamedTuple"
]
},
{
@@ -182,13 +183,19 @@
" )\n",
"\n",
"\n",
"def reduce_module_stats(module_stats):\n",
"class Summary(NamedTuple):\n",
" mean: float\n",
" std: float\n",
" total: float\n",
"\n",
"\n",
"def reduce_module_stats(module_stats) -> dict[str, Summary]:\n",
" # [{\"a\": 0.3}, {\"a\": 0.4}] -> {\"a\": (0.35, stddev), \"#Instances\": 2}\n",
" r = {\"#Instances\": len(module_stats)}\n",
" num_instances = len(module_stats)\n",
" r = {\"#Instances\": Summary(mean=num_instances, std=0.0, total=num_instances)}\n",
" keys = module_stats[0].keys()\n",
" for stats in module_stats[1:]:\n",
" assert stats.keys() == keys\n",
" Summary = namedtuple(\"Number\", [\"mean\", \"std\", \"total\"])\n",
" for k in keys:\n",
" values = [stats[k] for stats in module_stats]\n",
" r[k] = Summary(mean=np.mean(values), std=np.std(values), total=np.sum(values))\n",
@@ -197,21 +204,26 @@
"\n",
"# Aggregate HLO module statistics over repeated executions of them\n",
"agg_module_stats = [(k, reduce_module_stats(v)) for k, v in module_stats.items()]\n",
"sort_key = lambda x: x[1][\"GPU time [ms]\"].total\n",
"\n",
"\n",
"def sort_key(x):\n",
" return x[1][\"GPU time [ms]\"].total\n",
"\n",
"\n",
"agg_module_stats.sort(key=sort_key, reverse=True)\n",
"total = sum(sort_key(x) for x in agg_module_stats)\n",
"print(\" Active GPU time #Exec. #Thunks Module name\")\n",
"accounted_time, top_n = 0.0, None\n",
"for n, tup in enumerate(agg_module_stats):\n",
" module_name, module_stats = tup\n",
" module_name, stats = tup\n",
" module_time = sort_key(tup)\n",
" print(\n",
" \" {:7.2f}% {:9.2f}ms {:5} {:5.0f}±{:<3.0f} {}\".format(\n",
" 100.0 * module_time / total,\n",
" module_time,\n",
" module_stats[\"#Instances\"],\n",
" module_stats[\"#Thunks\"].mean,\n",
" module_stats[\"#Thunks\"].std,\n",
" stats[\"#Instances\"].mean,\n",
" stats[\"#Thunks\"].mean,\n",
" stats[\"#Thunks\"].std,\n",
" module_name,\n",
" )\n",
" )\n",
@@ -263,9 +275,9 @@
"\n",
"# Project the thunk runtime data onto some other data structures, to be\n",
"# presented in different ways.\n",
"op_runtime = defaultdict(float)\n",
"op_name_runtime = defaultdict(float)\n",
"src_runtime = defaultdict(float)\n",
"op_runtime: dict[str, float] = defaultdict(float)\n",
"op_name_runtime: dict[tuple[str, ...], float] = defaultdict(float)\n",
"src_runtime: dict[tuple[str, ...], float] = defaultdict(float)\n",
"\n",
"# Dummy entries to massage the source code view\n",
"gpu_active = [\"[GPU active]\"]\n",
@@ -304,10 +316,15 @@
" for called_comp_id in hlo_inst.called_computation_ids\n",
" for called_inst in hlo_module.find_computation(called_comp_id).instructions\n",
" ]\n",
" src_runtime_preferences = [set(), set(), [tuple(gpu_active_unknown)]]\n",
" op_name_runtime_preferences = [set(), [tuple(gpu_active_unknown)]]\n",
" non_empty_stack_traces = set()\n",
" non_empty_op_names = set()\n",
" src_runtime_preferences: tuple[set[tuple[str, ...]], ...] = (\n",
" set(),\n",
" set(),\n",
" {tuple(gpu_active_unknown)},\n",
" )\n",
" op_name_runtime_preferences: tuple[set[tuple[str, ...]], ...] = (\n",
" set(),\n",
" {tuple(gpu_active_unknown)},\n",
" )\n",
" for inst in [hlo_inst] + called_instructions:\n",
" frames = hlo_module.get_stack_frames(inst.metadata.stack_frame_id)\n",
" op_name = [inst.metadata.op_name] if len(inst.metadata.op_name) else []\n",
@@ -413,7 +430,9 @@
" # program, there may be different sub-groupings that are participating in smaller\n",
" # collectives in the strict/NCCL sense. TODO: it would be better to identify those\n",
" # sub-groupings and group them, but we currently lack the relevant information.\n",
" collective_df = df.groupby([\"ProgramId\", \"Name\", \"ModuleExecution\"])\n",
" collective_df = df.groupby(\n",
" [\"ProgramId\", \"Name\", \"ModuleExecution\", \"ThunkExecution\"]\n",
" )\n",
" # Take the fastest device kernel as a proxy for the actual bandwidth of the\n",
" # collective.\n",
" bandwidth_df = collective_df.agg(\n",
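
A toy illustration of why "ThunkExecution" joins the grouping key (standalone sketch; the "DurMs" column and all values are made up): without it, two launches of the same all-to-all thunk within one module execution would collapse into a single collective row.

import pandas as pd

df = pd.DataFrame(
    {
        "ProgramId": [1, 1, 1, 1],
        "Name": ["all-to-all.1"] * 4,
        "ModuleExecution": [0, 0, 0, 0],
        "ThunkExecution": [0, 0, 1, 1],  # same thunk launched twice in one execution
        "DurMs": [1.0, 1.1, 0.9, 1.2],  # one row per participating device kernel
    }
)
# Fastest kernel per collective as the bandwidth proxy, mirroring the .agg above
print(df.groupby(["ProgramId", "Name", "ModuleExecution", "ThunkExecution"])["DurMs"].min())
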
12 changes: 9 additions & 3 deletions .github/container/jax_nsys/install.sh
@@ -10,6 +10,7 @@
# The expectation is that those archives will be copied and extracted on a
# laptop or workstation, and this installation script will be run there, while
# the `nsys-jax` wrapper is executed on a remote GPU cluster.
set -ex
SCRIPT_DIR=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
VIRTUALENV="${SCRIPT_DIR}/nsys_jax_venv"
if [[ ! -d "${VIRTUALENV}" ]]; then
@@ -18,12 +19,17 @@ if [[ ! -d "${VIRTUALENV}" ]]; then
. "${VIRTUALENV}/bin/activate"
python -m pip install -U pip
"${SCRIPT_DIR}/nsys-jax-ensure-protobuf"
python -m pip install jupyterlab
# matplotlib is a dependency of Analysis.ipynb but not jax_nsys
python -m pip install jupyterlab matplotlib
python -m pip install -e "${SCRIPT_DIR}/python/jax_nsys"
curl -o "${VIRTUALENV}/bin/flamegraph.pl" https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl
chmod 755 "${VIRTUALENV}/bin/flamegraph.pl"
else
echo "Virtual environment already exists, not installing anything..."
fi
echo "Launching: cd ${SCRIPT_DIR} && ${VIRTUALENV}/bin/python -m jupyterlab Analysis.ipynb"
cd "${SCRIPT_DIR}" && "${VIRTUALENV}/bin/python" -m jupyterlab Analysis.ipynb
if [ -z ${NSYS_JAX_INSTALL_SKIP_LAUNCH+x} ]; then
echo "Launching: cd ${SCRIPT_DIR} && ${VIRTUALENV}/bin/python -m jupyterlab Analysis.ipynb"
cd "${SCRIPT_DIR}" && "${VIRTUALENV}/bin/python" -m jupyterlab Analysis.ipynb
else
echo "Skipping launch of jupyterlab due to NSYS_JAX_INSTALL_SKIP_LAUNCH"
fi
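
Note that `[ -z ${NSYS_JAX_INSTALL_SKIP_LAUNCH+x} ]` tests whether the variable is set at all, so setting it to any value, even an empty one as in `NSYS_JAX_INSTALL_SKIP_LAUNCH= ./install.sh`, skips the JupyterLab launch; presumably this is what lets the new CI exercise the install script non-interactively.
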
15 changes: 11 additions & 4 deletions .github/container/jax_nsys/python/jax_nsys/jax_nsys/__init__.py
@@ -1,10 +1,17 @@
import pandas as _pd

_pd.options.mode.copy_on_write = True

from .analysis import calculate_collective_metrics, generate_compilation_statistics
from .data_loaders import load_profiler_data
from .protobuf import xla_module_metadata
from .protobuf_utils import compile_protos
from .utils import remove_child_ranges
from .visualization import create_flamegraph, display_flamegraph

__all__ = [
"calculate_collective_metrics",
"compile_protos",
"create_flamegraph",
"display_flamegraph",
"generate_compilation_statistics",
"load_profiler_data",
"remove_child_ranges",
"xla_module_metadata",
]
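
With the explicit __all__, the public surface can be verified directly (assuming the package has been installed, e.g. by install.sh):

import jax_nsys

print(sorted(jax_nsys.__all__))
# ['calculate_collective_metrics', 'compile_protos', 'create_flamegraph',
#  'display_flamegraph', 'generate_compilation_statistics',
#  'load_profiler_data', 'remove_child_ranges', 'xla_module_metadata']
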
89 changes: 54 additions & 35 deletions .github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
@@ -2,13 +2,15 @@
import functools
import math
import numpy as np
import pandas as pd
import pandas as pd # type: ignore

from .protobuf import xla_module_metadata
from .utils import make_child_mask

pd.options.mode.copy_on_write = True

def element_type_in_bits(element_type: int) -> int:

def element_type_width(element_type: int) -> int:
"""
Given an int representing an XLA PrimitiveType enum value, return the width of that
type in bits.
@@ -29,8 +31,35 @@ def element_type_in_bits(element_type: int) -> int:
raise Exception(f"Could not deduce size of {enum_name}")


def _collective_correction(kind: str, size: int) -> tuple[float, float]:
"""
Calculate the correction factor from algorithm bandwidth to bus bandwidth, see:
https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#bus-bandwidth
"""
match kind:
# For AllGather the size in the bandwidth calculation is the total/output size
# https://github.com/NVIDIA/nccl-tests/blob/c6afef0b6f76ffc55d4172d971be6cf5a08a73a4/doc/PERFORMANCE.md#allgather
case "all-gather":
return (size, (size - 1) / size)
case "all-reduce":
return (1, 2 * (size - 1) / size)
case "all-to-all":
# https://github.com/NVIDIA/nccl-tests/blob/a1efb427e764241bc43d2d91be875c9f55da03a5/src/alltoall.cu#L44
return (1, (size - 1) / size)
case "collective-broadcast":
return (1, 1)
case "collective-permute":
return (1, 1)
# For ReduceScatter the size in the bandwidth calculation is the total size
# https://github.com/NVIDIA/nccl-tests/blob/c6afef0b6f76ffc55d4172d971be6cf5a08a73a4/doc/PERFORMANCE.md#reducescatter
case "reduce-scatter":
return (size, (size - 1) / size)
case _:
assert False, f"Unknown collective kind {kind}"


@functools.lru_cache
def get_message_size(program_id: int, instruction: str) -> int:
def get_message_size(program_id: int, instruction: str) -> pd.Series:
"""
Given the name of a collective instruction (e.g. all-gather-start.N), calculate the
message size in bytes. See https://openxla.org/xla/operation_semantics#allgather,
@@ -40,56 +69,46 @@ def get_message_size(program_id: int, instruction: str) -> int:
"""
module_proto = xla_module_metadata(program_id)
_, inst = module_proto.find_instruction(instruction)
assert inst.opcode in {
"all-gather-start",
"all-reduce-start",
"collective-broadcast",
"collective-permute-start",
"reduce-scatter",
}, f"{instruction}: message size calculation for {inst.opcode} has not yet been validated"
assert (
inst.opcode
in {
"all-gather-start",
"all-reduce-start",
"all-to-all",
"collective-broadcast",
"collective-permute-start",
"reduce-scatter",
}
), f"{instruction}: message size calculation for {inst.opcode} has not yet been validated"
if inst.opcode == "collective-permute-start":
# See https://openxla.org/xla/operation_semantics#collectivepermute, which
# generates pair-wise send+recv between devices
collective_size = 2
else:
# replica_groups is something like {{0,1},{4,5},{2,3},{6,7}}, if there are 8
# devices that are doing pair-wise collectives
collective_sizes = tuple(
{len(group.replica_ids) for group in inst.replica_groups}
)
assert (
len(collective_sizes) == 1
collective_size = len(inst.replica_groups[0].replica_ids)
assert all(
len(group.replica_ids) == collective_size for group in inst.replica_groups
), f"Heterogeneous collective {inst.replica_groups} could not be interpreted"
collective_size = collective_sizes[0]
total_msg_size = 0
for operand_id in inst.operand_ids:
_, operand = module_proto.find_instruction_by_id(operand_id)
msg_size_bits = math.prod(
operand.shape.dimensions,
start=element_type_in_bits(operand.shape.element_type),
start=element_type_width(operand.shape.element_type),
)
if inst.opcode == "reduce-scatter":
# NCCL's convention is that the message size of a reduce-scatter is the size of output buffer:
# https://github.com/NVIDIA/nccl/blob/ab2b89c4c339bd7f816fbc114a4b05d386b66290/src/collectives.cc#L122
assert msg_size_bits % collective_size == 0
msg_size_bits //= collective_size
assert msg_size_bits % 8 == 0
total_msg_size += msg_size_bits // 8
msg_size_bits, rem = divmod(msg_size_bits, collective_size)
assert rem == 0
msg_size_bytes, rem = divmod(msg_size_bits, 8)
assert rem == 0
total_msg_size += msg_size_bytes

# Calculate the correction factor from algorithm bandwidth to bus bandwidth, see:
# https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#bus-bandwidth
collective = inst.opcode.removesuffix("-start")
bw_correction, bus_correction = {
# For AllGather the size in the bandwidth calculation is the total/output size
# https://github.com/NVIDIA/nccl-tests/blob/c6afef0b6f76ffc55d4172d971be6cf5a08a73a4/doc/PERFORMANCE.md#allgather
"all-gather": (collective_size, (collective_size - 1) / collective_size),
"all-reduce": (1, 2 * (collective_size - 1) / collective_size),
"collective-broadcast": (1, 1),
"collective-permute": (1, 1),
# For ReduceScatter the size in the bandwidth calculation is the total size
# https://github.com/NVIDIA/nccl-tests/blob/c6afef0b6f76ffc55d4172d971be6cf5a08a73a4/doc/PERFORMANCE.md#reducescatter
"reduce-scatter": (collective_size, (collective_size - 1) / collective_size),
}[collective]
bw_correction, bus_correction = _collective_correction(collective, collective_size)
return pd.Series(
[total_msg_size, collective, collective_size, bw_correction, bus_correction],
index=[
@@ -153,7 +172,7 @@ def generate_compilation_statistics(compile_df: pd.DataFrame) -> pd.DataFrame:
main_thread = main_thread[0]

# Aggregate compilation stats in here
compile_time_ns = defaultdict(lambda: np.zeros(2))
compile_time_ns: dict[str, np.ndarray] = defaultdict(lambda: np.zeros(2))

# Identify the ranges in the main thread that represent parallel compilation, i.e.
# ranges whose child ranges are in different threads.
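
Worked example of the message-size convention in get_message_size above (standalone arithmetic; the shape is made up): for a reduce-scatter over 4 ranks whose operand holds 1024 f32 elements, NCCL counts only the per-rank output buffer.

msg_size_bits = 1024 * 32  # 1024 elements x 32 bits each
msg_size_bits, rem = divmod(msg_size_bits, 4)  # reduce-scatter divides by collective size
assert rem == 0
msg_size_bytes, rem = divmod(msg_size_bits, 8)
assert rem == 0
print(msg_size_bytes)  # 1024
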
.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py
@@ -1,12 +1,14 @@
import lzma
import numpy as np
import pandas as pd
import pandas as pd # type: ignore
import pathlib
import re

from .protobuf import xla_module_metadata
from .utils import make_child_mask

pd.options.mode.copy_on_write = True


def _classify_comms(thunk_df: pd.DataFrame, prefix: pathlib.Path) -> pd.DataFrame:
# Classify each thunk as either communication or computation, as we only
@@ -245,6 +247,14 @@ def clean_data_frame(d, extra_columns=[]):
value=r"\2",
regex=True,
)
# Add a new column describing which (0th, 1st, ...) execution of the thunk
# within the given module execution this is. For example, while loops in the
# HLO can lead to the same thunk being executed multiple times within the same
# module execution.
thunk_df["ThunkExecution"] = thunk_df.groupby(
["TID", "ProgramId", "Name", "ModuleExecution"]
).cumcount()

# Classify thunks as communication/computation and save to output
output["thunk"] = _classify_comms(thunk_df, prefix)

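
A minimal illustration of the new ThunkExecution numbering (toy frame; real profiles carry many more columns): cumcount restarts at zero for each (thread, program, thunk, module-execution) group.

import pandas as pd

df = pd.DataFrame(
    {
        "TID": [1, 1, 1, 1],
        "ProgramId": [7, 7, 7, 7],
        "Name": ["thunk.1", "thunk.1", "thunk.1", "thunk.2"],
        "ModuleExecution": [0, 0, 1, 0],
    }
)
df["ThunkExecution"] = df.groupby(
    ["TID", "ProgramId", "Name", "ModuleExecution"]
).cumcount()
print(df["ThunkExecution"].tolist())  # [0, 1, 0, 0]
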
.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf_utils.py
@@ -1,12 +1,12 @@
# WARNING: it is tacitly assumed that the protobuf compiler (protoc) is
# compatible with the google.protobuf version.
import glob
import google.protobuf
import os
import pathlib
import shutil
import subprocess
import sys
from typing import Optional


def which(executable: str) -> pathlib.Path:
@@ -28,7 +28,11 @@ def which(executable: str) -> pathlib.Path:
return pathlib.Path(exe)


def compile_protos(proto_dir: str | pathlib.Path, output_dir: str | pathlib.Path):
def compile_protos(
proto_dir: str | pathlib.Path,
output_dir: str | pathlib.Path,
output_stub_dir: Optional[str | pathlib.Path] = None,
):
if not os.path.isdir(proto_dir):
raise Exception(f"Input: {proto_dir} is not a directory")
if not os.path.isdir(output_dir):
@@ -39,6 +43,12 @@ def compile_protos(proto_dir: str | pathlib.Path, output_dir: str | pathlib.Path
raise Exception(f"Did not find any .proto files under {proto_dir}")
protoc = which("protoc")
# Generate code to load the protobuf files
args: list[str | pathlib.Path] = [protoc, f"-I={proto_dir}", f"--python_out={output_dir}"]
args: list[str | pathlib.Path] = [
protoc,
f"-I={proto_dir}",
f"--python_out={output_dir}",
]
if output_stub_dir is not None:
args.append(f"--pyi_out={output_stub_dir}")
args += proto_files
subprocess.run(args, check=True)
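
With the new optional argument, .pyi type stubs are emitted alongside the generated bindings; a usage sketch (directory names are hypothetical):

compile_protos(
    proto_dir="protos",  # input .proto files
    output_dir="generated",  # protoc --python_out target
    output_stub_dir="generated",  # additionally passes --pyi_out for .pyi stubs
)
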
4 changes: 3 additions & 1 deletion .github/container/jax_nsys/python/jax_nsys/jax_nsys/utils.py
@@ -1,6 +1,8 @@
import pandas as pd
import pandas as pd # type: ignore
from typing import Optional

pd.options.mode.copy_on_write = True


def make_child_mask(df: pd.DataFrame, parent_row: int) -> pd.Series:
"""