
Parallelize across multiple GPUs with MPI4Jax #1071

Open
kianorr opened this issue Jun 25, 2024 · 8 comments · May be fixed by #1495
Labels
enhancement (General label for enhancement. Please also tag with "Speed", "Interface", "Functionality", etc.)
P3 (Highest Priority, someone is/should be actively working on this)
performance (New feature or request to make the code faster)
question (Further information is requested)

Comments

@kianorr
Collaborator

kianorr commented Jun 25, 2024

The general idea is:

  • Physics objectives are put onto separate GPUs
  • Constraints are taken out
  • The Jacobians are then combined on a single GPU to create the $A$ matrix, and an SVD is performed on this $A$ matrix
  • From this a new equilibrium is created, which is then fed back to each GPU (a rough sketch of this flow follows below)
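A rough sketch of what this flow could look like with mpi4jax. This is only an illustration, assuming one objective per MPI rank, equal-sized Jacobian blocks, and the mpi4jax collectives that return a (result, token) pair; my_objective and the SVD step are stand-ins, not DESC's actual interfaces.

import jax
import jax.numpy as jnp
import mpi4jax
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

def my_objective(x):
    # stand-in for the physics objective assigned to this rank
    return jnp.sin(x) * (rank + 1)

def local_jacobian(x):
    # each rank computes only its own block of rows of A
    return jax.jacfwd(my_objective)(x)

x = jnp.ones(10)               # optimization variables, replicated on all ranks
J_local = local_jacobian(x)

# gather all blocks; result has shape (comm.size, *J_local.shape)
J_all, _ = mpi4jax.allgather(J_local, comm=comm)

if rank == 0:
    A = J_all.reshape(-1, J_all.shape[-1])   # stack blocks into the full A matrix
    U, s, Vt = jnp.linalg.svd(A, full_matrices=False)
    # ...use the SVD to take a step / build the new equilibrium, then broadcast
    # the updated variables back to every rank (e.g. with mpi4jax.bcast)
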
@dpanici dpanici added low priority Nice to have, but not needed right away performance New feature or request to make the code faster enhancement General label for enhancement. Please also tag with "Speed", "Interface", "Functionality", etc and removed low priority Nice to have, but not needed right away labels Aug 13, 2024
@dpanici
Collaborator

dpanici commented Aug 13, 2024

Add view-only link to hackathon doc

@kianorr
Collaborator Author

kianorr commented Aug 20, 2024

Instructions for installing on della-gpu are under "mpi4jax installation instructions on della-gpu" in this Google Doc: https://docs.google.com/document/d/1x6nGZEiZnAiWBDf20Mcbwbob9a6GMbx3ZpT_huUVZF4/edit?usp=sharing

@dpanici dpanici added the P3 Highest Priority, someone is/should be actively working on this label Nov 11, 2024
@dpanici dpanici added the question Further information is requested label Nov 25, 2024
@dpanici
Collaborator

dpanici commented Nov 25, 2024

@kianorr @f0uriest can you put here what the constraints/limitations of this approach are and what situations it could actually help with?

@YigitElma
Collaborator

Maybe instead of separating objectives, we could distribute the transforms and profiles of the compute method to different GPUs; this would make use of multi-GPU parallelism even for single-objective cases (like ForceBalance).

We can distribute the quantities associated with each grid point and then parallelize over them. The question is then whether this works for flux-surface-averaged quantities. For example, does ForceBalance depend on any quantity that requires a flux surface average in its calculation?

@dpanici @f0uriest @ddudt

@dpanici
Collaborator

dpanici commented Dec 2, 2024

Add a flag for grids to pad out so that grid.num_nodes is evenly divisible by the number of GPUs, with the extra pad nodes assigned a weight of 0 in grid.weights.
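A minimal sketch of that padding logic, using plain arrays rather than DESC's actual Grid class; the pad_nodes_for_devices helper and its argument names are hypothetical and only meant to mirror grid.num_nodes / grid.weights:

import numpy as np

def pad_nodes_for_devices(nodes, weights, n_devices):
    # pad a (num_nodes, 3) node array so num_nodes % n_devices == 0;
    # padded nodes duplicate the last real node but get weight 0, so they
    # contribute nothing to quadrature or objective values
    num_nodes = nodes.shape[0]
    n_pad = (-num_nodes) % n_devices
    if n_pad == 0:
        return nodes, weights
    nodes = np.concatenate([nodes, np.repeat(nodes[-1:], n_pad, axis=0)], axis=0)
    weights = np.concatenate([weights, np.zeros(n_pad)])
    return nodes, weights

# example: 100 nodes split across 3 GPUs -> padded to 102
nodes, weights = pad_nodes_for_devices(np.random.rand(100, 3), np.ones(100), n_devices=3)
assert nodes.shape[0] % 3 == 0 and weights[100:].sum() == 0
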

@ddudt
Collaborator

ddudt commented Dec 2, 2024

More info:

  • Memory usage can already be reduced for many applications with either deriv_mode="blocked", or with jac_chunk_size=1 and deriv_mode="batched" on the outer ObjectiveFunction (see the sketch after this list)
    • Maximal memory reduction comes from deriv_mode="blocked" combined with jac_chunk_size=1 on each sub-Objective in the ObjectiveFunction. If the problem still won't fit with these settings, then multi-GPU parallelization over the objectives won't help any further.
  • If the issue is that the resolution is too high (the Jacobian is too "wide"), then jac_chunk_size=1 is the solution, and resolving this issue would not help further
  • If the issue is that there are too many objectives (the Jacobian is too "tall"), then deriv_mode="blocked" is the solution, and resolving this issue would not help beyond giving a speed improvement
  • If memory is still an issue, then parallelizing the grid nodes across multiple devices could help. This is conceptually similar to deriv_mode="blocked", but would sub-divide each objective (helping with "tall" Jacobians, or very complex objectives)
  • Parallelizing across the objectives and/or the grid nodes would require some plumbing to implement
    • Parallelization across grid nodes would likely require the most work
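For reference, a sketch of how those existing options are passed, assuming the recent DESC keyword names (deriv_mode and jac_chunk_size); check your DESC version's docs for the exact signatures, and note that eq here is just a default Equilibrium for illustration:

from desc.equilibrium import Equilibrium
from desc.objectives import AspectRatio, ForceBalance, ObjectiveFunction

eq = Equilibrium()  # stand-in; use your actual equilibrium

# "batched" Jacobian computed in small chunks: helps when the Jacobian is too "wide"
obj_batched = ObjectiveFunction(ForceBalance(eq), deriv_mode="batched", jac_chunk_size=1)

# "blocked": each sub-objective computes its own Jacobian block, which helps when the
# Jacobian is too "tall"; chunking each sub-objective as well gives maximal memory reduction
obj_blocked = ObjectiveFunction(
    (ForceBalance(eq, jac_chunk_size=1), AspectRatio(eq=eq, jac_chunk_size=1)),
    deriv_mode="blocked",
)
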

@ddudt ddudt changed the title Implement multi-gpu functionality for separate objectives with MPI4Jax Parallelize across multiple GPUs with MPI4Jax Dec 2, 2024
@YigitElma
Collaborator

YigitElma commented Dec 24, 2024

Some initial trials with a dummy function, just to get used to the syntax. I ran this on della with 3 A100s.
multi-gpu-jax.zip

There is a dummy compute function:

import jax
import jax.numpy as jnp

@jax.jit
def compute(params, points):
    # evaluate a cosine series: sum over m of params[m] * cos(m * points)
    modes = jnp.arange(len(params))
    basis = jnp.outer(modes, points)
    res = jnp.dot(params, jnp.cos(basis))
    return res

Some benchmarks with inputs:

from jax.sharding import NamedSharding, PartitionSpec as P

num_sharding = 3
mesh = jax.make_mesh((num_sharding,), ('points',))
sharding = NamedSharding(mesh, P('points'))        # shard along the points axis
replicated_sharding = NamedSharding(mesh, P())     # replicate across all devices

points = jnp.arange(3 * 30000)
params = jax.random.normal(jax.random.key(0), (1000,))

points_sharded = jax.device_put(points, sharding)
params_sharded = jax.device_put(params, replicated_sharding)

# warm up the JIT cache, then time each variant
compute(params_sharded, points_sharded)
%timeit compute(params_sharded, points_sharded).block_until_ready()

compute(params, points)
%timeit compute(params, points).block_until_ready()

compute(params_sharded, points)
%timeit compute(params_sharded, points).block_until_ready()

Timings, in the same order as the %timeit calls above:

643 μs ± 3.61 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.67 ms ± 818 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.82 ms ± 7.73 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Also, to make everything faster (I guess the bottleneck is now the QR factorization), we might need jax-ml/jax#16597.

@YigitElma
Collaborator

> Then the question is, can we use this for flux surface averaged stuff?

Distribute the grid points such that the ones on the same flux surface stay on the same device. Depending on the grid used, we might need to assign a different number of flux surfaces to different devices, to keep the number of grid points the same for each device.
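A rough sketch of one way to do that assignment; assign_surfaces_to_devices is a hypothetical helper that takes the rho coordinate of each node and greedily balances whole flux surfaces across devices:

import numpy as np

def assign_surfaces_to_devices(node_rho, n_devices):
    # group nodes by flux surface (unique rho value) and assign whole surfaces
    # to devices, roughly balancing the number of nodes per device
    surfaces, inverse, counts = np.unique(node_rho, return_inverse=True, return_counts=True)
    device_of_surface = np.zeros(len(surfaces), dtype=int)
    load = np.zeros(n_devices)
    for i in np.argsort(counts)[::-1]:   # place the largest surfaces first
        d = int(np.argmin(load))         # device with the fewest nodes so far
        device_of_surface[i] = d
        load[d] += counts[i]
    return device_of_surface[inverse]    # device id for every grid node

# example: 3 surfaces with 8, 16, and 24 nodes spread over 2 devices
rho = np.repeat([0.2, 0.5, 1.0], [8, 16, 24])
print(assign_surfaces_to_devices(rho, n_devices=2))
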

@YigitElma YigitElma linked a pull request Dec 25, 2024 that will close this issue