diff --git a/src/patch_denoise/bindings/cli.py b/src/patch_denoise/bindings/cli.py index ea5890c..25bc64d 100644 --- a/src/patch_denoise/bindings/cli.py +++ b/src/patch_denoise/bindings/cli.py @@ -216,7 +216,17 @@ def _get_parser(): type=IsFile, help=( "Phase of the input data. This MUST be in radians. " - "No conversion would be applied." + "No rescaling will be applied." + ), + ) + data_group.add_argument( + "--noise-map-phase", + metavar="FILE", + default=None, + type=IsFile, + help=( + "Phase component of the noise map estimation file. " + "This MUST be in radians. No rescaling will be applied." ), ) @@ -283,17 +293,19 @@ def main(): if args.input_phase is not None: input_data, affine = load_complex_nifti(args.input_file, args.input_phase) - input_data, affine = load_as_array(args.input_file) + else: + input_data, affine = load_as_array(args.input_file) kwargs = args.extra if args.nan_to_num is not None: input_data = np.nan_to_num(input_data, nan=args.nan_to_num) + n_nans = np.isnan(input_data).sum() if n_nans > 0: logging.warning( - f"{n_nans}/{np.prod(input_data.shape)} voxels are NaN." - " You might want to use --nan-to-num=", + f"{n_nans}/{input_data.size} voxels are NaN. " + "You might want to use --nan-to-num=", stacklevel=0, ) @@ -302,14 +314,30 @@ def main(): affine_mask = None else: mask, affine_mask = load_as_array(args.mask) - noise_map, affine_noise_map = load_as_array(args.noise_map) + + if args.noise_map is not None and args.noise_map_phase is not None: + noise_map, affine_noise_map = load_complex_nifti( + args.noise_map, + args.noise_map_phase, + ) + elif args.noise_map is not None: + noise_map, affine_noise_map = load_as_array(args.noise_map) + elif args.noise_map_phase is not None: + raise ValueError( + "The phase component of the noise map has been provided, " + "but not the magnitude." + ) + else: + noise_map = None + affine_noise_map = None if affine is not None: - if affine_mask is not None and np.allclose(affine, affine_mask): + if (affine_mask is not None) and not np.allclose(affine, affine_mask): logging.warning( "Affine matrix of input and mask does not match", stacklevel=2 ) - if affine_noise_map is not None and np.allclose(affine, affine_noise_map): + + if (affine_noise_map is not None) and not np.allclose(affine, affine_noise_map): logging.warning( "Affine matrix of input and noise map does not match", stacklevel=2 ) @@ -344,7 +372,7 @@ def main(): if noise_map is None: raise RuntimeError("A noise map must be specified for this method.") - denoised_data, patchs_weight, noise_std_map, rank_map = denoise_func( + denoised_data, _, noise_std_map, _ = denoise_func( input_data, patch_shape=args.patch_shape, patch_overlap=args.patch_overlap, diff --git a/tests/test_spacetime_utils.py b/tests/test_spacetime_utils.py index 4272fef..8131795 100644 --- a/tests/test_spacetime_utils.py +++ b/tests/test_spacetime_utils.py @@ -70,13 +70,30 @@ def f(x): @pytest.mark.parametrize("block_dim", range(5, 10)) -def test_noise_estimation(medium_random_matrix, block_dim): - """Test noise estimation.""" - noise_map = estimate_noise(medium_random_matrix, block_dim) - - real_std = np.nanstd(medium_random_matrix) - err = np.nanmean(noise_map - real_std) - assert err <= 0.1 * real_std +def test_noise_estimation(block_dim): + """Test noise estimation. + + The mean patch-wise standard deviation should be close to the overall + standard deviation. + """ + for seed in range(15): + print(f"Seed: {seed}") + rng = np.random.RandomState(seed) + medium_random_matrix = rng.randn(200, 200, 100) + print(f"Mean of raw: {np.nanmean(medium_random_matrix)}") + print(f"Max of raw: {np.nanmax(medium_random_matrix)}") + print(f"Min of raw: {np.nanmin(medium_random_matrix)}") + real_std = np.nanstd(medium_random_matrix) + print(f"SD of raw: {real_std}") + + noise_map = estimate_noise(medium_random_matrix, block_dim) + print(f"Mean of noise map: {np.nanmean(noise_map)}") + print(f"Max of noise map: {np.nanmax(noise_map)}") + print(f"Min of noise map: {np.nanmin(noise_map)}") + print(f"SD of noise map: {np.nanstd(noise_map)}") + err = np.nanmean(noise_map - real_std) + print(f"Err: {err}") + assert err <= 0.1 * real_std @parametrize_random_matrix