Merge in changes from the experimental dask-ms branch, i.e. fragment support.
JSKenyon committed Sep 21, 2023
2 parents 9f1d03a + d8a9719 commit 34db3fb
Showing 7 changed files with 78 additions and 45 deletions.
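
Every file below follows the same pattern: dask-ms's storage readers (xds_from_storage_ms, xds_from_storage_table) are swapped for the fragment-aware readers in daskms.experimental.fragments, which the unchanged call sites suggest are drop-in replacements. As orientation, here is a minimal sketch of the new read path; the Measurement Set path is hypothetical and the arguments are copied from the summary.py changes below.

from daskms.experimental.fragments import (
    xds_from_ms_fragment,
    xds_from_table_fragment
)

# Hypothetical input; any store dask-ms understands should work here.
ms_path = "observation.ms"

# Main table read, mirroring the call in quartical/apps/summary.py.
data_xds_list = xds_from_ms_fragment(
    ms_path,
    index_cols=("TIME",),
    columns=("TIME", "FLAG", "FLAG_ROW", "DATA"),
)

# Subtables keep the usual "::" addressing.
ant_xds = xds_from_table_fragment(ms_path + "::ANTENNA")[0]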
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -31,7 +31,7 @@ columnar = "^1.4.1"
 "ruamel.yaml" = "^0.17.26"
 dask = {extras = ["diagnostics"], version = "^2023.1.0"}
 distributed = "^2023.1.0"
-dask-ms = {git = "https://github.com/ratt-ru/dask-ms.git", branch = "multisource-experimental", extras = ["s3", "xarray", "zarr"]}
+dask-ms = {git = "https://github.com/ratt-ru/dask-ms.git", extras = ["s3", "xarray", "zarr"]}
 codex-africanus = {extras = ["dask", "scipy", "astropy", "python-casacore"], version = "^0.3.4"}
 astro-tigger-lsm = "^1.7.2"
 loguru = "^0.7.0"
31 changes: 17 additions & 14 deletions quartical/apps/summary.py
@@ -1,6 +1,9 @@
 import argparse
 from pathlib import Path
-from daskms import xds_from_storage_ms, xds_from_storage_table
+from daskms.experimental.fragments import (
+    xds_from_ms_fragment,
+    xds_from_table_fragment
+)
 from daskms.fsspec_store import DaskMSStore
 import numpy as np
 import dask.array as da
@@ -44,7 +47,7 @@ def configure_loguru(output_dir):
 def antenna_info(path):

     # NOTE: Assume one dataset for now.
-    ant_xds = xds_from_storage_table(path + "::ANTENNA")[0]
+    ant_xds = xds_from_table_fragment(path + "::ANTENNA")[0]

     antenna_names = ant_xds.NAME.values
     antenna_mounts = ant_xds.MOUNT.values
@@ -64,7 +67,7 @@ def antenna_info(path):

 def data_desc_info(path):

-    dd_xds_list = xds_from_storage_table(  # noqa
+    dd_xds_list = xds_from_table_fragment(  # noqa
         path + "::DATA_DESCRIPTION",
         group_cols=["__row__"],
         chunks={"row": 1, "chan": -1}
@@ -76,7 +79,7 @@ def data_desc_info(path):

 def feed_info(path):

-    feed_xds_list = xds_from_storage_table(
+    feed_xds_list = xds_from_table_fragment(
         path + "::FEED",
         group_cols=["SPECTRAL_WINDOW_ID"],
         chunks={"row": -1}
@@ -106,15 +109,15 @@ def feed_info(path):

 def flag_cmd_info(path):

-    flag_cmd_xds = xds_from_storage_table(path + "::FLAG_CMD")  # noqa
+    flag_cmd_xds = xds_from_table_fragment(path + "::FLAG_CMD")  # noqa

     # Not printing any summary information for this subtable yet - not sure
     # what is relevant.


 def field_info(path):

-    field_xds = xds_from_storage_table(path + "::FIELD")[0]
+    field_xds = xds_from_table_fragment(path + "::FIELD")[0]

     ids = [i for i in field_xds.SOURCE_ID.values]
     names = [n for n in field_xds.NAME.values]
@@ -141,23 +144,23 @@ def field_info(path):

 def history_info(path):

-    history_xds = xds_from_storage_table(path + "::HISTORY")[0]  # noqa
+    history_xds = xds_from_table_fragment(path + "::HISTORY")[0]  # noqa

     # Not printing any summary information for this subtable yet - not sure
     # what is relevant.


 def observation_info(path):

-    observation_xds = xds_from_storage_table(path + "::OBSERVATION")[0]  # noqa
+    observation_xds = xds_from_table_fragment(path + "::OBSERVATION")[0]  # noqa

     # Not printing any summary information for this subtable yet - not sure
     # what is relevant.


 def polarization_info(path):

-    polarization_xds = xds_from_storage_table(path + "::POLARIZATION")[0]
+    polarization_xds = xds_from_table_fragment(path + "::POLARIZATION")[0]

     corr_types = polarization_xds.CORR_TYPE.values

@@ -175,15 +178,15 @@ def polarization_info(path):

 def processor_info(path):

-    processor_xds = xds_from_storage_table(path + "::PROCESSOR")[0]  # noqa
+    processor_xds = xds_from_table_fragment(path + "::PROCESSOR")[0]  # noqa

     # Not printing any summary information for this subtable yet - not sure
     # what is relevant.


 def spw_info(path):

-    spw_xds_list = xds_from_storage_table(
+    spw_xds_list = xds_from_table_fragment(
         path + "::SPECTRAL_WINDOW",
         group_cols=["__row__"],
         chunks={"row": 1, "chan": -1}
@@ -207,7 +210,7 @@ def spw_info(path):

 def state_info(path):

-    state_xds = xds_from_storage_table(path + "::STATE")[0]  # noqa
+    state_xds = xds_from_table_fragment(path + "::STATE")[0]  # noqa

     # Not printing any summary information for this subtable yet - not sure
     # what is relevant.
@@ -226,7 +229,7 @@ def source_info(path):

 def pointing_info(path):

-    pointing_xds = xds_from_storage_table(path + "::POINTING")[0]  # noqa
+    pointing_xds = xds_from_table_fragment(path + "::POINTING")[0]  # noqa

     # Not printing any summary information for this subtable yet - not sure
     # what is relevant.
@@ -355,7 +358,7 @@ def summary():
     # Open the data, grouping by the usual columns. Use these datasets to
     # produce some useful summaries.

-    data_xds_list = xds_from_storage_ms(
+    data_xds_list = xds_from_ms_fragment(
         path,
         index_cols=("TIME",),
         columns=("TIME", "FLAG", "FLAG_ROW", "DATA"),
9 changes: 9 additions & 0 deletions quartical/config/argument_schema.yaml
@@ -193,6 +193,15 @@ output:
       Name of directory in which QuartiCal logging outputs will be stored.
       s3 is not currently supported for these outputs.

+  fragment_path:
+    dtype: Optional[str]
+    info:
+      If set, QuartiCal will not mutate the input by e.g. writing flags,
+      but will instead write a fragment to this location. A fragment is a
+      zarr-backed data format that is read and dynamically combined with
+      any parent datasets. This allows QuartiCal to operate in an entirely
+      read-only fashion. This option is experimental.
+
   log_to_terminal:
     default: true
     dtype: bool
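
The option above is exercised in ms_handler.py further down. Here is a hedged sketch of the intended round trip, assuming hypothetical paths; the xds_to_table_fragment argument order is taken from the call added in ms_handler.py, and the read-combine behaviour is as described in the option text above.

from daskms.experimental.fragments import (
    xds_from_ms_fragment,
    xds_to_table_fragment
)

ms_path = "observation.ms"             # parent MS, never mutated
fragment_path = "flags.fragment.zarr"  # hypothetical zarr-backed fragment

# Read from the parent as usual.
xds_list = xds_from_ms_fragment(ms_path)

# ... update e.g. FLAG/FLAG_ROW on the datasets ...

# Write only the changed columns into the fragment, leaving the parent
# read-only; the fragment records ms_path as its parent.
writes = xds_to_table_fragment(
    xds_list,
    fragment_path,
    ms_path,
    columns=("FLAG", "FLAG_ROW"),
    rechunk=True  # As in ms_handler.py: map zarr chunks correctly to disk.
)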
8 changes: 4 additions & 4 deletions quartical/data_handling/angles.py
@@ -2,7 +2,7 @@
 import casacore.measures
 import casacore.quanta as pq

-from daskms import xds_from_storage_table
+from daskms.experimental.fragments import xds_from_table_fragment
 import dask.array as da
 import threading
 from dask.graph_manipulation import clone
@@ -24,9 +24,9 @@ def make_parangle_xds_list(ms_path, data_xds_list):

     # This may need to be more sophisticated. TODO: Can we guarantee that
     # these only ever have one element?
-    anttab = xds_from_storage_table(ms_path + "::ANTENNA")[0]
-    feedtab = xds_from_storage_table(ms_path + "::FEED")[0]
-    fieldtab = xds_from_storage_table(ms_path + "::FIELD")[0]
+    anttab = xds_from_table_fragment(ms_path + "::ANTENNA")[0]
+    feedtab = xds_from_table_fragment(ms_path + "::FEED")[0]
+    fieldtab = xds_from_table_fragment(ms_path + "::FIELD")[0]

     # We do this eagerly to make life easier.
     feeds = feedtab.POLARIZATION_TYPE.values
9 changes: 6 additions & 3 deletions quartical/data_handling/chunking.py
@@ -1,7 +1,10 @@
 import dask.delayed as dd
 import numpy as np
 import dask.array as da
-from daskms import xds_from_storage_ms, xds_from_storage_table
+from daskms.experimental.fragments import (
+    xds_from_ms_fragment,
+    xds_from_table_fragment
+)


 def compute_chunking(ms_opts, compute=True):
@@ -10,7 +13,7 @@ def compute_chunking(ms_opts, compute=True):
     # necessary to determine initial chunking over row and chan. TODO: Test
     # multi-SPW/field cases. Implement a memory budget.

-    indexing_xds_list = xds_from_storage_ms(
+    indexing_xds_list = xds_from_ms_fragment(
         ms_opts.path,
         columns=("TIME", "INTERVAL"),
         index_cols=("TIME",),
@@ -24,7 +27,7 @@ def compute_chunking(ms_opts, compute=True):
         compute=False
     )

-    spw_xds_list = xds_from_storage_table(
+    spw_xds_list = xds_from_table_fragment(
         ms_opts.path + "::SPECTRAL_WINDOW",
         group_cols=["__row__"],
         columns=["CHAN_FREQ", "CHAN_WIDTH"],
44 changes: 29 additions & 15 deletions quartical/data_handling/ms_handler.py
@@ -2,9 +2,12 @@
 import warnings
 import dask.array as da
 import numpy as np
-from daskms import (xds_from_storage_ms,
-                    xds_from_storage_table,
-                    xds_to_storage_table)
+from daskms import xds_to_storage_table
+from daskms.experimental.fragments import (
+    xds_from_ms_fragment,
+    xds_from_table_fragment,
+    xds_to_table_fragment
+)
 from dask.graph_manipulation import clone
 from loguru import logger
 from quartical.weights.weights import initialize_weights
@@ -28,15 +31,15 @@ def read_xds_list(model_columns, ms_opts):
         data_xds_list: A list of appropriately chunked xarray datasets.
     """

-    antenna_xds = xds_from_storage_table(ms_opts.path + "::ANTENNA")[0]
+    antenna_xds = xds_from_table_fragment(ms_opts.path + "::ANTENNA")[0]

     n_ant = antenna_xds.dims["row"]

     logger.info("Antenna table indicates {} antennas were present for this "
                 "observation.", n_ant)

     # Determine the number/type of correlations present in the measurement set.
-    pol_xds = xds_from_storage_table(ms_opts.path + "::POLARIZATION")[0]
+    pol_xds = xds_from_table_fragment(ms_opts.path + "::POLARIZATION")[0]

     try:
         corr_types = [CORR_TYPES[ct] for ct in pol_xds.CORR_TYPE.values[0]]
@@ -56,7 +59,7 @@ def read_xds_list(model_columns, ms_opts):
     # probably need to be done on a per xds basis. Can probably be accomplished
     # by merging the field xds grouped by DDID into data grouped by DDID.

-    field_xds = xds_from_storage_table(ms_opts.path + "::FIELD")[0]
+    field_xds = xds_from_table_fragment(ms_opts.path + "::FIELD")[0]
     phase_dir = np.squeeze(field_xds.PHASE_DIR.values)
     field_names = field_xds.NAME.values

@@ -90,7 +93,7 @@ def read_xds_list(model_columns, ms_opts):
         schema[ms_opts.weight_column] = {'dims': ('chan', 'corr')}

     try:
-        data_xds_list = xds_from_storage_ms(
+        data_xds_list = xds_from_ms_fragment(
             ms_opts.path,
             columns=columns,
             index_cols=("TIME",),
@@ -103,7 +106,7 @@ def read_xds_list(model_columns, ms_opts):
             f"Invalid/missing column specified. Underlying error: {e}."
         ) from e

-    spw_xds_list = xds_from_storage_table(
+    spw_xds_list = xds_from_table_fragment(
         ms_opts.path + "::SPECTRAL_WINDOW",
         group_cols=["__row__"],
         columns=["CHAN_FREQ", "CHAN_WIDTH"],
@@ -213,7 +216,7 @@ def write_xds_list(xds_list, ref_xds_list, ms_path, output_opts):
     if not (output_opts.products or output_opts.flags):
         return [None] * len(xds_list)  # Write nothing to the MS.

-    pol_xds = xds_from_storage_table(ms_path + "::POLARIZATION")[0]
+    pol_xds = xds_from_table_fragment(ms_path + "::POLARIZATION")[0]
     corr_types = [CORR_TYPES[ct] for ct in pol_xds.CORR_TYPE.values[0]]
     ms_n_corr = len(corr_types)

@@ -295,12 +298,23 @@ def write_xds_list(xds_list, ref_xds_list, ms_path, output_opts):

     with warnings.catch_warnings():  # We anticipate spurious warnings.
         warnings.simplefilter("ignore")
-        write_xds_list = xds_to_storage_table(
-            xds_list,
-            ms_path,
-            columns=output_cols,
-            rechunk=True  # Needed to ensure zarr chunks map correctly to disk.
-        )
+
+        if output_opts.fragment_path:
+            write_xds_list = xds_to_table_fragment(
+                xds_list,
+                output_opts.fragment_path,
+                ms_path,
+                columns=output_cols,
+                rechunk=True  # Ensure zarr chunks map correctly to disk.
+            )
+
+        else:
+            write_xds_list = xds_to_storage_table(
+                xds_list,
+                ms_path,
+                columns=output_cols,
+                rechunk=True  # Ensure zarr chunks map correctly to disk.
+            )

     return write_xds_list

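
One consequence of the branch above, sketched under the same hypothetical paths as earlier: because the parent MS is never mutated, updated values should only be visible through a fragment-aware read of the fragment itself, while a plain storage read of the parent still returns the original data.

from daskms import xds_from_storage_ms
from daskms.experimental.fragments import xds_from_ms_fragment

original = xds_from_storage_ms("observation.ms")        # parent, unchanged
combined = xds_from_ms_fragment("flags.fragment.zarr")  # parent + fragment view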
20 changes: 12 additions & 8 deletions quartical/data_handling/predict.py
@@ -7,7 +7,8 @@
 import dask
 from xarray import DataArray, Dataset
 from dask.graph_manipulation import clone
-from daskms import xds_from_storage_table
+from daskms.experimental.fragments import xds_from_table_fragment
+
 from loguru import logger
 import numpy as np
 import Tigger
@@ -310,21 +311,24 @@ def get_support_tables(ms_path):
              "SPECTRAL_WINDOW", "POLARIZATION", "FEED")}

     # All rows at once
-    lazy_tables = {"ANTENNA": xds_from_storage_table(n["ANTENNA"]),
-                   "FEED": xds_from_storage_table(n["FEED"])}
+    lazy_tables = {"ANTENNA": xds_from_table_fragment(n["ANTENNA"]),
+                   "FEED": xds_from_table_fragment(n["FEED"])}

     compute_tables = {
         # NOTE: Even though this has a fixed shape, I have amended it to
         # also group by row. This just makes life fractionally easier.
-        "DATA_DESCRIPTION": xds_from_storage_table(n["DATA_DESCRIPTION"],
-                                                   group_cols="__row__"),
+        "DATA_DESCRIPTION": xds_from_table_fragment(
+            n["DATA_DESCRIPTION"], group_cols="__row__"
+        ),
         # Variably shaped, need a dataset per row
         "FIELD":
-            xds_from_storage_table(n["FIELD"], group_cols="__row__"),
+            xds_from_table_fragment(n["FIELD"], group_cols="__row__"),
         "SPECTRAL_WINDOW":
-            xds_from_storage_table(n["SPECTRAL_WINDOW"], group_cols="__row__"),
+            xds_from_table_fragment(
+                n["SPECTRAL_WINDOW"], group_cols="__row__"
+            ),
         "POLARIZATION":
-            xds_from_storage_table(n["POLARIZATION"], group_cols="__row__"),
+            xds_from_table_fragment(n["POLARIZATION"], group_cols="__row__"),
     }

     lazy_tables.update(dask.compute(compute_tables)[0])
