✨ Enable features with dtype = 'str' #2226

Merged: 25 commits, Nov 28, 2024
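This change lets a feature carry a plain string dtype instead of forcing every string column to be treated as categorical. A minimal usage sketch, not part of the diff below; it assumes a working lamindb instance, and the feature name is invented:

import lamindb as ln

# after this change, "str" is a valid feature dtype, so a free-text string
# column no longer has to be registered as categorical
ln.Feature(name="donor_note", dtype="str").save()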
58 changes: 31 additions & 27 deletions docs/curate-df.ipynb
@@ -324,8 +324,35 @@
"outputs": [],
"source": [
"# validate again\n",
"validated = curate.validate()\n",
"validated"
"curate.validate()"
]
},
{
"cell_type": "markdown",
"id": "ab7cfff0",
"metadata": {},
"source": [
"Save a curated artifact."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e7e6372",
"metadata": {},
"outputs": [],
"source": [
"artifact = curate.save_artifact(description=\"My curated dataframe\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74f91cc8",
"metadata": {},
"outputs": [],
"source": [
"artifact.describe(print_types=True)"
]
},
{
@@ -360,9 +387,8 @@
" \"ENSG00000153563\": [13, 14, 15],\n",
" \"ENSGcorrupted\": [16, 17, 18]\n",
" }, \n",
" index=df.index\n",
" index=df.index # because we already curated the dataframe above, it will validate \n",
")\n",
"\n",
"adata = ad.AnnData(X=X, obs=df)\n",
"adata"
]
@@ -383,20 +409,7 @@
" var_index=bt.Gene.ensembl_gene_id, # validate var.index against Gene.ensembl_gene_id\n",
" categoricals=categoricals, \n",
" organism=\"human\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8d8cee5",
"metadata": {
"tags": [
"hide-output"
]
},
"outputs": [],
"source": [
")\n",
"curate.validate()"
]
},
@@ -465,18 +478,9 @@
" categoricals=categoricals, \n",
" organism=\"human\",\n",
")\n",
"\n",
"curate.validate()"
]
},
{
"cell_type": "markdown",
"id": "38a30170",
"metadata": {},
"source": [
"## Save a curated artifact"
]
},
{
"cell_type": "markdown",
"id": "a814ef37",
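For orientation, the updated docs notebook validates the curated DataFrame and then saves it as an artifact. A condensed sketch of those cells, assuming `curate` is the DataFrame curator constructed earlier in curate-df.ipynb (its construction is not shown in this hunk):

curate.validate()  # validate again after fixing the dataframe
artifact = curate.save_artifact(description="My curated dataframe")
artifact.describe(print_types=True)  # inspect the saved artifact's annotations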
12 changes: 7 additions & 5 deletions lamindb/_curate.py
@@ -123,7 +123,7 @@ def standardize(self, key: str) -> None:
Inplace modification of the dataset.

Args:
key: `str` The name of the column to standardize.
key: The name of the column to standardize.

Returns:
None
@@ -336,7 +336,7 @@ def standardize(self, key: str):
"""Replace synonyms with standardized values.

Args:
key: `str` The key referencing the slot in the DataFrame from which to draw terms.
key: The key referencing the slot in the DataFrame from which to draw terms.

Modifies the input dataset inplace.
"""
@@ -675,9 +675,10 @@ def standardize(self, key: str):
"""Replace synonyms with standardized values.

Args:
key: `str` The key referencing the slot in `adata.obs` from which to draw terms. Same as the key in `categoricals`.
- If "var_index", standardize the var.index.
- If "all", standardize all obs columns and var.index.
key: The key referencing the slot in `adata.obs` from which to draw terms. Same as the key in `categoricals`.

- If "var_index", standardize the var.index.
- If "all", standardize all obs columns and var.index.

Inplace modification of the dataset.
"""
@@ -1529,6 +1530,7 @@ def _add_labels(
feature=feature,
feature_ref_is_name=feature_ref_is_name,
label_ref_is_name=label_ref_is_name,
from_curator=True,
)

if artifact._accessor == "MuData":
115 changes: 60 additions & 55 deletions lamindb/_feature.py
@@ -1,40 +1,59 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Literal, get_args

import lamindb_setup as ln_setup
import pandas as pd
from lamin_utils import logger
from lamindb_setup.core._docs import doc_args
from lnschema_core.models import Artifact, Feature
from lnschema_core.models import Artifact, Feature, Record
from lnschema_core.types import FeatureDtype
from pandas.api.types import CategoricalDtype, is_string_dtype

from lamindb.core.exceptions import ValidationError

from ._query_set import RecordsList
from ._utils import attach_func_to_class_method
from .core._settings import settings
from .core.schema import dict_schema_name_to_model_name

if TYPE_CHECKING:
from lnschema_core.types import FieldAttr

FEATURE_TYPES = {
"number": "number",
"int": "int",
"float": "float",
"bool": "bool",
"date": "date",
"datetime": "datetime",
"str": "cat",
"object": "cat",
}


def convert_numpy_dtype_to_lamin_feature_type(dtype, str_as_cat: bool = True) -> str:
orig_type = dtype.name
# strip precision qualifiers
type = "".join(i for i in orig_type if not i.isdigit())
if type == "object" or type == "str":
type = "cat" if str_as_cat else "str"
return type
from pandas.core.dtypes.base import ExtensionDtype


FEATURE_DTYPES = set(get_args(FeatureDtype))


def get_dtype_str_from_dtype(dtype: Any) -> str:
if not isinstance(dtype, list) and dtype.__name__ in FEATURE_DTYPES:
dtype_str = dtype.__name__
else:
error_message = "dtype has to be of type Record or list[Record]"
if isinstance(dtype, Record):
dtype = [dtype]

[codecov: added line lamindb/_feature.py#L34 was not covered by tests]
elif not isinstance(dtype, list):
raise ValueError(error_message)

[codecov: added line lamindb/_feature.py#L36 was not covered by tests]
registries_str = ""
for registry in dtype:
if not hasattr(registry, "__get_name_with_schema__"):
raise ValueError(error_message)

[codecov: added line lamindb/_feature.py#L40 was not covered by tests]
registries_str += registry.__get_name_with_schema__() + "|"
dtype_str = f'cat[{registries_str.rstrip("|")}]'
return dtype_str


def convert_pandas_dtype_to_lamin_dtype(pandas_dtype: ExtensionDtype) -> str:
if is_string_dtype(pandas_dtype):
if not isinstance(pandas_dtype, CategoricalDtype):
dtype = "str"
else:
dtype = "cat"

[codecov: added line lamindb/_feature.py#L51 was not covered by tests]
else:
# strip precision qualifiers
dtype = "".join(dt for dt in pandas_dtype.name if not dt.isdigit())
assert dtype in FEATURE_DTYPES # noqa: S101
return dtype


def __init__(self, *args, **kwargs):
@@ -47,28 +66,16 @@
dtype: type | str = kwargs.pop("dtype") if "dtype" in kwargs else None
# cast type
if dtype is None:
raise ValueError("Please pass dtype!")
raise ValueError(f"Please pass dtype, one of {FEATURE_DTYPES}")
elif dtype is not None:
if not isinstance(dtype, str):
if not isinstance(dtype, list) and dtype.__name__ in FEATURE_TYPES:
dtype_str = FEATURE_TYPES[dtype.__name__]
else:
if not isinstance(dtype, list):
raise ValueError("dtype has to be a list of Record types")
registries_str = ""
for cls in dtype:
if not hasattr(cls, "__get_name_with_schema__"):
raise ValueError("each element of the list has to be a Record")
registries_str += cls.__get_name_with_schema__() + "|"
dtype_str = f'cat[{registries_str.rstrip("|")}]'
dtype_str = get_dtype_str_from_dtype(dtype)
else:
dtype_str = dtype
# add validation that a registry actually exists
if dtype_str not in FEATURE_TYPES.values() and not dtype_str.startswith(
"cat"
):
if dtype_str not in FEATURE_DTYPES and not dtype_str.startswith("cat"):
raise ValueError(
f"dtype is {dtype_str} but has to be one of 'number', 'int', 'float', 'cat', 'bool', 'cat[...]'!"
f"dtype is {dtype_str} but has to be one of {FEATURE_DTYPES}!"
)
if dtype_str != "cat" and dtype_str.startswith("cat"):
registries_str = dtype_str.replace("cat[", "").rstrip("]")
@@ -81,6 +88,13 @@
)
kwargs["dtype"] = dtype_str
super(Feature, self).__init__(*args, **kwargs)
if not self._state.adding:
if not (
self.dtype.startswith("cat") if dtype == "cat" else self.dtype == dtype
):
raise ValidationError(
f"Feature {self.name} already exists with dtype {self.dtype}, you passed {dtype}"
)


def categoricals_from_df(df: pd.DataFrame) -> dict:
@@ -94,7 +108,9 @@
for key in string_cols:
c = pd.Categorical(df[key])
if len(c.categories) < len(c):
categoricals[key] = c
logger.warning(
f"consider changing the dtype of string column `{key}` to categorical"
)
return categoricals


@@ -103,29 +119,18 @@
def from_df(cls, df: pd.DataFrame, field: FieldAttr | None = None) -> RecordsList:
"""{}""" # noqa: D415
field = Feature.name if field is None else field
registry = field.field.model
if registry != Feature:
raise ValueError("field must be a Feature FieldAttr!")

[codecov: added line lamindb/_feature.py#L124 was not covered by tests]
categoricals = categoricals_from_df(df)

dtypes = {}
# categoricals_with_unmapped_categories = {} # type: ignore
for name, col in df.items():
if name in categoricals:
dtypes[name] = "cat"
else:
dtypes[name] = convert_numpy_dtype_to_lamin_feature_type(col.dtype)

# silence the warning "loaded record with exact same name "
verbosity = settings.verbosity
try:
settings.verbosity = "error"

registry = field.field.model
if registry != Feature:
raise ValueError("field must be a Feature FieldAttr!")
# create records for all features including non-validated
dtypes[name] = convert_pandas_dtype_to_lamin_dtype(col.dtype)
with logger.mute(): # silence the warning "loaded record with exact same name "
features = [Feature(name=name, dtype=dtype) for name, dtype in dtypes.items()]
finally:
settings.verbosity = verbosity

assert len(features) == len(df.columns) # noqa: S101
return RecordsList(features)

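As a reading aid for the refactor above: the new get_dtype_str_from_dtype helper encodes a registry, or a list of registries, as a composite categorical dtype string. A hedged sketch, assuming a lamindb instance with the bionty schema mounted; the feature name is invented:

import bionty as bt
import lamindb as ln

# a list of registries is encoded as something like "cat[bionty.CellType]",
# which Feature.__init__ then validates against the known registries
ln.Feature(name="cell_type", dtype=[bt.CellType]).save()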
10 changes: 6 additions & 4 deletions lamindb/_feature_set.py
@@ -10,7 +10,7 @@
from lnschema_core import Feature, FeatureSet, Record, ids
from lnschema_core.types import FieldAttr, ListLike

from ._feature import convert_numpy_dtype_to_lamin_feature_type
from ._feature import convert_pandas_dtype_to_lamin_dtype
from ._record import init_self_from_db
from ._utils import attach_func_to_class_method
from .core.exceptions import ValidationError
@@ -26,7 +26,7 @@

from ._query_set import QuerySet

NUMBER_TYPE = "number"
NUMBER_TYPE = "num"
DICT_KEYS_TYPE = type({}.keys()) # type: ignore


@@ -179,13 +179,15 @@ def from_df(
logger.warning("no validated features, skip creating feature set")
return None
if registry == Feature:
validated_features = Feature.from_df(df.loc[:, validated])
validated_features = Feature.from_values(
df.columns, field=field, organism=organism
)
feature_set = FeatureSet(validated_features, name=name, dtype=None)
else:
dtypes = [col.dtype for (_, col) in df.loc[:, validated].items()]
if len(set(dtypes)) != 1:
raise ValueError(f"data types are heterogeneous: {set(dtypes)}")
dtype = convert_numpy_dtype_to_lamin_feature_type(dtypes[0])
dtype = convert_pandas_dtype_to_lamin_dtype(dtypes[0])
validated_features = registry.from_values(
df.columns[validated],
field=field,
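FeatureSet.from_df now collects the validated features via Feature.from_values(df.columns, ...) instead of Feature.from_df. A hypothetical call, with invented column names, assuming the default field is Feature.name and that the columns have been registered as features first:

import pandas as pd
import lamindb as ln

df = pd.DataFrame({"cell_count": [100, 250], "treatment": ["drug_a", "drug_b"]})
ln.Feature.from_df(df).save()            # register the columns as features first
feature_set = ln.FeatureSet.from_df(df)  # then group the validated features into a feature set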
7 changes: 7 additions & 0 deletions lamindb/_query_set.py
@@ -170,6 +170,13 @@ def one(self) -> Record:
"""Exactly one result. Throws error if there are more or none."""
return one_helper(self)

def save(self) -> RecordsList:
"""Save all records to the database."""
from lamindb._save import save

save(self)
return self


class QuerySet(models.QuerySet):
"""Sets of records returned by queries.
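RecordsList, which Feature.from_df returns, gains a save() method so the whole list can be persisted in one call. A sketch with an invented dataframe:

import pandas as pd
import lamindb as ln

df = pd.DataFrame({"temperature": [21.5, 22.0], "donor_note": ["healthy", "follow-up"]})
features = ln.Feature.from_df(df)  # a RecordsList of Feature records
features.save()                    # new: bulk-save, returns the same RecordsList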