Skip to content

Commit

Permalink
Merge branch 'v0.2.1-dev' into v0.2.1-degridder
Browse files Browse the repository at this point in the history
  • Loading branch information
JSKenyon committed Sep 13, 2023
2 parents c107e48 + b740a37 commit 9f1d03a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 9 deletions.
35 changes: 29 additions & 6 deletions quartical/apps/backup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
from math import prod, ceil
from quartical.data_handling.selection import filter_xds_list
from daskms import xds_from_storage_ms, xds_to_storage_table
from daskms.experimental.zarr import xds_to_zarr, xds_from_zarr
from daskms.fsspec_store import DaskMSStore
Expand All @@ -10,8 +11,9 @@
def backup():
parser = argparse.ArgumentParser(
description='Backup any Measurement Set column to zarr. Backups will '
'be labelled automatically using the current datetime, '
'the Measurement Set name and the column name.'
'be labelled using a combination of the passed in label '
'(defaults to datetime), the Measurement Set name and '
'the column name.'
)

parser.add_argument(
Expand All @@ -33,19 +35,34 @@ def backup():
type=str,
help='Name of column to be backed up.'
)
parser.add_argument(
'--label',
type=str,
help='An explicit label to include in the backup name. Defaults to '
'datetime at which the backup was created. Full name will be '
'given by [label]-[msname]-[column].bkp.qc.'
)
parser.add_argument(
'--nthread',
type=int,
default=1,
help='Number of threads to use.'
)
parser.add_argument(
'--field-id',
type=int,
help='Field ID to back up.'
)

args = parser.parse_args()

ms_name = args.ms_path.full_path.rsplit("/", 1)[1]
column_name = args.column_name

timestamp = time.strftime("%Y%m%d-%H%M%S")
if args.label:
label = args.label
else:
label = time.strftime("%Y%m%d-%H%M%S")

# This call exists purely to get the relevant shape and dtype info.
data_xds_list = xds_from_storage_ms(
Expand All @@ -55,8 +72,11 @@ def backup():
group_cols=("FIELD_ID", "DATA_DESC_ID", "SCAN_NUMBER"),
)

# Use existing functionality. TODO: Improve and expose DDID selection.
xdso = filter_xds_list(data_xds_list, args.field_id)

# Compute appropriate chunks (256MB by default) to keep zarr happy.
chunks = [chunk_by_size(xds[column_name]) for xds in data_xds_list]
chunks = [chunk_by_size(xds[column_name]) for xds in xdso]

# Repeat of above call but now with correct chunking information.
data_xds_list = xds_from_storage_ms(
Expand All @@ -67,9 +87,12 @@ def backup():
chunks=chunks
)

# Use existing functionality. TODO: Improve and expose DDID selection.
xdso = filter_xds_list(data_xds_list, args.field_id)

bkp_xds_list = xds_to_zarr(
data_xds_list,
f"{args.zarr_dir.url}::{timestamp}-{ms_name}-{column_name}.bkp.qc",
xdso,
f"{args.zarr_dir.url}::{label}-{ms_name}-{column_name}.bkp.qc",
)

dask.compute(bkp_xds_list, num_workers=args.nthread)
Expand Down
12 changes: 12 additions & 0 deletions quartical/calibration/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,18 @@ def make_visibility_output(

itr = enumerate(zip(data_xds_list, mapping_xds_list))

if output_opts.subtract_directions:
n_dir = data_xds_list[0].dims['dir'] # Should be the same on all xdss.
requested = set(output_opts.subtract_directions)
valid = set(range(n_dir))
invalid = requested - valid
if invalid:
raise ValueError(
f"User has specified output.subtract_directions as "
f"{requested} but the following directions are not present "
f"in the model: {invalid}."
)

for xds_ind, (data_xds, mapping_xds) in itr:
data_col = data_xds.DATA.data
model_col = data_xds.MODEL_DATA.data
Expand Down
12 changes: 9 additions & 3 deletions quartical/data_handling/selection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
def filter_xds_list(xds_list, fields, ddids):
def filter_xds_list(xds_list, fields=[], ddids=[]):

filter_fields = {"FIELD_ID": fields,
"DATA_DESC_ID": ddids}
# Wrap a bare int in a list so that the `in` test below accepts it.
fields = [fields] if isinstance(fields, int) else fields
ddids = [ddids] if isinstance(ddids, int) else ddids

filter_fields = {
"FIELD_ID": fields,
"DATA_DESC_ID": ddids
}

for k, v in filter_fields.items():
fil = filter(lambda xds: getattr(xds, k) in v, xds_list)
Expand Down

0 comments on commit 9f1d03a

Please sign in to comment.