diff --git a/quartical/config/argument_schema.yaml b/quartical/config/argument_schema.yaml index be3566fe..52c78af9 100644 --- a/quartical/config/argument_schema.yaml +++ b/quartical/config/argument_schema.yaml @@ -310,6 +310,15 @@ mad_flags: X_bl > max_deviation*X will be flagged. Set to zero to disable flagging on this statistic. + use_off_diagonals: + dtype: bool + default: False + info: + Controls whether or not the mad flagger will be run on the off-diagonal + elements of the residual. This is disabled by default as the residual + will tend to contain structure in the absence of a polarised model and + adequate leakage calibration. + solver: terms: dtype: List[str] diff --git a/quartical/flagging/flagging.py b/quartical/flagging/flagging.py index 9755c40b..cc4ba1ef 100644 --- a/quartical/flagging/flagging.py +++ b/quartical/flagging/flagging.py @@ -123,6 +123,8 @@ def valid_median(arr): def add_mad_graph(data_xds_list, mad_opts): + diag_corrs = ['RR', 'LL', 'XX', 'YY'] + bl_thresh = mad_opts.threshold_bl gbl_thresh = mad_opts.threshold_global max_deviation = mad_opts.max_deviation @@ -147,6 +149,13 @@ def add_mad_graph(data_xds_list, mad_opts): n_bl_w_autos = (n_ant * (n_ant - 1))/2 + n_ant n_t_chunk, n_f_chunk, _ = residuals.numblocks + if mad_opts.use_off_diagonals: + corr_sel = tuple(np.arange(residuals.shape[-1])) + else: + corr_sel = tuple( + [i for i, c in enumerate(xds.corr.values) if c in diag_corrs] + ) + wres = da.blockwise( compute_whitened_residual, ("rowlike", "chan", "corr"), residuals, ("rowlike", "chan", "corr"), @@ -228,6 +237,7 @@ def add_mad_graph(data_xds_list, mad_opts): gbl_thresh, None, bl_thresh, None, max_deviation, None, + corr_sel, None, n_ant, None, dtype=np.int8, align_arrays=False, diff --git a/quartical/flagging/flagging_kernels.py b/quartical/flagging/flagging_kernels.py index 2ad152f5..21d7eb8a 100644 --- a/quartical/flagging/flagging_kernels.py +++ b/quartical/flagging/flagging_kernels.py @@ -99,6 +99,7 @@ def compute_mad_flags( gbl_threshold, bl_threshold, max_deviation, + corr_sel, n_ant ): @@ -114,7 +115,7 @@ def compute_mad_flags( gbl_threshold2 = gbl_threshold ** 2 or np.inf max_deviation2 = max_deviation ** 2 or np.inf - for corr in range(n_corr): + for corr in corr_sel: gbl_mad_real = gbl_mad_and_med_real[0, 0, corr, 0] * scale_factor gbl_med_real = gbl_mad_and_med_real[0, 0, corr, 1]