diff --git a/quartical/apps/backup.py b/quartical/apps/backup.py index fcb44a69..1a15ea8c 100644 --- a/quartical/apps/backup.py +++ b/quartical/apps/backup.py @@ -1,5 +1,6 @@ import argparse from math import prod, ceil +from quartical.data_handling.selection import filter_xds_list from daskms import xds_from_storage_ms, xds_to_storage_table from daskms.experimental.zarr import xds_to_zarr, xds_from_zarr from daskms.fsspec_store import DaskMSStore @@ -10,8 +11,9 @@ def backup(): parser = argparse.ArgumentParser( description='Backup any Measurement Set column to zarr. Backups will ' - 'be labelled automatically using the current datetime, ' - 'the Measurement Set name and the column name.' + 'be labelled using a combination of the passed in label ' + '(defaults to datetime), the Measurement Set name and ' + 'the column name.' ) parser.add_argument( @@ -33,19 +35,34 @@ def backup(): type=str, help='Name of column to be backed up.' ) + parser.add_argument( + '--label', + type=str, + help='An explicit label to include in the backup name. Defaults to ' + 'datetime at which the backup was created. Full name will be ' + 'given by [label]-[msname]-[column].bkp.qc.' + ) parser.add_argument( '--nthread', type=int, default=1, help='Number of threads to use.' ) + parser.add_argument( + '--field-id', + type=int, + help='Field ID to back up.' + ) args = parser.parse_args() ms_name = args.ms_path.full_path.rsplit("/", 1)[1] column_name = args.column_name - timestamp = time.strftime("%Y%m%d-%H%M%S") + if args.label: + label = args.label + else: + label = time.strftime("%Y%m%d-%H%M%S") # This call exists purely to get the relevant shape and dtype info. data_xds_list = xds_from_storage_ms( @@ -55,8 +72,11 @@ def backup(): group_cols=("FIELD_ID", "DATA_DESC_ID", "SCAN_NUMBER"), ) + # Use existing functionality. TODO: Improve and expose DDID selection. + xdso = filter_xds_list(data_xds_list, args.field_id) + # Compute appropriate chunks (256MB by default) to keep zarr happy. - chunks = [chunk_by_size(xds[column_name]) for xds in data_xds_list] + chunks = [chunk_by_size(xds[column_name]) for xds in xdso] # Repeat of above call but now with correct chunking information. data_xds_list = xds_from_storage_ms( @@ -67,9 +87,12 @@ def backup(): chunks=chunks ) + # Use existing functionality. TODO: Improve and expose DDID selection. + xdso = filter_xds_list(data_xds_list, args.field_id) + bkp_xds_list = xds_to_zarr( - data_xds_list, - f"{args.zarr_dir.url}::{timestamp}-{ms_name}-{column_name}.bkp.qc", + xdso, + f"{args.zarr_dir.url}::{label}-{ms_name}-{column_name}.bkp.qc", ) dask.compute(bkp_xds_list, num_workers=args.nthread) diff --git a/quartical/calibration/calibrate.py b/quartical/calibration/calibrate.py index 83ed022b..dba5f2f2 100644 --- a/quartical/calibration/calibrate.py +++ b/quartical/calibration/calibrate.py @@ -253,6 +253,18 @@ def make_visibility_output( itr = enumerate(zip(data_xds_list, mapping_xds_list)) + if output_opts.subtract_directions: + n_dir = data_xds_list[0].dims['dir'] # Should be the same on all xdss. + requested = set(output_opts.subtract_directions) + valid = set(range(n_dir)) + invalid = requested - valid + if invalid: + raise ValueError( + f"User has specified output.subtract_directions as " + f"{requested} but the following directions are not present " + f"in the model: {invalid}." + ) + for xds_ind, (data_xds, mapping_xds) in itr: data_col = data_xds.DATA.data model_col = data_xds.MODEL_DATA.data diff --git a/quartical/data_handling/selection.py b/quartical/data_handling/selection.py index 4f6f9301..bd0f66c3 100644 --- a/quartical/data_handling/selection.py +++ b/quartical/data_handling/selection.py @@ -1,7 +1,13 @@ -def filter_xds_list(xds_list, fields, ddids): +def filter_xds_list(xds_list, fields=[], ddids=[]): - filter_fields = {"FIELD_ID": fields, - "DATA_DESC_ID": ddids} + # If we specify an int, make it a list. Might be worth improving. + fields = [fields] if isinstance(fields, int) else fields + ddids = [ddids] if isinstance(ddids, int) else ddids + + filter_fields = { + "FIELD_ID": fields, + "DATA_DESC_ID": ddids + } for k, v in filter_fields.items(): fil = filter(lambda xds: getattr(xds, k) in v, xds_list)