From 4f18b6b164c405ed2aea579b5c36c721a67e3e6f Mon Sep 17 00:00:00 2001
From: Thibaut Lunet
Date: Fri, 3 Jan 2025 17:56:47 +0100
Subject: [PATCH] TL: finished MPI time-parallel SDC implementation, test OK

---
 pySDC/playgrounds/dedalus/sdc.py | 133 +++++++++++++++++++++++++++----
 1 file changed, 119 insertions(+), 14 deletions(-)

diff --git a/pySDC/playgrounds/dedalus/sdc.py b/pySDC/playgrounds/dedalus/sdc.py
index 0259b7fb5f..eba3902288 100644
--- a/pySDC/playgrounds/dedalus/sdc.py
+++ b/pySDC/playgrounds/dedalus/sdc.py
@@ -286,7 +286,7 @@ def leftIsNode(self):
     def doProlongation(self):
         return not self.rightIsNode or self.forceProl
 
-    def _computeMX0(self, MX0):
+    def _computeMX0(self, MX0: CoeffSystem):
         """
         Compute MX0 term used in RHS of both initStep and sweep methods
 
@@ -380,6 +380,7 @@ def _evalF(self, F, time, dt, wall_time):
         t0 = solver.sim_time
         solver.sim_time = time
         if self.firstEval:
+            # Possibly write fields to file
             solver.evaluator.evaluate_scheduled(
                 wall_time=wall_time, timestep=dt, sim_time=time,
                 iteration=solver.iteration)
@@ -596,7 +597,7 @@ def _sweep(self, k):
                     axpy(a=-dt*qE[k, m, i], x=Fk[i].data, y=RHS.data)
                     axpy(a=dt*qI[k, m, i], x=LXk[i].data, y=RHS.data)
 
-                # Add LX terms from iteration k+1 and current nodes
+                # Add LX terms from iteration k and current nodes
                 axpy(a=dt*qI[k, m, m], x=LXk[m].data, y=RHS.data)
 
             # Solve system and store node solution in solver state
@@ -717,6 +718,7 @@ def step(self, dt, wall_time):
         self.firstEval = True
         self.firstStep = False
 
+
 def initSpaceTimeMPI(nProcSpace=None, nProcTime=None, groupTime=False):
 
     gComm = MPI.COMM_WORLD
@@ -776,7 +778,7 @@ def initSpaceTimeMPI(nProcSpace=None, nProcTime=None, groupTime=False):
         sColor = (gRank - gRank % nProcSpace) / nProcSpace
         sComm = gComm.Split(sColor, gRank)
         gComm.Barrier()
-
+
     return gComm, sComm, tComm
 
 
@@ -788,7 +790,7 @@ class SDCIMEX_MPI(SpectralDeferredCorrectionIMEX):
     def initSpaceTimeComms(cls, nProcSpace=None, groupTime=False):
         gComm, sComm, cls.comm = initSpaceTimeMPI(nProcSpace, cls.getM(), groupTime)
         return gComm, sComm, cls.comm
-
+
     @property
     def rank(self):
         return self.comm.Get_rank()
@@ -796,6 +798,8 @@ def rank(self):
 
     def __init__(self, solver):
         assert isinstance(self.comm, MPI.Intracomm), "comm is not a MPI communicator"
+        assert self.diagonal, "MPI parallelization works only with diagonal SDC"
+        assert not self.forceProl, "MPI parallelization not implemented with forceProl"
 
         # Store class attributes as instance attributes
         self.infos = self.getInfos()
@@ -814,10 +818,9 @@ def __init__(self, solver):
 
         # Attributes
         self.axpy = blas.get_blas_funcs('axpy', dtype=solver.dtype)
         self.dt = None
-
-        self.firstEval = (self.rank == 0)
-        self.firstStep = True
+        self.firstEval = (self.rank == self.M-1)
+        self.firstStep = True
 
     def _updateLHS(self, dt, init=False):
         """Update LHS and LHS solvers for each subproblem
@@ -831,22 +834,124 @@ def _updateLHS(self, dt, init=False):
             subproblem. The default is False.
""" # Attribute references - qI = self.QDeltaI + m = self.rank + qI = self.QDeltaI[:, m, m] solver = self.solver # Update LHS and LHS solvers for each subproblems for sp in solver.subproblems: if init: # Potentially instantiate list of solver (ony first time step) - sp.LHS_solvers = [[None for _ in range(self.M)] for _ in range(self.nSweeps)] + sp.LHS_solvers = [None for _ in range(self.nSweeps)] for k in range(self.nSweeps): - m = self.rank if solver.store_expanded_matrices: raise NotImplementedError("code correction required") np.copyto(sp.LHS.data, sp.M_exp.data) - self.axpy(a=dt*qI[k, m, m], x=sp.L_exp.data, y=sp.LHS.data) + self.axpy(a=dt*qI[k], x=sp.L_exp.data, y=sp.LHS.data) else: - sp.LHS = (sp.M_min + dt*qI[k, m, m]*sp.L_min) - sp.LHS_solvers[k][m] = solver.matsolver(sp.LHS, solver) + sp.LHS = (sp.M_min + dt*qI[k]*sp.L_min) + sp.LHS_solvers[k] = solver.matsolver(sp.LHS, solver) if self.initSweep == "QDELTA": - raise NotImplementedError() \ No newline at end of file + raise NotImplementedError() + + def _solveAndStoreState(self, k): + """ + Solve LHS * X = RHS using the LHS associated to a given node, + and store X into the solver state. + It uses the current RHS attribute of the object. + + Parameters + ---------- + k : int + Sweep index (0 for the first sweep). + """ + # Attribute references + solver = self.solver + RHS = self.RHS + + self._presetStateCoeffSpace(solver.state) + + # Solve and store for each subproblem + for sp in solver.subproblems: + # Slice out valid subdata, skipping invalid components + spRHS = RHS.get_subdata(sp) + spX = sp.LHS_solvers[k].solve(spRHS) # CREATES TEMPORARY + sp.scatter_inputs(spX, solver.state) + + def _computeMX0(self, MX0:CoeffSystem): + """ + Compute MX0 term used in RHS of both initStep and sweep methods + + Update the MX0 attribute of the timestepper object. 
+        """
+        if self.rank == self.M-1:  # only the last node computes MX0
+            super()._computeMX0(MX0)
+        # Broadcast MX0 to all nodes
+        self.comm.Bcast(MX0.data, root=self.M-1)
+
+    def _initSweep(self):
+        t0, dt, wall_time = self.solver.sim_time, self.dt, self.wall_time
+        Fk, LXk = self.F[0], self.LX[0]
+        if self.initSweep == 'COPY':
+            if self.rank == self.M-1:  # only the last node evaluates
+                self._evalLX(LXk)
+                self._evalF(Fk, t0, dt, wall_time)
+            # Broadcast LXk and Fk to all nodes
+            self.comm.Bcast(LXk.data, root=self.M-1)
+            self.comm.Bcast(Fk.data, root=self.M-1)
+        else:
+            raise NotImplementedError()
+
+    def _sweep(self, k):
+        """Perform a sweep for the current time-step"""
+        # Only compute for the current node
+        m = self.rank
+
+        # Attribute references
+        tau, qI, q = self.nodes[m], self.QDeltaI[:, m, m], self.Q[:, m]
+        solver = self.solver
+        t0, dt, wall_time = solver.sim_time, self.dt, self.wall_time
+        RHS, MX0 = self.RHS, self.MX0
+        Fk, LXk, Fk1, LXk1 = self.F[0], self.LX[0], self.F[1], self.LX[1]
+        axpy = self.axpy
+
+        # Build RHS
+        if RHS.data.size:
+
+            # Initialize with MX0 term
+            np.copyto(RHS.data, MX0.data)
+
+            # Add quadrature terms using a reduced sum across nodes
+            recvBuf = np.zeros_like(RHS.data)
+            sendBuf = np.zeros_like(RHS.data)
+            for i in range(self.M-1, -1, -1):  # start from the last node
+                sendBuf.fill(0)
+                axpy(a=dt*q[i], x=Fk.data, y=sendBuf)
+                axpy(a=-dt*q[i], x=LXk.data, y=sendBuf)
+                self.comm.Reduce(sendBuf, recvBuf, root=i, op=MPI.SUM)
+            RHS.data += recvBuf
+
+            # Add the LX term from iteration k at the current node
+            axpy(a=dt*qI[k], x=LXk.data, y=RHS.data)
+
+        # Solve system and store node solution in solver state
+        self._solveAndStoreState(k)
+
+        if k < self.nSweeps-1:
+            tEval = t0 + dt*tau
+            # Evaluate and store F(X, t) with current state
+            self._evalF(Fk1, tEval, dt, wall_time)
+            # Evaluate and store LX with current state
+            self._evalLX(LXk1)
+
+        # Swap positions of iterates k and k+1 in storage,
+        # i.e., make the new evaluation the old one for the next iteration
+        self.F.rotate()
+        self.LX.rotate()
+
+    def step(self, dt, wall_time):
+        super().step(dt, wall_time)
+
+        # Only the last rank (i.e., node) is allowed to (possibly) write outputs
+        if self.rank != self.M-1:
+            self.firstEval = False
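
---

For reference, a condensed standalone sketch of the space/time communicator
splitting performed by initSpaceTimeMPI: consecutive global ranks share a
spatial communicator (the patch's sColor formula is an integer division in
disguise), and ranks at the same position inside their space group share a
time communicator. The time color below is an assumption, since only the
space split appears in the hunks; run with e.g. mpiexec -n 6.

    # split_demo.py -- hypothetical sketch, not part of the patch
    from mpi4py import MPI

    gComm = MPI.COMM_WORLD
    gRank, gSize = gComm.Get_rank(), gComm.Get_size()

    nProcSpace = 2                      # spatial ranks per time group
    assert gSize % nProcSpace == 0
    nProcTime = gSize // nProcSpace     # e.g. the number of SDC nodes M

    # Same space color => same spatial communicator
    sColor = gRank // nProcSpace        # == (gRank - gRank % nProcSpace) / nProcSpace
    sComm = gComm.Split(sColor, gRank)

    # Assumed time split: same position inside each space group
    tColor = gRank % nProcSpace
    tComm = gComm.Split(tColor, gRank)

    print(f"global {gRank}: space {sComm.Get_rank()}/{nProcSpace},"
          f" time {tComm.Get_rank()}/{nProcTime}")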
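The new assert on self.diagonal in __init__ is what makes the
one-solver-per-rank storage of _updateLHS valid: with a diagonal QDeltaI,
the implicit system at node m decouples into (M + dt*qDelta[m]*L) X_m = RHS_m,
so each rank factorizes only the LHS of its own node. A rough SciPy
illustration of that decoupling (operator, mass matrix and coefficients
are made up):

    import numpy as np
    import scipy.sparse as sp
    import scipy.sparse.linalg as spla

    N, dt = 64, 1e-2
    L = sp.diags([-1, 2, -1], [-1, 0, 1], shape=(N, N)) * N**2  # stand-in stiff operator
    M = sp.identity(N)                                          # stand-in mass matrix

    qDelta = np.array([0.1, 0.3, 0.6])  # made-up diagonal QDelta entries
    m = 1                               # pretend this MPI rank owns node 1

    lhs = (M + dt * qDelta[m] * L).tocsc()
    solveNode = spla.factorized(lhs)    # factorize once (init), reuse every sweep

    rhs = np.ones(N)
    X_m = solveNode(rhs)                # node-local implicit solve, no communication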
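In the new _sweep, the quadrature sum needs F and LX from every node, but
each of them lives on a single rank. The patch assembles the sum with one
MPI.Reduce per node, every rank acting as root exactly once; since Reduce
writes the receive buffer only on the root, recvBuf afterwards holds exactly
the quadrature row of the local node. A self-contained toy version of that
pattern (made-up data; run with e.g. mpiexec -n 4):

    # reduce_demo.py -- hypothetical sketch, not part of the patch
    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    m, M = comm.Get_rank(), comm.Get_size()

    rng = np.random.default_rng(42)     # same seed => same Q on every rank
    Q = rng.random((M, M))              # stand-in quadrature matrix
    v = np.full(3, float(m + 1))        # node value owned by this rank only

    sendBuf = np.zeros_like(v)
    recvBuf = np.zeros_like(v)
    for i in range(M - 1, -1, -1):      # one reduction per target node
        sendBuf[:] = Q[i, m] * v        # this rank's contribution to node i
        comm.Reduce(sendBuf, recvBuf, root=i, op=MPI.SUM)

    # Reduce only writes recvBuf on the root, so each rank keeps the result
    # of the iteration where it was root: recvBuf == sum_i Q[m, i] * v_i
    expected = sum(Q[m, i] * (i + 1) for i in range(M))
    assert np.allclose(recvBuf, expected)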
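Finally, a hypothetical end-to-end driver, assuming a Dedalus v3 environment
and that SDCIMEX_MPI plugs into problem.build_solver the same way as the
serial SpectralDeferredCorrectionIMEX; the 1D Burgers setup is purely
illustrative:

    # driver.py -- e.g. mpiexec -n 8 python driver.py (4 SDC nodes x 2 spatial ranks)
    import numpy as np
    import dedalus.public as d3
    from pySDC.playgrounds.dedalus.sdc import SDCIMEX_MPI

    # Split COMM_WORLD; the time communicator is stored on the class (cls.comm)
    gComm, sComm, tComm = SDCIMEX_MPI.initSpaceTimeComms(nProcSpace=2)

    # Build the problem on the *spatial* communicator only
    xcoord = d3.Coordinate('x')
    dist = d3.Distributor(xcoord, comm=sComm, dtype=np.float64)
    xbasis = d3.RealFourier(xcoord, size=256, bounds=(0, 2*np.pi))
    u = dist.Field(name='u', bases=xbasis)
    dx = lambda f: d3.Differentiate(f, xcoord)

    problem = d3.IVP([u], namespace=locals())
    problem.add_equation("dt(u) - 0.01*dx(dx(u)) = -u*dx(u)")  # implicit diffusion, explicit advection

    solver = problem.build_solver(SDCIMEX_MPI)
    solver.stop_sim_time = 1.0
    while solver.proceed:
        solver.step(1e-3)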