STYLE: Small edits requested in PR
daniellepace committed Nov 14, 2023
1 parent 8d8f6b0 commit 3967219
Showing 4 changed files with 87 additions and 73 deletions.
14 changes: 6 additions & 8 deletions ml4h/arguments.py
@@ -72,7 +72,7 @@ def parse_args():
     parser.add_argument('--dicoms', default='./dicoms/', help='Path to folder of dicoms.')
     parser.add_argument('--sample_csv', default=None, help='Path to CSV with Sample IDs to restrict tensor paths')
     parser.add_argument('--tsv_style', default='standard', choices=['standard', 'genetics'], help='Format choice for the TSV file produced in output by infer and explore modes.')
-    parser.add_argument('--app_csv', help='Path to file used to link sample IDs between UKBB applications 17488 and 7089')
+    parser.add_argument('--app_csv', help='Path to file used by the recipe')
     parser.add_argument('--tensors', help='Path to folder containing tensors, or where tensors will be written.')
     parser.add_argument('--output_folder', default='./recipes_output/', help='Path to output folder for recipes.py runs.')
     parser.add_argument('--model_file', help='Path to a saved model architecture and weights (hd5).')
@@ -273,9 +273,9 @@ def parse_args():
     )
 
     # 2D image data augmentation parameters
-    parser.add_argument('--rotation_factor', default=0., type=float, help='a float represented as fraction of 2 Pi, e.g., rotation_factor = 0.014 results in an output rotated by a random amount in the range [-5 degrees, 5 degrees]')
-    parser.add_argument('--zoom_factor', default=0., type=float, help='a float represented as fraction of value, e.g., zoom_factor = 0.05 results in an output zoomed in a random amount in the range [-5%, 5%]')
-    parser.add_argument('--translation_factor', default=0., type=float, help='a float represented as a fraction of value, e.g., translation_factor = 0.05 results in an output shifted by a random amount in the range [-5%, 5%] in the x- and y- directions')
+    parser.add_argument('--rotation_factor', default=0., type=float, help='for data augmentation, a float represented as fraction of 2 Pi, e.g., rotation_factor = 0.014 results in an output rotated by a random amount in the range [-5 degrees, 5 degrees]')
+    parser.add_argument('--zoom_factor', default=0., type=float, help='for data augmentation, a float represented as fraction of value, e.g., zoom_factor = 0.05 results in an output zoomed in a random amount in the range [-5%, 5%]')
+    parser.add_argument('--translation_factor', default=0., type=float, help='for data augmentation, a float represented as a fraction of value, e.g., translation_factor = 0.05 results in an output shifted by a random amount in the range [-5%, 5%] in the x- and y- directions')

     # Run specific and debugging arguments
     parser.add_argument('--id', default='no_id', help='Identifier for this run, user-defined string to keep experiments organized.')
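
As a quick sanity check on the `--rotation_factor` help text above: the factor is a fraction of a full turn, so 0.014 of 2 Pi works out to roughly 5 degrees. A minimal sketch (values illustrative):

```python
# Sanity check of the --rotation_factor help text: the factor is a
# fraction of 2*pi, so rotation_factor * 360 bounds the rotation in degrees.
import math

rotation_factor = 0.014
max_degrees = math.degrees(rotation_factor * 2 * math.pi)  # == rotation_factor * 360
print(f'rotated by a random amount in [-{max_degrees:.1f}, {max_degrees:.1f}] degrees')  # ~5.0
```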
@@ -384,15 +384,13 @@ def parse_args():
         default='3M',
     )
 
-    # Arguments for explorations/infer_medians
+    # Arguments for explorations/infer_stats_from_segmented_regions
     parser.add_argument('--analyze_ground_truth', default=True, help='Whether or not to filter by images with ground truth segmentations, for comparison')
-    parser.add_argument('--dates_file', help='File containing dates for each sample_id')
-    parser.add_argument('--structures_to_analyze', nargs='*', default=[], help='Structure names to include in the .tsv outputs')
+    parser.add_argument('--structures_to_analyze', nargs='*', default=[], help='Structure names to include in the .tsv files and scatter plots')
     parser.add_argument('--erosion_radius', default=1, type=int, help='Radius of the unit disk structuring element for erosion preprocessing')
     parser.add_argument('--intensity_thresh', type=float, help='Threshold value for preprocessing')
     parser.add_argument('--intensity_thresh_in_structures', nargs='*', default=[], help='Structure names whose pixels should be replaced if the image has intensity above the threshold')
     parser.add_argument('--intensity_thresh_out_structure', help='Replacement structure name')
-    parser.add_argument('--results_to_plot', nargs='*', default=[], help='Structure names to make scatter plots')
 
     # TensorMap prefix for convenience
     parser.add_argument('--tensormap_prefix', default="ml4h.tensormap", type=str, help="Module prefix path for TensorMaps. Defaults to \"ml4h.tensormap\"")
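
The `--erosion_radius` help above refers to morphological erosion with a disk-shaped structuring element. A minimal sketch of that preprocessing step, assuming scikit-image (the mask is invented for illustration):

```python
# Minimal sketch of disk-based erosion preprocessing, assuming scikit-image;
# the mask here is invented, not data from the repo.
import numpy as np
from skimage.morphology import binary_erosion, disk

mask = np.zeros((9, 9), dtype=bool)
mask[2:7, 2:7] = True                   # a 5x5 square region
eroded = binary_erosion(mask, disk(1))  # erosion_radius = 1 peels one pixel off the border
print(eroded.sum())                     # 9: only the 3x3 core survives
```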
123 changes: 66 additions & 57 deletions ml4h/explorations.py
@@ -726,14 +726,70 @@ def _thresh_labels_above(y, img, intensity_thresh, in_labels, out_label, nb_orig
     y = _to_categorical(y, nb_orig_channels)
     return y
 
-def infer_medians(args):
-    assert(args.batch_size == 1) # no support here for iterating over larger batches
-    assert(len(args.tensor_maps_in) == 1) # no support here for multiple input maps
-    assert(len(args.tensor_maps_out) == 1) # no support here for multiple output channels
+def _scatter_plots_from_segmented_region_stats(
+    inference_tsv_true, inference_tsv_pred, structures_to_analyze,
+    output_folder, id, input_name, output_name,
+):
+    df_true = pd.read_csv(inference_tsv_true, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL)
+    df_pred = pd.read_csv(inference_tsv_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL)
+
+    results_to_plot = [f'{s}_median' for s in structures_to_analyze]
+    for col in results_to_plot:
+        for i in ['all', 'filter_outliers']: # Two types of plots
+            plot_data = pd.concat(
+                [df_true['sample_id'], df_true[col], df_pred[col]],
+                axis=1, keys=['sample_id', 'true', 'pred'],
+            )
+
+            if i == 'all':
+                true_outliers = plot_data[plot_data.true == 0]
+                pred_outliers = plot_data[plot_data.pred == 0]
+                logging.info(f'sample_ids where {col} is zero in the manual segmentation:')
+                logging.info(true_outliers['sample_id'].to_list())
+                logging.info(f'sample_ids where {col} is zero in the model segmentation:')
+                logging.info(pred_outliers['sample_id'].to_list())
+            elif i == 'filter_outliers':
+                plot_data = plot_data[plot_data.true != 0]
+                plot_data = plot_data[plot_data.pred != 0]
+                plot_data = plot_data.drop('sample_id', axis=1)
+
+            plt.figure()
+            g = lmplot(x='true', y='pred', data=plot_data)
+            ax = plt.gca()
+            title = col.replace('_', ' ')
+            ax.set_xlabel(f'{title} T1 Time (ms) - Manual Segmentation')
+            ax.set_ylabel(f'{title} T1 Time (ms) - Model Segmentation')
+            if i == 'all':
+                min_value = -50
+                max_value = 1300
+            elif i == 'filter_outliers':
+                min_value, max_value = plot_data.min(), plot_data.max()
+                min_value = min([min_value['true'], min_value['pred']]) - 100
+                max_value = max([max_value['true'], max_value['pred']]) + 100
+            ax.set_xlim([min_value, max_value])
+            ax.set_ylim([min_value, max_value])
+            res = stats.pearsonr(plot_data['true'], plot_data['pred'])
+            conf = res.confidence_interval(confidence_level=0.95)
+            text = f'Pearson Correlation Coefficient r={res.statistic:.2f},\n95% CI {conf.low:.2f} - {conf.high:.2f}'
+            ax.text(0.25, 0.1, text, transform=ax.transAxes)
+            if i == 'all':
+                postfix = ''
+            elif i == 'filter_outliers':
+                postfix = '_no_outliers'
+            logging.info(f'{col} pearson{postfix} {res.statistic}')
+            figure_path = os.path.join(
+                output_folder, id, f'{col}_{id}_{input_name}_{output_name}{postfix}.png',
+            )
+            plt.savefig(figure_path)
+
+def infer_stats_from_segmented_regions(args):
+    assert args.batch_size == 1, 'no support here for iterating over larger batches'
+    assert len(args.tensor_maps_in) == 1, 'no support here for multiple input maps'
+    assert len(args.tensor_maps_out) == 1, 'no support here for multiple output channels'

     tm_in = args.tensor_maps_in[0]
     tm_out = args.tensor_maps_out[0]
-    assert(tm_in.shape[-1] == 1) # no support here for stats on multiple input channels
+    assert tm_in.shape[-1] == 1, 'no support here for stats on multiple input channels'

     # don't filter datasets for ground truth segmentations if we want to run inference on everything
     # TODO HELP - this isn't giving me all 56K anymore
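
One note on the assert style: the parenthesized form `assert(condition, 'message')` evaluates a two-element tuple, which is always truthy, so the check can never fail. The asserts above therefore separate the condition from the message. A minimal illustration:

```python
# A parenthesized assert evaluates a tuple, which is always truthy:
assert (1 == 2, 'never raised')  # BAD: passes silently (CPython warns "assertion is always true")
# The correct form separates the condition from the failure message:
assert 1 == 1, 'raised only when the condition is False'
```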
@@ -762,7 +818,7 @@ def infer_medians(args):
     intensity_thresh_out_channel = tm_out.channel_map[args.intensity_thresh_out_structure]
 
     # Get the dates
-    with open(args.dates_file, mode='r') as dates_file:
+    with open(args.app_csv, mode='r') as dates_file:
         dates_reader = csv.reader(dates_file)
         dates_dict = {rows[0]:rows[1] for rows in dates_reader}
 
@@ -832,57 +888,10 @@ def infer_medians(args):

     # Scatter plots
     if args.analyze_ground_truth:
-        df_true = pd.read_csv(inference_tsv_true, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL)
-        df_pred = pd.read_csv(inference_tsv_pred, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL)
-
-        for col in args.results_to_plot:
-            for i in ['all', 'filter_outliers']: # Two types of plots
-                plot_data = pd.concat(
-                    [df_true['sample_id'], df_true[col], df_pred[col]],
-                    axis=1, keys=['sample_id', 'true', 'pred'],
-                )
-
-                if i == 'all':
-                    true_outliers = plot_data[plot_data.true == 0]
-                    pred_outliers = plot_data[plot_data.pred == 0]
-                    logging.info(f'sample_ids where {col} is zero in the manual segmentation:')
-                    logging.info(true_outliers['sample_id'].to_list())
-                    logging.info(f'sample_ids where {col} is zero in the model segmentation:')
-                    logging.info(pred_outliers['sample_id'].to_list())
-                elif i == 'filter_outliers':
-                    plot_data = plot_data[plot_data.true != 0]
-                    plot_data = plot_data[plot_data.pred != 0]
-                    plot_data = plot_data.drop('sample_id', axis=1)
-
-                plt.figure()
-                g = lmplot(x='true', y='pred', data=plot_data)
-                ax = plt.gca()
-                title = col.replace('_', ' ')
-                ax.set_xlabel(f'{title} T1 Time (ms) - Manual Segmentation')
-                ax.set_ylabel(f'{title} T1 Time (ms) - Model Segmentation')
-                if i == 'all':
-                    min_value = -50
-                    max_value = 1300
-                elif i == 'filter_outliers':
-                    min_value, max_value = plot_data.min(), plot_data.max()
-                    min_value = min([min_value['true'], min_value['pred']]) - 100
-                    max_value = min([max_value['true'], max_value['pred']]) + 100
-                ax.set_xlim([min_value, max_value])
-                ax.set_ylim([min_value, max_value])
-                res = stats.pearsonr(plot_data['true'], plot_data['pred'])
-                conf = res.confidence_interval(confidence_level=0.95)
-                text = f'Pearson Correlation Coefficient r={res.statistic:.2f},\n95% CI {conf.low:.2f} - {conf.high:.2f}'
-                ax.text(0.25, 0.1, text, transform=ax.transAxes)
-                if i == 'all':
-                    postfix = ''
-                elif i == 'filter_outliers':
-                    postfix = '_no_outliers'
-                logging.info(f'{col} pearson{postfix} {res.statistic}')
-                figure_path = os.path.join(
-                    args.output_folder, args.id,
-                    f'{col}_{args.id}_{tm_in.input_name()}_{output_name}{postfix}.png',
-                )
-                plt.savefig(figure_path)
+        _scatter_plots_from_segmented_region_stats(
+            inference_tsv_true, inference_tsv_pred, args.structures_to_analyze,
+            args.output_folder, args.id, tm_in.input_name(), args.output_name,
+        )

 def _softmax(x):
     """Compute softmax values for each set of scores in x."""
6 changes: 3 additions & 3 deletions ml4h/recipes.py
@@ -23,7 +23,7 @@
 from ml4h.tensorize.tensor_writer_mgb import write_tensors_mgb
 from ml4h.models.model_factory import make_multimodal_multitask_model
 from ml4h.tensor_generators import BATCH_INPUT_INDEX, BATCH_OUTPUT_INDEX, BATCH_PATHS_INDEX
-from ml4h.explorations import test_labels_to_label_map, infer_with_pixels, latent_space_dataframe, infer_medians
+from ml4h.explorations import test_labels_to_label_map, infer_with_pixels, latent_space_dataframe, infer_stats_from_segmented_regions
 from ml4h.explorations import mri_dates, ecg_dates, predictions_to_pngs, sample_from_language_model
 from ml4h.explorations import plot_while_learning, plot_histograms_of_tensors_in_pdf, explore, pca_on_tsv
 from ml4h.tensor_generators import TensorGenerator, test_train_valid_tensor_generators, big_batch_from_minibatch_generator
@@ -66,8 +66,8 @@ def run(args):
         infer_hidden_layer_multimodal_multitask(args)
     elif 'infer_pixels' == args.mode:
         infer_with_pixels(args)
-    elif 'infer_medians' == args.mode:
-        infer_medians(args)
+    elif 'infer_stats_from_segmented_regions' == args.mode:
+        infer_stats_from_segmented_regions(args)
     elif 'infer_encoders' == args.mode:
         infer_encoders_block_multimodal_multitask(args)
     elif 'test_scalar' == args.mode:
17 changes: 12 additions & 5 deletions ml4h/tensor_generators.py
@@ -666,6 +666,13 @@ def get_train_valid_test_paths(

     logging.info(f'Found {len(train_paths)} train, {len(valid_paths)} validation, and {len(test_paths)} testing tensors at: {tensors}')
     logging.debug(f'Discarded {len(discard_paths)} tensors due to given ratios')
+    if len(train_paths) == 0 and len(valid_paths) == 0 and len(test_paths) == 0:
+        raise ValueError(
+            f'Not enough tensors at {tensors}\n'
+            f'Found {len(train_paths)} training, {len(valid_paths)} validation, and {len(test_paths)} testing tensors\n'
+            f'Discarded {len(discard_paths)} tensors',
+        )
+
     return train_paths, valid_paths, test_paths
 
 
@@ -728,8 +735,8 @@ def get_train_valid_test_paths_split_by_csvs(
 # https://stackoverflow.com/questions/65475057/keras-data-augmentation-pipeline-for-image-segmentation-dataset-image-and-mask
 def augment_using_layers(images, mask, in_shapes, out_shapes, rotation_factor, zoom_factor, translation_factor):
 
-    assert(len(in_shapes) == 1) # no support for multiple inputs
-    assert(len(out_shapes) == 1) # no support for mulitple outputs
+    assert len(in_shapes) == 1, 'no support for multiple inputs'
+    assert len(out_shapes) == 1, 'no support for multiple outputs'

     def aug():
         rota = tf.keras.layers.RandomRotation(factor=rotation_factor, fill_mode='constant')
@@ -877,11 +884,11 @@ def test_train_valid_tensor_generators(
     )
 
     do_augmentation = bool(rotation_factor or zoom_factor or translation_factor)
-    logging.info(f'doing_augmentation {do_augmentation}')
+    logging.info(f'doing_augmentation {do_augmentation} with rotation {rotation_factor}, zoom {zoom_factor}, translation {translation_factor}')
 
     if do_augmentation:
-        assert(len(tensor_maps_in) == 1) # no support for multiple input tensors
-        assert(len(tensor_maps_out) == 1) # no support for multiple output tensors
+        assert len(tensor_maps_in) == 1, 'no support for multiple input tensors'
+        assert len(tensor_maps_out) == 1, 'no support for multiple output tensors'

     if wrap_with_tf_dataset or do_augmentation:
         in_shapes = {tm.input_name(): (batch_size,) + tm.static_shape() for tm in tensor_maps_in}
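
For context, the three factors configure Keras preprocessing layers inside `augment_using_layers`. A standalone sketch of an equivalent pipeline, assuming TensorFlow 2.x (shapes and values illustrative; not the repo's exact `aug()`):

```python
# Standalone sketch of the augmentation the three factors configure,
# assuming TensorFlow 2.x; not the repo's exact aug() implementation.
import tensorflow as tf

def make_augmenter(rotation_factor=0.014, zoom_factor=0.05, translation_factor=0.05):
    return tf.keras.Sequential([
        tf.keras.layers.RandomRotation(factor=rotation_factor, fill_mode='constant'),
        tf.keras.layers.RandomZoom(height_factor=zoom_factor, fill_mode='constant'),
        tf.keras.layers.RandomTranslation(
            height_factor=translation_factor, width_factor=translation_factor,
            fill_mode='constant',
        ),
    ])

images = tf.random.uniform((2, 224, 224, 1))         # a batch of illustrative 2D images
augmented = make_augmenter()(images, training=True)  # random ops apply only when training=True
```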
