Skip to content

Commit

Permalink
Rename ts_slope_ metrics to slope
Browse files Browse the repository at this point in the history
  • Loading branch information
daberer committed Oct 30, 2024
1 parent ca9c482 commit 1413b4f
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 33 deletions.
46 changes: 23 additions & 23 deletions src/qa4sm_reader/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,16 +134,16 @@ def get_status_colors():
'YlGn'], # sequential: increasing value good (n_obs, STDerr)
'qua_neutr':
get_status_colors(), # qualitative category with 2 forced colors
# Added colormaps for ts_slope metrics
'div_ts_BIAS': matplotlib.colormaps[
# Added colormaps for slope metrics
'div_slopeBIAS': matplotlib.colormaps[
'RdBu_r'
], # diverging colormap for ts_slope_BIAS
'div_ts_R': matplotlib.colormaps[
], # diverging colormap for slopeBIAS
'div_slopeR': matplotlib.colormaps[
'PiYG'
], # diverging colormap for ts_slope_R
'div_ts_urmsd': matplotlib.colormaps[
], # diverging colormap for slopeR
'div_slopeURMSD': matplotlib.colormaps[
'PuOr'
] # diverging colormap for ts_slope_urmsd
] # diverging colormap for slopeURMSD
}

_colormaps = { # from /qa4sm/validator/validation/graphics.py
Expand All @@ -166,9 +166,9 @@ def get_status_colors():
'err_std': _cclasses['seq_worse'],
'beta': _cclasses['div_neutr'],
'status': _cclasses['qua_neutr'],
'ts_slope_R': _cclasses['div_ts_R'],
'ts_slope_urmsd': _cclasses['div_ts_urmsd'],
'ts_slope_BIAS': _cclasses['div_ts_BIAS'],
'slopeR': _cclasses['div_slopeR'],
'slopeURMSD': _cclasses['div_slopeURMSD'],
'slopeBIAS': _cclasses['div_slopeBIAS'],
}

# Colorbars for difference plots
Expand Down Expand Up @@ -202,7 +202,7 @@ def get_status_colors():
'mse_bias', 'mse_var', 'RSS', 'tau', 'p_tau', 'status'
],
'triple': ['snr', 'err_std', 'beta', 'status'],
'pairwise_stability': ['ts_slope_urmsd', 'ts_slope_R', 'ts_slope_BIAS']
'pairwise_stability': ['slopeURMSD', 'slopeR', 'slopeBIAS']
}

def get_metric_format(group, metric_dict):
Expand Down Expand Up @@ -270,9 +270,9 @@ def get_metric_format(group, metric_dict):
'err_std': [None, None],
'beta': [None, None],
'status': [-1, len(status)-2],
'ts_slope_R': [None, None],
'ts_slope_urmsd': [None, None],
'ts_slope_BIAS': [None, None],
'slopeR': [None, None],
'slopeURMSD': [None, None],
'slopeBIAS': [None, None],
}
# mask values out of range
_metric_mask_range = {
Expand Down Expand Up @@ -304,9 +304,9 @@ def get_metric_format(group, metric_dict):
'err_std': ' in {}',
'beta': ' in {}',
'status': '',
'ts_slope_R': '',
'ts_slope_urmsd': ' in {}',
'ts_slope_BIAS': ' in {}',
'slopeR': '',
'slopeURMSD': ' in {}',
'slopeBIAS': ' in {}',
}


Expand Down Expand Up @@ -391,9 +391,9 @@ def get_metric_units(dataset, metric=None, raise_error=False):
}

STABILITY_METRICS = {
'ts_slope_R' : 'Theil-Sen slope of R',
'ts_slope_urmsd' : 'Theil-Sen slope of urmsd',
'ts_slope_BIAS' : 'Theil-Sen slope of BIAS'
'slopeR' : 'Theil-Sen slope of R',
'slopeURMSD' : 'Theil-Sen slope of urmsd',
'slopeBIAS' : 'Theil-Sen slope of BIAS'
}

_metric_name = {**COMMON_METRICS, **READER_EXCLUSIVE_METRICS, **TC_METRICS, **STABILITY_METRICS}
Expand Down Expand Up @@ -775,9 +775,9 @@ def get_resolution_info(dataset, raise_error=False):
'tau',
'p_tau',
'status',
'ts_slope_R',
'ts_slope_urmsd',
'ts_slope_BIAS',
'slopeR',
'slopeURMSD',
'slopeBIAS',
]

METRIC_TEMPLATE = '_between_{ds1}_and_{ds2}'
Expand Down
6 changes: 3 additions & 3 deletions src/qa4sm_reader/netcdf_transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,12 @@ def drop_obs_dim(self) -> None:

def mask_redundant_tsw_values(self) -> None:
"""
For all variables starting with 'ts_slope_', replace all tsw values ('2010', '2011', etc.) with NaN
For all variables starting with 'slope', replace all tsw values ('2010', '2011', etc.) with NaN
except for the default tsw.
"""
ts_slope_vars = [var for var in self.transcribed_dataset if var.startswith("ts_slope_")]
slope_vars = [var for var in self.transcribed_dataset if var.startswith("slope")]

for var in ts_slope_vars:
for var in slope_vars:
if TEMPORAL_SUB_WINDOW_NC_COORD_NAME in self.transcribed_dataset[var].dims:
mask = self.transcribed_dataset[var][TEMPORAL_SUB_WINDOW_NC_COORD_NAME] == DEFAULT_TSW
self.transcribed_dataset[var] = self.transcribed_dataset[var].where(mask, other=np.nan)
Expand Down
2 changes: 0 additions & 2 deletions src/qa4sm_reader/plot_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,8 @@ def plot_all(filepath: str,
stability = all(item.isdigit() for item in comparison_periods if item != 'bulk')
cbp = QA4SMCompPlotter(filepath)
comparison_boxplot_dir = os.path.join(out_dir, globals.CLUSTERED_BOX_PLOT_OUTDIR)
#if not stability:
os.makedirs(comparison_boxplot_dir, exist_ok=True)


for available_metric in cbp.metric_kinds_available:
if available_metric in metrics.keys(
) and available_metric not in metrics_not_to_plot:
Expand Down
7 changes: 3 additions & 4 deletions src/qa4sm_reader/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,13 +885,14 @@ def plot_metric(self,
fnames_bplot = None
Metric = self.img.metrics[metric]

fnames_mapplot = None
if Metric.name == 'status':
fnames_bplot = self.barplot(metric='status',
period=period,
out_types=out_types,
save_files=save_all)

elif Metric.g == 'common' or Metric.g == 'pairwise':
elif Metric.g == 'common' or Metric.g == 'pairwise' or Metric.g == 'pairwise_stability':
fnames_bplot = self.boxplot_basic(metric=metric,
period=period,
out_types=out_types,
Expand All @@ -908,9 +909,7 @@ def plot_metric(self,
period=period,
out_types=out_types,
save_files=save_all,
**plotting_kwargs)
else:
fnames_mapplot = None
**plotting_kwargs)

return fnames_bplot, fnames_mapplot

Expand Down
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/test_netcdf_transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ def test_plotting(seasonal_qa4sm_file, monthly_qa4sm_file, stability_qa4sm_file,
), f"{tmp_stability_dir / tsw / f'{tsw}_statistics_table.csv'} does not exist"

plot_dir = Path(tmp_stability_dir / globals.DEFAULT_TSW)
assert len(list(plot_dir.iterdir())) == 63
assert len(list(plot_dir.iterdir())) == 69
assert all(file.suffix in [".png", ".svg", ".csv"] for file in plot_dir.iterdir()), "Not all files have been saved as .png or .csv"


Expand Down

0 comments on commit 1413b4f

Please sign in to comment.