Skip to content

Commit

Permalink
Merge branch 'master' into sf_xdl
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix authored Jan 17, 2024
2 parents 762692d + 0064b83 commit 5b0dbec
Show file tree
Hide file tree
Showing 11 changed files with 183 additions and 86 deletions.
2 changes: 1 addition & 1 deletion docker/vm_boot_images/config/tensorflow-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ google-cloud-storage
umap-learn[plot]
neurite
voxelmorph
pystrum
pystrum
6 changes: 4 additions & 2 deletions ml4h/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def parse_args():
)
parser.add_argument('--balance_csvs', default=[], nargs='*', help='Balances batches with representation from sample IDs in this list of CSVs')
parser.add_argument('--optimizer', default='radam', type=str, help='Optimizer for model training')
parser.add_argument('--learning_rate_schedule', default=None, type=str, choices=['triangular', 'triangular2'], help='Adjusts learning rate during training.')
parser.add_argument('--learning_rate_schedule', default=None, type=str, choices=['triangular', 'triangular2', 'cosine_decay'], help='Adjusts learning rate during training.')
parser.add_argument('--anneal_rate', default=0., type=float, help='Annealing rate in epochs of loss terms during training')
parser.add_argument('--anneal_shift', default=0., type=float, help='Annealing offset in epochs of loss terms during training')
parser.add_argument('--anneal_max', default=2.0, type=float, help='Annealing maximum value')
Expand Down Expand Up @@ -385,10 +385,12 @@ def parse_args():
)

# Arguments for explorations/infer_stats_from_segmented_regions
parser.add_argument('--analyze_ground_truth', default=True, help='Whether or not to filter by images with ground truth segmentations, for comparison')
parser.add_argument('--analyze_ground_truth', default=False, action='store_true', help='Whether or not to filter by images with ground truth segmentations, for comparison')
parser.add_argument('--structures_to_analyze', nargs='*', default=[], help='Structure names to include in the .tsv files and scatter plots')
parser.add_argument('--erosion_radius', default=1, type=int, help='Radius of the unit disk structuring element for erosion preprocessing')
parser.add_argument('--intensity_thresh', type=float, help='Threshold value for preprocessing')
parser.add_argument('--intensity_thresh_percentile', type=float, help='Threshold percentile for preprocessing, between 0 and 100 inclusive')
parser.add_argument('--intensity_thresh_k_means', nargs='*', default=[], type=int, help='Preprocessing using k-means specified as two numbers, the first is the number of clusters and the second is the cluster index to keep')
parser.add_argument('--intensity_thresh_in_structures', nargs='*', default=[], help='Structure names whose pixels should be replaced if the images has intensity above the threshold')
parser.add_argument('--intensity_thresh_out_structure', help='Replacement structure name')

Expand Down
5 changes: 4 additions & 1 deletion ml4h/defines.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ def __str__(self):
'aortic_root': 7, 'ascending_aorta': 8, 'pulmonary_artery': 9, 'ascending_aortic_wall': 10, 'LVOT': 11,
}
MRI_LIVER_SEGMENTED_CHANNEL_MAP = {'background': 0, 'liver': 1, 'inferior_vena_cava': 2, 'abdominal_aorta': 3, 'body': 4}

MRI_PANCREAS_SEGMENTED_CHANNEL_MAP = {
'background': 0, 'body': 1, 'pancreas': 2, 'liver': 3, 'stomach': 4, 'spleen': 5,
'kidney': 6, 'bowel': 7, 'spine': 8, 'aorta':9, 'ivc': 10,
}

# TODO: These values should ultimately come from the coding table
CODING_VALUES_LESS_THAN_ONE = [-10, -1001]
Expand Down
45 changes: 30 additions & 15 deletions ml4h/explorations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pandas as pd
import multiprocessing as mp
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from tensorflow.keras.models import Model


Expand All @@ -37,7 +38,7 @@
from ml4h.plots import evaluate_predictions, subplot_rocs, subplot_scatters, plot_categorical_tmap_over_time
from ml4h.defines import JOIN_CHAR, MRI_SEGMENTED_CHANNEL_MAP, CODING_VALUES_MISSING, CODING_VALUES_LESS_THAN_ONE
from ml4h.defines import TENSOR_EXT, IMAGE_EXT, ECG_CHAR_2_IDX, ECG_IDX_2_CHAR, PARTNERS_CHAR_2_IDX, PARTNERS_IDX_2_CHAR, PARTNERS_READ_TEXT
from ml4h.tensorize.tensor_writer_ukbb import _unit_disk
from ml4h.tensorize.tensor_writer_ukbb import unit_disk

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
Expand Down Expand Up @@ -719,13 +720,26 @@ def _get_csv_row(sample_id, means, medians, stds, date):
csv_row = [sample_id] + res[0].astype('str').tolist() + [date]
return csv_row

def _thresh_labels_above(y, img, intensity_thresh, in_labels, out_label, nb_orig_channels):
def _thresh_labels_above(y, img, intensity_thresh, intensity_thresh_percentile, in_labels, out_label, nb_orig_channels):
y = np.argmax(y, axis=-1)[..., np.newaxis]
y[np.logical_and(img >= intensity_thresh, np.isin(y, in_labels))] = out_label
if intensity_thresh:
img_intensity_thresh = intensity_thresh
elif intensity_thresh_percentile:
img_intensity_thresh = np.percentile(img, intensity_thresh_percentile)
y[np.logical_and(img >= img_intensity_thresh, np.isin(y, in_labels))] = out_label
y = y[..., 0]
y = _to_categorical(y, nb_orig_channels)
return y

def _intensity_thresh_k_means(y, img, intensity_thresh_k_means):
X = img[y==1][...,np.newaxis]
if X.size > 1:
kmeans = KMeans(n_clusters=intensity_thresh_k_means[0], random_state=0, n_init="auto").fit(X)
labels = kmeans.predict(img.flatten()[...,np.newaxis])
labels = np.reshape(labels, img.shape)
y[np.logical_and(labels==intensity_thresh_k_means[1], y==1)] = 0
return y

def _scatter_plots_from_segmented_region_stats(
inference_tsv_true, inference_tsv_pred, structures_to_analyze,
output_folder, id, input_name, output_name,
Expand Down Expand Up @@ -759,13 +773,9 @@ def _scatter_plots_from_segmented_region_stats(
title = col.replace('_', ' ')
ax.set_xlabel(f'{title} T1 Time (ms) - Manual Segmentation')
ax.set_ylabel(f'{title} T1 Time (ms) - Model Segmentation')
if i == 'all':
min_value = -50
max_value = 1300
elif i == 'filter_outliers':
min_value, max_value = plot_data.min(), plot_data.max()
min_value = min([min_value['true'], min_value['pred']]) - 100
max_value = min([max_value['true'], max_value['pred']]) + 100
min_value, max_value = plot_data.min(), plot_data.max()
min_value = min([min_value['true'], min_value['pred']]) - 100
max_value = min([max_value['true'], max_value['pred']]) + 100
ax.set_xlim([min_value, max_value])
ax.set_ylim([min_value, max_value])
res = stats.pearsonr(plot_data['true'], plot_data['pred'])
Expand Down Expand Up @@ -798,7 +808,6 @@ def infer_stats_from_segmented_regions(args):
assert(tm_in.shape[-1] == 1, 'no support here for stats on multiple input channels')

# don't filter datasets for ground truth segmentations if we want to run inference on everything
# TODO HELP - this isn't giving me all 56K anymore
if not args.analyze_ground_truth:
args.output_tensors = []
args.tensor_maps_out = []
Expand All @@ -815,11 +824,13 @@ def infer_stats_from_segmented_regions(args):

# Structuring element used for the erosion
if args.erosion_radius > 0:
structure = _unit_disk(args.erosion_radius)[np.newaxis, ..., np.newaxis]
structure = unit_disk(args.erosion_radius)[np.newaxis, ..., np.newaxis]

# Setup for intensity thresholding
do_intensity_thresh = args.intensity_thresh_in_structures and args.intensity_thresh_out_structure
if do_intensity_thresh:
assert (not (args.intensity_thresh and args.intensity_thresh_percentile))
assert (not (args.intensity_thresh_k_means and len(args.intensity_thresh_in_structures) > 1))
intensity_thresh_in_channels = [tm_out.channel_map[k] for k in args.intensity_thresh_in_structures]
intensity_thresh_out_channel = tm_out.channel_map[args.intensity_thresh_out_structure]

Expand Down Expand Up @@ -870,19 +881,23 @@ def infer_stats_from_segmented_regions(args):

if args.analyze_ground_truth:
if do_intensity_thresh:
y_true = _thresh_labels_above(y_true, img, args.intensity_thresh, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
y_true = _thresh_labels_above(y_true, img, args.intensity_thresh, args.intensity_thresh_percentile, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
y_true = np.delete(y_true, bad_channels, axis=-1)
if args.erosion_radius > 0:
y_true = binary_erosion(y_true, structure).astype(y_true.dtype)
if args.intensity_thresh_k_means:
y_true = _intensity_thresh_k_means(y_true, img, args.intensity_thresh_k_means)
means_true, medians_true, stds_true = _compute_masked_stats(rescaled_img, y_true, nb_good_channels)
csv_row_true = _get_csv_row(sample_id, means_true, medians_true, stds_true, date)
inference_writer_true.writerow(csv_row_true)

if do_intensity_thresh:
y_pred = _thresh_labels_above(y_pred, img, args.intensity_thresh, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
y_pred = _thresh_labels_above(y_pred, img, args.intensity_thresh, args.intensity_thresh_percentile, intensity_thresh_in_channels, intensity_thresh_out_channel, nb_orig_channels)
y_pred = np.delete(y_pred, bad_channels, axis=-1)
if args.erosion_radius > 0:
y_pred = binary_erosion(y_pred, structure).astype(y_pred.dtype)
if args.intensity_thresh_k_means:
y_pred = _intensity_thresh_k_means(y_pred, img, args.intensity_thresh_k_means)
means_pred, medians_pred, stds_pred = _compute_masked_stats(rescaled_img, y_pred, nb_good_channels)
csv_row_pred = _get_csv_row(sample_id, means_pred, medians_pred, stds_pred, date)
inference_writer_pred.writerow(csv_row_pred)
Expand All @@ -896,7 +911,7 @@ def infer_stats_from_segmented_regions(args):
if args.analyze_ground_truth:
_scatter_plots_from_segmented_region_stats(
inference_tsv_true, inference_tsv_pred, args.structures_to_analyze,
args.output_folder, args.id, tm_in.input_name(), args.output_name,
args.output_folder, args.id, tm_in.input_name(), tm_out.output_name(),
)

def _softmax(x):
Expand Down
8 changes: 7 additions & 1 deletion ml4h/models/layer_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tensorflow.keras.layers import MaxPooling2D, MaxPooling3D, Average, AveragePooling1D, AveragePooling2D, AveragePooling3D, Layer
from tensorflow.keras.layers import SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Concatenate, Add
from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalAveragePooling3D
from tensorflow.keras.regularizers import L1, L2


Tensor = tf.Tensor
Expand Down Expand Up @@ -52,9 +53,14 @@
# class name -> (dimension -> class)
'spatial_dropout': {2: SpatialDropout1D, 3: SpatialDropout2D, 4: SpatialDropout3D},
'dropout': defaultdict(lambda _: Dropout),
'l1': L1,
'l2': L2,
}
DENSE_REGULARIZATION_CLASSES = {
'dropout': Dropout, # TODO: add l1, l2
'dropout': Dropout,
'dropout': Dropout,
'l1': L1,
'l2': L2,
}


Expand Down
7 changes: 6 additions & 1 deletion ml4h/models/legacy_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from tensorflow.keras.layers import SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Concatenate, Add
from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalAveragePooling3D
from tensorflow.keras.layers.experimental.preprocessing import RandomRotation, RandomZoom, RandomContrast
from tensorflow.keras.regularizers import L1, L2
import tensorflow_probability as tfp

from ml4h.metrics import get_metric_dict
Expand Down Expand Up @@ -79,9 +80,13 @@ class BottleneckType(Enum):
# class name -> (dimension -> class)
'spatial_dropout': {2: SpatialDropout1D, 3: SpatialDropout2D, 4: SpatialDropout3D},
'dropout': defaultdict(lambda _: Dropout),
'l1': L1,
'l2': L2,
}
DENSE_REGULARIZATION_CLASSES = {
'dropout': Dropout, # TODO: add l1, l2
'dropout': Dropout,
'l1': L1,
'l2': L2,
}


Expand Down
3 changes: 3 additions & 0 deletions ml4h/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow_addons.optimizers import RectifiedAdam, TriangularCyclicalLearningRate, Triangular2CyclicalLearningRate
from tensorflow.keras.optimizers.schedules import CosineDecay

from ml4h.plots import plot_find_learning_rate
from ml4h.tensor_generators import TensorGenerator
Expand Down Expand Up @@ -40,6 +41,8 @@ def _get_learning_rate_schedule(learning_rate: float, learning_rate_schedule: st
initial_learning_rate=learning_rate / 5, maximal_learning_rate=learning_rate,
step_size=steps_per_epoch * 5,
)
if learning_rate_schedule == 'cosine_decay':
return CosineDecay(initial_learning_rate=learning_rate, decay_steps=steps_per_epoch)
else:
raise ValueError(f'Learning rate schedule "{learning_rate_schedule}" unknown.')

Expand Down
77 changes: 38 additions & 39 deletions ml4h/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,17 @@
from ml4h.models.model_factory import make_multimodal_multitask_model
from ml4h.ml4ht_integration.tensor_generator import TensorMapDataLoader2
from ml4h.tensor_generators import BATCH_INPUT_INDEX, BATCH_OUTPUT_INDEX, BATCH_PATHS_INDEX

from ml4h.explorations import test_labels_to_label_map, infer_with_pixels, infer_stats_from_segmented_regions
from ml4h.explorations import test_labels_to_label_map, infer_with_pixels, latent_space_dataframe, infer_stats_from_segmented_regions
from ml4h.explorations import mri_dates, ecg_dates, predictions_to_pngs, sample_from_language_model
from ml4h.plots import plot_roc, plot_precision_recall_per_class, plot_scatter
from ml4h.explorations import plot_while_learning, plot_histograms_of_tensors_in_pdf, explore, pca_on_tsv
from ml4h.models.legacy_models import get_model_inputs_outputs, make_shallow_model, make_hidden_layer_model
from ml4h.tensor_generators import TensorGenerator, test_train_valid_tensor_generators, big_batch_from_minibatch_generator
from ml4h.data_descriptions import dataframe_data_description_from_tensor_map, ECGDataDescription, DataFrameDataDescription
from ml4h.metrics import get_roc_aucs, get_precision_recall_aucs, get_pearson_coefficients, log_aucs, \
log_pearson_coefficients, concordance_index_censored

from ml4h.plots import plot_dice
from ml4h.plots import plot_reconstruction, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp

from ml4h.metrics import get_roc_aucs, get_precision_recall_aucs, get_pearson_coefficients, log_aucs, log_pearson_coefficients, concordance_index_censored
from ml4h.plots import plot_dice, plot_reconstruction, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp
from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_survival, plot_dice
from ml4h.plots import plot_reconstruction, plot_hit_to_miss_transforms, plot_saliency_maps, plot_partners_ecgs, plot_ecg_rest_mp
from ml4h.plots import subplot_rocs, subplot_comparison_rocs, subplot_scatters, subplot_comparison_scatters, plot_prediction_calibrations
from ml4h.models.legacy_models import make_character_model_plus, embed_model_predict, make_siamese_model, legacy_multimodal_multitask_model
from ml4h.plots import evaluate_predictions, plot_scatters, plot_rocs, plot_precision_recalls, subplot_roc_per_class, plot_tsne, plot_survival
Expand Down Expand Up @@ -143,7 +140,7 @@ def run(args):

except Exception as e:
logging.exception(e)

if args.gcs_cloud_bucket is not None:
save_to_google_cloud(args)

Expand Down Expand Up @@ -225,37 +222,39 @@ def train_multimodal_multitask(args):
if merger:
merger.save(f'{args.output_folder}{args.id}/merger.h5')

test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)
performance_metrics = _predict_and_evaluate(
model, test_data, test_labels, args.tensor_maps_in, args.tensor_maps_out, args.tensor_maps_protected,
args.batch_size, args.hidden_layer, os.path.join(args.output_folder, args.id + '/'), test_paths,
args.embed_visualization, args.alpha, args.dpi, args.plot_width, args.plot_height,
)
performance_metrics = {}
if args.test_steps > 0:
test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)
performance_metrics = _predict_and_evaluate(
model, test_data, test_labels, args.tensor_maps_in, args.tensor_maps_out, args.tensor_maps_protected,
args.batch_size, args.hidden_layer, os.path.join(args.output_folder, args.id + '/'), test_paths,
args.embed_visualization, args.alpha, args.dpi, args.plot_width, args.plot_height,
)

predictions_list = model.predict(test_data)
samples = min(args.test_steps * args.batch_size, 12)
out_path = os.path.join(args.output_folder, args.id, 'reconstructions/')
if len(args.tensor_maps_out) == 1:
predictions_list = [predictions_list]
predictions_dict = {name: pred for name, pred in zip(model.output_names, predictions_list)}
logging.info(f'Predictions and shapes are: {[(p, predictions_dict[p].shape) for p in predictions_dict]}')

for i, etm in enumerate(encoders):
embed = encoders[etm].predict(test_data[etm.input_name()])
if etm.output_name() in predictions_dict:
plot_reconstruction(etm, test_data[etm.input_name()], predictions_dict[etm.output_name()], out_path, test_paths, samples)
for dtm in decoders:
reconstruction = decoders[dtm].predict(embed)
logging.info(f'{dtm.name} has prediction shape: {reconstruction.shape} from embed shape: {embed.shape}')
my_out_path = os.path.join(out_path, f'decoding_{dtm.name}_from_{etm.name}/')
os.makedirs(os.path.dirname(my_out_path), exist_ok=True)
if dtm.axes() > 1:
plot_reconstruction(dtm, test_labels[dtm.output_name()], reconstruction, my_out_path, test_paths, samples)
else:
evaluate_predictions(
dtm, reconstruction, test_labels[dtm.output_name()], {}, dtm.name, my_out_path,
test_paths, dpi=args.dpi, width=args.plot_width, height=args.plot_height,
)
predictions_list = model.predict(test_data)
samples = min(args.test_steps * args.batch_size, 12)
out_path = os.path.join(args.output_folder, args.id, 'reconstructions/')
if len(args.tensor_maps_out) == 1:
predictions_list = [predictions_list]
predictions_dict = {name: pred for name, pred in zip(model.output_names, predictions_list)}
logging.info(f'Predictions and shapes are: {[(p, predictions_dict[p].shape) for p in predictions_dict]}')

for i, etm in enumerate(encoders):
embed = encoders[etm].predict(test_data[etm.input_name()])
if etm.output_name() in predictions_dict:
plot_reconstruction(etm, test_data[etm.input_name()], predictions_dict[etm.output_name()], out_path, test_paths, samples)
for dtm in decoders:
reconstruction = decoders[dtm].predict(embed)
logging.info(f'{dtm.name} has prediction shape: {reconstruction.shape} from embed shape: {embed.shape}')
my_out_path = os.path.join(out_path, f'decoding_{dtm.name}_from_{etm.name}/')
os.makedirs(os.path.dirname(my_out_path), exist_ok=True)
if dtm.axes() > 1:
plot_reconstruction(dtm, test_labels[dtm.output_name()], reconstruction, my_out_path, test_paths, samples)
else:
evaluate_predictions(
dtm, reconstruction, test_labels[dtm.output_name()], {}, dtm.name, my_out_path,
test_paths, dpi=args.dpi, width=args.plot_width, height=args.plot_height,
)
return performance_metrics


Expand Down
2 changes: 1 addition & 1 deletion ml4h/tensor_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
:param paths: If weights is provided, paths should be a list of path lists the same length as weights
"""
self.augment = augment
self.paths = sum(paths) if isinstance(paths[0], list) else paths
self.paths = sum(paths) if (len(paths) > 0 and isinstance(paths[0], list)) else paths
self.run_on_main_thread = num_workers == 0
self.q = None
self.stats_q = None
Expand Down
Loading

0 comments on commit 5b0dbec

Please sign in to comment.