Skip to content

Commit

Permalink
Merge pull request #104 from arnor-sigurdsson/improve/prediction-hand…
Browse files Browse the repository at this point in the history
…ling

Improve prediction handling
  • Loading branch information
arnor-sigurdsson authored Dec 11, 2024
2 parents 255bb6e + 0d675d8 commit 3023c45
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 94 deletions.
4 changes: 4 additions & 0 deletions docs/tutorials/h_survival_analysis/01_survival_flchain.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ Let's configure and train a model on the FLChain data. Here are the key configur
:language: yaml
:caption: input.yaml

.. literalinclude:: ../tutorial_files/h_survival_analysis/01_flchain/fusion.yaml
:language: yaml
:caption: fusion.yaml

.. literalinclude:: ../tutorial_files/h_survival_analysis/01_flchain/output.yaml
:language: yaml
:caption: output.yaml
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ which is also supported for survival models
:language: yaml
:caption: input.yaml

.. literalinclude:: ../tutorial_files/h_survival_analysis/02_flchain_cox/fusion.yaml
:language: yaml
:caption: fusion.yaml

.. literalinclude:: ../tutorial_files/h_survival_analysis/02_flchain_cox/output.yaml
:language: yaml
:caption: output.yaml
Expand Down
2 changes: 1 addition & 1 deletion eir/data_load/data_preparation_modules/imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def impute_missing_output_modalities(
case ComputedArrayOutputInfo():
assert output_type == "array"
shape = output_object.data_dimensions.full_shape()
approach = "random"
approach = "constant"

case ComputedTabularOutputInfo() | ComputedSurvivalOutputInfo():
assert output_type in ("tabular", "survival")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def predict_tabular_wrapper_with_labels(
ids=cur_ids,
predictions=predictions_np,
labels=target_labels_np,
tabular_output_type=target_head_name,
prediction_classes=classes,
cat_loss_name=cat_loss_name,
)
Expand Down Expand Up @@ -119,6 +120,7 @@ def _merge_ids_predictions_and_labels(
ids: Sequence[str],
predictions: np.ndarray,
labels: Optional[np.ndarray],
tabular_output_type: str,
cat_loss_name: str,
prediction_classes: Union[Sequence[str], None] = None,
label_column_name: str = "True Label",
Expand All @@ -134,7 +136,7 @@ def _merge_ids_predictions_and_labels(
if prediction_classes is None:
prediction_classes = [f"Score Class {i}" for i in range(predictions.shape[1])]

if cat_loss_name == "BCEWithLogitsLoss":
if tabular_output_type == "cat" and cat_loss_name == "BCEWithLogitsLoss":
assert predictions.shape[1] == 1, predictions.shape
assert len(prediction_classes) == 2, len(prediction_classes)
prediction_classes = prediction_classes[1]
Expand Down Expand Up @@ -184,6 +186,7 @@ def predict_tabular_wrapper_no_labels(
df_predictions = _merge_ids_and_predictions(
ids=cur_ids,
predictions=predictions,
tabular_output_type=target_head_name,
prediction_classes=classes,
cat_loss_name=cat_loss_name,
)
Expand All @@ -206,6 +209,7 @@ def _merge_ids_and_predictions(
ids: Sequence[str],
predictions: np.ndarray,
cat_loss_name: str,
tabular_output_type: str,
prediction_classes: Optional[Sequence[str]] = None,
) -> pd.DataFrame:
df = pd.DataFrame()
Expand All @@ -216,7 +220,7 @@ def _merge_ids_and_predictions(
if prediction_classes is None:
prediction_classes = [f"Score Class {i}" for i in range(predictions.shape[1])]

if cat_loss_name == "BCEWithLogitsLoss":
if tabular_output_type == "cat" and cat_loss_name == "BCEWithLogitsLoss":
assert predictions.shape[1] == 1, predictions.shape
assert len(prediction_classes) == 2, len(prediction_classes)
prediction_classes = prediction_classes[1]
Expand Down
9 changes: 8 additions & 1 deletion eir/train_utils/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,14 @@ def save_classification_predictions(
output_folder: Path,
) -> None:

validation_predictions_total = val_outputs.argmax(axis=1)
def sigmoid(x):
return 1 / (1 + np.exp(-x))

# BCEWithLogitsLoss case
if val_outputs.shape[1] == 1:
validation_predictions_total = sigmoid(val_outputs).round().astype(int)
else:
validation_predictions_total = val_outputs.argmax(axis=1)

df_predictions = _parse_valid_classification_predictions(
val_true=val_labels,
Expand Down
140 changes: 58 additions & 82 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ contourpy==1.3.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
coverage==7.6.9 ; python_full_version >= "3.12.0" and python_version < "3.13"
coverage[toml]==7.6.9 ; python_full_version >= "3.12.0" and python_version < "3.13"
cycler==0.12.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
debugpy==1.8.9 ; python_full_version >= "3.12.0" and python_version < "3.13"
debugpy==1.8.10 ; python_full_version >= "3.12.0" and python_version < "3.13"
decorator==5.1.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
deeplake==4.0.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
defusedxml==0.7.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
Expand All @@ -55,7 +55,7 @@ fastapi==0.115.6 ; python_full_version >= "3.12.0" and python_version < "3.13"
fastjsonschema==2.21.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
filelock==3.16.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
flake8==7.1.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
fonttools==4.55.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
fonttools==4.55.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
formulaic==1.0.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
fqdn==1.5.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
frozenlist==1.5.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
Expand Down Expand Up @@ -99,7 +99,7 @@ jupyter==1.1.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
jupyterlab-pygments==0.3.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
jupyterlab-server==2.27.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
jupyterlab-widgets==3.0.13 ; python_full_version >= "3.12.0" and python_version < "3.13"
jupyterlab==4.3.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
jupyterlab==4.3.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
kiwisolver==1.4.7 ; python_full_version >= "3.12.0" and python_version < "3.13"
lifelines==0.30.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
markupsafe==3.0.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ einops==0.8.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
executing==2.1.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
fastapi==0.115.6 ; python_full_version >= "3.12.0" and python_version < "3.13"
filelock==3.16.1 ; python_full_version >= "3.12.0" and python_version < "3.13"
fonttools==4.55.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
fonttools==4.55.3 ; python_full_version >= "3.12.0" and python_version < "3.13"
formulaic==1.0.2 ; python_full_version >= "3.12.0" and python_version < "3.13"
frozenlist==1.5.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
fsspec==2024.10.0 ; python_full_version >= "3.12.0" and python_version < "3.13"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -869,11 +869,8 @@ def _should_compile():
"lr": 1e-03 * 4,
"gradient_accumulation_steps": 4,
},
"training_control": {
"mixing_alpha": 0.2,
},
"attribution_analysis": {
"max_attributions_per_class": 100,
"max_attributions_per_class": 200,
},
},
"input_configs": [
Expand Down

0 comments on commit 3023c45

Please sign in to comment.