Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add docstrings #14

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/behapy/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ def load_events(root: Path,
run (str): run ID
"""
events_path = get_events_path(root, subject, session, task, run)

"""Locate the events from the BIDS root directory.

Args:
root (Path): path to the root of the BIDS dataset
subject (str): subject ID
session (str): session ID
task (str): task ID
run (str): run ID
"""
if not events_path.exists():
raise ValueError(f'Events file {events_path} does not exist')
events = pd.read_csv(events_path, index_col=0)
Expand Down Expand Up @@ -109,7 +119,7 @@ def _find_nearest(origin, fit):
second['nearest'] = False
second.loc[first['origin'], 'nearest'] = True
return second['nearest'].to_numpy()


def _build_single_regressor(data: pd.DataFrame,
events: pd.Series,
Expand All @@ -132,6 +142,14 @@ def _build_single_regressor(data: pd.DataFrame,
def build_design_matrix(data: pd.DataFrame,
events: pd.DataFrame,
window: Tuple[float, float]) -> pd.DataFrame:

"""Build the design matrix.

Args:
    data: preprocessed signal
    events: behavioural events (e.g. lp, pel)
    window: time window either side of the event (e.g. -20 to 10 s)
"""

regressor_dfs = []
for event in events.event_id.unique():
matrix, offsets = _build_single_regressor(
Expand All @@ -149,6 +167,13 @@ def build_design_matrix(data: pd.DataFrame,
def regress(design_matrix: pd.DataFrame,
data: pd.DataFrame,
min_events=50) -> pd.Series:
"""Run OLS regression based on data in the design matrix.

Args:
    design_matrix: the matrix of all events
    data: preprocessed signal
    min_events: lower cut-off to remove infrequent events
"""

dm = design_matrix.loc[:, design_matrix.sum() > min_events]
if dm.empty:
return pd.Series(dtype=float, index=dm.columns)
Expand Down
73 changes: 67 additions & 6 deletions src/behapy/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,17 @@ def load_channel(root, subject, session, task, run, label, channel):


def load_signal(root, subject, session, task, run, label, iso_channel='iso'):

"""Load a raw signal, including the isosbestic channel if present.

Args:
root (Path): path to the root of the BIDS dataset
subject (str): subject ID
session (str): session ID
task (str): task ID
run (str): run ID
label (str): label ID
    iso_channel (str): name of the isosbestic control channel (default 'iso')
"""
root = Path(root).absolute()
recordings = pd.DataFrame(
Expand Down Expand Up @@ -128,6 +138,12 @@ def load_signal(root, subject, session, task, run, label, iso_channel='iso'):


def downsample(signal, factor=None):
"""Downsampling of the data.

Args:
signal: concat data
    factor: downsampling factor; if None, a reasonable default is chosen
"""
if factor is None:
# Downsample to something reasonable
factor = 1
Expand All @@ -144,6 +160,17 @@ def downsample(signal, factor=None):


def save_rejections(tree, root, subject, session, task, run, label):
"""Save the provided IntervalTree as a csv.

Args:
    tree (IntervalTree): the rejected intervals to be saved
root (Path): path to the root of the BIDS dataset
subject (str): subject ID
session (str): session ID
task (str): task ID
run (str): run ID
label (str): label ID
"""
# Save the provided IntervalTree as a CSV
fn = get_rejected_intervals_path(root, subject, session, task, run,
label)
Expand All @@ -155,7 +182,16 @@ def save_rejections(tree, root, subject, session, task, run, label):


def load_rejections(root, subject, session, task, run, label):
# Load rejected intervals if present
"""Load rejected intervals if present.

Args:
root (Path): path to the root of the BIDS dataset
subject (str): subject ID
session (str): session ID
task (str): task ID
run (str): run ID
label (str): label ID
"""
fn = get_rejected_intervals_path(root, subject, session, task, run,
label)
intervals = []
Expand All @@ -168,13 +204,18 @@ def load_rejections(root, subject, session, task, run, label):


def find_discontinuities(signal, mean_window=3, std_window=30, nstd_thresh=2):
# This currently relies on the isobestic channel being valid.
"""Find discontinuities in the signal.

# How many samples to consider for the sliding mean
Args:
signal: concat data
mean_window: how many samples to consider for the sliding mean
    std_window: window length (s) for the sliding-window STD estimate
    nstd_thresh: discontinuity threshold, in multiples of the characteristic STD
Uses the median of a sliding window STD as characteristic STD.
Currently relies on the isosbestic channel being valid.
Assume that the std of the iso channel is constant.
"""
n = int(signal.attrs['fs'] * mean_window)
# Assume that the STD of the iso channel is constant. We can
# then use the median of a sliding window STD as our
# characteristic STD.
std_n = int(signal.attrs['fs'] * std_window)
# iso_rstds = np.std(sliding_window_view(site.iso(), std_n), axis=-1)
data = signal[signal.attrs['channel']].to_numpy()
Expand Down Expand Up @@ -225,6 +266,15 @@ def find_discontinuities(signal, mean_window=3, std_window=30, nstd_thresh=2):

def find_disconnects(signal, zero_nstd_thresh=5, mean_window=3, std_window=30,
nstd_thresh=2):
"""Find disconnections in the signal.

Args:
signal: concat data
    zero_nstd_thresh: threshold in STDs for flagging near-zero
        (disconnected) signal — TODO confirm
    mean_window: how many samples to consider for the sliding mean
    std_window: window length (s) for the sliding-window STD estimate
    nstd_thresh: discontinuity threshold, in multiples of the characteristic STD
"""
bounds = find_discontinuities(signal, mean_window=mean_window,
std_window=std_window, nstd_thresh=nstd_thresh)
# data = signal[signal.attrs['iso_channel']].to_numpy()
Expand Down Expand Up @@ -416,6 +466,17 @@ def normalise(signal, control, mask, fs, method='fit', detrend=True):


def preprocess(root, subject, session, task, run, label):
    """Preprocess the data for the given recording.

Args:
root (Path): path to the root of the BIDS dataset
subject (str): subject ID
session (str): session ID
task (str): task ID
run (str): run ID
label (str): label ID
"""

config = load_preprocess_config(root)
intervals = load_rejections(root, subject, session, task, run, label)
# Check if the recording has rejections saved
Expand Down