diff --git a/src/qa4sm_reader/comparing.py b/src/qa4sm_reader/comparing.py index 28c82c8..3952892 100644 --- a/src/qa4sm_reader/comparing.py +++ b/src/qa4sm_reader/comparing.py @@ -141,7 +141,7 @@ def common_metrics(self) -> dict: img_metrics = {} for metric in img.metrics: # hardcoded because n_obs cannot be compared. todo: exclude empty metrics (problem: the values are not loaded here) - if metric in glob.metric_groups[0] or metric in [ + if metric in glob.metric_groups['common'] or metric in [ "tau", "p_tau" ]: continue diff --git a/src/qa4sm_reader/globals.py b/src/qa4sm_reader/globals.py index 7b382a4..b16c629 100644 --- a/src/qa4sm_reader/globals.py +++ b/src/qa4sm_reader/globals.py @@ -68,6 +68,7 @@ os.path.abspath(__file__)), 'static', 'images', 'logo', 'QA4SM_logo_long.png') + # === filename template === ds_fn_templ = "{i}-{ds}.{var}" ds_fn_sep = "_with_" @@ -133,6 +134,16 @@ def get_status_colors(): 'YlGn'], # sequential: increasing value good (n_obs, STDerr) 'qua_neutr': get_status_colors(), # qualitative category with 2 forced colors + # Added colormaps for slope metrics + 'div_slopeBIAS': matplotlib.colormaps[ + 'RdBu_r' + ], # diverging colormap for slopeBIAS + 'div_slopeR': matplotlib.colormaps[ + 'PiYG' + ], # diverging colormap for slopeR + 'div_slopeURMSD': matplotlib.colormaps[ + 'PuOr' + ] # diverging colormap for slopeURMSD } _colormaps = { # from /qa4sm/validator/validation/graphics.py @@ -155,6 +166,9 @@ def get_status_colors(): 'err_std': _cclasses['seq_worse'], 'beta': _cclasses['div_neutr'], 'status': _cclasses['qua_neutr'], + 'slopeR': _cclasses['div_slopeR'], + 'slopeURMSD': _cclasses['div_slopeURMSD'], + 'slopeBIAS': _cclasses['div_slopeBIAS'], } # Colorbars for difference plots @@ -180,34 +194,42 @@ def get_status_colors(): # METRICS AND VARIABLES DEFINITIONS # ===================================================== -# 0=common metrics, 2=paired metrics (2 datasets), 3=triple metrics (TC, 3 datasets) +# common metrics, pairwise metrics (2 datasets), triple metrics (TC, 3 datasets), pairwise stability metrics(2 datasets) metric_groups = { - 0: ['n_obs'], - 2: [ + 'common': ['n_obs'], + 'pairwise': [ 'R', 'p_R', 'rho', 'p_rho', 'RMSD', 'BIAS', 'urmsd', 'mse', 'mse_corr', - 'mse_bias', 'mse_var', 'RSS', 'tau', 'p_tau', 'status' + 'mse_bias', 'mse_var', 'RSS', 'tau', 'p_tau', 'status' ], - 3: ['snr', 'err_std', 'beta', 'status'] + 'triple': ['snr', 'err_std', 'beta', 'status'], + 'pairwise_stability': ['slopeURMSD', 'slopeR', 'slopeBIAS'] } +def get_metric_format(group, metric_dict): + # metric groups 'pairwise and 'pairwise_stability should be handled the same + if group == "pairwise_stability": + group = "pairwise" + return metric_dict.get(group) + # === variable template === # how the metric is separated from the rest var_name_metric_sep = { - 0: "{metric}", - 2: "{metric}_between_", - 3: "{metric}_{mds_id:d}-{mds}_between_" + 'common': "{metric}", + 'pairwise': "{metric}_between_", + 'triple': "{metric}_{mds_id:d}-{mds}_between_" } + var_name_CI = { - 0: "{metric}_ci_{bound}_between_", - 2: "{metric}_ci_{bound}_between_", - 3: "{metric}_ci_{bound}_{mds_id:d}-{mds}_between_" + 'common': "{metric}_ci_{bound}_between_", + 'pairwise': "{metric}_ci_{bound}_between_", + 'triple': "{metric}_ci_{bound}_{mds_id:d}-{mds}_between_" } # how two datasets are separated, ids must be marked as numbers with :d! 
var_name_ds_sep = { - 0: None, - 2: "{ref_id:d}-{ref_ds}_and_{sat_id0:d}-{sat_ds0}", - 3: - "{ref_id:d}-{ref_ds}_and_{sat_id0:d}-{sat_ds0}_and_{sat_id1:d}-{sat_ds1}" + 'common': None, + 'pairwise': "{ref_id:d}-{ref_ds}_and_{sat_id0:d}-{sat_ds0}", + 'triple': + "{ref_id:d}-{ref_ds}_and_{sat_id0:d}-{sat_ds0}_and_{sat_id1:d}-{sat_ds1}", } # === metadata templates === @@ -221,14 +243,11 @@ def get_status_colors(): # format should have (metric, ds, ref, other ds) _variable_pretty_name = { - 0: "{}", - 2: "{}\nof {}\nwith {} as reference", - 3: "{} of {} \n against {}, {}" + 'common': "{}", + 'pairwise': "{}\nof {}\nwith {} as reference", + 'triple': "{} of {} \n against {}, {}" } -# check if every metric has a colormap -for group in metric_groups.keys(): - assert all([m in _colormaps.keys() for m in metric_groups[group]]) # Value ranges of metrics, either absolute values, or a quantile between 0 and 1 _metric_value_ranges = { # from /qa4sm/validator/validation/graphics.py @@ -251,6 +270,9 @@ def get_status_colors(): 'err_std': [None, None], 'beta': [None, None], 'status': [-1, len(status)-2], + 'slopeR': [None, None], + 'slopeURMSD': [None, None], + 'slopeBIAS': [None, None], } # mask values out of range _metric_mask_range = { @@ -282,11 +304,14 @@ def get_status_colors(): 'err_std': ' in {}', 'beta': ' in {}', 'status': '', + 'slopeR': '', + 'slopeURMSD': ' in {}', + 'slopeBIAS': ' in {}', } # units for all datasets -def get_metric_units(dataset, raise_error=False): +def get_metric_units(dataset, metric=None, raise_error=False): # function to get m.u. with possibility to raise error _metric_units = { # from /qa4sm/validator/validation/graphics.py 'ISMN': 'm³/m³', @@ -311,22 +336,24 @@ def get_metric_units(dataset, raise_error=False): 'SMOS_SBPCA': 'm³/m³', } - try: - return _metric_units[dataset] + unit = _metric_units.get(dataset) - except KeyError: + if unit is None: if raise_error: - raise KeyError( - f"The dataset {dataset} has not been specified in {__name__}") - + raise KeyError(f"The dataset '{dataset}' has not been specified in {__name__}.") else: warnings.warn( - f"The dataset {dataset} has not been specified in {__name__}. " - f"Set 'raise_error' to True to raise an exception for this case." + f"The dataset '{dataset}' has not been specified in {__name__}. " + "Set 'raise_error' to True to raise an exception for this case.", + UserWarning ) - return "n.a." 
+ if metric in STABILITY_METRICS: + unit += ' per decade' + + return unit + COMMON_METRICS = { 'R': 'Pearson\'s r', 'p_R': 'Pearson\'s r p-value', @@ -352,7 +379,7 @@ def get_metric_units(dataset, raise_error=False): 'p_rho': 'Spearman\'s ρ p-value', 'tau': 'Kendall rank correlation', 'p_tau': 'Kendall tau p-value', - 'status': 'Validation errors' + 'status': 'Validation errors', } QA4SM_EXCLUSIVE_METRICS = { @@ -363,7 +390,13 @@ def get_metric_units(dataset, raise_error=False): 'status': '# status', } -_metric_name = {**COMMON_METRICS, **READER_EXCLUSIVE_METRICS, **TC_METRICS} +STABILITY_METRICS = { + 'slopeR' : 'Theil-Sen slope of R', + 'slopeURMSD' : 'Theil-Sen slope of urmsd', + 'slopeBIAS' : 'Theil-Sen slope of BIAS' +} + +_metric_name = {**COMMON_METRICS, **READER_EXCLUSIVE_METRICS, **TC_METRICS, **STABILITY_METRICS} METRICS = {**COMMON_METRICS, **QA4SM_EXCLUSIVE_METRICS} @@ -426,6 +459,7 @@ def get_metric_units(dataset, raise_error=False): } } + # BACKUPS # ===================================================== # to fallback to in case the dataset attributes in the .nc file are @@ -741,6 +775,9 @@ def get_resolution_info(dataset, raise_error=False): 'tau', 'p_tau', 'status', + 'slopeR', + 'slopeURMSD', + 'slopeBIAS', ] METRIC_TEMPLATE = '_between_{ds1}_and_{ds2}' @@ -783,9 +820,26 @@ def get_resolution_info(dataset, raise_error=False): "Oct": [[10, 1], [10, 31]], "Nov": [[11, 1], [11, 30]], "Dec": [[12, 1], [12, 31]], + }, + #Fix as TemporalSubWindowsCreator checks pre-defined tsw on init + "stability":{ } } +def add_annual_subwindows(years): + # Loops through each year and create a subwindow for the entire year + annual_subwindows = {} + for year in years: + annual_subwindows[str(year)] = [ + [year, 1, 1], + [year, 12, 31] + ] + TEMPORAL_SUB_WINDOWS["custom"] = annual_subwindows + + +#years = [2008, 2009, 2010] +#add_annual_subwindows(years) + CLUSTERED_BOX_PLOT_STYLE = { 'fig_params': { 'title_fontsize': 20, @@ -806,8 +860,6 @@ def get_resolution_info(dataset, raise_error=False): CLUSTERED_BOX_PLOT_OUTDIR = 'comparison_boxplots' - - # netCDF transcription related settings # ===================================================== OLD_NCFILE_SUFFIX = '.old' diff --git a/src/qa4sm_reader/handlers.py b/src/qa4sm_reader/handlers.py index 314b3d7..3592d97 100644 --- a/src/qa4sm_reader/handlers.py +++ b/src/qa4sm_reader/handlers.py @@ -1,16 +1,14 @@ # -*- coding: utf-8 -*- from dataclasses import dataclass -import warnings - from qa4sm_reader import globals from parse import * -import warnings as warn +import warnings import re -from typing import List, Optional, Tuple, Dict, Any, Union - import matplotlib import matplotlib.axes from matplotlib.figure import Figure +from typing import List, Optional, Tuple, Dict, Any, Union + class MixinVarmeta: @@ -26,14 +24,14 @@ def pretty_name(self): template = "" template = template + globals._variable_pretty_name[self.g] - if self.g == 0: + if self.g == 'common': name = template.format(self.metric) - elif self.g == 2: + elif self.g == 'pairwise' or self.g == 'pairwise_stability': name = template.format(self.Metric.pretty_name, self.metric_ds[1]['pretty_title'], self.ref_ds[1]['pretty_title']) - elif self.g == 3: + elif self.g == 'triple': name = template.format(self.Metric.pretty_name, self.metric_ds[1]['pretty_title'], self.ref_ds[1]['pretty_title'], @@ -65,7 +63,7 @@ def get_varmeta(self) -> Tuple[Tuple, Tuple, Tuple, Tuple]: scale_ds: id, dict this is the scaling dataset """ - if self.g == 0: + if self.g == 'common': ref_ds = 
self.Datasets.dataset_metadata(self.Datasets._ref_id()) mds, dss, scale_ds = None, None, None @@ -82,17 +80,17 @@ def get_varmeta(self) -> Tuple[Tuple, Tuple, Tuple, Tuple]: warnings.warn( f"ID of scaling reference dataset could not be parsed, " f"units of spatial reference are used.") - + ref_ds = self.Datasets.dataset_metadata(self.parts['ref_id']) mds = self.Datasets.dataset_metadata(self.parts['sat_id0']) dss = None # if metric is status and globals.metric_groups is 3, add third dataset - if self.g == 3 and self.metric == 'status': + if self.g == 'triple' and self.metric == 'status': dss = self.Datasets.dataset_metadata(self.parts['sat_id1']) # if metric is TC, add third dataset - elif self.g == 3: + elif self.g == 'triple': mds = self.Datasets.dataset_metadata(self.parts['mds_id']) dss = self.Datasets.dataset_metadata(self.parts['sat_id1']) if dss == mds: @@ -134,6 +132,11 @@ def _ref_dc(self) -> int: ref_dc = 0 try: + # print(f'globals._ref_ds_attr: {globals._ref_ds_attr}') + # print(f'self.meta: {self.meta}') + # print( + # f'parse(globals._ds_short_name_attr, val_ref): {parse(globals._ds_short_name_attr, self.meta[globals._ref_ds_attr])}' + # ) # print(f'globals._ref_ds_attr: {globals._ref_ds_attr}') # print(f'self.meta: {self.meta}') # print( @@ -142,7 +145,7 @@ def _ref_dc(self) -> int: val_ref = self.meta[globals._ref_ds_attr] ref_dc = parse(globals._ds_short_name_attr, val_ref)[0] except KeyError as e: - warn("The netCDF file does not contain the attribute {}".format( + warnings.warn("The netCDF file does not contain the attribute {}".format( globals._ref_ds_attr)) raise e @@ -409,24 +412,24 @@ def ismetric(self) -> bool: def _parse_wrap(self, pattern, g): """Wrapper function that handles case of metric 'status' that occurs - in two globals.metric_groups (2,3). This is because a status array + in two globals.metric_groups (pairwise,triple). 
This is because a status array can be the result of a validation between two or three datasets (tc) """ - # ignore this case - (status is also in globals.metric_groups 2 but - # should be treated as group 3) + # ignore this case - (status is also in pairwise metric_groups but + # should be treated as triple metric_group) if self.varname.startswith('status') and (self.varname.count('_and_') - == 2) and g == 2: + == 2) and g == 'pairwise': return None # parse self.varname when three datasets elif self.varname.startswith('status') and (self.varname.count('_and_') - == 2) and g == 3: - template = globals.var_name_ds_sep[3] + == 2) and g == 'triple': + template = globals.var_name_ds_sep['triple'] return parse( - '{}{}'.format(globals.var_name_metric_sep[2], template), + '{}{}'.format(globals.var_name_metric_sep['pairwise'], template), self.varname) return parse(pattern, self.varname) - def _parse_varname(self) -> Tuple[str, int, dict]: + def _parse_varname(self) -> Tuple[str, str, dict]: """ Parse the name to get the metric, group and variable data @@ -434,7 +437,7 @@ def _parse_varname(self) -> Tuple[str, int, dict]: ------- metric : str metric name - g : int + g : str group parts : dict dictionary of MetricVariable data @@ -442,10 +445,10 @@ def _parse_varname(self) -> Tuple[str, int, dict]: metr_groups = list(globals.metric_groups.keys()) # check which group it belongs to for g in metr_groups: - template = globals.var_name_ds_sep[g] + template = globals.get_metric_format(g, globals.var_name_ds_sep) if template is None: template = '' - pattern = '{}{}'.format(globals.var_name_metric_sep[g], template) + pattern = '{}{}'.format(globals.get_metric_format(g, globals.var_name_metric_sep), template) # parse infromation from pattern and name parts = self._parse_wrap(pattern, g) @@ -455,7 +458,7 @@ def _parse_varname(self) -> Tuple[str, int, dict]: return parts['metric'], g, parts.named # perhaps it's a CI variable else: - pattern = '{}{}'.format(globals.var_name_CI[g], template) + pattern = '{}{}'.format(globals.get_metric_format(g, globals.var_name_CI), template) parts = parse(pattern, self.varname) if parts is not None and parts[ 'metric'] in globals.metric_groups[g]: @@ -535,7 +538,7 @@ def _get_attribute(self, attr: str): """ for n, Var in enumerate(self.variables): value = getattr(Var, attr) - # special case for "status" attribute (self.g can be 2 or 3) + # special case for "status" attribute (self.g can be 'pairwise' or 'triple') if n != 0 and not Var.varname.startswith('status'): assert value == previous, "The attribute {} is not equal in all variables".format( attr) @@ -573,3 +576,4 @@ class CWContainer: centers: List[float] widths: List[float] name: Optional[str] = 'Generic Dataset' + diff --git a/src/qa4sm_reader/img.py b/src/qa4sm_reader/img.py index dbe8e91..88fc7d9 100644 --- a/src/qa4sm_reader/img.py +++ b/src/qa4sm_reader/img.py @@ -351,11 +351,11 @@ def group_metrics(self, metrics: list = None) -> Union[None, Tuple[dict, dict, d # fill dictionaries for metric in metrics: Metric = self.metrics[metric] - if Metric.g == 0: + if Metric.g == 'common': common[metric] = Metric - elif Metric.g == 2: + elif Metric.g == 'pairwise' or Metric.g == 'pairwise_stability': double[metric] = Metric - elif Metric.g == 3: + elif Metric.g == 'triple': triple[metric] = Metric return common, double, triple @@ -490,17 +490,17 @@ def _metric_stats(self, metric, id=None) -> list: # find the statistics for the metric variable var_stats = [mean, values.median(), iqr] - if Var.g == 0: + if Var.g == 'common': 
var_stats.append('All datasets') var_stats.extend([globals._metric_name[metric], Var.g]) else: i, ds_name = Var.metric_ds - if Var.g == 2: + if Var.g == 'pairwise' or Var.g == 'pairwise_stability': var_stats.append('{}-{} ({})'.format( i, ds_name['short_name'], ds_name['pretty_version'])) - elif Var.g == 3: + elif Var.g == 'triple': o, other_ds = Var.other_ds var_stats.append( '{}-{} ({}); other ref: {}-{} ({})'.format( diff --git a/src/qa4sm_reader/intra_annual_temp_windows.py b/src/qa4sm_reader/intra_annual_temp_windows.py index b8d27ce..efaf349 100644 --- a/src/qa4sm_reader/intra_annual_temp_windows.py +++ b/src/qa4sm_reader/intra_annual_temp_windows.py @@ -148,6 +148,103 @@ def end_date_pretty(self) -> str: return self.end_date.strftime('%Y-%m-%d') +class TemporalSubWindowsFactory: + @staticmethod + def create(temporal_sub_window_type: str, overlap: int, period: Optional[List[datetime]], custom_subwindows: Optional[dict] = None): + """ + Factory method to instantiate the appropriate TemporalSubWindowsCreator + based on the type of metrics passed from the validation run. + """ + # Handle intra-annual or stability based on the type passed + if temporal_sub_window_type in ['seasons', 'months']: + return TemporalSubWindowsFactory._create_intra_annual(temporal_sub_window_type, overlap, period) + + elif temporal_sub_window_type == 'stability': + return TemporalSubWindowsFactory._create_stability(overlap, period, custom_subwindows) + + return None + + @staticmethod + def _create_intra_annual(temporal_sub_window_type: str, overlap: int, period: Optional[List[datetime]]): + """ + Create a TemporalSubWindowsCreator for intra-annual metrics. + """ + temp_sub_wdw_instance = TemporalSubWindowsCreator( + temporal_sub_window_type=temporal_sub_window_type, + overlap=overlap, + custom_file=None # Default to no custom file for intra-annual metrics + ) + + # Set default sub-windows + if not period: + period = [datetime(year=1978, month=1, day=1), datetime.now()] + + default_temp_sub_wndw = NewSubWindow(DEFAULT_TSW, period[0], period[1]) + temp_sub_wdw_instance.add_temp_sub_wndw( + new_temp_sub_wndw=default_temp_sub_wndw, + insert_as_first_wndw=True + ) + + return temp_sub_wdw_instance + + @staticmethod + def _create_stability(overlap: int, period: Optional[List[datetime]], custom_subwindows: Optional[dict] = None): + """ + Create a TemporalSubWindowsCreator for stability metrics. 
+ """ + temp_sub_wdw_instance = TemporalSubWindowsCreator( + temporal_sub_window_type="stability", + overlap=overlap, + custom_file=None + ) + + # Remove existing sub-windows + temp_sub_wdw_instance.remove_temp_sub_wndws() + + # Set default sub-windows + if not period: + period = [datetime(year=1978, month=1, day=1), datetime.now()] + + default_temp_sub_wndw = NewSubWindow(DEFAULT_TSW, period[0], period[1]) + temp_sub_wdw_instance.add_temp_sub_wndw(default_temp_sub_wndw) + + # Add annual sub-windows based on the years in the period + years = list(range(*(map(lambda dt: dt.year, period)))) + add_annual_subwindows(temp_sub_wdw_instance, years) + + # Add custom sub-windows if provided + if custom_subwindows: + for key, value in custom_subwindows.items(): + new_subwindow = NewSubWindow( + name=key, + begin_date=datetime(value[0][0], value[0][1], value[0][2]), + end_date=datetime(value[1][0], value[1][1], value[1][2]), + ) + temp_sub_wdw_instance.add_temp_sub_wndw(new_subwindow) + + return temp_sub_wdw_instance + + +# Helper function to add annual sub-windows +def add_annual_subwindows(temp_sub_wdw_instance: 'TemporalSubWindowsCreator', years: List[int]): + """ + Add annual sub-windows based on the list of years. + + Parameters + ---------- + temp_sub_wdw_instance: TemporalSubWindowsCreator + The instance to which the sub-windows will be added. + years: List[int] + List of years to generate sub-windows for. + """ + for year in years: + temp_sub_wdw_instance.add_temp_sub_wndw( + NewSubWindow(f"{year}", datetime(year, 1, 1), datetime(year, 12, 31)) + ) + + + + class TemporalSubWindowsCreator(TemporalSubWindowsDefault): '''Class to create custom temporal sub-windows, based on the default definitions. @@ -401,6 +498,15 @@ def overwrite_temp_sub_wndw( except Exception as e: print(f'Error: {e}') return None + + def remove_temp_sub_wndws(self) -> None: + '''Removes all existing custom temporal sub-windows. + + This method clears both the `custom_temporal_sub_windows` and the `additional_temp_sub_wndws_container`. + ''' + self.custom_temporal_sub_windows.clear() + self.additional_temp_sub_wndws_container.clear() + print("All custom temporal sub-windows have been removed.") @property def names(self) -> List[str]: diff --git a/src/qa4sm_reader/netcdf_transcription.py b/src/qa4sm_reader/netcdf_transcription.py index e4f501d..05744d7 100644 --- a/src/qa4sm_reader/netcdf_transcription.py +++ b/src/qa4sm_reader/netcdf_transcription.py @@ -10,7 +10,7 @@ from pathlib import Path from qa4sm_reader.intra_annual_temp_windows import TemporalSubWindowsCreator, InvalidTemporalSubWindowError -from qa4sm_reader.globals import METRICS, TC_METRICS, NON_METRICS, METADATA_TEMPLATE, \ +from qa4sm_reader.globals import METRICS, TC_METRICS, STABILITY_METRICS, NON_METRICS, METADATA_TEMPLATE, \ IMPLEMENTED_COMPRESSIONS, ALLOWED_COMPRESSION_LEVELS, \ INTRA_ANNUAL_METRIC_TEMPLATE, INTRA_ANNUAL_TCOL_METRIC_TEMPLATE, \ TEMPORAL_SUB_WINDOW_SEPARATOR, DEFAULT_TSW, TEMPORAL_SUB_WINDOW_NC_COORD_NAME, \ @@ -209,6 +209,24 @@ def is_valid_tcol_metric_name(self, tcol_metric_name): ] return any( tcol_metric_name.startswith(prefix) for prefix in valid_prefixes) + + def is_valid_stability_metric_name(self, metric_name): + """ + Checks if a given stability metric name is valid, based on the defined `globals.INTRA_ANNUAL_METRIC_TEMPLATE`. + + Parameters: + metric_name (str): The stability metric name to be checked. + + Returns: + bool: True if the stability metric name is valid, False otherwise. 
+ """ + valid_prefixes = [ + "".join( + template.format(tsw=tsw, metric=metric) + for template in INTRA_ANNUAL_METRIC_TEMPLATE) + for tsw in self.provided_tsws for metric in STABILITY_METRICS + ] + return any(metric_name.startswith(prefix) for prefix in valid_prefixes) @property def metrics_list(self) -> List[str]: @@ -229,6 +247,7 @@ def metrics_list(self) -> List[str]: metric for metric in self.pytesmo_results if self.is_valid_metric_name(metric) or self.is_valid_tcol_metric_name(metric) + or self.is_valid_stability_metric_name(metric) ] if len(_metrics) != 0: # intra-annual case @@ -270,6 +289,18 @@ def drop_obs_dim(self) -> None: self.transcribed_dataset = self.transcribed_dataset.drop_dims( 'obs') + def mask_redundant_tsw_values(self) -> None: + """ + For all variables starting with 'slope', replace all tsw values ('2010', '2011', etc.) with NaN + except for the default tsw. + """ + slope_vars = [var for var in self.transcribed_dataset if var.startswith("slope")] + + for var in slope_vars: + if TEMPORAL_SUB_WINDOW_NC_COORD_NAME in self.transcribed_dataset[var].dims: + mask = self.transcribed_dataset[var][TEMPORAL_SUB_WINDOW_NC_COORD_NAME] == DEFAULT_TSW + self.transcribed_dataset[var] = self.transcribed_dataset[var].where(mask, other=np.nan) + @staticmethod def update_dataset_var(ds: xr.Dataset, var: str, coord_key: str, coord_val: str, data_vals: List) -> xr.Dataset: @@ -303,14 +334,13 @@ def update_dataset_var(ds: xr.Dataset, var: str, coord_key: str, def get_transcribed_dataset(self) -> xr.Dataset: """ - Get the transcribed dataset, containing all metric and non-metric data provided by the pytesmo results. Metadata - is not yet included. + Get the transcribed dataset, containing all metric and non-metric data provided by the pytesmo results. Returns ------- xr.Dataset - The transcribed, metadata-less dataset. + The transcribed dataset. 
""" self.only_default_case, self.provided_tsws = self.temporal_sub_windows_checker( ) @@ -327,7 +357,7 @@ def get_transcribed_dataset(self) -> xr.Dataset: _tsw, new_name = new_name.split(TEMPORAL_SUB_WINDOW_SEPARATOR) if new_name not in self.transcribed_dataset: - # takes the data associated with the metric new_name and adds it as a new variabel + # takes the data associated with the metric new_name and adds it as a new variable # more precisely, it assigns copies of this data to each temporal sub-window, which is the new dimension self.transcribed_dataset[new_name] = self.pytesmo_results[ var_name].expand_dims( @@ -356,6 +386,7 @@ def get_transcribed_dataset(self) -> xr.Dataset: self.get_pytesmo_attrs() self.handle_n_obs() self.drop_obs_dim() + self.mask_redundant_tsw_values() self.transcribed_dataset[ TEMPORAL_SUB_WINDOW_NC_COORD_NAME].attrs = dict( @@ -715,9 +746,14 @@ def get_custom_tsws(tsw_list): if tsw not in month_order and tsw not in seasons_1_order and tsw not in seasons_2_order ] - return customs, list(set(tsw_list) - set(customs)) + return customs, list(set(tsw_list) - set(customs)) custom_tsws, tsw_list = get_custom_tsws(tsw_list) + + if all(tsw.isdigit() for tsw in custom_tsws): + custom_tsws = sorted(custom_tsws, key=int) + + lens = {len(tsw) for tsw in tsw_list} if lens == {2} and all( diff --git a/src/qa4sm_reader/plot_all.py b/src/qa4sm_reader/plot_all.py index a6be966..17397d3 100644 --- a/src/qa4sm_reader/plot_all.py +++ b/src/qa4sm_reader/plot_all.py @@ -5,9 +5,9 @@ from itertools import chain import pandas as pd +from qa4sm_reader.netcdf_transcription import Pytesmo2Qa4smResultsTranscriber from qa4sm_reader.plotter import QA4SMPlotter, QA4SMCompPlotter from qa4sm_reader.img import QA4SMImg -from qa4sm_reader.netcdf_transcription import Pytesmo2Qa4smResultsTranscriber import qa4sm_reader.globals as globals import numpy as np import matplotlib.pyplot as plt @@ -72,6 +72,7 @@ def plot_all(filepath: str, list of filenames for created comparison boxplots """ + if isinstance(save_metadata, bool): if not save_metadata: save_metadata = 'never' @@ -87,10 +88,16 @@ def plot_all(filepath: str, # initialise image and plotter fnames_bplot, fnames_mapplot, fnames_csv = [], [], [] + + comparison_periods = None if temporal_sub_windows is None: periods = Pytesmo2Qa4smResultsTranscriber.get_tsws_from_ncfile(filepath) else: periods = np.array(temporal_sub_windows) + # Filter out all items that are purely digits + # Needs to be here because when qa4sm-validaion is run the temporal_sub_windows is not None + comparison_periods = periods + periods = [period for period in periods if not period.isdigit()] for period in periods: print(f'period: {period}') @@ -147,21 +154,25 @@ def plot_all(filepath: str, fnames_cbplot = [] if isinstance(out_type, str): out_type = [out_type] - metrics_not_to_plot = list(set(chain(globals._metadata_exclude, globals.metric_groups[3], ['n_obs']))) # metadata, tcol metrics, n_obs - if globals.DEFAULT_TSW in periods and len(periods) > 1: + metrics_not_to_plot = list(set(chain(globals._metadata_exclude, globals.metric_groups['triple'], ['n_obs']))) # metadata, tcol metrics, n_obs + if globals.DEFAULT_TSW in comparison_periods and len(comparison_periods) > 1: + #check if stability metrics where calculated + stability = all(item.isdigit() for item in comparison_periods if item != 'bulk') cbp = QA4SMCompPlotter(filepath) - if not os.path.isdir(os.path.join(out_dir, globals.CLUSTERED_BOX_PLOT_OUTDIR)): - os.makedirs(os.path.join(out_dir, 
globals.CLUSTERED_BOX_PLOT_OUTDIR))
+        comparison_boxplot_dir = os.path.join(out_dir, globals.CLUSTERED_BOX_PLOT_OUTDIR)
+        os.makedirs(comparison_boxplot_dir, exist_ok=True)
 
         for available_metric in cbp.metric_kinds_available:
             if available_metric in metrics.keys(
             ) and available_metric not in metrics_not_to_plot:
+
                 spth = [Path(out_dir) / globals.CLUSTERED_BOX_PLOT_OUTDIR / f'{globals.CLUSTERED_BOX_PLOT_SAVENAME.format(metric=available_metric, filetype=_out_type)}' for _out_type in out_type]
 
                 _fig = cbp.plot_cbp(
                     chosen_metric=available_metric,
                     out_name=spth,
+                    stability=stability
                 )
                 plt.close(_fig)
                 fnames_cbplot.extend(spth)
diff --git a/src/qa4sm_reader/plotter.py b/src/qa4sm_reader/plotter.py
index 6728e94..c704489 100644
--- a/src/qa4sm_reader/plotter.py
+++ b/src/qa4sm_reader/plotter.py
@@ -318,7 +318,7 @@ def _yield_values(
         for n, Var in enumerate(Vars):
             values = Var.values[Var.varname]
             # changes if it's a common-type Var
-            if Var.g == 0:
+            if Var.g == 'common':
                 box_cap_ds = 'All datasets'
             else:
                 box_cap_ds = self._box_caption(Var, tc=tc)
@@ -409,7 +409,7 @@ def _boxplot_definition(self,
         ax.set_title(title, pad=globals.title_pad)
         if self.img.has_CIs:
             offset = 0.08  # offset smaller as CI variables have a larger caption
-            if Var.g == 0:
+            if Var.g == 'common':
                 offset = 0.03  # offset larger as common metrics have a shorter caption
 
         # fig.tight_layout()
@@ -674,7 +674,7 @@ def barplot(
             if values.empty:
                 return None
 
-            if len(self.img.triple) and Var.g == 2:
+            if len(self.img.triple) and Var.g == 'pairwise':
                 continue
 
             ref_meta, mds_meta, other_meta, _ = Var.get_varmeta()
@@ -776,7 +776,7 @@ def mapplot_var(
                 save_name = self.create_filename(Var=Var,
                                                  type="mapplot_status",
                                                  period=period)
-        elif Var.g == 0:
+        elif Var.g == 'common':
             title = "{} between all datasets".format(
                 globals._metric_name[metric])
             if period:
@@ -784,7 +784,7 @@ def mapplot_var(
             save_name = self.create_filename(Var,
                                              type='mapplot_common',
                                              period=period)
-        elif Var.g == 2:
+        elif Var.g == 'pairwise' or Var.g == 'pairwise_stability':
             title = self.create_title(Var=Var,
                                       type='mapplot_basic',
                                       period=period)
@@ -845,7 +845,7 @@ def mapplot_metric(self,
         fnames = []
         for Var in self.img._iter_vars(type="metric",
                                        filter_parms={"metric": metric}):
-            if len(self.img.triple) and Var.g == 2 and metric == 'status':
+            if len(self.img.triple) and Var.g == 'pairwise' and metric == 'status':
                 continue
             if not (np.isnan(Var.values.to_numpy()).all() or Var.is_CI):
                 fns = self.mapplot_var(Var,
@@ -858,7 +858,7 @@ def mapplot_metric(self,
                 continue
             if save_files:
                 fnames.extend(fns)
-        plt.close('all')
+        plt.close('all')
 
         if fnames:
             return fnames
@@ -876,27 +876,29 @@ def plot_metric(self,
         ----------
         metric: str
             name of the metric
-        out_types: str or list of str, Optional
-            extensions which the files should be saved in. Default is 'png'
+        out_types: str or list
+            extensions which the files should be saved in
         save_all: bool, optional. Default is True.
             all plotted images are saved to the output directory
         plotting_kwargs: arguments for mapplot function.
""" + fnames_bplot = None Metric = self.img.metrics[metric] - + + fnames_mapplot = None if Metric.name == 'status': fnames_bplot = self.barplot(metric='status', period=period, out_types=out_types, save_files=save_all) - elif Metric.g == 0 or Metric.g == 2: + elif Metric.g == 'common' or Metric.g == 'pairwise' or Metric.g == 'pairwise_stability': fnames_bplot = self.boxplot_basic(metric=metric, period=period, out_types=out_types, save_files=save_all, **plotting_kwargs) - elif Metric.g == 3: + elif Metric.g == 'triple': fnames_bplot = self.boxplot_tc(metric=metric, period=period, out_types=out_types, @@ -907,9 +909,7 @@ def plot_metric(self, period=period, out_types=out_types, save_files=save_all, - **plotting_kwargs) - else: - fnames_mapplot = None + **plotting_kwargs) return fnames_bplot, fnames_mapplot @@ -1690,6 +1690,7 @@ def _iter_vars(self, def plot_cbp(self, chosen_metric: str, + stability: bool, out_name: Optional[Union[List, List[str]]] = None) -> matplotlib.figure.Figure: """ Plot a Clustered Boxplot for a chosen metric @@ -1707,7 +1708,7 @@ def plot_cbp(self, the boxplot """ - + anchor_list = None def get_metric_vars( generic_metric: str) -> Dict[str, hdl.MetricVariable]: _dict = {} @@ -1797,9 +1798,22 @@ def sanitize_dataframe(df: pd.DataFrame, legend_entries = get_legend_entries(cbp_obj=self.cbp, generic_metric=chosen_metric) + + anchor_list = None + if stability: + # get the first dataset to deduce the number of anchors - important for the boxplot setup + unique_groups = metric_df.columns.get_level_values(0).unique() + first_df = metric_df.loc[:, metric_df.columns.get_level_values(0) == unique_groups[0]] + first_df = sanitize_dataframe(first_df, keep_empty_cols=False) + anchor_number = len(first_df.columns) + anchor_list = np.arange(anchor_number).astype(float) + + if anchor_list is None: + anchor_list = self.cbp.anchor_list + centers_and_widths = self.cbp.centers_and_widths( - anchor_list=self.cbp.anchor_list, + anchor_list=anchor_list, no_of_ds=self.cbp.no_of_ds, space_per_box_cluster=0.9, rel_indiv_box_width=0.8) @@ -1817,29 +1831,17 @@ def sanitize_dataframe(df: pd.DataFrame, legend_entries = get_legend_entries(cbp_obj=self.cbp, generic_metric=chosen_metric) - - centers_and_widths = self.cbp.centers_and_widths( - anchor_list=self.cbp.anchor_list, - no_of_ds=self.cbp.no_of_ds, - space_per_box_cluster=0.9, - rel_indiv_box_width=0.8) - - figwidth = globals.boxplot_width * (len(metric_df.columns) + 1 - ) # otherwise it's too narrow - figsize = [figwidth, globals.boxplot_height] - fig_kwargs = { - 'figsize': figsize, - 'dpi': 'figure', - 'bbox_inches': 'tight' - } - + cbp_fig = self.cbp.figure_template(incl_median_iqr_n_axs=False, fig_kwargs=fig_kwargs) - + legend_handles = [] for dc_num, (dc_val_name, Var) in enumerate(Vars.items()): _df = Var.values # get the dataframe for the specific metric, potentially with NaNs - _df = sanitize_dataframe(_df, keep_empty_cols=True) # sanitize the dataframe + if not stability: + _df = sanitize_dataframe(_df, keep_empty_cols=True) # sanitize the dataframe + else: + _df = sanitize_dataframe(_df, keep_empty_cols=False) # remove redundant columns bp = cbp_fig.ax_box.boxplot( [_df[col] for col in _df.columns], positions=centers_and_widths[dc_num].centers, @@ -1883,7 +1885,7 @@ def sanitize_dataframe(df: pd.DataFrame, ncols=_ncols) xtick_pos = self.cbp.centers_and_widths( - anchor_list=self.cbp.anchor_list, + anchor_list=anchor_list, no_of_ds=1, space_per_box_cluster=0.7, rel_indiv_box_width=0.8) @@ -1897,8 +1899,14 @@ def get_xtick_labels(df: 
pd.DataFrame) -> List: f"{tsw[1]}\nEmpty" if count == 0 else f"{tsw[1]}" for tsw, count in _count_dict.items() ] + xtick_labels = get_xtick_labels(_df) - cbp_fig.ax_box.set_xticklabels(get_xtick_labels(_df), ) + if len(xtick_labels) > 19 and stability: + xtick_labels = [label.replace("\n", " ") for label in xtick_labels] + cbp_fig.ax_box.set_xticklabels(xtick_labels) + cbp_fig.ax_box.tick_params(axis='x', rotation=315) + else: + cbp_fig.ax_box.set_xticklabels(xtick_labels) cbp_fig.ax_box.tick_params( axis='both', labelsize=globals.CLUSTERED_BOX_PLOT_STYLE['fig_params'] diff --git a/src/qa4sm_reader/plotting_methods.py b/src/qa4sm_reader/plotting_methods.py index 1596747..38cee02 100644 --- a/src/qa4sm_reader/plotting_methods.py +++ b/src/qa4sm_reader/plotting_methods.py @@ -458,7 +458,7 @@ def style_map( map_resolution, edgecolor='black', facecolor='none') - ax.add_feature(borders, linewidth=0.2, zorder=3) + ax.add_feature(borders, linewidth=0.5, zorder=3) if add_us_states: ax.add_feature(cfeature.STATES, linewidth=0.1, zorder=3) @@ -670,12 +670,12 @@ def _make_cbar(fig, if label is None: label = globals._metric_name[metric] + \ globals._metric_description[metric].format( - globals.get_metric_units(ref_short) + globals.get_metric_units(ref_short, metric) ) if scl_short: label = globals._metric_name[metric] + \ globals._metric_description[metric].format( - globals.get_metric_units(scl_short) + globals.get_metric_units(scl_short, metric) ) extend = get_extend_cbar(metric) diff --git a/tests/test_data/intra_annual/stability/0-ESA_CCI_SM_passive.sm_with_1-ERA5_LAND.swvl1_tsw_stability_pytesmo.nc b/tests/test_data/intra_annual/stability/0-ESA_CCI_SM_passive.sm_with_1-ERA5_LAND.swvl1_tsw_stability_pytesmo.nc new file mode 100644 index 0000000..c81cbec Binary files /dev/null and b/tests/test_data/intra_annual/stability/0-ESA_CCI_SM_passive.sm_with_1-ERA5_LAND.swvl1_tsw_stability_pytesmo.nc differ diff --git a/tests/test_data/intra_annual/stability/0-ESA_CCI_SM_passive.sm_with_1-ERA5_LAND.swvl1_tsw_stability_qa4sm.nc b/tests/test_data/intra_annual/stability/0-ESA_CCI_SM_passive.sm_with_1-ERA5_LAND.swvl1_tsw_stability_qa4sm.nc new file mode 100644 index 0000000..010a320 Binary files /dev/null and b/tests/test_data/intra_annual/stability/0-ESA_CCI_SM_passive.sm_with_1-ERA5_LAND.swvl1_tsw_stability_qa4sm.nc differ diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 42ed2e4..dff9b5c 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -214,7 +214,7 @@ def test_parse_varname(tc_metrics): for var in [tc_metrics["beta"], tc_metrics["r"], tc_metrics["n_obs"]]: info = var._parse_varname() assert type(info[0]) == str - assert type(info[1]) == int + assert type(info[1]) == str assert type(info[2]) == dict diff --git a/tests/test_image.py b/tests/test_image.py index bb1ee36..aae8ad7 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -75,7 +75,7 @@ def test_load_vars(img): def test_iter_vars(img): for Var in img._iter_vars(type="metric"): - assert Var.g in [0, 2, 3] + assert Var.g in ['common', 'pairwise', 'triple'] for Var in img._iter_vars(type="metric", filter_parms={'metric': 'R'}): assert Var.varname in [ 'R_between_0-ERA5_LAND_and_2-SMOS_IC', @@ -84,8 +84,8 @@ def test_iter_vars(img): def test_iter_metrics(img): - for Metr in img._iter_metrics(**{'g': 2}): - assert Metr.name in globals.metric_groups[2] + for Metr in img._iter_metrics(**{'g': 'pairwise'}): + assert Metr.name in globals.metric_groups['pairwise'] def test_group_vars(img): @@ -109,8 +109,8 @@ def 
test_group_metrics(img): def test_load_metrics(img): - assert len(img.metrics.keys()) == len(globals.metric_groups[0]) + len( - globals.metric_groups[2]) - 2 + assert len(img.metrics.keys()) == len(globals.metric_groups['common']) + len( + globals.metric_groups['pairwise']) - 2 def test_ds2df(img): @@ -135,9 +135,9 @@ def test_metric_df(img): def test_metrics_in_file(img): """Test that all metrics are initialized correctly""" - assert list(img.common.keys()) == globals.metric_groups[0] + assert list(img.common.keys()) == globals.metric_groups['common'] for m in img.double.keys(): # tau is not in the results - assert m in globals.metric_groups[2] + assert m in globals.metric_groups['pairwise'] assert list(img.triple.keys()) == [] # this is not the TC test case # with merged return value @@ -153,7 +153,7 @@ def test_vars_in_file(img): vars.append(Var.varname) vars_should = ['n_obs'] # since the valination is non-TC - for metric in globals.metric_groups[2]: + for metric in globals.metric_groups['pairwise']: vars_should.append( '{}_between_0-ERA5_LAND_and_1-C3S_combined'.format(metric)) vars_should.append( @@ -173,13 +173,13 @@ def test_find_groups(img): """Test that all metrics for a specific group can be collected""" common_group = [] for name, Metric in img.common.items(): - assert Metric.name in globals.metric_groups[0] + assert Metric.name in globals.metric_groups['common'] assert len(Metric.variables) == 1 common_group.append(name) double_group = [] for name, Metric in img.double.items(): - assert Metric.name in globals.metric_groups[2] + assert Metric.name in globals.metric_groups['pairwise'] if name in [ 'p_R', 'p_rho', 'RMSD', 'mse', 'mse_corr', 'mse_bias', 'mse_var', 'RSS', 'status' @@ -250,16 +250,16 @@ def test_stats_df(img): for name, Metric in img.metrics.items(): stats = img._metric_stats(name) if not stats: # find metrics without values - if Metric.g == 1: + if Metric.g == 1: # don't know what was meant here, g was defined as 0,2 or 3 empty_metrics += 1 - elif Metric.g == 2: # stats table has an entry for metric, for sat dataset (in common and triple metrics) + elif Metric.g == 'pairwise': # stats table has an entry for metric, for sat dataset (in common and triple metrics) empty_metrics += 2 tot_stats = len( img.common.keys()) + 2 * len(img.double.keys()) - empty_metrics assert tot_stats == 27 - glob_stats = len(globals.metric_groups[0]) + 2 * len( - globals.metric_groups[2]) - empty_metrics + glob_stats = len(globals.metric_groups['common']) + 2 * len( + globals.metric_groups['pairwise']) - empty_metrics assert glob_stats == 31 # We drop the corr. 
significance statistics diff --git a/tests/test_intra_annual_temp_windows.py b/tests/test_intra_annual_temp_windows.py index c8c0a1b..fa06fe7 100644 --- a/tests/test_intra_annual_temp_windows.py +++ b/tests/test_intra_annual_temp_windows.py @@ -151,7 +151,7 @@ def test_default_monthly_sub_windows_attributes( assert default_monthly_sub_windows_no_overlap.custom_file == default_seasonal_sub_windows_no_overlap.custom_file == None assert default_monthly_sub_windows_no_overlap.available_temp_sub_wndws == default_seasonal_sub_windows_no_overlap.available_temp_sub_wndws == [ - 'seasons', 'months' + 'seasons', 'months', 'stability' ] assert default_monthly_sub_windows_no_overlap.names == [ @@ -231,7 +231,7 @@ def test_default_seasonal_sub_windows_attributes( assert default_seasonal_sub_windows_no_overlap.custom_file == None assert default_seasonal_sub_windows_no_overlap.available_temp_sub_wndws == [ - 'seasons', 'months' + 'seasons', 'months', 'stability' ] assert default_seasonal_sub_windows_no_overlap.names == [ diff --git a/tests/test_netcdf_transcription.py b/tests/test_netcdf_transcription.py index e33eda8..4319ce3 100644 --- a/tests/test_netcdf_transcription.py +++ b/tests/test_netcdf_transcription.py @@ -94,6 +94,13 @@ def monthly_qa4sm_file(TEST_DATA_DIR) -> Path: return Path(TEST_DATA_DIR / 'intra_annual' / 'monthly' / '0-ISMN.soil_moisture_with_1-C3S.sm_tsw_months_qa4sm.nc') +@pytest.fixture +def stability_pytesmo_file(TEST_DATA_DIR) -> Path: + return Path(TEST_DATA_DIR / 'intra_annual' / 'stability' / '0-ESA_CCI_SM_passive.sm_with_1-ERA5_LAND.swvl1_tsw_stability_pytesmo.nc') + +@pytest.fixture +def stability_qa4sm_file(TEST_DATA_DIR) -> Path: + return Path(TEST_DATA_DIR / 'intra_annual' / 'stability' / '0-ESA_CCI_SM_passive.sm_with_1-ERA5_LAND.swvl1_tsw_stability_qa4sm.nc') #------------------Helper functions------------------------ @@ -255,6 +262,8 @@ def test_qr_globals_attributes(): "Oct": [[10, 1], [10, 31]], "Nov": [[11, 1], [11, 30]], "Dec": [[12, 1], [12, 31]], + }, + "stability":{ } } @@ -430,8 +439,7 @@ def test_bulk_case_transcription(TEST_DATA_DIR, tmp_paths): @log_function_call -def test_correct_file_transcription(seasonal_pytesmo_file, seasonal_qa4sm_file, - monthly_pytesmo_file, monthly_qa4sm_file): +def test_correct_file_transcription(seasonal_pytesmo_file, seasonal_qa4sm_file, monthly_pytesmo_file, monthly_qa4sm_file, stability_pytesmo_file, stability_qa4sm_file): ''' Test the transcription of the test files with the correct temporal sub-windows and the correct output nc files''' @@ -440,6 +448,8 @@ def test_correct_file_transcription(seasonal_pytesmo_file, seasonal_qa4sm_file, assert seasonal_qa4sm_file.exists assert monthly_pytesmo_file.exists assert monthly_qa4sm_file.exists + assert stability_pytesmo_file.exists + assert stability_qa4sm_file.exists # instantiate proper TemporalSubWindowsCreator instances for the corresponding test files bulk_tsw = NewSubWindow( @@ -452,11 +462,30 @@ def test_correct_file_transcription(seasonal_pytesmo_file, seasonal_qa4sm_file, monthly_tsws = TemporalSubWindowsCreator('months') monthly_tsws.add_temp_sub_wndw(bulk_tsw, insert_as_first_wndw=True) + stability_tsws = TemporalSubWindowsCreator(temporal_sub_window_type="stability") + stability_tsws.add_temp_sub_wndw(bulk_tsw, insert_as_first_wndw=True) + + + # Add annual sub-windows based on the years in the period + period = [datetime(year=2009, month=1, day=1), datetime(year=2022, month=12, day=31)] + years = list(range(period[0].year, period[1].year + 1)) + + 
globals.add_annual_subwindows(years) + + for key, value in globals.TEMPORAL_SUB_WINDOWS['custom'].items(): + new_subwindow = NewSubWindow( + name=key, + begin_date=datetime(value[0][0], value[0][1], value[0][2]), + end_date=datetime(value[1][0], value[1][1], value[1][2]), + ) + stability_tsws.add_temp_sub_wndw(new_subwindow) + + # make sure the above defined temporal sub-windows are indeed the ones on the expected output nc files - assert seasons_tsws.names == Pytesmo2Qa4smResultsTranscriber.get_tsws_from_ncfile( - seasonal_qa4sm_file) - assert monthly_tsws.names == Pytesmo2Qa4smResultsTranscriber.get_tsws_from_ncfile( - monthly_qa4sm_file) + assert seasons_tsws.names == Pytesmo2Qa4smResultsTranscriber.get_tsws_from_ncfile(seasonal_qa4sm_file) + assert monthly_tsws.names == Pytesmo2Qa4smResultsTranscriber.get_tsws_from_ncfile(monthly_qa4sm_file) + assert stability_tsws.names == Pytesmo2Qa4smResultsTranscriber.get_tsws_from_ncfile(stability_qa4sm_file) + # instantiate transcribers for the test files seasonal_transcriber = Pytesmo2Qa4smResultsTranscriber( @@ -469,17 +498,25 @@ def test_correct_file_transcription(seasonal_pytesmo_file, seasonal_qa4sm_file, pytesmo_results=monthly_pytesmo_file, intra_annual_slices=monthly_tsws, keep_pytesmo_ncfile=False) + + stability_transcriber = Pytesmo2Qa4smResultsTranscriber( + pytesmo_results=stability_pytesmo_file, + intra_annual_slices=stability_tsws, + keep_pytesmo_ncfile=False) assert seasonal_transcriber.exists assert monthly_transcriber.exists + assert stability_transcriber.exists # get the transcribed datasets seasonal_transcribed_ds = seasonal_transcriber.get_transcribed_dataset() monthly_transcribed_ds = monthly_transcriber.get_transcribed_dataset() + stability_transcribed_ds = stability_transcriber.get_transcribed_dataset() # check that the transcribed datasets are indeed xarray.Dataset instances assert isinstance(seasonal_transcribed_ds, xr.Dataset) assert isinstance(monthly_transcribed_ds, xr.Dataset) + assert isinstance(stability_transcribed_ds, xr.Dataset) # check that the transcribed datasets are equal to the expected output files # xr.testing.assert_equal(ds1, ds2) runs a more detailed comparison of the two datasets as compared to ds1.equals(ds2) @@ -487,6 +524,8 @@ def test_correct_file_transcription(seasonal_pytesmo_file, seasonal_qa4sm_file, expected_seasonal_ds = f with xr.open_dataset(monthly_qa4sm_file) as f: expected_monthly_ds = f + with xr.open_dataset(stability_qa4sm_file) as f: + expected_stability_ds = f #!NOTE: pytesmo/QA4SM offer the possibility to calculate Kendall's tau, but currently this metric is deactivated. #! Therefore, in a real validation run no tau related metrics will be transcribed to the QA4SM file, even though they might be present in the pytesmo file. 
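
# --- Illustrative sketch (editor's addition, not part of the patch) -----------
# The manual stability set-up earlier in this test mirrors what the new
# TemporalSubWindowsFactory in intra_annual_temp_windows.py builds. A minimal
# usage sketch, assuming the patched module is importable; overlap=0 and the
# 2009-2022 period are example values only:

from datetime import datetime

from qa4sm_reader.intra_annual_temp_windows import TemporalSubWindowsFactory

example_period = [datetime(2009, 1, 1), datetime(2022, 12, 31)]
stability_creator = TemporalSubWindowsFactory.create(
    temporal_sub_window_type="stability",  # dispatches to _create_stability()
    overlap=0,
    period=example_period,
)
# .names starts with the default bulk window, followed by the annual
# sub-windows derived from the years of the example period
print(stability_creator.names)
# -------------------------------------------------------------------------------
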
@@ -510,13 +549,14 @@ def test_correct_file_transcription(seasonal_pytesmo_file, seasonal_qa4sm_file, assert None == xr.testing.assert_equal( seasonal_transcribed_ds, expected_seasonal_ds) # returns None if the datasets are equal + assert None == xr.testing.assert_equal(stability_transcribed_ds, expected_stability_ds) # returns None if the datasets are equal # the method above does not check attrs of the datasets, so we do it here # Creation date and qa4sm_reader might differ, so we exclude them from the comparison datasets = [ monthly_transcribed_ds, expected_monthly_ds, seasonal_transcribed_ds, expected_seasonal_ds - ] + , stability_transcribed_ds, expected_stability_ds] attrs_to_be_excluded = ['date_created', 'qa4sm_version'] for ds in datasets: for attr in attrs_to_be_excluded: @@ -525,6 +565,7 @@ def test_correct_file_transcription(seasonal_pytesmo_file, seasonal_qa4sm_file, assert seasonal_transcribed_ds.attrs == expected_seasonal_ds.attrs assert monthly_transcribed_ds.attrs == expected_monthly_ds.attrs + assert stability_transcribed_ds.attrs == expected_stability_ds.attrs # Compare the coordinate attributes for coord in seasonal_transcribed_ds.coords: @@ -555,11 +596,12 @@ def test_correct_file_transcription(seasonal_pytesmo_file, seasonal_qa4sm_file, seasonal_transcribed_ds.close() monthly_transcribed_ds.close() + stability_transcribed_ds.close() #TODO: refactoring @log_function_call -def test_plotting(seasonal_qa4sm_file, monthly_qa4sm_file, tmp_paths): +def test_plotting(seasonal_qa4sm_file, monthly_qa4sm_file, stability_qa4sm_file, tmp_paths): ''' Test the plotting of the test files with temporal sub-windows beyond the bulk case (this scenario covered in other tests) ''' @@ -572,6 +614,10 @@ def test_plotting(seasonal_qa4sm_file, monthly_qa4sm_file, tmp_paths): tmp_paths) tmp_monthly_dir = tmp_monthly_file.parent + tmp_stability_file, _ = get_tmp_single_test_file(stability_qa4sm_file, tmp_paths) + + tmp_stability_dir = tmp_stability_file.parent + # check the output directories pa.plot_all( @@ -583,10 +629,7 @@ def test_plotting(seasonal_qa4sm_file, monthly_qa4sm_file, tmp_paths): out_type=['png', 'svg'], ) - metrics_not_plotted = [ - *globals.metric_groups[0], *globals.metric_groups[3], - *globals._metadata_exclude - ] + metrics_not_plotted = [*globals.metric_groups['common'], *globals.metric_groups['triple'], *globals._metadata_exclude] tsw_dirs_expected = Pytesmo2Qa4smResultsTranscriber.get_tsws_from_ncfile( tmp_seasonal_file) @@ -694,6 +737,68 @@ def test_plotting(seasonal_qa4sm_file, monthly_qa4sm_file, tmp_paths): ), f"{tmp_monthly_dir / tsw / f'{tsw}_statistics_table.csv'} does not exist" + # now check the file with stability temporal sub-windows and without tcol metrics and the count of the plots + + pa.plot_all( + filepath=tmp_stability_file, + temporal_sub_windows=Pytesmo2Qa4smResultsTranscriber. 
+ get_tsws_from_ncfile(tmp_stability_file), + out_dir=tmp_stability_dir, + save_all=True, + save_metadata=True, + out_type=['png', 'svg'], + ) + + tsw_dirs_expected = Pytesmo2Qa4smResultsTranscriber.get_tsws_from_ncfile( + tmp_stability_file) + + # Subfolders for tsw should not exist in the stability - case + for t, tsw in enumerate(tsw_dirs_expected): + if tsw == globals.DEFAULT_TSW: + assert Path(tmp_stability_dir / tsw).is_dir(), f"{tmp_stability_dir / tsw} is not a directory" + else: + assert not Path(tmp_stability_dir / tsw).exists(), f"{tmp_stability_dir / tsw} should not exist" + continue + + # no tcol metrics present here + for metric in [*list(globals.METRICS.keys())]: + if metric in metrics_not_plotted: + continue + # tsw specific plots + assert Path( + tmp_stability_dir / tsw / f"{tsw}_boxplot_{metric}.png" + ).exists( + ), f"{tmp_stability_dir / tsw / f'{tsw}_boxplot_{metric}.png'} does not exist" + assert Path( + tmp_stability_dir / tsw / f"{tsw}_boxplot_{metric}.svg" + ).exists( + ), f"{tmp_stability_dir / tsw / f'{tsw}_boxplot_{metric}.svg'} does not exist" + + if t == 0: + #comparison boxplots + assert Path(tmp_stability_dir / 'comparison_boxplots').is_dir() + assert Path( + tmp_stability_dir / 'comparison_boxplots' / + globals.CLUSTERED_BOX_PLOT_SAVENAME.format(metric=metric, + filetype='png') + ).exists( + ), f"{tmp_stability_dir / 'comparison_boxplots' / globals.CLUSTERED_BOX_PLOT_SAVENAME.format(metric=metric, filetype='png')} does not exist" + assert Path( + tmp_stability_dir / 'comparison_boxplots' / + globals.CLUSTERED_BOX_PLOT_SAVENAME.format(metric=metric, + filetype='svg') + ).exists( + ), f"{tmp_stability_dir / 'comparison_boxplots' / globals.CLUSTERED_BOX_PLOT_SAVENAME.format(metric=metric, filetype='svg')} does not exist" + assert Path( + tmp_stability_dir / tsw / f'{tsw}_statistics_table.csv' + ).is_file( + ), f"{tmp_stability_dir / tsw / f'{tsw}_statistics_table.csv'} does not exist" + + plot_dir = Path(tmp_stability_dir / globals.DEFAULT_TSW) + assert len(list(plot_dir.iterdir())) == 69 + assert all(file.suffix in [".png", ".svg", ".csv"] for file in plot_dir.iterdir()), "Not all files have been saved as .png or .csv" + + @log_function_call def test_write_to_netcdf_default(TEST_DATA_DIR, tmp_paths): temp_netcdf_file: Path = get_tmp_single_test_file( @@ -853,5 +958,4 @@ def test_is_valid_tcol_metric_name(seasonal_pytesmo_file, # keep_pytesmo_ncfile=True) # transcriber.pytesmo_results.close() # ds.close() - - test_bulk_case_transcription() + test_bulk_case_transcription() \ No newline at end of file
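
# --- Illustrative sketch (editor's addition, not part of the patch) -----------
# Quick demonstration of the reworked globals helpers touched above, assuming
# the patched qa4sm_reader is importable: get_metric_units() now takes an
# optional metric and appends ' per decade' for stability metrics, and
# get_metric_format() maps the 'pairwise_stability' group onto the plain
# 'pairwise' name templates.

from qa4sm_reader import globals

print(globals.get_metric_units('ISMN'))                      # m³/m³
print(globals.get_metric_units('ISMN', metric='slopeBIAS'))  # m³/m³ per decade

# 'pairwise_stability' variables reuse the pairwise variable-name template
print(globals.get_metric_format('pairwise_stability', globals.var_name_metric_sep))
# -> {metric}_between_
# -------------------------------------------------------------------------------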