Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🚸 Update .get() to accept expressions #1815

Merged
merged 3 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions docs/introduction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -424,10 +424,13 @@
"metadata": {},
"outputs": [],
"source": [
"# get an entity by uid (here, the current notebook)\n",
"# get a single record by uid (here, the current notebook)\n",
"transform = ln.Transform.get(\"FPnfDtJz8qbE\")\n",
"\n",
"# filter by description\n",
"# get a single record by matching a field\n",
"transform = ln.Transform.get(name=\"Introduction\")\n",
"\n",
"# get a set of records by filtering on description\n",
"ln.Artifact.filter(description=\"my RNA-seq\").df()\n",
"\n",
"# query all artifacts ingested from the current notebook\n",
Expand Down
25 changes: 23 additions & 2 deletions docs/records.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,34 @@
"users_dict = ln.User.lookup().dict()"
]
},
{
"cell_type": "markdown",
"id": "d54676dd",
"metadata": {},
"source": [
"## Query exactly one record"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "463ff17c",
"metadata": {},
"outputs": [],
"source": [
"# by uid\n",
"ln.User.get(\"DzTjkKse\")\n",
"# by any expression involving fields\n",
"ln.User.get(handle=\"testuser1\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "45ac3b5c",
"metadata": {},
"source": [
"## Filter by metadata"
"## Query sets of records"
]
},
{
Expand Down Expand Up @@ -148,7 +169,7 @@
"\n",
"- `.df()`: A pandas `DataFrame` with each record in a row.\n",
"- `.all()`: A {class}`~lamindb.core.QuerySet`.\n",
"- `.one()`: Exactly one record. Will raise an error if there is none.\n",
"- `.one()`: Exactly one record. Raises an error if there is none or more than one. Equivalent to the `.get()` method shown above.\n",
"- `.one_or_none()`: Either one record or `None` if there is no query result."
]
},
Expand Down
8 changes: 3 additions & 5 deletions lamindb/_query_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@
Transform,
)

from lamindb.core.exceptions import DoesNotExist

if TYPE_CHECKING:
from lnschema_core.types import ListLike, StrField


class NoResultFound(Exception):
    """Signal that a query expected to match one record matched none."""


class MultipleResultsFound(Exception):
    """Signal that a query expected to match one record matched several."""

Expand Down Expand Up @@ -59,7 +57,7 @@ def get_keys_from_df(data: list, registry: Record) -> list[str]:

def one_helper(self):
if len(self) == 0:
raise NoResultFound
raise DoesNotExist
elif len(self) > 1:
raise MultipleResultsFound(self)
else:
Expand Down
19 changes: 15 additions & 4 deletions lamindb/_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,17 @@ def filter(cls, **expressions) -> QuerySet:

@classmethod # type:ignore
@doc_args(Record.get.__doc__)
def get(cls, idlike: int | str) -> Record:
def get(
cls,
idlike: int | str | None = None,
**expressions,
) -> Record:
"""{}""" # noqa: D415
from lamindb._filter import filter

if isinstance(idlike, int):
return filter(cls, id=idlike).one()
else:
elif isinstance(idlike, str):
qs = filter(cls, uid__startswith=idlike)
if issubclass(cls, IsVersioned):
if len(idlike) <= cls._len_stem_uid:
Expand All @@ -138,20 +142,27 @@ def get(cls, idlike: int | str) -> Record:
return qs.one()
else:
return qs.one()
else:
assert idlike is None # noqa: S101
# below behaves exactly like `.one()`
return cls.objects.get(**expressions)


@classmethod # type:ignore
@doc_args(Record.df.__doc__)
def df(
cls, include: str | list[str] | None = None, join: str = "inner"
cls,
include: str | list[str] | None = None,
join: str = "inner",
limit: int = 100,
) -> pd.DataFrame:
"""{}""" # noqa: D415
from lamindb._filter import filter

query_set = filter(cls)
if hasattr(cls, "updated_at"):
query_set = query_set.order_by("-updated_at")
return query_set.df(include=include, join=join)
return query_set[:limit].df(include=include, join=join)


# from_values doesn't apply for QuerySet or Manager
Expand Down
9 changes: 9 additions & 0 deletions lamindb/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
.. autosummary::
:toctree: .

DoesNotExist
ValidationError
NotebookNotSavedError
NoTitleError
Expand All @@ -21,6 +22,14 @@ class ValidationError(SystemExit):
pass


# Modeled on Django's DoesNotExist; the counterpart of SQLAlchemy's NoResultFound.
class DoesNotExist(Exception):
    """Signal that no record was found."""


# -------------------------------------------------------------------------------------
# ln.track() AKA run_context
# -------------------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion sub/lnschema-core
4 changes: 2 additions & 2 deletions tests/test_queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import bionty as bt
import lamindb as ln
import pytest
from lamindb._query_set import MultipleResultsFound, NoResultFound
from lamindb._query_set import DoesNotExist, MultipleResultsFound
from lnschema_core.users import current_user_id


Expand Down Expand Up @@ -88,7 +88,7 @@ def test_one_first():
assert qs.one_or_none().handle == "testuser1"

qs = ln.User.filter(handle="test")
with pytest.raises(NoResultFound):
with pytest.raises(DoesNotExist):
qs.one()
qs = bt.Source.filter().all()
with pytest.raises(MultipleResultsFound):
Expand Down
28 changes: 20 additions & 8 deletions tests/test_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import lamindb as ln
import pytest
from lamindb import _record as registry
from lamindb import _record


def test_signatures():
Expand All @@ -18,11 +18,11 @@ class Mock:
# class methods
class_methods = ["filter", "get", "df", "search", "lookup", "from_values", "using"]
for name in class_methods:
setattr(Mock, name, getattr(registry, name))
assert signature(getattr(Mock, name)) == registry.SIGS.pop(name)
setattr(Mock, name, getattr(_record, name))
assert signature(getattr(Mock, name)) == _record.SIGS.pop(name)
# methods
for name, sig in registry.SIGS.items():
assert signature(getattr(registry, name)) == sig
for name, sig in _record.SIGS.items():
assert signature(getattr(_record, name)) == sig


def test_init_with_args():
Expand Down Expand Up @@ -54,7 +54,7 @@ def get_search_test_filepaths():
shutil.rmtree("unregistered_storage/")


def test_search_artifact(get_search_test_filepaths):
def test_search_and_get(get_search_test_filepaths):
artifact1 = ln.Artifact(
"./unregistered_storage/test-search1", description="nonsense"
)
Expand Down Expand Up @@ -102,6 +102,18 @@ def test_search_artifact(get_search_test_filepaths):
# multi-field search
res = ln.Artifact.search("txt", field=["key", "description", "suffix"]).df()
assert res.iloc[0].suffix == ".txt"

# get

artifact = ln.Artifact.get(description="test-search4")
assert artifact == artifact4

# Artifact.DoesNotExist is made private in some use cases, so we test for
# the private Artifact._DoesNotExist here rather than the public name
with pytest.raises(ln.Artifact._DoesNotExist):
ln.Artifact.get(description="test-search1000000")

#
artifact0.delete(permanent=True, storage=True)
artifact1.delete(permanent=True, storage=True)
artifact2.delete(permanent=True, storage=True)
Expand All @@ -119,7 +131,7 @@ def test_pass_version():
def test_get_name_field():
transform = ln.Transform(name="test")
transform.save()
assert registry.get_name_field(ln.Run(transform)) == "started_at"
assert _record.get_name_field(ln.Run(transform)) == "started_at"
with pytest.raises(ValueError):
registry.get_name_field(ln.Artifact.ulabels.through())
_record.get_name_field(ln.Artifact.ulabels.through())
transform.delete()
Loading