Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add options to consolidate with ecg2x notebooks #564

Merged
merged 1 commit into from
May 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 35 additions & 12 deletions ml4h/data_descriptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from ml4ht.data.defines import SampleID

import os
import glob
from typing import Callable, List, Union, Optional, Tuple, Dict, Any
Expand All @@ -12,7 +10,7 @@

from ml4ht.data.data_description import DataDescription
from ml4ht.data.util.date_selector import DATE_OPTION_KEY
from ml4ht.data.defines import LoadingOption, Tensor
from ml4ht.data.defines import LoadingOption, SampleID, Tensor

from ml4h.TensorMap import TensorMap
from ml4h.defines import PARTNERS_DATETIME_FORMAT, ECG_REST_AMP_LEADS
Expand Down Expand Up @@ -211,7 +209,7 @@ def _loading_options(
sites = [
decompress_data(
data_compressed=hd5[f'{self.hd5_path_to_ecg}/{date}/sitename'][()],
dtype='str'
dtype='str',
)
for date in dates
]
Expand Down Expand Up @@ -329,17 +327,23 @@ def __init__(
col: str,
process_col: Callable[[Any], Tensor] = None,
name: str = None,
loading_options_col: str = 'start_fu_datetime',
restrict_to_loading_option: bool = False,
):
"""
Gets data from a column of the provided DataFrame.
:param col: The column name to get data from
:param process_col: Function to turn the column value into Tensor
:param name: Optional overwrite of the df column name
:param loading_options_col: Which column to use when getting loading options
:param restrict_to_loading_option: Whether to use the loading_option when getting raw data
"""
self.process_col = process_col or self._default_process_call
self.df = df
self.col = col
self._name = name or col
self.loading_options_col = loading_options_col
self.restrict_to_loading_option = restrict_to_loading_option

@staticmethod
def _default_process_call(x: Any) -> Tensor:
Expand All @@ -352,34 +356,47 @@ def name(self) -> str:
def get_loading_options(self, sample_id):
row = self.df.loc[sample_id]
return [{
'start_fu_datetime': pd.to_datetime(row['start_fu_datetime']),
self.loading_options_col: pd.to_datetime(row[self.loading_options_col]),
}]

def get_raw_data(
self,
sample_id: SampleID,
loading_option: LoadingOption,
) -> Tensor:
col_val = self.df.loc[sample_id][self.col]
if not self.restrict_to_loading_option:
col_val = self.df.loc[sample_id][self.col]
else:
col_val = self.df.loc[self.df[self.loading_options_col] == loading_option[self.loading_options_col]].loc[sample_id][self.col]
if self.col == 'age_in_days' and 'day_delta' in loading_option:
col_val -= loading_option['day_delta']
return self.process_col(col_val)


def one_hot_sex(x):
return np.array([1, 0], dtype=np.float32) if x in [0, "Female"] else np.array([0, 1], dtype=np.float32)

def one_hot_n(n):
def one_hot(x):
if isinstance(x, str) and x in ["Female", "Male"]:
return one_hot_sex(x)
else:
a = np.zeros((n), dtype=np.float32)
a[int(x)] = 1
return a
return one_hot

def make_zscore(mu, std):
def zscore(x):
return (x-mu) / (1e-8+std)
return zscore


def dataframe_data_description_from_tensor_map(
tensor_map: TensorMap,
dataframe: pd.DataFrame,
is_input: bool = False,
loading_options_col = 'start_fu_datetime',
restrict_to_loading_option = False,
do_zscore: bool = True,
) -> DataDescription:
if tensor_map.is_survival_curve():
if tensor_map.name == 'survival_curve_af':
Expand All @@ -396,14 +413,18 @@ def dataframe_data_description_from_tensor_map(
event_column=event_column,
)
if tensor_map.is_categorical():
process_col = one_hot_sex
else:
process_col = one_hot_n(len(tensor_map.channel_map))
elif do_zscore:
process_col = make_zscore(dataframe[tensor_map.name].mean(), dataframe[tensor_map.name].std())
else:
process_col = lambda x:x
return DataFrameDataDescription(
dataframe,
col=tensor_map.name,
process_col=process_col,
name=tensor_map.input_name() if is_input else tensor_map.output_name(),
loading_options_col=loading_options_col,
restrict_to_loading_option=restrict_to_loading_option,
)


Expand Down Expand Up @@ -440,7 +461,8 @@ def get_loading_options(self, sample_id):
row = self.wide_df.loc[sample_id]
ecg_date = pd.to_datetime(row[DATE_OPTION_KEY])
start_date = ecg_date + (
pd.to_timedelta(row[self.start_age_column]) - pd.to_timedelta(row[self.ecg_age_column]))
pd.to_timedelta(row[self.start_age_column]) - pd.to_timedelta(row[self.ecg_age_column])
)
return [{
DATE_OPTION_KEY: ecg_date,
'start_date': start_date,
Expand Down Expand Up @@ -478,7 +500,8 @@ def get_raw_data(self, sample_id, loading_option):
cur_date = ecg_date + datetime.timedelta(days=day_delta)
survival_then_censor[i] = float(cur_date < censor_date)
survival_then_censor[self.intervals + i] = has_disease * float(
censor_date <= cur_date < censor_date + datetime.timedelta(days=days_per_interval))
censor_date <= cur_date < censor_date + datetime.timedelta(days=days_per_interval),
)
# Handle prevalent diseases
if has_disease and pd.to_timedelta(row[self.event_age]) <= pd.to_timedelta(ecg_age):
survival_then_censor[self.intervals] = has_disease
Expand Down
Loading