From ae78b8ca546fc78447473f888cb1334cff506e43 Mon Sep 17 00:00:00 2001 From: Landman Bester Date: Wed, 7 Dec 2022 13:03:30 +0200 Subject: [PATCH] Better FFT's (#74) * persist, scatter and compute * update remote * update remote * in place fft's in clark * add pad_and_shift_cube functions * remove profiling statement * typos in clark * more typos in clark * import prange * pass nthreads to clark * nthreads in clark * clean speed up * fix double normalisation by sum in restore Co-authored-by: landmanbester --- README.rst | 9 ++- pfb/deconv/clark.py | 22 +++++- pfb/operators/hessian.py | 59 ++++++++++++++- pfb/operators/psf.py | 160 +++++++++++++-------------------------- pfb/opt/pcg.py | 46 +++++++---- pfb/opt/power_method.py | 53 ++++++++++++- pfb/opt/primal_dual.py | 12 +-- pfb/parser/clean.yaml | 43 +++-------- pfb/utils/dist.py | 17 +++-- pfb/utils/misc.py | 122 ++++++++++++++++++++++++++--- pfb/workers/clean.py | 100 +++++++----------------- pfb/workers/restore.py | 13 ++-- pfb/workers/spotless.py | 140 ++++++++++++++-------------------- tests/test_clean.py | 4 +- 14 files changed, 450 insertions(+), 350 deletions(-) diff --git a/README.rst b/README.rst index 37ac646c4..0fa90fa44 100644 --- a/README.rst +++ b/README.rst @@ -6,8 +6,13 @@ Install the package by cloning and running :code:`$ pip install -e pfb-clean/` -Note casacore needs to be installed on the system for this to work. +Note casacore needs to be installed on the system for this to work. + +It is strongly recommended to install ducc in no binary mode eg + +:code:`$ git clone https://gitlab.mpcdf.mpg.de/mtr/ducc.git` +:code:`$ pip install -e ducc` If you find any of this useful please cite (for now) -https://arxiv.org/abs/2101.08072 \ No newline at end of file +https://arxiv.org/abs/2101.08072 diff --git a/pfb/deconv/clark.py b/pfb/deconv/clark.py index a4d147e36..c92c6d716 100644 --- a/pfb/deconv/clark.py +++ b/pfb/deconv/clark.py @@ -3,6 +3,7 @@ from functools import partial import numba import dask.array as da +from pfb.operators.psf import psf_convolve_cube import pyscilog log = pyscilog.get_logger('CLARK') @@ -61,6 +62,7 @@ def subminor(A, psf, Ip, Iq, model, wsums, gamma=0.05, th=0.0, maxit=10000): Idelq = q - Iq # find where PSF overlaps with image mask = (np.abs(Idelp) <= nxo2) & (np.abs(Idelq) <= nyo2) + # Ipp, Iqq = psf[:, nxo2 - Ip[mask], nyo2 - Iq[mask]] A = subtract(A[:, mask], psf, Idelp[mask], Idelq[mask], xhat, nxo2, nyo2) @@ -73,10 +75,9 @@ def subminor(A, psf, Ip, Iq, model, wsums, gamma=0.05, th=0.0, maxit=10000): k += 1 return model - def clark(ID, PSF, - psfo, + PSFHAT, threshold=0, gamma=0.05, pf=0.05, @@ -86,7 +87,8 @@ def clark(ID, report_freq=1, verbosity=1, psfopts=None, - sigmathreshold=2): + sigmathreshold=2, + nthreads=1): nband, nx, ny = ID.shape _, nx_psf, ny_psf = PSF.shape wsums = np.amax(PSF, axis=(1,2)) @@ -97,6 +99,10 @@ def clark(ID, wsums = np.amax(PSF, axis=(1,2)) model = np.zeros((nband, nx, ny), dtype=ID.dtype) IR = ID.copy() + # pre-allocate arrays for doing FFT's + xout = np.empty(ID.shape, dtype=ID.dtype, order='C') + xpad = np.empty(PSF.shape, dtype=ID.dtype, order='C') + xhat = np.empty(PSFHAT.shape, dtype=PSFHAT.dtype) # square avoids abs of full array IRsearch = np.sum(IR, axis=0)**2 pq = IRsearch.argmax() @@ -116,7 +122,15 @@ def clark(ID, th=subth, maxit=submaxit) # subtract from full image (as in major cycle) - IR = ID - psfo(model) + psf_convolve_cube( + xpad, + xhat, + xout, + PSFHAT, + ny_psf, + model, + nthreads=nthreads) + IR = ID - xout IRsearch = np.sum(IR, axis=0)**2 pq = IRsearch.argmax() p = pq//ny diff --git a/pfb/operators/hessian.py b/pfb/operators/hessian.py index af9c41dd9..011898611 100644 --- a/pfb/operators/hessian.py +++ b/pfb/operators/hessian.py @@ -5,6 +5,8 @@ from daskms.optimisation import inlined_array from ducc0.wgridder import ms2dirty, dirty2ms from uuid import uuid4 +from pfb.operators.psf import (psf_convolve_slice, + psf_convolve_cube) def hessian_xds(x, xds, hessopts, wsum, sigmainv, mask, @@ -90,8 +92,8 @@ def _hessian_impl(x, uvw, weight, vis_mask, freq, beam, pixsize_y=cell, epsilon=epsilon, nthreads=nthreads, - do_wstacking=wstack) #, - # double_precision_accumulation=double_accum) + do_wstacking=wstack, + double_precision_accumulation=double_accum) if beam is not None: convim *= beam @@ -117,3 +119,56 @@ def hessian(x, uvw, weight, vis_mask, freq, beam, hessopts): beam, bout, hessopts, None, dtype=x.dtype) + + +def hessian_psf_slice( + xpad, # preallocated array to store padded image + xhat, # preallocated array to store FTd image + xout, # preallocated array to store output image + psfhat, + beam, + lastsize, + x, # input image, not overwritten + nthreads=1, + sigmainv=1, + wsum=1): + """ + Tikhonov regularised Hessian approx + """ + + if beam is not None: + psf_convolve_slice(xpad, xhat, xout, + psfhat/wsum, lastsize, x*beam) + else: + psf_convolve_slice(xpad, xhat, xout, + psfhat/wsum, lastsize, x) + + if beam is not None: + xout *= beam + + return xout + x * sigmainv + + +def hessian_psf_cube( + xpad, # preallocated array to store padded image + xhat, # preallocated array to store FTd image + xout, # preallocated array to store output image + beam, + psfhat, + lastsize, + x, # input image, not overwritten + nthreads=1, + sigmainv=1): + """ + Tikhonov regularised Hessian approx + """ + + if beam is not None: + psf_convolve_cube(x*beam, xpad, xhat, xout, psfhat, lastsize) + else: + psf_convolve_cube(x, xpad, xhat, xout, psfhat, lastsize) + + if beam is not None: + xout *= beam + + return xout + x * sigmainv diff --git a/pfb/operators/psf.py b/pfb/operators/psf.py index d9a6137de..f23bede64 100644 --- a/pfb/operators/psf.py +++ b/pfb/operators/psf.py @@ -4,9 +4,44 @@ from uuid import uuid4 from ducc0.fft import r2c, c2r, c2c, good_size from uuid import uuid4 +from pfb.utils.misc import pad_and_shift, unpad_and_unshift +from pfb.utils.misc import pad_and_shift_cube, unpad_and_unshift_cube import gc -iFs = np.fft.ifftshift -Fs = np.fft.fftshift + + +def psf_convolve_slice( + xpad, # preallocated array to store padded image + xhat, # preallocated array to store FTd image + xout, # preallocated array to store output image + psfhat, + lastsize, + x, # input image, not overwritten + nthreads=1): + pad_and_shift(x, xpad) + r2c(xpad, axes=(0, 1), nthreads=nthreads, + forward=True, inorm=0, out=xhat) + xhat *= psfhat + c2r(xhat, axes=(0, 1), forward=False, out=xpad, + lastsize=lastsize, inorm=2, nthreads=nthreads) + unpad_and_unshift(xpad, xout) + return xout + + +def psf_convolve_cube(xpad, # preallocated array to store padded image + xhat, # preallocated array to store FTd image + xout, # preallocated array to store output image + psfhat, + lastsize, + x, # input image, not overwritten + nthreads=1): + pad_and_shift_cube(x, xpad) + r2c(xpad, axes=(1, 2), nthreads=nthreads, + forward=True, inorm=0, out=xhat) + xhat *= psfhat + c2r(xhat, axes=(1, 2), forward=False, out=xpad, + lastsize=lastsize, inorm=2, nthreads=nthreads) + unpad_and_unshift_cube(xpad, xout) + return xout def psf_convolve_xds(x, xds, psfopts, wsum, sigmainv, mask, @@ -52,43 +87,11 @@ def psf_convolve_xds(x, xds, psfopts, wsum, sigmainv, mask, else: return convim -def _psf_convolve_impl(x, psfhat, beam, - nthreads=None, - padding=None, - unpad_x=None, - unpad_y=None, - lastsize=None): - nx, ny = x.shape - xhat = x if beam is None else x * beam - xhat = iFs(np.pad(xhat, padding, mode='constant'), axes=(0, 1)) - xhat = r2c(xhat, axes=(0, 1), nthreads=nthreads, - forward=True, inorm=0) - xhat = c2r(xhat * psfhat, axes=(0, 1), forward=False, - lastsize=lastsize, inorm=2, nthreads=nthreads) - convim = Fs(xhat, axes=(0, 1))[unpad_x, unpad_y] - - if beam is not None: - convim *= beam - return convim -def _psf_convolve(x, psfhat, beam, psfopts): - return _psf_convolve_impl(x, psfhat, beam, **psfopts) -def psf_convolve(x, psfhat, beam, psfopts): - if not isinstance(x, da.Array): - x = da.from_array(x, chunks=(-1, -1), name=False) - if beam is None: - bout = None - else: - bout = ('nx', 'ny') - return da.blockwise(_psf_convolve, ('nx', 'ny'), - x, ('nx', 'ny'), - psfhat, ('nx', 'ny'), - beam, bout, - psfopts, None, - align_arrays=False, - dtype=x.dtype) + + def _psf_convolve_cube_impl(x, psfhat, beam, @@ -113,10 +116,8 @@ def _psf_convolve_cube_impl(x, psfhat, beam, return convim -def _psf_convolve_cube(x, psfhat, beam, psfopts): - return _psf_convolve_cube_impl(x, psfhat, beam, **psfopts) -def psf_convolve_cube(x, psfhat, beam, psfopts, +def psf_convolve_cube_dask(x, psfhat, beam, psfopts, wsum=1, sigmainv=None, compute=True): if not isinstance(x, da.Array): x = da.from_array(x, chunks=(1, -1, -1), @@ -150,72 +151,17 @@ def psf_convolve_cube(x, psfhat, beam, psfopts, else: return convim - -def _hessian_reg_psf(x, beam, psfhat, - nthreads=None, - sigmainv=None, - padding=None, - unpad_x=None, - unpad_y=None, - lastsize=None): - """ - Tikhonov regularised Hessian approx - """ - if isinstance(psfhat, da.Array): - psfhat = psfhat.compute() - if isinstance(beam, da.Array): - beam = beam.compute() - - if beam is not None: - xhat = iFs(np.pad(beam*x, padding, mode='constant'), axes=(1, 2)) - else: - xhat = iFs(np.pad(x, padding, mode='constant'), axes=(1, 2)) - xhat = r2c(xhat, axes=(1, 2), nthreads=nthreads, - forward=True, inorm=0) - xhat = c2r(xhat * psfhat, axes=(1, 2), forward=False, - lastsize=lastsize, inorm=2, nthreads=nthreads) - im = Fs(xhat, axes=(1, 2))[:, unpad_x, unpad_y] - - if beam is not None: - im *= beam - - if np.any(sigmainv): - return im + x * sigmainv - else: - return im - -def _hessian_reg_psf_slice( - x, - psfhat=None, - beam=None, - nthreads=None, - sigmainv=None, - padding=None, - unpad_x=None, - unpad_y=None, - lastsize=None, - wsum=1.0): - """ - Tikhonov regularised Hessian approx - """ - if isinstance(psfhat, da.Array): - psfhat = psfhat.compute() - if isinstance(beam, da.Array): - beam = beam.compute() - if beam is not None: - xhat = iFs(np.pad(beam*x, padding, mode='constant'), axes=(0, 1)) - else: - xhat = iFs(np.pad(x, padding, mode='constant'), axes=(0, 1)) - xhat = r2c(xhat, axes=(0, 1), nthreads=nthreads, - forward=True, inorm=0) - xhat = c2r(xhat * psfhat/wsum, axes=(0, 1), forward=False, - lastsize=lastsize, inorm=2, nthreads=nthreads) - im = Fs(xhat, axes=(0, 1))[unpad_x, unpad_y] - - if beam is not None: - im *= beam - - if np.any(sigmainv): - return im + x * sigmainv +def psf_convolve_slice_dask(x, psfhat, beam, psfopts): + if not isinstance(x, da.Array): + x = da.from_array(x, chunks=(-1, -1), name=False) + if beam is None: + bout = None else: - return im + bout = ('nx', 'ny') + return da.blockwise(psf_convolve_slice, ('nx', 'ny'), + x, ('nx', 'ny'), + psfhat, ('nx', 'ny'), + beam, bout, + psfopts, None, + align_arrays=False, + dtype=x.dtype) diff --git a/pfb/opt/pcg.py b/pfb/opt/pcg.py index 4e5b80343..3321c36f6 100644 --- a/pfb/opt/pcg.py +++ b/pfb/opt/pcg.py @@ -132,12 +132,14 @@ def M(x): return x else: return x, r -from pfb.operators.psf import _hessian_reg_psf_slice as hessian_psf +from pfb.operators.hessian import hessian_psf_slice def _pcg_psf_impl(psfhat, b, x0, beam, - hessopts, + lastsize, + nthreads, + sigmainv, tol=1e-5, maxit=500, minit=100, @@ -148,20 +150,26 @@ def _pcg_psf_impl(psfhat, A specialised distributed version of pcg when the operator implements convolution with the psf (+ L2 regularisation by sigma**2) ''' - nband, nbasis, nmax = b.shape - model = np.zeros((nband, nbasis, nmax), dtype=b.dtype) - sigmainvsq = hessopts['sigmainv']**2 + nband, nx, ny = b.shape + _, nx_psf, nyo2 = psfhat.shape + model = np.zeros((nband, nx, ny), dtype=b.dtype, order='C') # PCG preconditioner - if sigmainvsq > 0: - def M(x): return x / sigmainvsq + if sigmainv > 0: + def M(x): return x / sigmainv else: M = None for k in range(nband): - A = partial(hessian_psf, - psfhat=psfhat[k], - beam=beam[k], - **hessopts) + A = partial(hessian_psf_slice, + np.empty((nx_psf, lastsize), dtype=b.dtype, order='C'), # xpad + np.empty((nx_psf, nyo2), dtype=psfhat.dtype, order='C'), # xhat + np.empty((nx, ny), dtype=b.dtype, order='C'), # xout + psfhat[k], + beam[k], + lastsize, + nthreads=nthreads, + sigmainv=sigmainv) + model[k] = pcg(A, b[k], x0[k], M=M, tol=tol, maxit=maxit, minit=minit, verbosity=verbosity, report_freq=report_freq, @@ -173,20 +181,26 @@ def _pcg_psf(psfhat, b, x0, beam, - hessopts, + lastsize, + nthreads, + sigmainv, cgopts): return _pcg_psf_impl(psfhat, b, x0, beam, - hessopts, + lastsize, + nthreads, + sigmainv, **cgopts) def pcg_psf(psfhat, b, x0, beam, - hessopts, + lastsize, + nthreads, + sigmainv, cgopts, compute=True): @@ -221,7 +235,9 @@ def pcg_psf(psfhat, b, ('nb', 'nx', 'ny'), x0, ('nb', 'nx', 'ny'), beam, bout, - hessopts, None, + lastsize, None, + nthreads, None, + sigmainv, None, cgopts, None, align_arrays=False, dtype=b.dtype) diff --git a/pfb/opt/power_method.py b/pfb/opt/power_method.py index 712ebd367..249214032 100644 --- a/pfb/opt/power_method.py +++ b/pfb/opt/power_method.py @@ -1,7 +1,9 @@ import numpy as np +import dask.array as da from operator import getitem -from distributed import wait, get_client +from distributed import wait, get_client, as_completed from scipy.linalg import norm +from copy import deepcopy import pyscilog log = pyscilog.get_logger('PM') @@ -102,3 +104,52 @@ def power_method_dist(Af, wait([b, bnorm, beta]) return beta + + +def power2(A, bp, bnorm): + bp /= bnorm + b = A(bp) + bsumsq = da.sum(b**2) + beta_num = da.vdot(b, bp) + beta_den = da.vdot(bp, bp) + + return b, bsumsq, beta_num, beta_den + + +def power_method_persist(ddsf, + Af, + nx, + ny, + nband, + tol=1e-5, + maxit=200): + + client = get_client() + b = [] + bssq = [] + for ds in ddsf: + wid = ds.worker + tmp = client.persist(da.random.normal(0, 1, (nx, ny)), + workers={wid}) + b.append(tmp) + bssq.append(da.sum(b**2)) + + bssq = da.stack(bssq) + bnorm = da.sqrt(da.sum(bssq)) + bp = deepcopy(b) + beta_num = [da.array()] + + for k in range(maxit): + for i, ds, A in enumerate(zip(ddsf, Af)): + bp[i] = b[i]/bnorm + b[i] = A(bp[i]) + bssq[i] = da.sum(b[i]**2) + beta_num = da.vdot(b, bp) + beta_den = da.vdot(bp, bp) + + + + + + + diff --git a/pfb/opt/primal_dual.py b/pfb/opt/primal_dual.py index b05758141..b4e6200b7 100644 --- a/pfb/opt/primal_dual.py +++ b/pfb/opt/primal_dual.py @@ -101,13 +101,13 @@ def get_ratio(vtildes, lam, sigma, l1weights): ratio[mask] = vsoft[mask] / vmfs[mask] return ratio -def update(ds, A, y, vtilde, ratio, psi, psiH, **kwargs): +def update(ds, A, y, vtilde, ratio, **kwargs): sigma = kwargs['sigma'] lam = kwargs['lam'] tau = kwargs['tau'] gamma = kwargs['gamma'] - # psi = kwargs['psi'] - # psiH = kwargs['psiH'] + psi = kwargs['psi'] + psiH = kwargs['psiH'] xp = ds.MODEL.values vp = ds.DUAL.values @@ -182,9 +182,9 @@ def primal_dual_dist( wait([ratio]) future = client.map(update, - ddsf, Af, yf, vtildes, [ratio]*len(ddsf), psif, psiHf, - # psi=psi, - # psiH=psiH, + ddsf, Af, yf, vtildes, [ratio]*len(ddsf), + psi=psi, + psiH=psiH, pure=False, wsum=wsum, sigma=sigma, diff --git a/pfb/parser/clean.yaml b/pfb/parser/clean.yaml index 806fbb608..61008e046 100644 --- a/pfb/parser/clean.yaml +++ b/pfb/parser/clean.yaml @@ -23,31 +23,11 @@ inputs: info: Imaging products to produce. Options are I, Q, U, V. Only single Stokes products are currently supported - residual_name: - dtype: str - abbreviation: rname - default: RESIDUAL - info: - Name of residual in dds - model_name: - dtype: str - abbreviation: mname - default: MODEL - info: - Name of model in mds mask: dtype: str abbreviation: mask info: Path to mask.fits - algo: - dtype: str - default: clark - choices: - - hogbom - - clark - info: - Which minor cycle to use dirosion: dtype: int default: 1 @@ -67,7 +47,14 @@ inputs: dtype: bool default: true info: - Run PCG based flux mop between major iterations + Trigger PCG based flux mop if minor cycle stalls, + the final threshold is reached or on the final iteration. + mop_gamma: + dtype: float + default: 0.65 + info: + Step size for flux mop. Should be between (0,1). + A value of 1 is most aggressive. nmiter: dtype: int default: 5 @@ -103,22 +90,16 @@ inputs: abbreviation: spf info: Peak factor of sub-minor loop - hogbom_maxit: - dtype: int - default: 5000 - abbreviation: hmaxit - info: - Maximum number of peak finding iterations between major cycles - clark_maxit: + minor_maxit: dtype: int default: 50 - abbreviation: cmaxit + abbreviation: mmaxit info: Maximum number of PSF convolutions between major cycles - sub_maxit: + subminor_maxit: dtype: int default: 1000 - abbreviation: smaxit + abbreviation: smmaxit info: Maximum number of iterations for the sub-minor cycle verbose: diff --git a/pfb/utils/dist.py b/pfb/utils/dist.py index 7f1f09eef..10ca88d7e 100644 --- a/pfb/utils/dist.py +++ b/pfb/utils/dist.py @@ -18,8 +18,9 @@ def get_resid_and_stats(dds, wsum): def accum_wsums(dds): wsum = 0 + # import pdb; pdb.set_trace() for ds in dds: - wsum += ds.WSUM.values[0] + wsum += ds.WSUM.data[0] return wsum def get_eps(modelp, dds): @@ -46,9 +47,9 @@ def l1reweight(dds, l1weight, psiH, wsum, pix_per_beam, alpha=2): return alpha/(1 + (l2mod/rms[:, None])**2) def get_cbeam_area(dds, wsum): - psf_mfs = np.zeros(dds[0].PSF.shape, dtype=dds[0].PSF.dtype) + psf_mfs = da.zeros(dds[0].PSF.shape, dtype=dds[0].PSF.dtype) for ds in dds: - psf_mfs += ds.PSF.values/wsum + psf_mfs += ds.PSF.data/wsum # beam pars in pixel units GaussPar = fitcleanbeam(psf_mfs[None], level=0.5, pixsize=1.0)[0] return GaussPar[0]*GaussPar[1]*np.pi/4 @@ -58,11 +59,13 @@ def get_cbeam_area(dds, wsum): def init_dual_and_model(ds, **kwargs): dct = {} if 'MODEL' not in ds: - model = da.zeros((kwargs['nx'], kwargs['ny']), chunks=(-1, -1)) - dct['MODEL'] = (('x', 'y'), model) + model = np.zeros((kwargs['nx'], kwargs['ny'])) + dct['MODEL'] = (('x', 'y'), da.from_array(model, chunks=(-1, -1))) + # dct['MODEL'] = (('x', 'y'), model) if 'DUAL' not in ds: - dual = da.zeros((kwargs['nbasis'], kwargs['nmax']), chunks=(-1,-1)) - dct['DUAL'] = (('b', 'c'), dual) + dual = np.zeros((kwargs['nbasis'], kwargs['nmax'])) + dct['DUAL'] = (('b', 'c'), da.from_array(dual, chunks=(-1, -1))) + # dct['DUAL'] = (('b', 'c'), dual) ds_out = ds.assign(**dct) return ds_out diff --git a/pfb/utils/misc.py b/pfb/utils/misc.py index 890921c3a..b4bb91e14 100644 --- a/pfb/utils/misc.py +++ b/pfb/utils/misc.py @@ -1,7 +1,7 @@ import sys import numpy as np import numexpr as ne -from numba import jit, njit +from numba import jit, njit, prange import dask import dask.array as da from dask.distributed import performance_report @@ -820,8 +820,7 @@ def coerce_literal(func, literals): return -def dds2cubes(dds, opts, apparent=False, log=None): - nband = opts.nband +def dds2cubes(dds, nband, apparent=False): real_type = dds[0].DIRTY.dtype complex_type = np.result_type(real_type, np.complex64) nx, ny = dds[0].DIRTY.shape @@ -829,7 +828,7 @@ def dds2cubes(dds, opts, apparent=False, log=None): dtype=real_type) for _ in range(nband)] model = [da.zeros((nx, ny), chunks=(-1, -1), dtype=real_type) for _ in range(nband)] - if opts.residual_name in dds[0]: + if 'RESIDUAL' in dds[0]: residual = [da.zeros((nx, ny), chunks=(-1, -1), dtype=real_type) for _ in range(nband)] else: @@ -851,12 +850,12 @@ def dds2cubes(dds, opts, apparent=False, log=None): b = ds.bandid if apparent: dirty[b] += ds.DIRTY.data - if opts.residual_name in ds: - residual[b] += ds.get(opts.residual_name).data + if 'RESIDUAL' in ds: + residual[b] += ds.RESIDUAL.data else: dirty[b] += ds.DIRTY.data * ds.BEAM.data - if opts.residual_name in ds: - residual[b] += ds.get(opts.residual_name).data * ds.BEAM.data + if 'RESIDUAL' in ds: + residual[b] += ds.RESIDUAL.data * ds.BEAM.data if 'PSF' in ds: psf[b] += ds.PSF.data psfhat[b] += ds.PSFHAT.data @@ -868,7 +867,7 @@ def dds2cubes(dds, opts, apparent=False, log=None): wsum = wsums.sum() dirty = da.stack(dirty)/wsum model = da.stack(model) - if opts.residual_name in ds: + if 'RESIDUAL' in ds: residual = da.stack(residual)/wsum if 'PSF' in ds: psf = da.stack(psf)/wsum @@ -1104,3 +1103,108 @@ def lthreshold(x, sigma, kind='l1'): elif kind=='l1': absx = np.abs(x) return np.where(absx > sigma, absx - sigma, 0) * np.sign(x) + + +@njit(nogil=True, cache=True, inline='always') +def pad_and_shift(x, out): + ''' + Pad x with zeros so as to have the same shape as out and perform + ifftshift in place + ''' + nxi, nyi = x.shape + nxo, nyo = out.shape + if nxi >= nxo or nyi >= nyo: + raise ValueError('Output must be larger than input') + padx = nxo-nxi + pady = nyo-nyi + out[...] = 0.0 + # first and last quadrants + for i in range(nxi//2): + for j in range(nyi//2): + # first to last quadrant + out[padx + nxi//2 + i, pady + nyi//2 + j] = x[i, j] + # last to first quadrant + out[i, j] = x[nxi//2+i, nyi//2+j] + # third to second quadrant + out[i, pady + nyi//2 + j] = x[nxi//2 + i, j] + # second to third quadrant + out[padx + nxi//2 + i, j] = x[i, nyi//2 + j] + return out + + +@njit(nogil=True, cache=True, inline='always') +def unpad_and_unshift(x, out): + ''' + fftshift x and unpad it into out + ''' + nxi, nyi = x.shape + nxo, nyo = out.shape + if nxi < nxo or nyi < nyo: + raise ValueError('Output must be smaller than input') + out[...] = 0.0 + padx = nxo-nxi + pady = nyo-nyi + for i in range(nxo//2): + for j in range(nyo//2): + # first to last quadrant + out[nxo//2+i, nyo//2+j] = x[i, j] + # last to first quadrant + out[i, j] = x[padx + nxo//2 + i, pady + nyo//2 + j] + # third to second quadrant + out[nxo//2 + i, j] = x[i, pady + nyo//2 + j] + # second to third quadrant + out[i, nyo//2 + j] = x[padx + nxo//2 + i, j] + return out + + +@njit(nogil=True, cache=True, inline='always', parallel=True) +def pad_and_shift_cube(x, out): + ''' + Pad x with zeros so as to have the same shape as out and perform + ifftshift in place + ''' + nband, nxi, nyi = x.shape + nband, nxo, nyo = out.shape + if nxi >= nxo or nyi >= nyo: + raise ValueError('Output must be larger than input') + padx = nxo-nxi + pady = nyo-nyi + out[...] = 0.0 + for b in prange(nband): + for i in range(nxi//2): + for j in range(nyi//2): + # first to last quadrant + out[b, padx + nxi//2 + i, pady + nyi//2 + j] = x[b, i, j] + # last to first quadrant + out[b, i, j] = x[b, nxi//2+i, nyi//2+j] + # third to second quadrant + out[b, i, pady + nyi//2 + j] = x[b, nxi//2 + i, j] + # second to third quadrant + out[b, padx + nxi//2 + i, j] = x[b, i, nyi//2 + j] + return out + + +@njit(nogil=True, cache=True, inline='always', parallel=True) +def unpad_and_unshift_cube(x, out): + ''' + fftshift x and unpad it into out + ''' + nband, nxi, nyi = x.shape + nband, nxo, nyo = out.shape + if nxi < nxo or nyi < nyo: + raise ValueError('Output must be smaller than input') + out[...] = 0.0 + padx = nxo-nxi + pady = nyo-nyi + for b in prange(nband): + for i in range(nxo//2): + for j in range(nyo//2): + # first to last quadrant + out[b, nxo//2+i, nyo//2+j] = x[b, i, j] + # last to first quadrant + out[b, i, j] = x[b, padx + nxo//2 + i, pady + nyo//2 + j] + # third to second quadrant + out[b, nxo//2 + i, j] = x[b, i, pady + nyo//2 + j] + # second to third quadrant + out[b, i, nyo//2 + j] = x[b, padx + nxo//2 + i, j] + return out diff --git a/pfb/workers/clean.py b/pfb/workers/clean.py index 02f0c37ae..359eba082 100644 --- a/pfb/workers/clean.py +++ b/pfb/workers/clean.py @@ -66,10 +66,8 @@ def _clean(**kw): from pfb.deconv.hogbom import hogbom from pfb.deconv.clark import clark from daskms.experimental.zarr import xds_from_zarr, xds_to_zarr - from pfb.operators.hessian import hessian from pfb.opt.pcg import pcg, pcg_psf from pfb.operators.hessian import hessian_xds - from pfb.operators.psf import psf_convolve_cube, _hessian_reg_psf from scipy import ndimage from copy import copy @@ -87,14 +85,14 @@ def _clean(**kw): dds = dask.persist(dds)[0] nx_psf, ny_psf = dds[0].nx_psf, dds[0].ny_psf + lastsize = ny_psf # stitch dirty/psf in apparent scale output_type = dds[0].DIRTY.dtype dirty, model, residual, psf, psfhat, _, wsums = dds2cubes( dds, - opts, - apparent=True, - log=log) + opts.nband, + apparent=True) wsum = np.sum(wsums) psf_mfs = np.sum(psf, axis=0) assert (psf_mfs.max() - 1.0) < 2*opts.epsilon @@ -144,40 +142,7 @@ def _clean(**kw): mask=np.ones((nx, ny), dtype=output_type), compute=True, use_beam=False) - # set up image space Hessian - npad_xl = (nx_psf - nx)//2 - npad_xr = nx_psf - nx - npad_xl - npad_yl = (ny_psf - ny)//2 - npad_yr = ny_psf - ny - npad_yl - padding = ((npad_xl, npad_xr), (npad_yl, npad_yr)) - unpad_x = slice(npad_xl, -npad_xr) - unpad_y = slice(npad_yl, -npad_yr) - lastsize = ny + np.sum(padding[-1]) - psfopts = {} - psfopts['nthreads'] = opts.nvthreads - psfopts['padding'] = padding - psfopts['unpad_x'] = unpad_x - psfopts['unpad_y'] = unpad_y - psfopts['lastsize'] = lastsize - # psfo = partial(psf_convolve_cube, - # psfhat=da.from_array(psfhat, chunks=(1, -1, -1)), - # beam=None, - # psfopts=psfopts, - # wsum=1, # psf is normalised to sum to one - # sigmainv=0, - # compute=True) - - hess2opts = copy(psfopts) - hess2opts['sigmainv'] = 1e-8 - - padding = ((0, 0), (npad_xl, npad_xr), (npad_yl, npad_yr)) - psfo = partial(_hessian_reg_psf, beam=mask[None, :, :], psfhat=psfhat, - nthreads=opts.nthreads, sigmainv=0, - padding=padding, unpad_x=unpad_x, unpad_y=unpad_y, - lastsize = lastsize) - - - + # PCG related options for flux mop cgopts = {} cgopts['tol'] = opts.cg_tol cgopts['maxit'] = opts.cg_maxit @@ -198,34 +163,18 @@ def _clean(**kw): print(f"Iter 0: peak residual = {rmax:.3e}, rms = {rms:.3e}", file=log) for k in range(opts.nmiter): - if opts.algo.lower() == 'clark': - print("Running Clark", file=log) - # import cProfile - # with cProfile.Profile() as pr: - x, status = clark(mask*residual, psf, psfo, - threshold=threshold, - gamma=opts.gamma, - pf=opts.peak_factor, - maxit=opts.clark_maxit, - subpf=opts.sub_peak_factor, - submaxit=opts.sub_maxit, - verbosity=opts.verbose, - report_freq=opts.report_freq, - sigmathreshold=opts.sigmathreshold) - # pr.print_stats(sort='cumtime') - # quit() - - elif opts.algo.lower() == 'hogbom': - print("Running Hogbom", file=log) - x, status = hogbom(residual, psf, - threshold=threshold, - gamma=opts.gamma, - pf=opts.peak_factor, - maxit=opts.hogbom_maxit, - verbosity=opts.verbose, - report_freq=opts.report_freq) - else: - raise ValueError(f'{opts.algo} is not a valid algo option') + print("Cleaning", file=log) + x, status = clark(mask*residual, psf, psfhat, + threshold=threshold, + gamma=opts.gamma, + pf=opts.peak_factor, + maxit=opts.minor_maxit, + subpf=opts.sub_peak_factor, + submaxit=opts.subminor_maxit, + verbosity=opts.verbose, + report_freq=opts.report_freq, + sigmathreshold=opts.sigmathreshold, + nthreads=opts.nthreads) model += x @@ -257,15 +206,20 @@ def _clean(**kw): struct = ndimage.generate_binary_structure(2, opts.dirosion) mopmask = ndimage.binary_dilation(mopmask, structure=struct) mopmask = ndimage.binary_erosion(mopmask, structure=struct) - # hess2opts['sigmainv'] = 1e-8 x0 = np.zeros_like(x) x0[:, mopmask] = residual_mfs[mopmask] + # TODO - applying mask as beam is wasteful mopmask = mopmask[None, :, :].astype(residual.dtype) - hess2opts['sigmainv'] = rmax - x = pcg_psf(psfhat, mopmask*residual, x0, - mopmask, hess2opts, cgopts) - - model += x + x = pcg_psf(psfhat, + mopmask*residual, + x0, + mopmask, + lastsize, + opts.nvthreads, + rmax, # used as sigmainv + cgopts) + + model += opts.mop_gamma*x save_fits(f'{basename}_premop{k}_resid_mfs.fits', residual_mfs, hdr_mfs) diff --git a/pfb/workers/restore.py b/pfb/workers/restore.py index 965e00952..8daa2cc5b 100644 --- a/pfb/workers/restore.py +++ b/pfb/workers/restore.py @@ -73,16 +73,15 @@ def _restore(**kw): dirty, model, residual, psf, _, _, wsums = dds2cubes(dds, - opts, - apparent=True, - log=log) + nband, + apparent=True) wsum = np.sum(wsums) output_type = dirty.dtype - psf_mfs = np.sum(psf, axis=0)/wsum - residual_mfs = np.sum(residual, axis=0)/wsum + psf_mfs = np.sum(psf, axis=0) + residual_mfs = np.sum(residual, axis=0) fmask = wsums > 0 - residual[fmask] /= wsums[fmask, None, None] - psf[fmask] /= wsums[fmask, None, None] + residual[fmask] /= wsums[fmask, None, None]/wsum + psf[fmask] /= wsums[fmask, None, None]/wsum # sanity check assert (psf_mfs.max() - 1.0) < 2e-7 assert ((np.amax(psf, axis=(1,2)) - 1.0) < 2e-7).all() diff --git a/pfb/workers/spotless.py b/pfb/workers/spotless.py index 70dfd6fda..2ea84f6a3 100644 --- a/pfb/workers/spotless.py +++ b/pfb/workers/spotless.py @@ -103,6 +103,7 @@ def spotless(**kw): client = stack.enter_context(Client(cluster)) client.wait_for_workers(opts.nworkers) + client.amm.stop() # TODO - prettier config printing print('Input Options:', file=log) @@ -155,19 +156,6 @@ def _spotless(**kw): 'yo2':-1, 'b':-1, 'c':-1}) - if opts.memory_greedy: - # pass workers as set - ddsf = [ds.persist(workers={names[i]}) - for ds, i in zip(dds, cycle(range(opts.nworkers)))] - else: - ddsf = client.scatter(dds) - - # names={} - # for ds in ddsf: - # b = ds.result().bandid - # tmp = client.who_has(ds) - # for key in tmp.keys(): - # names[b] = tmp[key] real_type = dds[0].DIRTY.dtype complex_type = np.result_type(real_type, np.complex64) @@ -197,10 +185,6 @@ def _spotless(**kw): freq_out[b] = ds.freq_out hdr = set_wcs(cell_deg, cell_deg, nx, ny, radec, freq_out) - # assumed constant - wsum = client.submit(accum_wsums, ddsf).result() - pix_per_beam = client.submit(get_cbeam_area, ddsf, wsum).result() - # dictionary setup print("Setting up dictionary", file=log) bases = tuple(opts.bases.split(',')) @@ -210,13 +194,12 @@ def _spotless(**kw): bases, opts.nlevels) ntot = tuple(ntot) - psiH = partial(im2coef, bases=bases, ntot=ntot, nmax=nmax, nlevels=opts.nlevels) - # avoids pickling on dumba Dict + # avoid pickling dumba Dict psi = partial(coef2im, bases=bases, ntot=ntot, @@ -225,45 +208,61 @@ def _spotless(**kw): nx=nx, ny=ny) + print('Scattering data', file=log) + ddsf = [] + for ds, i in zip(dds, cycle(range(opts.nworkers))): + if 'MODEL' not in ds: + model = da.zeros((ds.nx, ds.ny), + chunks=(-1, -1)) + ds = ds.assign(**{ + 'MODEL': (('x', 'y'), model) + }) + if 'DUAL' not in ds: + dual = da.zeros((kwargs['nbasis'], kwargs['nmax']), + chunks=(-1, -1)) + ds = ds.assign(**{ + 'DUAL': (('b', 'c'), dual) + }) + ds.attrs.update(**{'worker':names[i]}) + ddsf.append(client.persist(ds, workers={names[i]})) + + # assumed constant + wsum = accum_wsums(ddsf) + pix_per_beam = get_cbeam_area(ddsf, wsum) + + # this makes for cleaner algorithms but is it a bad pattern? Afs = [] for ds in ddsf: - tmp = client.who_has(ds.PSFHAT) - wip = list(tmp.values())[0][0] - wid = client.scheduler_info()['workers'][wip]['id'] - Af = client.submit(partial, - _hessian_reg_psf_slice, - psfhat=ds.PSFHAT.data, - beam=ds.BEAM.data, - wsum=wsum, - nthreads=opts.nthreads, - sigmainv=opts.sigmainv, - padding=psf_padding, - unpad_x=unpad_x, - unpad_y=unpad_y, - lastsize=lastsize, - workers={wid}) + # tmp = client.who_has(ds) + # wip = list(tmp.values())[0][0] + # wid = client.scheduler_info()['workers'][wip]['id'] + Af = partial(_hessian_reg_psf_slice, + psfhat=ds.PSFHAT.data, + beam=ds.BEAM.data, + wsum=wsum, + nthreads=opts.nthreads, + sigmainv=opts.sigmainv, + padding=psf_padding, + unpad_x=unpad_x, + unpad_y=unpad_y, + lastsize=lastsize) Afs.append(Af) - # initialise for backward step - ddsf = client.map(init_dual_and_model, ddsf, - nx=nx, - ny=ny, - nbasis=nbasis, - nmax=nmax, - pure=False) - try: l1ds = xds_from_zarr(f'{dds_name}::L1WEIGHT', chunks={'b':-1,'c':-1}) if 'L1WEIGHT' in l1ds: - l1weight = client.submit(lambda ds: ds[0].L1WEIGHT.values, l1ds, - workers={names[0]}) + l1weight = client.persist(ds[0].L1WEIGHT.data, + workers={names[0]}) else: raise except Exception as e: print(f'Did not find l1weights at {dds_name}/L1WEIGHT. ' 'Initialising to unity', file=log) - l1weight = client.submit(np.ones, (nbasis, nmax), workers={names[0]}) + l1weight = client.persist(da.ones((nbasis, nmax), + chunks=(-1, -1), + dtype=real_type), + workers={names[0]}) if opts.hessnorm is None: print('Getting spectral norm of Hessian approximation', file=log) @@ -272,8 +271,6 @@ def _spotless(**kw): hessnorm = opts.hessnorm print(f'hessnorm = {hessnorm:.3e}', file=log) - import pdb; pdb.set_trace() - # future contains mfs residual and stats residf = client.submit(get_resid_and_stats, ddsf, wsum, workers={names[0]}) @@ -354,42 +351,19 @@ def _spotless(**kw): if eps < opts.tol: break - if opts.fits_mfs or opts.fits_cubes: - print("Writing fits files", file=log) - - # construct a header from xds attrs - ra = dds[0].ra - dec = dds[0].dec - radec = [ra, dec] - - cell_rad = dds[0].cell_rad - cell_deg = np.rad2deg(cell_rad) - - ref_freq = np.mean(freq_out) - hdr_mfs = set_wcs(cell_deg, cell_deg, nx, ny, radec, ref_freq) - - residual = np.zeros((nband, nx, ny), dtype=real_type) - model = np.zeros((nband, nx, ny), dtype=real_type) - wsums = np.zeros(nband) - for ds in dds: - b = ds.bandid - wsums[b] += ds.WSUM.values[0] - residual[b] += ds.RESIDUAL.values - model[b] = ds.MODEL.values - wsum = np.sum(wsums) - residual_mfs = np.sum(residual, axis=0)/wsum - model_mfs = np.mean(model, axis=0) - save_fits(f'{basename}_residual_mfs.fits', residual_mfs, hdr_mfs) - save_fits(f'{basename}_model_mfs.fits', model_mfs, hdr_mfs) - - if opts.fits_cubes: - # need residual in Jy/beam - hdr = set_wcs(cell_deg, cell_deg, nx, ny, radec, freq_out) - save_fits(f'{basename}_model.fits', model, hdr) - fmask = wsums > 0 - residual[fmask] /= wsums[fmask, None, None] - save_fits(f'{basename}_residual.fits', - residual, hdr) + # convert to fits files + fitsout = [] + if opts.fits_mfs: + fitsout.append(dds2fits_mfs(dds, 'RESIDUAL', basename, norm_wsum=True)) + fitsout.append(dds2fits_mfs(dds, 'MODEL', basename, norm_wsum=False)) + + if opts.fits_cubes: + fitsout.append(dds2fits(dds, 'RESIDUAL', basename, norm_wsum=True)) + fitsout.append(dds2fits(dds, 'MODEL', basename, norm_wsum=False)) + + if len(fitsout): + print("Writing fits", file=log) + dask.compute(fitsout) if opts.scheduler=='distributed': from distributed import get_client diff --git a/tests/test_clean.py b/tests/test_clean.py index 6fdfff81c..90b75f5a9 100644 --- a/tests/test_clean.py +++ b/tests/test_clean.py @@ -9,8 +9,7 @@ pmp = pytest.mark.parametrize @pmp('do_gains', (False, True)) -@pmp('algo', ('hogbom', 'clark')) -def test_clean(do_gains, algo, tmp_path_factory): +def test_clean(do_gains, tmp_path_factory): ''' Here we test that clean correctly infers the fluxes of point sources placed at the centers of pixels in the presence of the wterm and DI gain @@ -219,7 +218,6 @@ def test_clean(do_gains, algo, tmp_path_factory): clean_args["output_filename"] = outname clean_args["postfix"] = postfix clean_args["nband"] = nchan - clean_args["algo"] = algo clean_args["dirosion"] = 0 clean_args["do_residual"] = False clean_args["nmiter"] = 100