From 8a6a0f07797dcef8eb190a89cfd2f78bee1f71c8 Mon Sep 17 00:00:00 2001
From: 1andrin <115493865+1andrin@users.noreply.github.com>
Date: Mon, 16 Dec 2024 01:20:32 +0100
Subject: [PATCH 1/3] updates to experiment_runner.py to improve the plot
 design and to add auto-detection of the current case

Signed-off-by: 1andrin <115493865+1andrin@users.noreply.github.com>
---
 notebooks/RunExperiments/experiment_runner.py | 304 +++++++++++++-----
 1 file changed, 220 insertions(+), 84 deletions(-)

diff --git a/notebooks/RunExperiments/experiment_runner.py b/notebooks/RunExperiments/experiment_runner.py
index 033b86a..da5fcf3 100644
--- a/notebooks/RunExperiments/experiment_runner.py
+++ b/notebooks/RunExperiments/experiment_runner.py
@@ -4,7 +4,7 @@
 import glob
 import copy
 import argparse
-from typing import List
+from typing import List, Union
 
 import numpy as np
 import matplotlib
@@ -47,6 +47,10 @@ def parse_arguments():
         help="Datasets to use (format: Size Name, e.g., Small Linear_RCT)",
     )
     parser.add_argument("--n_runs", type=int, default=1, help="Number of runs")
+    parser.add_argument(
+        "--num_samples", type=int, default=-1, help="Maximum number of iterations"
+    )
+
     parser.add_argument(
         "--outcome_model", type=str, default="nested", help="Outcome model type"
     )
@@ -130,7 +134,16 @@ def run_experiment(args):
         args.time_budget = None  # Use only components_time_budget
 
     for dataset_name, cd in data_sets.items():
-        case = dataset_name.split("_")[-1]
+        # Extract case while preserving original string checking logic
+        if "KCKP" in dataset_name:
+            case = "KCKP"
+        elif "KC" in dataset_name:
+            case = "KC"
+        elif "IV" in dataset_name:
+            case = "IV"
+        else:
+            case = "RCT"
+
         os.makedirs(f"{out_dir}/{case}", exist_ok=True)
 
         for i_run in range(1, args.n_runs + 1):
@@ -142,7 +155,8 @@ def run_experiment(args):
             for metric in args.metrics:
                 if metric == "ate":  # this is not something to optimize
                     continue
-                print(f"Optimzing {metric} for {dataset_name} (run {i_run})")
+
+                print(f"Optimizing {metric} for {dataset_name} (run {i_run})")
                 try:
                     fn = make_filename(metric, dataset_name, i_run)
                     out_fn = os.path.join(out_dir, case, fn)
@@ -150,13 +164,17 @@ def run_experiment(args):
                         print(f"File {out_fn} exists, skipping...")
                         continue
 
-                    if "KC" in dataset_name and "KCKP" not in dataset_name:
-                        propensity_model = "auto"
-                    elif "KCKP" in dataset_name:
+                    # Set propensity model using string checking like original version
+                    if "KCKP" in dataset_name:
+                        print(f"Using passthrough propensity model for {dataset_name}")
                         propensity_model = passthrough_model(
-                            cd.propensity_modifiers, include_control=False
+                            cd_i.propensity_modifiers, include_control=False
                         )
+                    elif "KC" in dataset_name:
+                        print(f"Using auto propensity model for {dataset_name}")
+                        propensity_model = "auto"
                     else:
+                        print(f"Using dummy propensity model for {dataset_name}")
                         propensity_model = "dummy"
 
                     ct = CausalTune(
@@ -236,6 +254,9 @@ def compute_scores(ct, metric, test_df):
         reverse=metric not in metrics_to_minimize(),
     )
 
+    # Debugging: Log final result structure
+    print(f"Returning scores for metric {metric}: Best estimator: {ct.best_estimator}")
+
     return {
         "best_estimator": ct.best_estimator,
         "best_config": ct.best_config,
@@ -282,25 +303,36 @@ def get_all_test_scores(out_dir, dataset_name):
 
 def generate_plots(
     out_dir: str,
-    log_scale: List[str] | None = None,
-    upper_bounds: dict | None = None,
-    lower_bounds: dict | None = None,
+    log_scale: Union[List[str], None] = None,
+    upper_bounds: Union[dict, None] = None,
+    lower_bounds: Union[dict, None] = None,
+    font_size=0,
 ):
     if log_scale is None:
         log_scale = ["energy_distance", "psw_energy_distance", "frobenius_norm"]
+    if upper_bounds is None:
+        upper_bounds = {}  # Use an empty dictionary if None
+    if lower_bounds is None:
+        lower_bounds = {}  # Use an empty dictionary if None
+
     metrics, datasets = extract_metrics_datasets(out_dir)
 
-    # Define names for metrics and experiments
+    # Remove 'ate' from metrics
+    metrics = [m for m in metrics if m.lower() != "ate"]
+
     metric_names = {
-        "psw_frobenius_norm": "Propensity Weighted Frobenius Norm",
-        "frobenius_norm": "Frobenius Norm",
+        "psw_frobenius_norm": "PSW\nFrobenius\nNorm",
+        "frobenius_norm": "Frobenius\nNorm",
         "erupt": "ERUPT",
         "codec": "CODEC",
-        "policy_risk": "Policy Risk",
-        "energy_distance": "Energy Distance",
-        "psw_energy_distance": "Propensity Weighted Energy Distance",
+        "auc": "AUC",
+        "qini": "Qini",
+        "bite": "BITE",
+        "policy_risk": "Policy\nRisk",
+        "energy_distance": "Energy\nDistance",
+        "psw_energy_distance": "PSW\nEnergy\nDistance",
+        "norm_erupt": "Normalized\nERUPT",
     }
 
-    # Coloring and marker styles
     colors = (
         [matplotlib.colors.CSS4_COLORS["black"]]
         + list(matplotlib.colors.TABLEAU_COLORS)
@@ -312,72 +344,157 @@
     )
     markers = ["o", "s", "D", "^", "v", "<", ">", "P", "*", "h", "X", "|", "_", "8"]
 
+    # Determine the problem type from the dataset name
+    problem = "iv" if any("IV" in dataset for dataset in datasets) else "backdoor"
+
     def plot_grid(title):
+        # Use determined problem type instead of hardcoding "backdoor"
+        all_metrics = [
+            m
+            for m in supported_metrics(problem, False, False)
+            if m.lower() != "ate" and m.lower() != "norm_erupt"
+        ]
+
         fig, axs = plt.subplots(
-            len(metrics), len(datasets), figsize=(20, 5 * len(metrics)), dpi=300
+            len(all_metrics), len(datasets), figsize=(20, 5 * len(all_metrics)), dpi=300
         )
-        if len(metrics) == 1 and len(datasets) == 1:
+
+        if len(all_metrics) == 1 and len(datasets) == 1:
             axs = np.array([[axs]])
-        elif len(metrics) == 1 or len(datasets) == 1:
+        elif len(all_metrics) == 1 or len(datasets) == 1:
             axs = axs.reshape(-1, 1) if len(datasets) == 1 else axs.reshape(1, -1)
 
-        for i, metric in enumerate(metrics):
-            for j, dataset in enumerate(datasets):
-                ax = axs[i, j]
-
+        # For multiple metrics in args.metrics, use the first one that has a results file
+        results_files = {}
+        for dataset in datasets:
+            for metric in args.metrics:
                 filename = make_filename(metric, dataset, 1)
                 filepath = os.path.join(out_dir, filename)
-                if os.path.exists(filepath):
-                    with open(filepath, "rb") as f:
-                        results = pickle.load(f)
-
-                    best_estimator = results["best_estimator"]
-                    CATE_gt = results["scores_per_estimator"][best_estimator][0][
-                        "test"
-                    ]["CATE_groundtruth"]
-                    CATE_est = results["scores_per_estimator"][best_estimator][0][
-                        "test"
-                    ]["CATE_estimate"]
-
-                    CATE_gt = np.array(CATE_gt).flatten()
-                    CATE_est = np.array(CATE_est).flatten()
-
-                    ax.scatter(CATE_gt, CATE_est, s=20, alpha=0.1)
-                    ax.plot(
-                        [min(CATE_gt), max(CATE_gt)],
-                        [min(CATE_gt), max(CATE_gt)],
-                        "k-",
-                        linewidth=0.5,
+                results_files[dataset] = filepath
+                break
+            if dataset not in results_files:
+                print(f"No results file found for dataset {dataset}")
+
+        for j, dataset in enumerate(datasets):
+            if dataset not in results_files:
+                continue
+
+            with open(results_files[dataset], "rb") as f:
+                results = pickle.load(f)
+
+            print(f"Loading results for Dataset: {dataset}")
+
+            for i, metric in enumerate(all_metrics):
+                ax = axs[i, j]
+
+                try:
+                    # Find best estimator for this metric
+                    best_estimator = None
+                    best_score = (
+                        float("inf")
+                        if metric in metrics_to_minimize()
+                        else float("-inf")
+                    )
+                    estimator_name = None
+
+                    for score in results["all_scores"]:
+                        if "test" in score and metric in score["test"]["scores"]:
+                            current_score = score["test"]["scores"][metric]
+                            if metric in metrics_to_minimize():
+                                if current_score < best_score:
+                                    best_score = current_score
+                                    best_estimator = score
+                                    estimator_name = score["test"]["scores"][
+                                        "estimator_name"
+                                    ]
+                            else:
+                                if current_score > best_score:
+                                    best_score = current_score
+                                    best_estimator = score
+                                    estimator_name = score["test"]["scores"][
+                                        "estimator_name"
+                                    ]
+
+                    if best_estimator:
+                        CATE_gt = np.array(
+                            best_estimator["test"]["CATE_groundtruth"]
+                        ).flatten()
+                        CATE_est = np.array(
+                            best_estimator["test"]["CATE_estimate"]
+                        ).flatten()
+
+                        # Plotting
+                        ax.scatter(CATE_gt, CATE_est, s=40, alpha=0.5)
+                        ax.plot(
+                            [min(CATE_gt), max(CATE_gt)],
+                            [min(CATE_gt), max(CATE_gt)],
+                            "k-",
+                            linewidth=1.0,
                     )
 
-                    try:
+                        # Calculate correlation coefficient
                         corr = np.corrcoef(CATE_gt, CATE_est)[0, 1]
+
+                        # Add correlation
                         ax.text(
                             0.05,
                             0.95,
                             f"Corr: {corr:.2f}",
                             transform=ax.transAxes,
                             verticalalignment="top",
-                            fontsize=8,
+                            fontsize=font_size + 12,
+                            fontweight="bold",
                         )
-                    except ValueError:
-                        print(f"Could not compute correlation for {dataset}_{metric}")
 
-                    ax.set_title(f"{best_estimator.split('.')[-1]}", fontsize=8)
-                else:
-                    ax.text(0.5, 0.5, "No data", ha="center", va="center")
+                        # Add estimator name at bottom center
+                        if estimator_name:
+                            estimator_base = estimator_name.split(".")[-1]
+                            ax.text(
+                                0.5,
+                                0.02,
+                                estimator_base,
+                                transform=ax.transAxes,
+                                horizontalalignment="center",
+                                color="blue",
+                                fontsize=font_size + 10,
+                            )
 
-                ax.set_xticks([])
-                ax.set_yticks([])
+                except Exception as e:
+                    print(
+                        f"Error processing metric {metric} for dataset {dataset}: {e}"
+                    )
+                    ax.text(
+                        0.5,
+                        0.5,
+                        "Error processing data",
+                        ha="center",
+                        va="center",
+                        fontsize=font_size + 12,
+                    )
 
                 if j == 0:
-                    ax.set_ylabel(metric_names.get(metric, metric), fontsize=10)
+                    # Create tight layout for ylabel
+                    ax.set_ylabel(
+                        metric_names.get(metric, metric),
+                        fontsize=font_size + 12,
+                        fontweight="bold",
+                        labelpad=5,  # Reduce padding between label and plot
+                    )
                 if i == 0:
-                    ax.set_title(dataset, fontsize=10)
+                    ax.set_title(
+                        dataset, fontsize=font_size + 14, fontweight="bold", pad=15
+                    )
+                ax.set_xticks([])
+                ax.set_yticks([])
 
-        plt.suptitle(f"Estimated CATEs vs. True CATEs: {title}", fontsize=16)
-        plt.tight_layout(rect=[0, 0, 1, 0.96])
+        plt.suptitle(
+            f"Estimated CATEs vs. True CATEs: {title}",
+            fontsize=font_size + 18,
+            fontweight="bold",
+        )
+        # Adjust spacing between subplots
+        plt.tight_layout(rect=[0.1, 0, 1, 0.96], h_pad=1.0, w_pad=0.5)
         plt.savefig(
             os.path.join(out_dir, "CATE_grid.pdf"), format="pdf", bbox_inches="tight"
         )
@@ -389,14 +506,19 @@ def plot_grid(title):
     def plot_mse_grid(title):
         df = get_all_test_scores(out_dir, datasets[0])
         est_names = sorted(df["estimator_name"].unique())
-        problem = "iv" if "IV" in datasets[0] else "backdoor"
+
+        # Problem type already determined at top level
         all_metrics = [
-            c for c in df.columns if c in supported_metrics(problem, False, False)
+            c
+            for c in df.columns
+            if c in supported_metrics(problem, False, False) and c.lower() != "ate"
         ]
 
         fig, axs = plt.subplots(
            len(all_metrics), len(datasets), figsize=(20, 5 * len(all_metrics)), dpi=300
         )
+
+        # Handle single plot cases
         if len(all_metrics) == 1 and len(datasets) == 1:
             axs = np.array([[axs]])
         elif len(all_metrics) == 1 or len(datasets) == 1:
@@ -405,6 +527,7 @@ def plot_mse_grid(title):
         legend_elements = []
         for j, dataset in enumerate(datasets):
             df = get_all_test_scores(out_dir, dataset)
+            # Apply bounds filtering
             for m, value in upper_bounds.items():
                 if m in df.columns:
                     df = df[df[m] < value].copy()
@@ -416,13 +539,12 @@ def plot_mse_grid(title):
                 ax = axs[i, j]
                 this_df = df[["estimator_name", metric, "MSE"]].dropna()
                 this_df = this_df[~np.isinf(this_df[metric].values)]
+
                 if len(this_df):
                     for idx, est_name in enumerate(est_names):
                         df_slice = this_df[this_df["estimator_name"] == est_name]
                         if "Dummy" not in est_name and len(df_slice):
-                            marker = markers[idx % len(markers)]
-                            # TODO: throw x-axis outliers away
                             ax.scatter(
                                 df_slice["MSE"],
                                 df_slice[metric],
@@ -452,40 +574,45 @@ def plot_mse_grid(title):
                     if metric in log_scale:
                         ax.set_yscale("log")
                     ax.grid(True)
-                    # ax.set_title(
-                    #     f"{results['best_estimator'].split('.')[-1]}", fontsize=8
-                    # )
                 else:
-                    ax.text(0.5, 0.5, "No data", ha="center", va="center")
+                    ax.text(
+                        0.5,
+                        0.5,
+                        "No data",
+                        ha="center",
+                        va="center",
+                        fontsize=font_size + 12,
+                    )
 
                 if j == 0:
-                    ax.set_ylabel(metric_names.get(metric, metric), fontsize=10)
+                    # Match ylabel style with plot_grid
+                    ax.set_ylabel(
+                        metric_names.get(metric, metric),
+                        fontsize=font_size + 12,
+                        fontweight="bold",
+                        labelpad=5,
+                    )
                 if i == 0:
-                    ax.set_title(dataset, fontsize=10)
+                    ax.set_title(
+                        dataset, fontsize=font_size + 14, fontweight="bold", pad=15
+                    )
 
-        plt.suptitle(f"MSE vs. Scores: {title}", fontsize=16)
-        plt.tight_layout(rect=[0, 0, 1, 0.96])
-        plt.savefig(
-            os.path.join(out_dir, "MSE_grid.pdf"), format="pdf", bbox_inches="tight"
+        plt.suptitle(
+            f"MSE vs. Scores: {title}",
+            fontsize=font_size + 18,
+            fontweight="bold",
         )
-        plt.savefig(
-            os.path.join(out_dir, "MSE_grid.png"), format="png", bbox_inches="tight"
-        )
-        plt.close()
 
-        # Create separate legend
-        fig_legend, ax_legend = plt.subplots(figsize=(6, 6))
-        ax_legend.legend(handles=legend_elements, loc="center", fontsize=10)
-        ax_legend.axis("off")
+        # Match spacing style with plot_grid
+        plt.tight_layout(rect=[0.1, 0, 1, 0.96], h_pad=1.0, w_pad=0.5)
         plt.savefig(
-            os.path.join(out_dir, "MSE_legend.pdf"), format="pdf", bbox_inches="tight"
+            os.path.join(out_dir, "MSE_grid.pdf"), format="pdf", bbox_inches="tight"
         )
         plt.savefig(
-            os.path.join(out_dir, "MSE_legend.png"), format="png", bbox_inches="tight"
+            os.path.join(out_dir, "MSE_grid.png"), format="png", bbox_inches="tight"
         )
         plt.close()
 
-    # Generate plots
     plot_grid("Experiment Results")
     plot_mse_grid("Experiment Results")
@@ -500,8 +627,17 @@ def plot_mse_grid(title):
     # args.timestamp_in_dirname = False
     # args.outcome_model = "auto"  # or use "nested" for the old-style nested model
     out_dir = run_experiment(args)
+    # Determine case from datasets
+    if any("IV" in dataset for dataset in args.datasets):
+        case = "IV"
+    elif any("KC" in dataset for dataset in args.datasets):
+        case = "KC"
+    elif any("KCKP" in dataset for dataset in args.datasets):
+        case = "KCKP"
+    else:
+        case = "RCT"
     # upper_bounds = {"MSE": 1e2, "policy_risk": 0.2}
     # lower_bounds = {"erupt": 0.06, "bite": 0.75}
     generate_plots(
-        os.path.join(out_dir, "RCT")
+        os.path.join(out_dir, case), font_size=8
     )  # , upper_bounds=upper_bounds, lower_bounds=lower_bounds)

From f9b718fb00c10f87700f227f011fe44b745d73cf Mon Sep 17 00:00:00 2001
From: 1andrin <115493865+1andrin@users.noreply.github.com>
Date: Mon, 16 Dec 2024 01:38:52 +0100
Subject: [PATCH 2/3] linter

Signed-off-by: 1andrin <115493865+1andrin@users.noreply.github.com>
---
 notebooks/RunExperiments/experiment_runner.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/notebooks/RunExperiments/experiment_runner.py b/notebooks/RunExperiments/experiment_runner.py
index da5fcf3..ba0eceb 100644
--- a/notebooks/RunExperiments/experiment_runner.py
+++ b/notebooks/RunExperiments/experiment_runner.py
@@ -1,19 +1,18 @@
+import argparse
+import copy
+import glob
 import os
-import sys
 import pickle
-import glob
-import copy
-import argparse
+import sys
+import warnings
+from datetime import datetime
 from typing import List, Union
 
-import numpy as np
 import matplotlib
 import matplotlib.pyplot as plt
+import numpy as np
 import pandas as pd
-
 from sklearn.model_selection import train_test_split
-from datetime import datetime
-import warnings
 
 warnings.filterwarnings("ignore")
 

From 9dff9c0e11db868df8f763e18d255df0ba8bfcbc Mon Sep 17 00:00:00 2001
From: 1andrin <115493865+1andrin@users.noreply.github.com>
Date: Mon, 16 Dec 2024 01:56:59 +0100
Subject: [PATCH 3/3] trying to fix linter E402

Signed-off-by: 1andrin <115493865+1andrin@users.noreply.github.com>
---
 notebooks/RunExperiments/experiment_runner.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/notebooks/RunExperiments/experiment_runner.py b/notebooks/RunExperiments/experiment_runner.py
index ba0eceb..bb935ef 100644
--- a/notebooks/RunExperiments/experiment_runner.py
+++ b/notebooks/RunExperiments/experiment_runner.py
@@ -14,10 +14,8 @@ import pandas as pd
 from sklearn.model_selection import train_test_split
 
-warnings.filterwarnings("ignore")
-
 # Ensure CausalTune is in the Python path
-root_path = os.path.realpath("../../../..")
+root_path = os.path.realpath("../../../..")  # noqa: E402
 sys.path.append(os.path.join(root_path, "causaltune"))  # noqa: E402
 
 # Import CausalTune and other custom modules after setting up the path
@@ -28,7 +26,10 @@
 from causaltune.score.scoring import (
     metrics_to_minimize,  # noqa: E402
     supported_metrics,  # noqa: E402
-)  # noqa: E402
+)
+
+# Configure warnings
+warnings.filterwarnings("ignore")
 
 
 def parse_arguments():
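
A note on the case auto-detection introduced in PATCH 1/3: the mapping from a dataset name to its case directory is a set of ordered substring checks, and the order matters because every "KCKP" dataset name also contains "KC" (which is why run_experiment tests "KCKP" before "KC"). Below is a minimal, self-contained sketch of that logic for reference; detect_case is a hypothetical helper name used only for illustration and is not a function defined in these patches.

def detect_case(dataset_name: str) -> str:
    # Ordered substring checks; "KCKP" must be tested before "KC",
    # otherwise the "KC" branch would shadow it.
    if "KCKP" in dataset_name:
        return "KCKP"
    if "KC" in dataset_name:
        return "KC"
    if "IV" in dataset_name:
        return "IV"
    return "RCT"


if __name__ == "__main__":
    # Dataset names follow the "Size Name" convention used by --datasets,
    # e.g. "Small Linear_RCT".
    assert detect_case("Small Linear_KCKP") == "KCKP"
    assert detect_case("Small Linear_KC") == "KC"
    assert detect_case("Small Linear_IV") == "IV"
    assert detect_case("Small Linear_RCT") == "RCT"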