Skip to content

Commit

Permalink
Better FFT's (#74)
Browse files Browse the repository at this point in the history
* persist, scatter and compute

* update remote

* update remote

* in place fft's in clark

* add pad_and_shift_cube functions

* remove profiling statement

* typos in clark

* more typos in clark

* import prange

* pass nthreads to clark

* nthreads in clark

* clean speed up

* fix double normalisation by sum in restore

Co-authored-by: landmanbester <[email protected]>
  • Loading branch information
landmanbester and landmanbester authored Dec 7, 2022
1 parent 2eb0521 commit ae78b8c
Show file tree
Hide file tree
Showing 14 changed files with 450 additions and 350 deletions.
9 changes: 7 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@ Install the package by cloning and running

:code:`$ pip install -e pfb-clean/`

Note casacore needs to be installed on the system for this to work.
Note casacore needs to be installed on the system for this to work.

It is strongly recommended to install ducc in no binary mode eg

:code:`$ git clone https://gitlab.mpcdf.mpg.de/mtr/ducc.git`
:code:`$ pip install -e ducc`

If you find any of this useful please cite (for now)

https://arxiv.org/abs/2101.08072
https://arxiv.org/abs/2101.08072
22 changes: 18 additions & 4 deletions pfb/deconv/clark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import partial
import numba
import dask.array as da
from pfb.operators.psf import psf_convolve_cube
import pyscilog
log = pyscilog.get_logger('CLARK')

Expand Down Expand Up @@ -61,6 +62,7 @@ def subminor(A, psf, Ip, Iq, model, wsums, gamma=0.05, th=0.0, maxit=10000):
Idelq = q - Iq
# find where PSF overlaps with image
mask = (np.abs(Idelp) <= nxo2) & (np.abs(Idelq) <= nyo2)
# Ipp, Iqq = psf[:, nxo2 - Ip[mask], nyo2 - Iq[mask]]
A = subtract(A[:, mask], psf,
Idelp[mask], Idelq[mask],
xhat, nxo2, nyo2)
Expand All @@ -73,10 +75,9 @@ def subminor(A, psf, Ip, Iq, model, wsums, gamma=0.05, th=0.0, maxit=10000):
k += 1
return model


def clark(ID,
PSF,
psfo,
PSFHAT,
threshold=0,
gamma=0.05,
pf=0.05,
Expand All @@ -86,7 +87,8 @@ def clark(ID,
report_freq=1,
verbosity=1,
psfopts=None,
sigmathreshold=2):
sigmathreshold=2,
nthreads=1):
nband, nx, ny = ID.shape
_, nx_psf, ny_psf = PSF.shape
wsums = np.amax(PSF, axis=(1,2))
Expand All @@ -97,6 +99,10 @@ def clark(ID,
wsums = np.amax(PSF, axis=(1,2))
model = np.zeros((nband, nx, ny), dtype=ID.dtype)
IR = ID.copy()
# pre-allocate arrays for doing FFT's
xout = np.empty(ID.shape, dtype=ID.dtype, order='C')
xpad = np.empty(PSF.shape, dtype=ID.dtype, order='C')
xhat = np.empty(PSFHAT.shape, dtype=PSFHAT.dtype)
# square avoids abs of full array
IRsearch = np.sum(IR, axis=0)**2
pq = IRsearch.argmax()
Expand All @@ -116,7 +122,15 @@ def clark(ID,
th=subth,
maxit=submaxit)
# subtract from full image (as in major cycle)
IR = ID - psfo(model)
psf_convolve_cube(
xpad,
xhat,
xout,
PSFHAT,
ny_psf,
model,
nthreads=nthreads)
IR = ID - xout
IRsearch = np.sum(IR, axis=0)**2
pq = IRsearch.argmax()
p = pq//ny
Expand Down
59 changes: 57 additions & 2 deletions pfb/operators/hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from daskms.optimisation import inlined_array
from ducc0.wgridder import ms2dirty, dirty2ms
from uuid import uuid4
from pfb.operators.psf import (psf_convolve_slice,
psf_convolve_cube)


def hessian_xds(x, xds, hessopts, wsum, sigmainv, mask,
Expand Down Expand Up @@ -90,8 +92,8 @@ def _hessian_impl(x, uvw, weight, vis_mask, freq, beam,
pixsize_y=cell,
epsilon=epsilon,
nthreads=nthreads,
do_wstacking=wstack) #,
# double_precision_accumulation=double_accum)
do_wstacking=wstack,
double_precision_accumulation=double_accum)

if beam is not None:
convim *= beam
Expand All @@ -117,3 +119,56 @@ def hessian(x, uvw, weight, vis_mask, freq, beam, hessopts):
beam, bout,
hessopts, None,
dtype=x.dtype)


def hessian_psf_slice(
xpad, # preallocated array to store padded image
xhat, # preallocated array to store FTd image
xout, # preallocated array to store output image
psfhat,
beam,
lastsize,
x, # input image, not overwritten
nthreads=1,
sigmainv=1,
wsum=1):
"""
Tikhonov regularised Hessian approx
"""

if beam is not None:
psf_convolve_slice(xpad, xhat, xout,
psfhat/wsum, lastsize, x*beam)
else:
psf_convolve_slice(xpad, xhat, xout,
psfhat/wsum, lastsize, x)

if beam is not None:
xout *= beam

return xout + x * sigmainv


def hessian_psf_cube(
xpad, # preallocated array to store padded image
xhat, # preallocated array to store FTd image
xout, # preallocated array to store output image
beam,
psfhat,
lastsize,
x, # input image, not overwritten
nthreads=1,
sigmainv=1):
"""
Tikhonov regularised Hessian approx
"""

if beam is not None:
psf_convolve_cube(x*beam, xpad, xhat, xout, psfhat, lastsize)
else:
psf_convolve_cube(x, xpad, xhat, xout, psfhat, lastsize)

if beam is not None:
xout *= beam

return xout + x * sigmainv
160 changes: 53 additions & 107 deletions pfb/operators/psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,44 @@
from uuid import uuid4
from ducc0.fft import r2c, c2r, c2c, good_size
from uuid import uuid4
from pfb.utils.misc import pad_and_shift, unpad_and_unshift
from pfb.utils.misc import pad_and_shift_cube, unpad_and_unshift_cube
import gc
iFs = np.fft.ifftshift
Fs = np.fft.fftshift


def psf_convolve_slice(
xpad, # preallocated array to store padded image
xhat, # preallocated array to store FTd image
xout, # preallocated array to store output image
psfhat,
lastsize,
x, # input image, not overwritten
nthreads=1):
pad_and_shift(x, xpad)
r2c(xpad, axes=(0, 1), nthreads=nthreads,
forward=True, inorm=0, out=xhat)
xhat *= psfhat
c2r(xhat, axes=(0, 1), forward=False, out=xpad,
lastsize=lastsize, inorm=2, nthreads=nthreads)
unpad_and_unshift(xpad, xout)
return xout


def psf_convolve_cube(xpad, # preallocated array to store padded image
xhat, # preallocated array to store FTd image
xout, # preallocated array to store output image
psfhat,
lastsize,
x, # input image, not overwritten
nthreads=1):
pad_and_shift_cube(x, xpad)
r2c(xpad, axes=(1, 2), nthreads=nthreads,
forward=True, inorm=0, out=xhat)
xhat *= psfhat
c2r(xhat, axes=(1, 2), forward=False, out=xpad,
lastsize=lastsize, inorm=2, nthreads=nthreads)
unpad_and_unshift_cube(xpad, xout)
return xout


def psf_convolve_xds(x, xds, psfopts, wsum, sigmainv, mask,
Expand Down Expand Up @@ -52,43 +87,11 @@ def psf_convolve_xds(x, xds, psfopts, wsum, sigmainv, mask,
else:
return convim

def _psf_convolve_impl(x, psfhat, beam,
nthreads=None,
padding=None,
unpad_x=None,
unpad_y=None,
lastsize=None):
nx, ny = x.shape
xhat = x if beam is None else x * beam
xhat = iFs(np.pad(xhat, padding, mode='constant'), axes=(0, 1))
xhat = r2c(xhat, axes=(0, 1), nthreads=nthreads,
forward=True, inorm=0)
xhat = c2r(xhat * psfhat, axes=(0, 1), forward=False,
lastsize=lastsize, inorm=2, nthreads=nthreads)
convim = Fs(xhat, axes=(0, 1))[unpad_x, unpad_y]

if beam is not None:
convim *= beam

return convim

def _psf_convolve(x, psfhat, beam, psfopts):
return _psf_convolve_impl(x, psfhat, beam, **psfopts)

def psf_convolve(x, psfhat, beam, psfopts):
if not isinstance(x, da.Array):
x = da.from_array(x, chunks=(-1, -1), name=False)
if beam is None:
bout = None
else:
bout = ('nx', 'ny')
return da.blockwise(_psf_convolve, ('nx', 'ny'),
x, ('nx', 'ny'),
psfhat, ('nx', 'ny'),
beam, bout,
psfopts, None,
align_arrays=False,
dtype=x.dtype)




def _psf_convolve_cube_impl(x, psfhat, beam,
Expand All @@ -113,10 +116,8 @@ def _psf_convolve_cube_impl(x, psfhat, beam,

return convim

def _psf_convolve_cube(x, psfhat, beam, psfopts):
return _psf_convolve_cube_impl(x, psfhat, beam, **psfopts)

def psf_convolve_cube(x, psfhat, beam, psfopts,
def psf_convolve_cube_dask(x, psfhat, beam, psfopts,
wsum=1, sigmainv=None, compute=True):
if not isinstance(x, da.Array):
x = da.from_array(x, chunks=(1, -1, -1),
Expand Down Expand Up @@ -150,72 +151,17 @@ def psf_convolve_cube(x, psfhat, beam, psfopts,
else:
return convim


def _hessian_reg_psf(x, beam, psfhat,
nthreads=None,
sigmainv=None,
padding=None,
unpad_x=None,
unpad_y=None,
lastsize=None):
"""
Tikhonov regularised Hessian approx
"""
if isinstance(psfhat, da.Array):
psfhat = psfhat.compute()
if isinstance(beam, da.Array):
beam = beam.compute()

if beam is not None:
xhat = iFs(np.pad(beam*x, padding, mode='constant'), axes=(1, 2))
else:
xhat = iFs(np.pad(x, padding, mode='constant'), axes=(1, 2))
xhat = r2c(xhat, axes=(1, 2), nthreads=nthreads,
forward=True, inorm=0)
xhat = c2r(xhat * psfhat, axes=(1, 2), forward=False,
lastsize=lastsize, inorm=2, nthreads=nthreads)
im = Fs(xhat, axes=(1, 2))[:, unpad_x, unpad_y]

if beam is not None:
im *= beam

if np.any(sigmainv):
return im + x * sigmainv
else:
return im

def _hessian_reg_psf_slice(
x,
psfhat=None,
beam=None,
nthreads=None,
sigmainv=None,
padding=None,
unpad_x=None,
unpad_y=None,
lastsize=None,
wsum=1.0):
"""
Tikhonov regularised Hessian approx
"""
if isinstance(psfhat, da.Array):
psfhat = psfhat.compute()
if isinstance(beam, da.Array):
beam = beam.compute()
if beam is not None:
xhat = iFs(np.pad(beam*x, padding, mode='constant'), axes=(0, 1))
else:
xhat = iFs(np.pad(x, padding, mode='constant'), axes=(0, 1))
xhat = r2c(xhat, axes=(0, 1), nthreads=nthreads,
forward=True, inorm=0)
xhat = c2r(xhat * psfhat/wsum, axes=(0, 1), forward=False,
lastsize=lastsize, inorm=2, nthreads=nthreads)
im = Fs(xhat, axes=(0, 1))[unpad_x, unpad_y]

if beam is not None:
im *= beam

if np.any(sigmainv):
return im + x * sigmainv
def psf_convolve_slice_dask(x, psfhat, beam, psfopts):
if not isinstance(x, da.Array):
x = da.from_array(x, chunks=(-1, -1), name=False)
if beam is None:
bout = None
else:
return im
bout = ('nx', 'ny')
return da.blockwise(psf_convolve_slice, ('nx', 'ny'),
x, ('nx', 'ny'),
psfhat, ('nx', 'ny'),
beam, bout,
psfopts, None,
align_arrays=False,
dtype=x.dtype)
Loading

0 comments on commit ae78b8c

Please sign in to comment.