From 6e97df50162693c12d97b16dbb330e55b5dd1692 Mon Sep 17 00:00:00 2001 From: Pablo Olivares <65406121+pab1s@users.noreply.github.com> Date: Mon, 15 Jul 2024 22:20:24 +0200 Subject: [PATCH] Homology Inference and plotting advances #29 --- experiments/homology.py | 260 ++++++++++++++++++++++++++++++++++ experiments/homology_times.py | 86 ----------- scripts/homology.sh | 21 ++- utils/ph_plotting.py | 190 +++++++++++++++++++++++++ 4 files changed, 465 insertions(+), 92 deletions(-) create mode 100644 experiments/homology.py delete mode 100644 experiments/homology_times.py create mode 100644 utils/ph_plotting.py diff --git a/experiments/homology.py b/experiments/homology.py new file mode 100644 index 0000000..bf0ebfe --- /dev/null +++ b/experiments/homology.py @@ -0,0 +1,260 @@ +import os +import argparse +import torch +import yaml +import numpy as np +import scipy.spatial +import datetime +import logging +import time +from torch.utils.data import DataLoader, random_split, Subset +from datasets.dataset import get_dataset +from datasets.transformations import get_transforms +from factories.model_factory import ModelFactory +from gtda.homology import VietorisRipsPersistence + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +def load_pretrained_model(model_path: str, config: dict, device: torch.device) -> torch.nn.Module: + """ + Load a pretrained model from a specified path using configurations. + + Args: + model_path (str): Path to the model file. + config (dict): Configuration dictionary specifying model details. + device (torch.device): The device to load the model onto. + + Returns: + torch.nn.Module: The loaded model. + """ + + model_factory = ModelFactory() + model = model_factory.create(config['model']['type'], num_classes=config['model']['parameters']['num_classes'], pretrained=False).to(device) + model.load_state_dict(torch.load(model_path, map_location=device)) + return model + +def load_config(config_path: str) -> dict: + """ + Load a YAML configuration file. + + Args: + config_path (str): Path to the configuration file. + + Returns: + dict: Configuration dictionary. + + Raises: + FileNotFoundError: If the config file does not exist. + yaml.YAMLError: If there is an error parsing the YAML file. + """ + + try: + with open(config_path, 'r') as file: + config = yaml.safe_load(file) + return config + except FileNotFoundError: + logging.error(f"Config file not found: {config_path}") + +def register_all_hooks(model: torch.nn.Module, activations: dict, layer_progress: dict) -> None: + """ + Register forward hooks to capture output activations of specific layers during model forwarding. + + Args: + model (torch.nn.Module): The model from which to capture activations. + activations (dict): A dictionary to store the activations. + layer_progress (dict): A dictionary to track the progress of output capturing. 
+ """ + + relevant_layers = [name for name, layer in model.named_modules() if isinstance(layer, (torch.nn.ReLU, torch.nn.SiLU, torch.nn.Linear))] + total_layers = len(relevant_layers) + + def get_activation(name): + def hook(model, input, output): + activations[name] = output.detach().cpu().numpy().reshape(output.size(0), -1) + current_layer_index = relevant_layers.index(name) + progress = (current_layer_index + 1) / total_layers * 100 + layer_progress[name] = progress + return hook + + for name in relevant_layers: + layer = dict(model.named_modules())[name] + layer.register_forward_hook(get_activation(name)) + +def compute_persistence_diagrams_using_giotto(distance_matrix: np.ndarray, dimensions: list = [0, 1]) -> np.ndarray: + """ + Compute persistence diagrams using Vietoris-Rips complex from a precomputed distance matrix. + + Args: + distance_matrix (np.ndarray): A square matrix of pairwise distances. + dimensions (list): List of homology dimensions to compute. + + Returns: + np.ndarray: Array of persistence diagrams. + """ + + vr_computator = VietorisRipsPersistence(homology_dimensions=dimensions, metric="precomputed") + diagrams = vr_computator.fit_transform([distance_matrix])[0] + return np.sort(diagrams[:, :2]) + +def save_persistence_diagram(persistence_diagram: np.ndarray, layer_name: str, dataset_type: str, model_name: str, progress: float, persistence_dir: str) -> None: + """ + Save the computed persistence diagram to a text file. + + Args: + persistence_diagram (np.ndarray): Array of persistence intervals. + layer_name (str): Name of the layer for which the diagram was computed. + dataset_type (str): Type of the dataset (e.g., train, test). + model_name (str): Name of the model. + progress (float): Percentage of the progress in the model processing. + persistence_dir (str): Directory to save the persistence diagrams. + """ + + timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + filename = f"{model_name}_{layer_name}_{dataset_type}_{progress:.2f}percent_{timestamp}.txt" + dataset_persistence_dir = os.path.join(persistence_dir, dataset_type) + os.makedirs(dataset_persistence_dir, exist_ok=True) + filepath = os.path.join(dataset_persistence_dir, filename) + + with open(filepath, 'w') as f: + for birth, death in persistence_diagram: + f.write(f'{birth} {death}\n') + + logging.info(f'Saved persistence diagram to {filepath}') + +def incremental_processing(loader: DataLoader, model: torch.nn.Module, device: torch.device, activations: dict, dataset_type: str, model_name: str, layer_progress: dict, persistence_dir: str) -> None: + """ + Process data incrementally, computing persistence diagrams for model activations layer by layer. + + Args: + loader (DataLoader): DataLoader for the dataset. + model (torch.nn.Module): Pretrained model. + device (torch.device): Device on which computation is performed. + activations (dict): Dictionary holding activations. + dataset_type (str): Type of the dataset (e.g., 'train', 'valid'). + model_name (str): Name of the model. + layer_progress (dict): Progress tracking for each layer. + persistence_dir (str): Directory to save persistence diagrams. 
+ """ + + with torch.no_grad(): + for inputs, labels in loader: + inputs = inputs.to(device) + output = model(inputs) + for name, feature_array in activations.items(): + if feature_array is not None: + progress = layer_progress.get(name, 0) + process_feature_layer(name, feature_array, dataset_type, model_name, progress, persistence_dir) + activations[name] = None + +def process_feature_layer(layer_name: str, feature_array: np.ndarray, dataset_type: str, model_name: str, progress: float, persistence_dir: str) -> None: + """ + Process a single layer's features to compute and save its persistence diagram. + + Args: + layer_name (str): Name of the layer. + feature_array (np.ndarray): Activations of the layer. + dataset_type (str): Type of the dataset being processed. + model_name (str): Model identifier. + progress (float): Progress of processing through the model. + persistence_dir (str): Directory where diagrams should be saved. + """ + + if feature_array.size > 0: + feature_array = feature_array.reshape(-1, feature_array.shape[-1]) + logging.info(f"Computing persistence diagrams for layer {layer_name} at {progress:.2f}% through the model") + distance_matrix = scipy.spatial.distance.pdist(feature_array) + square_distance_matrix = scipy.spatial.distance.squareform(distance_matrix) + persistence_diagram = compute_persistence_diagrams_using_giotto(square_distance_matrix) + save_persistence_diagram(persistence_diagram, layer_name, dataset_type, model_name, progress, persistence_dir) + else: + logging.info(f"Layer {layer_name}: No features to process at {progress:.2f}% through the model") + +def process_dataset(loader, dataset_type, model_name, persistence_dir) -> None: + """ + Process a dataset to compute persistence diagrams. + + Args: + loader (DataLoader): DataLoader for the dataset. + dataset_type (str): Type of the dataset (e.g., 'train', 'valid'). + model_name (str): Name of the model. + persistence_dir (str): Directory to save persistence diagrams. + """ + + for i, (inputs, labels) in enumerate(loader): + feature_array = inputs.view(inputs.size(0), -1).numpy() + logging.info(f"Computing persistence diagrams for {dataset_type} set, batch {i}") + distance_matrix = scipy.spatial.distance.pdist(feature_array) + square_distance_matrix = scipy.spatial.distance.squareform(distance_matrix) + persistence_diagram = compute_persistence_diagrams_using_giotto(square_distance_matrix) + save_persistence_diagram(persistence_diagram, f'batch_{i}', dataset_type, model_name, 100.0 * (i + 1) / len(loader), persistence_dir) + + +def main(config_name: str, model_name: str) -> None: + """ + Main function for computing persistence diagrams from a pretrained model. + + Args: + config_name (str): Name of the configuration file. + model_name (str): Name of the pretrained model. + """ + + model_path = f"outputs/models/DENSENET_REGULARIZADOR/{model_name}.pth" + config = load_config(f"config/{config_name}.yaml") + + if not torch.cuda.is_available(): + logging.warning("CUDA is not available. 
Running on CPU.")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    model = load_pretrained_model(model_path, config, device)
+
+    transforms = get_transforms(config['data']['eval_transforms'])
+    data = get_dataset(config['data']['name'], config['data']['dataset_path'], train=True, transform=transforms)
+
+    # Split data
+    total_size = len(data)
+    test_size = int(total_size * config['data']['test_size'])
+    val_size = int(total_size * config['data']['val_size'])
+    train_size = total_size - test_size - val_size
+    assert train_size > 0 and val_size > 0 and test_size > 0, "One of the splits has zero or negative size."
+    data_train, data_test = random_split(data, [train_size + val_size, test_size], generator=torch.Generator().manual_seed(config['random_seed']))
+    data_train, data_val = random_split(data_train, [train_size, val_size], generator=torch.Generator().manual_seed(config['random_seed']))
+
+    # Truncate each set to 128 images
+    num_images = 128
+    data_train = Subset(data_train, range(min(num_images, len(data_train))))
+    data_val = Subset(data_val, range(min(num_images, len(data_val))))
+    data_test = Subset(data_test, range(min(num_images, len(data_test))))
+
+    train_loader = DataLoader(data_train, batch_size=num_images, shuffle=True)
+    valid_loader = DataLoader(data_val, batch_size=num_images, shuffle=False)
+    test_loader = DataLoader(data_test, batch_size=num_images, shuffle=False)
+
+    model_dir = os.path.join("output_files", model_name)
+    os.makedirs(model_dir, exist_ok=True)
+    persistence_dir = os.path.join(model_dir, 'persistence_diagrams')
+    lle_dir = os.path.join(model_dir, 'lle_plots')
+    os.makedirs(persistence_dir, exist_ok=True)
+    os.makedirs(lle_dir, exist_ok=True)
+
+    activations = {}
+    layer_progress = {}
+    register_all_hooks(model, activations, layer_progress)
+
+    loaders = [train_loader, valid_loader, test_loader]
+    dataset_types = ['train', 'valid', 'test']
+
+    for loader, dataset_type in zip(loaders, dataset_types):
+        time_start = time.time()
+        # process_dataset(loader, dataset_type, model_name, persistence_dir)
+        incremental_processing(loader, model, device, activations, dataset_type, model_name, layer_progress, persistence_dir)
+        time_end = time.time()
+        logging.info(f"{dataset_type.capitalize()} processing time: {time_end - time_start:.2f} seconds")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Compute persistence diagrams from a pretrained model.")
+    parser.add_argument("config_name", type=str, help="Name of the configuration file.")
+    parser.add_argument("model_name", type=str, help="Name of the pretrained model.")
+    args = parser.parse_args()
+
+    main(args.config_name, args.model_name)
+
diff --git a/experiments/homology_times.py b/experiments/homology_times.py
deleted file mode 100644
index 1374401..0000000
--- a/experiments/homology_times.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import torch
-import yaml
-import numpy as np
-import scipy.spatial
-import time
-import logging
-from scipy.cluster.hierarchy import linkage
-from torch.utils.data import DataLoader
-from datasets.dataset import get_dataset
-from datasets.transformations import get_transforms
-from models import get_model
-from gtda.homology import VietorisRipsPersistence
-
-# Setup basic logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-
-def prepare_model(config, device):
-    model = get_model(
-        config['model']['name'],
-        config['model']['num_classes'],
-
pretrained=config['model']['pretrained'] - ).to(device) - model.classifier = torch.nn.Identity() # Remove the classifier head - return model - -def load_config(config_path: str): - try: - with open(config_path, 'r') as file: - return yaml.safe_load(file) - except FileNotFoundError: - logging.error(f"The configuration file at {config_path} was not found.") - raise - except yaml.YAMLError as exc: - logging.error(f"Error parsing YAML file: {exc}") - raise - -def compute_features_and_labels(test_loader, model, device): - all_features, all_labels = [], [] - with torch.no_grad(): - for inputs, labels in test_loader: - inputs = inputs.to(device) - features = model(inputs) - all_features.append(features.cpu().numpy()) - all_labels.append(labels.cpu().numpy()) - return np.concatenate(all_features), np.concatenate(all_labels) - -def compute_persistence_diagram_using_single_linkage(distance_matrix): - condensed_matrix = scipy.spatial.distance.squareform(distance_matrix) - deaths = linkage(condensed_matrix, method='single')[:, 2] - return np.array([[0, d] for d in deaths]) - -def compute_persistence_diagrams_using_giotto(distance_matrix, dimensions=[0,1]): - vr_computator = VietorisRipsPersistence(homology_dimensions=dimensions, metric="precomputed") - diagrams = vr_computator.fit_transform([distance_matrix])[0] - return diagrams[diagrams[:, 2] == 0][:, :2] # Filter zero-dimensional features - -def main(config_path: str): - config = load_config(config_path) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - transforms = get_transforms(config) - data_test = get_dataset(config['data']['name'], config['data']['dataset_path'], train=False, transform=transforms) - test_loader = DataLoader(data_test, batch_size=config['training']['batch_size'], shuffle=False) - - model = prepare_model(config, device) - - features, labels = compute_features_and_labels(test_loader, model, device) - logging.info(f"Features shape: {features.shape}") - - start_time = time.time() - distance_matrix = scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(features)) - logging.info(f"Time taken to compute distance matrix: {time.time() - start_time:.2f}s") - - start_time = time.time() - persistence_diagram_sl = compute_persistence_diagram_using_single_linkage(distance_matrix) - logging.info(f"Time taken for single linkage: {time.time() - start_time:.2f}s") - logging.info(f"Persistence Diagram (SL) Shape: {persistence_diagram_sl.shape}") - - dims = [0, 1] - start_time = time.time() - persistence_diagram_giotto = compute_persistence_diagrams_using_giotto(distance_matrix, dims) - logging.info(f"Time taken for Giotto: {time.time() - start_time:.2f}s") - logging.info(f"Persistence Diagram (Giotto) Shape: {persistence_diagram_giotto.shape}") - -if __name__ == "__main__": - main("config/config.yaml") diff --git a/scripts/homology.sh b/scripts/homology.sh index 0ff7193..93ea436 100644 --- a/scripts/homology.sh +++ b/scripts/homology.sh @@ -2,21 +2,30 @@ #SBATCH --job-name=homology # Process name #SBATCH --partition=dios # Queue for execution +#SBATCH -w dionisio # Node to execute the job #SBATCH --gres=gpu:1 # Number of GPUs to use #SBATCH --mail-type=END,FAIL # Notifications for job done & fail -#SBATCH --mail-user=pablolivares@correo.ugr.es # Where to send notification +#SBATCH --mail-user=user@mail.com # Where to send notification # Load necessary paths export PATH="/opt/anaconda/anaconda3/bin:$PATH" export PATH="/opt/anaconda/bin:$PATH" +export PYTHONPATH=$(dirname $(dirname "$0")) # Setup Conda 
environment
 eval "$(conda shell.bash hook)"
-conda activate /mnt/homeGPU/polivares/tda-nn/tda-nn-separability
+conda activate tda-nn-analysis
 export TFHUB_CACHE_DIR=.
 
-# Call the Python script with the provided arguments
-python homology_times.py
+# Ensure that the correct number of arguments is provided
+if [ "$#" -ne 2 ]; then
+    echo "Usage: $0 <config_name> <model_name>"
+    exit 1
+fi
+
+# Read the command-line arguments
+CONFIG_NAME=$1
+MODEL_NAME=$2
 
-# Notify by email when the process is completed, not needed if SLURM mail is set
-# mail -s "Proceso finalizado" pablolivares@correo.ugr.es <<< "El proceso ha finalizado"
+# Call the Python script with the provided arguments
+python experiments/homology.py "$CONFIG_NAME" "$MODEL_NAME"
diff --git a/utils/ph_plotting.py b/utils/ph_plotting.py
new file mode 100644
index 0000000..e405a91
--- /dev/null
+++ b/utils/ph_plotting.py
@@ -0,0 +1,190 @@
+import os
+import re
+import gudhi as gd
+import numpy as np
+import seaborn as sns
+import matplotlib.pyplot as plt
+from typing import List, Tuple, Optional
+
+
+def read_persistence_intervals(filename) -> List[Tuple[float, float]]:
+    """
+    Reads persistence intervals from a file.
+
+    Parameters:
+    filename (str): The path to the file containing persistence intervals.
+
+    Returns:
+    List[Tuple[float, float]]: A list of tuples containing the birth and death values.
+    """
+    with open(filename, 'r') as file:
+        lines = file.readlines()
+    intervals = [tuple(map(float, line.strip().split())) for line in lines]
+    return intervals
+
+def extract_percentage(filename: str) -> Optional[float]:
+    """
+    Extract the percentage from the filename using regex.
+
+    Args:
+        filename (str): The filename containing the percentage.
+
+    Returns:
+        Optional[float]: The extracted percentage or None if not found.
+    """
+
+    match = re.search(r'_([\d\.]+)percent_', filename)
+
+    return float(match.group(1)) if match else None
+
+def extract_info(directory_name: str, choice: str) -> Optional[int]:
+    """
+    Extract information based on a choice from the directory name using regex.
+
+    Args:
+        directory_name (str): Name of the directory.
+        choice (str): Choice from ['batch_size', 'optimizer', 'architecture'].
+
+    Returns:
+        Optional[int]: Indexed value based on the directory pattern.
+    """
+
+    patterns = {
+        'batch_size': r'_(\d+)_\d{4}-\d{2}-\d{2}',
+        'optimizer': r'_(Adam|SGD)_',
+        'architecture': r'^(densenet121|efficientnet_b0|resnet18)_'
+    }
+    indexes = {
+        'batch_size': {'8': 0, '16': 1, '32': 2, '64': 3},
+        'optimizer': {'Adam': 0, 'SGD': 1},
+        'architecture': {'densenet121': 0, 'efficientnet_b0': 1, 'resnet18': 2}
+    }
+
+    if choice in patterns:
+        match = re.search(patterns[choice], directory_name)
+        if match:
+            return indexes[choice][match.group(1)]
+    return None
+
+def compute_total_persistence(persistence_intervals: List[Tuple[float, float]]) -> float:
+    """
+    Calculate the total persistence from a list of persistence intervals.
+
+    Args:
+        persistence_intervals (List[Tuple[float, float]]): List of intervals.
+
+    Returns:
+        float: The total persistence calculated from the intervals.
+    """
+
+    max_lifetime = max(death - birth for birth, death in persistence_intervals if death != np.inf)
+    return sum((death if death != np.inf else max_lifetime) - birth for birth, death in persistence_intervals)  # / max_lifetime
+
+def plot_barcode_sets(base_path: str, model_dir: str, datasets: List[str]) -> None:
+    """
+    Plot barcode values for different datasets within a model directory.
+
+    Args:
+        base_path (str): Base directory where barcodes by dataset are located.
+ model_dir (str): Specific model directory to plot. + datasets (List[str]): List of datasets to process. + + Returns: + None: This function plots a graph. + """ + + colors = sns.color_palette("hls", len(datasets)) + plt.figure(figsize=(12, 8)) + sns.set_theme(style="whitegrid") + + for idx, dataset in enumerate(datasets): + dataset_path = os.path.join(base_path, model_dir, 'persistence_diagrams', dataset) + barcode_values = [] + if os.path.isdir(dataset_path): + for filename in sorted(os.listdir(dataset_path)): + if filename.endswith(".txt"): + percentage = extract_percentage(filename) + if percentage is not None: + filepath = os.path.join(dataset_path, filename) + intervals = read_persistence_intervals(filepath) + barcode = compute_total_persistence(intervals) + barcode_values.append((percentage, barcode)) + + if barcode_values: + barcode_values.sort() + percentages, entropies = zip(*barcode_values) + sns.regplot(x=np.array(percentages), y=np.array(entropies), order=2, + scatter_kws={'s': 100, 'color': colors[idx], 'alpha': 0.5}, + line_kws={'color': colors[idx], 'lw': 6}, + label=f'{dataset.capitalize()} Set') + + plt.ylim(0, 250000) + plt.xlabel('Percentage', fontsize=20) + plt.ylabel('Total Persistence', fontsize=20) + plt.legend(prop={'size': 24}) + plt.tight_layout() + plt.show() + +def plot_barcodes_groups(base_directory_path: str, choice: str) -> None: + """ + Plot barcode values grouped by specified attributes like batch size, optimizer, or architecture. + + Args: + base_directory_path (str): Base directory containing the model directories. + choice (str): Attribute to group the barcodes by. + """ + + colors = sns.color_palette("hls", 4) + plt.figure(figsize=(12, 8)) + sns.set_theme(style="whitegrid") + + for model_dir in os.listdir(base_directory_path): + model_path = os.path.join(base_directory_path, model_dir) + if os.path.isdir(model_path): + index = extract_info(model_dir, choice) + if index is not None: + persistence_diagram_path = os.path.join(model_path, 'persistence_diagrams/test') + if os.path.exists(persistence_diagram_path): + files_with_percentages = [(f, extract_percentage(f)) for f in os.listdir(persistence_diagram_path) if f.endswith(".txt") and extract_percentage(f) is not None] + sorted_files = sorted(files_with_percentages, key=lambda x: x[1]) + percentages, barcodes = zip(*[(p, compute_total_persistence(read_persistence_intervals(os.path.join(persistence_diagram_path, f)))) for f, p in sorted_files]) + sns.regplot(x=np.array(percentages), y=np.array(barcodes), order=2, + scatter_kws={'s': 100, 'color': colors[index], 'alpha': 0.5}, + line_kws={'color': colors[index], 'lw': 6}, + label=f'{model_dir}') + + plt.ylim(0, 250000) + plt.xlabel('Percentage', fontsize=20) + plt.ylabel('Total Persistence', fontsize=20) + plt.legend(prop={'size': 24}) + plt.tight_layout() + plt.show() + +def barcode_plot(filename) -> None: + """ + Plots the persistence barcode for a given file containing persistence intervals. + + Parameters: + filename (str): The path to the file containing persistence intervals. 
+
+    Returns:
+    None
+    """
+    persistence_intervals = read_persistence_intervals(filename)
+
+    # Separate persistence intervals by homology dimension: the saved files do not
+    # store dimensions, so intervals born at 0 are treated as H0, the rest as H1
+    H0_intervals = [(0, interval) for interval in persistence_intervals if interval[0] == 0.0]
+    H1_intervals = [(1, interval) for interval in persistence_intervals if interval[0] != 0.0]
+
+    diag = H0_intervals + H1_intervals
+
+    # Visualize persistence with GUDHI
+    plt.figure(figsize=(12, 8))
+    gd.plot_persistence_barcode(diag, legend=True)
+    plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)
+    plt.title('EfficientNet-B0 Regularized 100%')
+    plt.show()
+
+if __name__ == '__main__':
+    filename = 'output_files/FINAL/m_efficientnet_base/efficientnet_b0_Adam_8_2024-06-01_14-51-19/efficientnet_b0_Adam_8_2024-06-01_14-51-19_100percent_train.txt'
+    barcode_plot(filename)
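
Note: the sketch below is not part of the diff. It is a minimal, self-contained example of the workflow the new files implement, assuming giotto-tda, SciPy and NumPy are installed. The random point cloud and the names features and total_persistence are illustrative stand-ins for a batch of flattened layer activations; the functions it mirrors are compute_persistence_diagrams_using_giotto() in experiments/homology.py and compute_total_persistence() in utils/ph_plotting.py.

import numpy as np
import scipy.spatial
from gtda.homology import VietorisRipsPersistence

# Synthetic stand-in for one batch of flattened activations (128 points, 64 dims).
rng = np.random.default_rng(0)
features = rng.normal(size=(128, 64))

# Pairwise distance matrix, as in process_feature_layer().
distance_matrix = scipy.spatial.distance.squareform(scipy.spatial.distance.pdist(features))

# Vietoris-Rips persistence on the precomputed matrix, as in
# compute_persistence_diagrams_using_giotto(); fit_transform expects a
# collection of matrices, hence the single-element list.
vr = VietorisRipsPersistence(homology_dimensions=[0, 1], metric="precomputed")
diagram = vr.fit_transform([distance_matrix])[0]   # columns: birth, death, dimension
intervals = diagram[:, :2]                         # keep (birth, death) pairs only

# Total persistence over finite intervals, mirroring compute_total_persistence().
finite = intervals[np.isfinite(intervals[:, 1])]
total_persistence = float(np.sum(finite[:, 1] - finite[:, 0]))
print(f"{len(intervals)} intervals, total persistence = {total_persistence:.4f}")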