# nsys-jax.yaml
# Workflow-level settings: display name, run de-duplication, triggers and the
# environment shared by every job.
name: nsys-jax pure-Python CI

# One concurrency group per workflow and PR head branch; when head_ref is
# empty the run ID is used instead, giving each such run its own group.
# Superseded runs are cancelled everywhere except on refs/heads/main.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

on:
  pull_request:
    types: [opened, reopened, ready_for_review, synchronize]
    # Markdown files are excluded from the paths that trigger a PR run.
    paths-ignore: ['**.md']
  push:
    branches: [main]

env:
  # Whitespace-separated list of paths that the mypy and ruff jobs expand
  # unquoted on their command lines (and that mypy's notebook conversion
  # searches with find). Paths are relative to a checkout under JAX-Toolbox/.
  NSYS_JAX_PYTHON_FILES: |
    JAX-Toolbox/.github/container/nsys-jax
    JAX-Toolbox/.github/container/nsys-jax-combine
    JAX-Toolbox/.github/container/jax_nsys
    JAX-Toolbox/.github/container/jax-nccl-test

jobs:
# Static type-checking of the nsys-jax Python sources with mypy. The XLA
# .proto files are compiled first so that type stubs for the generated
# protobuf modules are available on MYPYPATH.
mypy:
  runs-on: ubuntu-24.04
  steps:
    - name: Check out the repository under ${GITHUB_WORKSPACE}
      uses: actions/checkout@v4
      with:
        path: JAX-Toolbox
        sparse-checkout: |
          .github/container
    - name: Setup Python 3.10
      uses: actions/setup-python@v5
      with:
        python-version: '3.10'
    # jax here is only a CPU-only build of the latest release, installed so
    # that the sources can be type-checked against it.
    - name: Install jax / jax_nsys / mypy
      run: >-
        pip install jax
        -e JAX-Toolbox/.github/container/jax_nsys/python/jax_nsys
        matplotlib mypy nbconvert types-protobuf
    - name: Install protoc
      run: ./JAX-Toolbox/.github/container/jax_nsys/install-protoc local_protoc
    - name: Fetch XLA .proto files
      uses: actions/checkout@v4
      with:
        path: xla
        repository: openxla/xla
        sparse-checkout: |
          *.proto
        sparse-checkout-cone-mode: false
    - name: Compile .proto files
      shell: bash -x -e {0}
      run: |
        mkdir compiled_protos compiled_stubs protos
        mv -v xla/third_party/tsl/tsl protos/
        mv -v xla/xla protos/
        # The locally installed protoc is put on PATH for this one command only.
        PATH=${PWD}/local_protoc/bin:$PATH python -c \
          "from jax_nsys import compile_protos; compile_protos(proto_dir='protos', output_dir='compiled_protos', output_stub_dir='compiled_stubs')"
        touch compiled_stubs/py.typed
    - name: Convert .ipynb to .py
      shell: bash -x -e {0}
      run: |
        # Convert every notebook found under the checked paths to a script.
        for nb in $(find ${NSYS_JAX_PYTHON_FILES} -name '*.ipynb'); do
          jupyter nbconvert --to script ${nb}
        done
    - name: Run mypy checks
      shell: bash -x -e {0}
      run: |
        # Point mypy at the stubs generated by the "Compile .proto files" step.
        MYPYPATH="${PWD}/compiled_stubs" mypy --scripts-are-modules ${NSYS_JAX_PYTHON_FILES}
# Exercises nsys-jax-combine and then executes the notebook shipped in the
# combined archive; in future the rendered notebook could perhaps be uploaded
# from here as well. The static input files were generated with something like
#   srun -n 4 --container-name=XXX --container-image=ghcr.io/nvidia/jax:pax-2024-07-06
#     env NPROC=4 XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler\
#     --xla_gpu_enable_command_buffer= nsys-jax -o ...-fsdp4-4proc-proc%q{SLURM_PROCID}
#     -- test-pax.sh --steps=5 --fsdp=4 --multiprocess
# with newer nsys-jax components bind-mounted in.
combine:
  runs-on: ubuntu-24.04
  steps:
    - name: Check out the repository under ${GITHUB_WORKSPACE}
      uses: actions/checkout@v4
    - name: Setup Python 3.12
      uses: actions/setup-python@v5
      with:
        python-version: '3.12'
    # TODO: a modern nsys-jax-combine given old .zip input should probably
    # produce a .zip with a modern jax_nsys/ by itself
    - name: Add modern jax_nsys/ files to static .zip inputs
      run: |
        cd .github/container/jax_nsys
        # Refresh each static input archive with the current directory contents
        # and list the result.
        for archive in ../../workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip; do
          zip -ur "${archive}" .
          zipinfo "${archive}"
        done
    - name: Use nsys-jax-combine to merge profiles from multiple nsys processes
      shell: bash -x -e {0}
      run: |
        pip install -e .github/container/jax_nsys/python/jax_nsys
        python .github/container/jax_nsys/install-protoc local_protoc
        # The locally installed protoc is put on PATH for this one command only.
        PATH=${PWD}/local_protoc/bin:$PATH .github/container/nsys-jax-combine \
          --analysis summary --analysis communication \
          -o pax_fsdp4_4proc.zip \
          .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip
    - name: Extract the output .zip file
      run: |
        mkdir combined/
        unzip -d combined/ pax_fsdp4_4proc.zip
    - name: Run the install script, but skip launching Jupyter Lab
      shell: bash -x -e {0}
      run: |
        pip install virtualenv
        NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./combined/install.sh
    - name: Test the Jupyter Lab installation and execute the notebook
      shell: bash -x -e {0}
      run: |
        pushd combined/
        ./nsys_jax_venv/bin/python -m jupyterlab --version
        # ipython is used to run the notebook so that failures produce a clear
        # error message.
        ./nsys_jax_venv/bin/ipython Analysis.ipynb
# Installs and executes the analysis notebook against a single-process profile,
# renders it, uploads the rendered files to a fresh Gist and then mirrors them
# into a Gist with a well-known ID. The input file was generated with something
# like
#   srun -n 1 --container-name=XXX --container-image=ghcr.io/nvidia/jax:pax-2024-07-06
#     env NPROC=4 XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler\
#     --xla_gpu_enable_command_buffer= nsys-jax -o ...-fsdp4-1proc -- test-pax.sh
#     --steps=5 --fsdp=4
notebook:
  env:
    # TODO: these could/should be saved in the repository settings instead
    # Destination Gist used when running on refs/heads/main ...
    RENDERED_NOTEBOOK_GIST_ID: 'e2cd3520201caab6b67385ed36fad3c1'
    # ... and the destination used for every other ref.
    MOCK_RENDERED_NOTEBOOK_GIST_ID: '16698d9e9e52320243165d61b5bb3975'
    # Regular expression (compiled as a JavaScript RegExp in the copy step)
    # selecting which files of the uploaded Gist get published.
    PUBLISH_NOTEBOOK_FILES: '(.*\.ipynb|.*\.svg)'
  runs-on: ubuntu-24.04
  steps:
    - name: Check out the repository under ${GITHUB_WORKSPACE}
      uses: actions/checkout@v4
    - name: Mock up the structure of an extracted .zip file
      shell: bash -x -e {0}
      run: |
        # Get the actual test data from a real archive, minus the .nsys-rep file
        unzip -d .github/container/jax_nsys/ .github/workflows/nsys-jax/test_data/pax_fsdp4_1proc.zip
    - name: Setup Python 3.10
      uses: actions/setup-python@v5
      with:
        python-version: '3.10'
    - name: Run the install script, but skip launching Jupyter Lab
      shell: bash -x -e {0}
      run: |
        pip install virtualenv
        NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./.github/container/jax_nsys/install.sh
    - name: Test the Jupyter Lab installation and execute the notebook
      shell: bash -x -e {0}
      run: |
        pushd .github/container/jax_nsys
        ./nsys_jax_venv/bin/python -m jupyterlab --version
        # Run with ipython for the sake of getting a clear error message
        ./nsys_jax_venv/bin/ipython Analysis.ipynb
    - name: Render the notebook
      id: render
      shell: bash -x -e {0}
      run: |
        pushd .github/container/jax_nsys
        workdir=$(mktemp -d)
        ./nsys_jax_venv/bin/jupyter nbconvert --execute --inplace Analysis.ipynb
        # Stage the executed notebook and SVG files in a scratch directory and
        # expose its path to later steps as the WORKDIR output.
        cp *.ipynb *.svg "${workdir}"
        echo "WORKDIR=${workdir}" >> "$GITHUB_OUTPUT"
    - name: Upload rendered notebook to Gist
      id: upload
      uses: actions/github-script@v7
      with:
        github-token: ${{ secrets.NVJAX_GIST_TOKEN }}
        script: |
          const currentDateTime = new Date().toISOString();
          const gistDescription =
            `Rendered IPython notebook from workflow: ${{ github.workflow }}, ` +
            `Run ID: ${{ github.run_id }}, ` +
            `Repository: ${{ github.repository }}, ` +
            `Event: ${{ github.event_name }}, ` +
            `Created: ${currentDateTime}`;
          const fs = require('fs').promises;
          const workdir = '${{ steps.render.outputs.WORKDIR }}';
          const files = await fs.readdir(workdir);
          // Create a new secret Gist holding every file staged by the render
          // step. Declared with const so no implicit global is created.
          const gist = await github.rest.gists.create({
            description: gistDescription,
            public: false,
            files: Object.fromEntries(
              await Promise.all(
                files.map(
                  async filename => {
                    const content = await fs.readFile(`${workdir}/${filename}`, 'utf8');
                    return [filename, { content }];
                  }
                )
              )
            )
          });
          console.log(gist);
          // The returned ID becomes this step's `result` output.
          return gist.data.id;
    - name: Copy rendered notebook to Gist with well-known ID
      uses: actions/github-script@v7
      with:
        github-token: ${{ secrets.NVJAX_GIST_TOKEN }}
        script: |
          // NOTE(review): this relies on github-script JSON-encoding the
          // previous step's return value, so the expression below expands to
          // a quoted string literal - confirm if result-encoding is changed.
          const srcId = ${{ steps.upload.outputs.result }};
          const dstId = "${{ github.ref == 'refs/heads/main' && env.RENDERED_NOTEBOOK_GIST_ID || env.MOCK_RENDERED_NOTEBOOK_GIST_ID }}";
          const { PUBLISH_NOTEBOOK_FILES } = process.env;
          // Fetch existing files from destination gist
          const { data: dstData } = await github.rest.gists.get({
            gist_id: dstId
          });
          // Mark existing files in destination gist for deletion
          let filesToUpdate = {};
          for (const filename of Object.keys(dstData.files)) {
            filesToUpdate[filename] = null;
          }
          // Fetch files from source gist
          const { data: srcData } = await github.rest.gists.get({
            gist_id: srcId
          });
          // Add or update files based on the pattern
          const pattern = new RegExp(`${PUBLISH_NOTEBOOK_FILES}`);
          for (const [filename, fileObj] of Object.entries(srcData.files)) {
            if (filename.match(pattern)) {
              // If the total gist size is too large, not all the content will have
              // been returned and we need some extra requests.
              if (fileObj.truncated) {
                const { data } = await github.request(fileObj.raw_url);
                // The raw body may arrive either as an already-decoded string
                // or as binary data; TextDecoder.decode() rejects strings, so
                // only decode when the body is not a string.
                filesToUpdate[filename] = {
                  content: typeof data === 'string' ? data : new TextDecoder().decode(data)
                };
              } else {
                filesToUpdate[filename] = {
                  content: fileObj.content
                };
              }
            }
          }
          // Update files in destination gist
          await github.rest.gists.update({
            gist_id: dstId,
            files: filesToUpdate
          });
          console.log("Files copied successfully.");
          console.log(Object.keys(filesToUpdate));
# Lint and formatting checks of the nsys-jax Python sources with ruff.
ruff:
  runs-on: ubuntu-24.04
  steps:
    - name: Check out the repository under ${GITHUB_WORKSPACE}
      uses: actions/checkout@v4
      with:
        path: JAX-Toolbox
        sparse-checkout: |
          .github/container
    - name: Setup Python 3.10
      uses: actions/setup-python@v5
      with:
        python-version: '3.10'
    - name: Install ruff
      run: pip install ruff
    - name: Run ruff checks
      # The shell is invoked without -e, so a failing `ruff check` does not
      # stop `ruff format --check` from running; the step fails afterwards if
      # either command reported a problem.
      shell: bash -x {0}
      run: |
        status=0
        ruff check ${NSYS_JAX_PYTHON_FILES} || status=1
        ruff format --check ${NSYS_JAX_PYTHON_FILES} || status=1
        exit ${status}