Skip to content

Commit

Permalink
Preserve image histogram range, when loading, processing and saving.
Browse files Browse the repository at this point in the history
Update cli interface to simplify and support PixInsight script integration
Add PyxInsight cli script
  • Loading branch information
p7ayfu77 committed Sep 28, 2024
1 parent 6335818 commit 8793c49
Show file tree
Hide file tree
Showing 9 changed files with 857 additions and 141 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ models/**
*.tif*
*.zip
*.tfevents.*
pjsr/
safe/
updates/
26 changes: 18 additions & 8 deletions astrodenoise/applayout.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def save(self, path):
self.dismiss_popup()

filepath = Path(path)
#AstroDeNoiseApp()

get_app(App.get_running_app()).lastpath = str(filepath.parent)

if self.fits_headers is None:
Expand All @@ -317,14 +317,19 @@ def save(self, path):
else:
result_tosave = np.array(self.currentimage.get_data('pre')[0])

# Clip to range [0,1] before save
result_tosave = np.clip(result_tosave, 0, 1)

try:
extension = filepath.suffix.lower()

if extension in supported_save_formats_fits:
result_forsave = np.moveaxis(np.transpose(result_tosave),1,2)
write_fits(filepath,result_forsave,headers=self.fits_headers)
write_fits(
filepath,
np.moveaxis(np.transpose(result_tosave),1,2),
headers=self.fits_headers)
elif extension in supported_save_formats_tiff:
result_tosave = (result_tosave - np.min(result_tosave)) / (np.max(result_tosave) - np.min(result_tosave))
# Scale to unit16 for tiff
result_tosave = (result_tosave * np.iinfo(np.uint16).max).astype(np.uint16)
imsave(filepath,data=result_tosave)
else:
Expand Down Expand Up @@ -459,7 +464,11 @@ def denoise(self, data, C=-2.8,B=0.25):

self.update_progress(0)
expand_low_actual = 0.5 - (self.expand_low/2)
normalizer = STFNormalizer(C=C,B=B,expand_low=expand_low_actual,do_after=False) if self.normalize_enabled else NoNormalizer(expand_low=expand_low_actual)
# Strength shifts pixel values to right of histogram by up to 0.5
#Strength=1 => Image + 0
#Strength=0 => Image + 0.5
normalizer = STFNormalizer(C=C,B=B,expand_low=expand_low_actual,do_after=False) if self.normalize_enabled else NoNormalizer(expand_low=expand_low_actual,do_after=False)


if self.denoise_enabled:
with tf.device(f"/{self.selected_device}:0"):
Expand All @@ -478,7 +487,7 @@ def denoise(self, data, C=-2.8,B=0.25):
result = normalizer.before(np.moveaxis(np.transpose(data),0,1),'YX')
self.update_progress(1)

return result
return result - expand_low_actual

@mainthread
def update_progress(self,progress):
Expand Down Expand Up @@ -526,8 +535,9 @@ def get_label_data(self):
}

def get_texture(self, result):
image = (result - np.min(result)) / (np.max(result) - np.min(result))
image = (image * 255).astype('uint8')

result = np.clip(result, 0, 1)
image = (result * 255).astype('uint8')

colorfmt='rgb'
if image.shape[2] == 1:
Expand Down
50 changes: 33 additions & 17 deletions astrodenoise/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import sys
import argparse
import tensorflow as tf
from tifffile import imread
from xisf import XISF
import numpy as npp
from pathlib import Path
from os.path import join as path_join
Expand All @@ -11,20 +13,27 @@
from astrodeep.utils.fits import read_fits, write_fits
from astrodenoise.version import modelversion

def get_exepath():
if getattr(sys, "frozen", False):
datadir = os.path.dirname(sys.executable)
else:
datadir = Path(os.path.dirname(__file__)).parent.as_posix()
return datadir

def cli():

parser = argparse.ArgumentParser()

parser.add_argument('input', type=str, nargs=1, help='Input image path, either tif or debayered fits file with data stored as 32bit float.')
parser.add_argument('--model','-m', type=str, default=modelversion, help='Alternative model name to use for de-noising.')
parser.add_argument('--models_folder', type=str, default='models', help='Alternative models folder root path.')
parser.add_argument('--tiles','-t', type=int, default=0, help='Use number of tiling slices when de-noising, useful for large images and limited memory.')
parser.add_argument('--tiles','-t', type=int, default=3, help='Use number of tiling slices when de-noising, useful for large images and limited memory.')
parser.add_argument('--overwrite','-o', action='store_true', help='Allow overwrite of existing output file. Default: False when not specified.')
parser.add_argument('--device','-d', choices=['GPU','CPU'], default='CPU', help='Optional select processing to target CPU or GCP. Default: CPU')
parser.add_argument('--normalize','-n', action='store_true', help='Enable STFNormalization before de-noising. Default: False when not specified.')
parser.add_argument('--norm-C', type=float, default=-2.8, help='C parameter for STF Normalization. Default: -2.8')
parser.add_argument('--norm-B', type=float, default=0.25, help='B parameter for STF Normalization, Higher B results in stronger stretch providing the ability target de-noising more effectively. . Default: 0.25, Range: 0 < B < 1')
parser.add_argument('--norm-restore', action='store_true', help='Restores output image to original data range after processing. Default: False when not specified.')
parser.add_argument('--norm-B', type=float, default=0.25, help='B parameter for STF Normalization, Higher B results in stronger stretch providing the ability target de-noising more effectively. . Default: 0.25, Range: 0 < B < 1')
parser.add_argument('--strength', type=float, default=0.5, help='The denoise strength applied. Default: 0.5')

args = parser.parse_args()

Expand All @@ -35,7 +44,9 @@ def predict(path,model):
if path.suffix in ['.fit','.fits']:
data, headers = read_fits(path)
elif path.suffix in ['.tif','.tiff']:
data, headers = npp.moveaxis(imread(path),-1,0), None
data, headers = npp.moveaxis(imread(path),-1,0), None
elif path.suffix in ['.xisf']:
data, headers = npp.moveaxis(XISF(path).read_image(0),-1,0), None
else:
print("Skipping unsupported format. Allowed formats: .tiff/.tif/.fits/.fit")
return
Expand All @@ -46,33 +57,38 @@ def predict(path,model):
if data.ndim == 2:
data = data[npp.newaxis,...]

print("Processing file:",path)
print("Image Dimensions:",data.shape)
print(f"Processing file:{path}\n")
print(f"Image Dimensions: {data.shape}\n")

n_tiles = None if args.tiles == 0 else (args.tiles,args.tiles)
n_tiles = None if args.tiles == 0 else (args.tiles, args.tiles)
if n_tiles is not None:
print("Processing with tilling:",n_tiles)

output_denoised = []

axes = 'YX'
normalizer = STFNormalizer(C=args.norm_C,B=args.norm_B,do_after=args.norm_restore) if args.normalize is True else NoNormalizer()
print("Using Normalization:",normalizer.params)
expand_low_actual = 0.5 - (args.strength/2)
normalizer = STFNormalizer(C=args.norm_C,B=args.norm_B,expand_low=expand_low_actual,do_after=True) if args.normalize is True else NoNormalizer(expand_low=expand_low_actual,do_after=True)
print(f"Using Normalization: {normalizer.params}\n")
print(f"Using Strength: {args.strength}\n")

output_denoised = []
for c in data:
output_denoised.append(
model.predict(c, axes, normalizer=normalizer,resizer=PadAndCropResizer(), n_tiles=n_tiles)
model.predict(c, axes, normalizer=normalizer, resizer=PadAndCropResizer(), n_tiles=n_tiles)
)

output_denoised_arr = npp.asarray(output_denoised)
# Clip to [0,1] range
output = output_denoised_arr.clip(0,1)

output_file_name = path.stem + f"_denoised.fits"
output_path = path_join(path.parent, 'denoised')
Path(output_path).mkdir(exist_ok=True)
output_file_path = path_join(output_path, output_file_name)
write_fits(output_file_path, output_denoised, headers, args.overwrite)
print("Output file saved:", output_file_path)
output_file_path = path_join(path.parent, output_file_name)
write_fits(output_file_path, output, headers, args.overwrite)

print("Output file saved:", output_file_path)

print("Loading model:", args.model)
model = CARE(config=None, name=args.model, basedir=args.models_folder)
model = CARE(config=None, name=args.model, basedir=Path(get_exepath()).joinpath(args.models_folder).as_posix())
file_or_path = args.input[0]

if os.path.isfile(file_or_path):
Expand Down
Loading

0 comments on commit 8793c49

Please sign in to comment.