Refactor SCDL Row Feature Index for Performance Improvement #443

Draft · wants to merge 2 commits into main
@@ -21,6 +21,8 @@

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq


__all__: Sequence[str] = ("RowFeatureIndex",)
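
For orientation before the rest of the diff: the change replaces the per-entry pandas DataFrame in _feature_arr with a plain dict mapping feature column names to NumPy arrays, and records each block's gene count explicitly in _num_genes_per_row. Below is a minimal sketch of how the new append_features signature would be used; the import path, column names, and values are illustrative assumptions rather than code from this PR.

```python
import numpy as np

# Assumed import path; adjust to wherever RowFeatureIndex lives in bionemo.scdl.
from bionemo.scdl.index.row_feature_index import RowFeatureIndex

# Illustrative feature block: one dict of column name -> NumPy array per appended block.
features = {
    "feature_id": np.array(["ENSG00000001", "ENSG00000002", "ENSG00000003"]),
    "feature_name": np.array(["GeneA", "GeneB", "GeneC"]),
}

index = RowFeatureIndex()
# 100 rows share this feature block; num_genes is the per-row feature length.
index.append_features(
    n_obs=100,
    features=features,
    num_genes=len(features["feature_id"]),
    label="dataset_0",
)
```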
@@ -45,9 +47,11 @@ class RowFeatureIndex:
def __init__(self) -> None:
"""Instantiates the index."""
self._cumulative_sum_index: np.array = np.array([-1])
self._feature_arr: List[pd.DataFrame] = []
self._feature_arr: List[dict] = []
self._num_genes_per_row: List[int] = []
self._version = importlib.metadata.version("bionemo.scdl")
self._labels: List[str] = []
self._feature_arr_lookup: List[dict] = []

def version(self) -> str:
"""Returns a version number.
@@ -60,7 +64,9 @@ def __len__(self) -> int:
"""The length is the number of rows or RowFeatureIndex length."""
return len(self._feature_arr)

def append_features(self, n_obs: int, features: pd.DataFrame, label: Optional[str] = None) -> None:
def append_features(
self, n_obs: int, features: dict, num_genes: Optional[int], label: Optional[str] = None
) -> None:
"""Updates the index with the given features.

The dataframe is inserted into the feature array by adding a
@@ -70,14 +76,18 @@ def append_features(self, n_obs: int, features: pd.DataFrame, label: Optional[st
n_obs (int): The number of times that these features occur in the
class.
features (dict): Corresponding features.
num_genes (int): The length of the features for each feature key in features (i.e., the number of genes).
label (str): Label for the features.
"""
if isinstance(features, pd.DataFrame):
raise TypeError("Expected a dictionary, but received a Pandas DataFrame.")
csum = max(self._cumulative_sum_index[-1], 0)
self._cumulative_sum_index = np.append(self._cumulative_sum_index, csum + n_obs)
self._feature_arr.append(features)
self._num_genes_per_row.append(num_genes)
self._labels.append(label)

def lookup(self, row: int, select_features: Optional[List[str]] = None) -> Tuple[pd.DataFrame, str]:
def lookup(self, row: int, select_features: Optional[List[str]] = None) -> Tuple[List[np.ndarray], str]:
"""Find the features at a given row.

It is assumed that the row is
@@ -113,17 +123,23 @@ def lookup(self, row: int, select_features: Optional[List[str]] = None) -> Tuple
d_id = sum(mask) - 1

# Retrieve the features for the identified value.
features = self._feature_arr[d_id]
features_dict = self._feature_arr[d_id]

# If specific features are to be selected, filter the features.
if select_features is not None:
features = features[select_features]
features = []
for feature in select_features:
if feature not in features_dict:
raise ValueError(f"Provided feature column {feature} in select_features not present in dataset.")
features.append(features_dict[feature])
else:
features = [features_dict[f] for f in features_dict]

# Return the features for the identified range.
return features, self._labels[d_id]

def number_vars_at_row(self, row: int) -> int:
"""Return number of variables (legnth of the dataframe) in a given row.
"""Return number of variables (length of the dataframe) in a given row.

Args:
row (int): The row in the feature index.
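
With the dict-backed storage, lookup returns a list of NumPy arrays (one per feature column) together with the label, instead of a DataFrame slice, and unknown names in select_features now raise a ValueError. A short usage sketch, continuing the illustrative index built in the first block above:

```python
# Continuing the illustrative example from the first sketch.
feats, label = index.lookup(row=5)
# feats is a list of NumPy arrays, one per feature column, in dict insertion order.

ids_only, _ = index.lookup(row=5, select_features=["feature_id"])
assert len(ids_only[0]) == index.number_vars_at_row(5)

# Requesting a column that is not present raises ValueError.
try:
    index.lookup(row=5, select_features=["not_a_column"])
except ValueError as err:
    print(err)
```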
@@ -132,7 +148,7 @@ def number_vars_at_row(self, row: int) -> int:
The length of the features at the row
"""
feats, _ = self.lookup(row=row)
return len(feats)
return len(feats[0])

def column_dims(self) -> List[int]:
"""Return the number of columns in all rows.
@@ -144,7 +160,7 @@ def column_dims(self) -> List[int]:
A list containing the lengths of the features in every row
"""
# Just take the total dim of the DataFrame(s)
return [len(feats) for feats in self._feature_arr]
return self._num_genes_per_row

def number_of_values(self) -> List[int]:
"""Get the total number of values in the array.
@@ -160,8 +176,10 @@ def number_of_values(self) -> List[int]:
self._cumulative_sum_index[i] - max(self._cumulative_sum_index[i - 1], 0)
for i in range(1, len(self._cumulative_sum_index))
]

vals = [n_rows * len(self._feature_arr[i]) for i, n_rows in enumerate(rows)]
vals = []
for i, n_rows in enumerate(rows):
num_genes = self._num_genes_per_row[i]
vals.append(n_rows * num_genes)
return vals

def number_of_rows(self) -> int:
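
Column dimensions and value counts now come straight from the stored gene counts, so for the illustrative block above the arithmetic is simply rows times genes:

```python
# Continuing the illustrative example: one block of 100 rows, 3 genes each.
assert index.column_dims() == [3]
assert index.number_of_values() == [100 * 3]  # 300 values in total
```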
@@ -201,7 +219,8 @@ def concat(self, other_row_index: RowFeatureIndex, fail_on_empty_index: bool = T
for i, feats in enumerate(list(other_row_index._feature_arr)):
c_span = other_row_index._cumulative_sum_index[i + 1]
label = other_row_index._labels[i]
self.append_features(c_span, feats, label)
num_genes = other_row_index._num_genes_per_row[i]
self.append_features(c_span, feats, num_genes, label)

return self

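concat forwards each block's stored gene count into append_features, so concatenated indexes keep their per-block dimensions without re-measuring any DataFrame. Continuing the sketch:

```python
# Continuing the illustrative example: merge a second index into the first.
other = RowFeatureIndex()
other.append_features(n_obs=50, features=features, num_genes=3, label="dataset_1")

index.concat(other)
assert index.column_dims() == [3, 3]
assert len(index) == 2  # two feature blocks are now tracked
```
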
@@ -213,10 +232,11 @@ def save(self, datapath: str) -> None:
"""
Path(datapath).mkdir(parents=True, exist_ok=True)
num_digits = len(str(len(self._feature_arr)))
for index, feature_dict in enumerate(self._feature_arr):
table = pa.table({column: pa.array(values) for column, values in feature_dict.items()})
dataframe_str_index = f"{index:0{num_digits}d}"
pq.write_table(table, f"{datapath}/dataframe_{dataframe_str_index}.parquet")

for dataframe_index, dataframe in enumerate(self._feature_arr):
dataframe_str_index = f"{dataframe_index:0{num_digits}d}"
dataframe.to_parquet(f"{datapath}/dataframe_{dataframe_str_index}.parquet", index=False)
np.save(Path(datapath) / "cumulative_sum_index.npy", self._cumulative_sum_index)
np.save(Path(datapath) / "labels.npy", self._labels)
np.save(Path(datapath) / "version.npy", np.array(self._version))
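
save now serializes each feature dict through pyarrow rather than pandas. The core of that round trip, written as a standalone sketch with an illustrative file name, looks roughly like this:

```python
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq

feature_dict = {"feature_id": np.array(["ENSG00000001", "ENSG00000002", "ENSG00000003"])}

# dict of arrays -> Arrow table -> parquet file, with no pandas in the path.
table = pa.table({column: pa.array(values) for column, values in feature_dict.items()})
pq.write_table(table, "dataframe_0.parquet")

# And back: parquet -> Arrow table -> dict of NumPy arrays, mirroring load() below.
loaded = pq.read_table("dataframe_0.parquet")
restored = {column: loaded[column].to_numpy() for column in loaded.column_names}
assert np.array_equal(restored["feature_id"], feature_dict["feature_id"])
```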
@@ -232,7 +252,15 @@ def load(datapath: str) -> RowFeatureIndex:
"""
new_row_feat_index = RowFeatureIndex()
parquet_data_paths = sorted(Path(datapath).rglob("*.parquet"))
new_row_feat_index._feature_arr = [pd.read_parquet(csv_path) for csv_path in parquet_data_paths]
new_row_feat_index._feature_arr = [pq.read_table(csv_path) for csv_path in parquet_data_paths]
new_row_feat_index._feature_arr = [
{column: table[column].to_numpy() for column in table.column_names}
for table in new_row_feat_index._feature_arr
]
new_row_feat_index._num_genes_per_row = [
len(feats[next(iter(feats.keys()))]) for feats in new_row_feat_index._feature_arr
]

new_row_feat_index._cumulative_sum_index = np.load(Path(datapath) / "cumulative_sum_index.npy")
new_row_feat_index._labels = np.load(Path(datapath) / "labels.npy", allow_pickle=True)
new_row_feat_index._version = np.load(Path(datapath) / "version.npy").item()
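
Putting it together, an end-to-end sketch of the refactored index with an illustrative directory name, continuing the example above:

```python
# Continuing the illustrative example: persist and restore the whole index.
index.save("feature_index")  # writes one parquet file per block plus .npy metadata
restored = RowFeatureIndex.load("feature_index")

feats, label = restored.lookup(row=0)
assert restored.column_dims() == index.column_dims()
assert restored.number_of_rows() == index.number_of_rows()
```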