Parallelize across multiple GPUs with MPI4Jax #1071
Comments
Add view-only link to hackathon doc.
Instructions to install on della are under "mpi4jax installation instructions on della-gpu" in this Google doc: https://docs.google.com/document/d/1x6nGZEiZnAiWBDf20Mcbwbob9a6GMbx3ZpT_huUVZF4/edit?usp=sharing
Maybe instead of separating objectives, we could distribute quantities related to each grid point and then parallelize over them. Then the question is, can we use this for flux surface averaged stuff? For example, for ...
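As a rough illustration of the flux-surface-average question, here is a minimal sketch of how a weighted surface average over grid points split across MPI ranks could be combined with mpi4jax.allreduce. The names local_values and local_weights are placeholders (not DESC quantities), and the (result, token) unpacking follows the token-style API in the mpi4jax docs, which may differ between versions:

```python
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD

@jax.jit
def flux_surface_average(local_values, local_weights):
    # each rank holds only its chunk of the surface's grid points and weights
    local_num = jnp.sum(local_values * local_weights)
    local_den = jnp.sum(local_weights)
    # sum the partial numerator/denominator over all ranks
    num, _ = mpi4jax.allreduce(local_num, op=MPI.SUM, comm=comm)
    den, _ = mpi4jax.allreduce(local_den, op=MPI.SUM, comm=comm)
    return num / den
```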
Flag for grids to pad out so that grid.num_nodes is evenly divisible by the number of GPUs, with the extra pad nodes assigned a weight of 0 in grid.weights.
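A minimal sketch of what that padding could look like. pad_grid_for_devices is a hypothetical helper (not an existing DESC function) and assumes the grid exposes its nodes and weights as arrays:

```python
import numpy as np

def pad_grid_for_devices(nodes, weights, n_devices):
    """Hypothetical helper: pad so the node count divides evenly by n_devices.

    Extra pad nodes duplicate the last node and get weight 0, so they do not
    contribute to any integral or average computed with grid.weights.
    """
    n_nodes = nodes.shape[0]
    remainder = n_nodes % n_devices
    if remainder == 0:
        return nodes, weights
    n_pad = n_devices - remainder
    pad_nodes = np.repeat(nodes[-1:], n_pad, axis=0)  # dummy copies of the last node
    pad_weights = np.zeros(n_pad)                     # zero weight -> no contribution
    return (np.concatenate([nodes, pad_nodes], axis=0),
            np.concatenate([weights, pad_weights]))
```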
More info:
Some initial trials for a dummy function, just for getting used to the syntax. I ran this on della with 3 A100s. There is a dummy compute function:

```python
import jax
import jax.numpy as jnp

@jax.jit
def compute(params, points):
    # evaluate a cosine series: sum_m params[m] * cos(m * points)
    modes = jnp.arange(len(params))
    eval = jnp.outer(modes, points)
    res = jnp.dot(params, jnp.cos(eval))
    return res
```

Some benchmarks with inputs (the %timeit lines are run in IPython):

```python
from jax.sharding import NamedSharding, PartitionSpec as P

num_sharding = 3
mesh = jax.make_mesh((num_sharding,), ('points',))
sharding = NamedSharding(mesh, P('points'))     # shard along the 'points' axis
replicated_sharding = NamedSharding(mesh, P())  # fully replicated on every device

points = jnp.arange(3 * 30000)
params = jax.random.normal(jax.random.key(0), (1000,))
points_sharded = jax.device_put(points, sharding)
params_sharded = jax.device_put(params, replicated_sharding)

# sharded points, replicated params
compute(params_sharded, points_sharded)
%timeit compute(params_sharded, points_sharded).block_until_ready()

# no explicit sharding
compute(params, points)
%timeit compute(params, points).block_until_ready()

# replicated params, unsharded points
compute(params_sharded, points)
%timeit compute(params_sharded, points).block_until_ready()
```
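To confirm where the shards actually end up (not part of the original benchmark, just a way to inspect the placement):

```python
# inspect how the arrays are laid out across the 3 devices
jax.debug.visualize_array_sharding(points_sharded)
print(points_sharded.sharding)  # NamedSharding over the 'points' axis
print(params_sharded.sharding)  # fully replicated
```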
Also, for making everything faster (I guess the bottleneck now is the QR), we might need jax-ml/jax#16597.
Distribute the grid points such that the ones on the same flux surface stay on the same device. Depending on the grid used, we might need to assign a different number of flux surfaces to different devices, to keep the number of grid points the same for each device.
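A minimal sketch of one way to do that balancing, using a greedy assignment. nodes_per_surface is an assumed input giving the number of grid points on each flux surface, not an existing DESC quantity:

```python
import numpy as np

def assign_surfaces_to_devices(nodes_per_surface, n_devices):
    """Greedy assignment of whole flux surfaces to devices.

    Each surface goes to the device with the fewest grid points so far,
    so surfaces never get split and per-device node counts stay close.
    """
    loads = np.zeros(n_devices, dtype=int)
    assignment = np.empty(len(nodes_per_surface), dtype=int)
    # place the largest surfaces first to keep the loads balanced
    for surf in np.argsort(nodes_per_surface)[::-1]:
        device = int(np.argmin(loads))
        assignment[surf] = device
        loads[device] += nodes_per_surface[surf]
    return assignment, loads
```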
The general idea is that eq is created, which is then fed back to each GPU.
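A rough sketch of that loop under stated assumptions: local_residuals and update_eq_params are placeholders standing in for DESC's objective evaluation and optimizer step, and the mpi4jax calls follow the token-style (result, token) API in its docs, which may differ between versions:

```python
from mpi4py import MPI
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

def local_residuals(eq_params, local_points):
    # placeholder residual, standing in for the actual objective evaluation
    return jnp.cos(jnp.outer(eq_params, local_points)).sum(axis=0)

def update_eq_params(eq_params, total_cost):
    # placeholder update, standing in for the optimizer step that rebuilds eq
    return eq_params - 1e-3 * total_cost

def optimization_step(eq_params, local_points):
    # 1) each GPU/rank evaluates the objective on its own chunk of grid points
    local_cost = jnp.sum(local_residuals(eq_params, local_points) ** 2)
    # 2) combine the per-rank pieces into the global objective
    total_cost, _ = mpi4jax.allreduce(local_cost, op=MPI.SUM, comm=comm)
    # 3) rank 0 updates eq, then the new parameters are fed back to every GPU
    if rank == 0:
        eq_params = update_eq_params(eq_params, total_cost)
    eq_params, _ = mpi4jax.bcast(eq_params, root=0, comm=comm)
    return eq_params
```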