TL: finished MPI time-parallel SDC implementation, test OK
tlunet committed Jan 3, 2025
1 parent b85fd2d commit 4f18b6b
Showing 1 changed file with 119 additions and 14 deletions.
133 changes: 119 additions & 14 deletions pySDC/playgrounds/dedalus/sdc.py
@@ -286,7 +286,7 @@ def leftIsNode(self):
def doProlongation(self):
return not self.rightIsNode or self.forceProl

def _computeMX0(self, MX0):
def _computeMX0(self, MX0:CoeffSystem):
"""
Compute MX0 term used in RHS of both initStep and sweep methods
@@ -380,6 +380,7 @@ def _evalF(self, F, time, dt, wall_time):
t0 = solver.sim_time
solver.sim_time = time
if self.firstEval:
# Possibly write fields to file (scheduled output)
solver.evaluator.evaluate_scheduled(
wall_time=wall_time, timestep=dt, sim_time=time,
iteration=solver.iteration)
@@ -596,7 +597,7 @@ def _sweep(self, k):
axpy(a=-dt*qE[k, m, i], x=Fk[i].data, y=RHS.data)
axpy(a=dt*qI[k, m, i], x=LXk[i].data, y=RHS.data)

# Add LX terms from iteration k+1 and current nodes
# Add LX terms from iteration k and current nodes
axpy(a=dt*qI[k, m, m], x=LXk[m].data, y=RHS.data)

# Solve system and store node solution in solver state
@@ -717,6 +718,7 @@ def step(self, dt, wall_time):
self.firstEval = True
self.firstStep = False


def initSpaceTimeMPI(nProcSpace=None, nProcTime=None, groupTime=False):

gComm = MPI.COMM_WORLD
@@ -776,7 +778,7 @@ def initSpaceTimeMPI(nProcSpace=None, nProcTime=None, groupTime=False):
sColor = (gRank - gRank % nProcSpace) // nProcSpace
sComm = gComm.Split(sColor, gRank)
gComm.Barrier()

return gComm, sComm, tComm
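
For readers unfamiliar with the color/key idiom used in initSpaceTimeMPI, here is a minimal standalone sketch of the same splitting with mpi4py. The time-color formula and the process counts are assumptions for illustration; only the space split appears verbatim above, assuming ranks laid out as gRank = tIndex*nProcSpace + sIndex.

# Standalone sketch -- run e.g. with `mpiexec -n 4 python split_demo.py`
from mpi4py import MPI

gComm = MPI.COMM_WORLD
gRank = gComm.Get_rank()
nProcSpace = 2   # illustrative value, assuming gRank = tIndex*nProcSpace + sIndex

# Ranks with the same space index share a time communicator (assumed convention)
tComm = gComm.Split(gRank % nProcSpace, gRank)
# Ranks with the same time index share a space communicator (as in the code above)
sComm = gComm.Split((gRank - gRank % nProcSpace) // nProcSpace, gRank)

print(f"global {gRank}: space rank {sComm.Get_rank()}, time rank {tComm.Get_rank()}")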


@@ -788,14 +790,16 @@ class SDCIMEX_MPI(SpectralDeferredCorrectionIMEX):
def initSpaceTimeComms(cls, nProcSpace=None, groupTime=False):
gComm, sComm, cls.comm = initSpaceTimeMPI(nProcSpace, cls.getM(), groupTime)
return gComm, sComm, cls.comm
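
Since comm is a class attribute, the communicators have to be set up before the solver instantiates the timestepper. A hypothetical driver fragment is sketched below; the module name `sdc` and the value of nProcSpace are assumptions, and note that nProcTime is taken from the number of nodes M, so the job must be launched with nProcSpace * M ranks in total.

# Hypothetical driver fragment -- module name and process counts are assumptions
from sdc import SDCIMEX_MPI   # assuming this file is importable as `sdc`

gComm, sComm, tComm = SDCIMEX_MPI.initSpaceTimeComms(nProcSpace=2)
print(f"rank {gComm.Get_rank()}: space {sComm.Get_rank()}, time {tComm.Get_rank()}")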

@property
def rank(self):
return self.comm.Get_rank()

def __init__(self, solver):

assert isinstance(self.comm, MPI.Intracomm), "comm is not an MPI communicator"
assert self.diagonal, "MPI parallelization works only with diagonal SDC"
assert not self.forceProl, "MPI parallelization not implemented with forceProl"

# Store class attributes as instance attributes
self.infos = self.getInfos()
@@ -814,10 +818,9 @@ def __init__(self, solver):
# Attributes
self.axpy = blas.get_blas_funcs('axpy', dtype=solver.dtype)
self.dt = None

self.firstEval = (self.rank == 0)
self.firstStep = True

self.firstEval = (self.rank == self.M-1)
self.firstStep = True

def _updateLHS(self, dt, init=False):
"""Update LHS and LHS solvers for each subproblem
@@ -831,22 +834,124 @@ def __init__(self, solver):
subproblem. The default is False.
"""
# Attribute references
qI = self.QDeltaI
m = self.rank
qI = self.QDeltaI[:, m, m]
solver = self.solver

# Update LHS and LHS solvers for each subproblem
for sp in solver.subproblems:
if init:
# Potentially instantiate the list of solvers (only at the first time step)
sp.LHS_solvers = [[None for _ in range(self.M)] for _ in range(self.nSweeps)]
sp.LHS_solvers = [None for _ in range(self.nSweeps)]
for k in range(self.nSweeps):
m = self.rank
if solver.store_expanded_matrices:
raise NotImplementedError("code correction required")
np.copyto(sp.LHS.data, sp.M_exp.data)
self.axpy(a=dt*qI[k, m, m], x=sp.L_exp.data, y=sp.LHS.data)
self.axpy(a=dt*qI[k], x=sp.L_exp.data, y=sp.LHS.data)
else:
sp.LHS = (sp.M_min + dt*qI[k, m, m]*sp.L_min)
sp.LHS_solvers[k][m] = solver.matsolver(sp.LHS, solver)
sp.LHS = (sp.M_min + dt*qI[k]*sp.L_min)
sp.LHS_solvers[k] = solver.matsolver(sp.LHS, solver)
if self.initSweep == "QDELTA":
raise NotImplementedError()
raise NotImplementedError()
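
The structural change in _updateLHS is that each MPI rank only ever uses the diagonal entry QDeltaI[k, m, m] of its own node m, so the per-subproblem solver list shrinks from nSweeps x M entries to nSweeps entries. Below is a scalar sketch of what each rank factorises, mirroring the Dedalus convention M*dX/dt + L*X = F; the values of lam, dt and QDeltaI are made up and stand in for the real subproblem matrices and QDelta coefficients.

# Scalar illustration only -- lam, dt and QDeltaI below are invented values
import numpy as np

lam, dt = 2.0, 0.1                       # stand-ins for sp.L_min and the time step
nSweeps, M = 2, 4
rng = np.random.default_rng(0)
QDeltaI = rng.random((nSweeps, M, M))    # stand-in for the real QDelta coefficients

m = 2                                    # this rank's node index (self.rank above)
qI = QDeltaI[:, m, m]                    # only "my" diagonal entry is needed

# One LHS (and one solver) per sweep instead of one per (sweep, node):
LHS = [1.0 + dt * qI[k] * lam for k in range(nSweeps)]   # mirrors M_min + dt*qI[k]*L_min
solvers = [lambda rhs, A=A: rhs / A for A in LHS]        # scalar stand-in for matsolver
print(solvers[0](1.0), solvers[1](1.0))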

def _solveAndStoreState(self, k):
"""
Solve LHS * X = RHS using the LHS associated with a given node,
and store X into the solver state.
It uses the current RHS attribute of the object.
Parameters
----------
k : int
Sweep index (0 for the first sweep).
"""
# Attribute references
solver = self.solver
RHS = self.RHS

self._presetStateCoeffSpace(solver.state)

# Solve and store for each subproblem
for sp in solver.subproblems:
# Slice out valid subdata, skipping invalid components
spRHS = RHS.get_subdata(sp)
spX = sp.LHS_solvers[k].solve(spRHS) # CREATES TEMPORARY
sp.scatter_inputs(spX, solver.state)

def _computeMX0(self, MX0:CoeffSystem):
"""
Compute MX0 term used in RHS of both initStep and sweep methods
Update the MX0 attribute of the timestepper object.
"""
if self.rank == self.M-1: # only the last node computes MX0
super()._computeMX0(MX0)
# Broadcast MX0 to all nodes
self.comm.Bcast(MX0.data, root=self.M-1)
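
This override reduces to a compute-on-one-rank-then-broadcast pattern: only the last node (rank M-1) evaluates MX0, then every rank receives a copy. A minimal standalone illustration follows, with a plain NumPy buffer standing in for MX0.data (array size and values are arbitrary).

# Standalone illustration -- run with any number of ranks
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
M, rank = comm.Get_size(), comm.Get_rank()   # one rank per collocation node

MX0 = np.zeros(8)                            # stand-in for MX0.data
if rank == M - 1:                            # only the last node computes MX0
    MX0[:] = 42.0                            # placeholder for the real computation
comm.Bcast(MX0, root=M - 1)                  # now every node holds the same MX0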

def _initSweep(self):
t0, dt, wall_time = self.solver.sim_time, self.dt, self.wall_time
Fk, LXk = self.F[0], self.LX[0]
if self.initSweep == 'COPY':
if self.rank == self.M-1: # only the last node evaluates
self._evalLX(LXk)
self._evalF(Fk, t0, dt, wall_time)
# Broadcast LXk and Fk to all nodes
self.comm.Bcast(LXk.data, root=self.M-1)
self.comm.Bcast(Fk.data, root=self.M-1)
else:
raise NotImplementedError()

def _sweep(self, k):
"""Perform a sweep for the current time-step"""
# Only compute for the current node
m = self.rank

# Attribute references
tau, qI, q = self.nodes[m], self.QDeltaI[:, m, m], self.Q[:, m]
solver = self.solver
t0, dt, wall_time = solver.sim_time, self.dt, self.wall_time
RHS, MX0 = self.RHS, self.MX0
Fk, LXk, Fk1, LXk1 = self.F[0], self.LX[0], self.F[1], self.LX[1]
axpy = self.axpy

# Build RHS
if RHS.data.size:

# Initialize with MX0 term
np.copyto(RHS.data, MX0.data)

# Add quadrature terms using a reduced sum across nodes
recvBuf = np.zeros_like(RHS.data)
sendBuf = np.zeros_like(RHS.data)
for i in range(self.M-1, -1, -1): # start from last node
sendBuf.fill(0)
axpy(a=dt*q[i], x=Fk.data, y=sendBuf)
axpy(a=-dt*q[i], x=LXk.data, y=sendBuf)
self.comm.Reduce(sendBuf, recvBuf, root=i, op=MPI.SUM)
RHS.data += recvBuf

# Add the LX term from iteration k at the current node
axpy(a=dt*qI[k], x=LXk.data, y=RHS.data)

# Solve system and store node solution in solver state
self._solveAndStoreState(k)

if k < self.nSweeps-1:
tEval = t0+dt*tau
# Evaluate and store F(X, t) with current state
self._evalF(Fk1, tEval, dt, wall_time)
# Evaluate and store LX with current state
self._evalLX(LXk1)

# Swap storage for iterations k and k+1,
# i.e. the new evaluations become the old ones for the next iteration
self.F.rotate()
self.LX.rotate()
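
The quadrature loop above replaces the serial sum over nodes by a collective: each rank holds F and LX of its own node only, contributes dt*q[i]*(F - LX) for every target node i, and a Reduce with root=i accumulates node i's quadrature sum on the rank that needs it. Below is a stripped-down sketch of that communication pattern, assuming (as the slicing q = self.Q[:, m] suggests) that Q[i, m] weights node m's contribution to node i; the values of Q, dt and the node data are invented, and the MX0 and implicit qI terms are left out.

# Standalone sketch of the reduced quadrature sum -- run with any number of ranks
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
M, m = comm.Get_size(), comm.Get_rank()       # one rank per node, m = "my" node

dt = 0.1
Q = np.arange(1.0, M * M + 1).reshape(M, M) / (M * M)   # invented quadrature matrix
fm = np.full(4, float(m + 1))                 # stand-in for (F - L X) at node m

recvBuf = np.zeros_like(fm)
for i in range(M - 1, -1, -1):                # build every node's quadrature term collectively
    sendBuf = dt * Q[i, m] * fm               # my contribution to node i's RHS
    comm.Reduce(sendBuf, recvBuf if i == m else None, root=i, op=MPI.SUM)
rhs = recvBuf                                 # each rank keeps only the sum reduced to it

# Cross-check against the serial formula: rhs == dt * sum_j Q[m, j] * f_j
allF = np.array(comm.allgather(fm))
assert np.allclose(rhs, dt * sum(Q[m, j] * allF[j] for j in range(M)))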

def step(self, dt, wall_time):
super().step(dt, wall_time)

# Only the last rank (i.e. last node) is allowed to (possibly) write outputs
if self.rank != self.M-1:
self.firstEval = False
