Add types to rows (#49)
* fix: 🐛 fix report

* feat: 🎸 add "features" (column types) to /rows response

It allows the rows to be parsed correctly.
severo authored Sep 24, 2021
1 parent 13e5332 commit c2a78e7
Showing 6 changed files with 100 additions and 13 deletions.
32 changes: 31 additions & 1 deletion README.md
@@ -461,10 +461,40 @@ Parameters:

Responses:

- `200`: JSON content with the following structure:
- `200`: JSON content that provides the types of the columns (see https://huggingface.co/docs/datasets/about_dataset_features.html) and the data rows, with the following structure:

```json
{
"features": [
{
"dataset": "glue",
"config": "ax",
"features": {
"premise": {
"dtype": "string",
"id": null,
"_type": "Value"
},
"hypothesis": {
"dtype": "string",
"id": null,
"_type": "Value"
},
"label": {
"num_classes": 3,
"names": ["entailment", "neutral", "contradiction"],
"names_file": null,
"id": null,
"_type": "ClassLabel"
},
"idx": {
"dtype": "int32",
"id": null,
"_type": "Value"
}
}
}
],
"rows": [
{
"dataset": "glue",
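The `features` block shown above makes it possible to decode encoded columns (such as `ClassLabel`) on the client side. A minimal sketch, assuming a locally running instance and the `glue`/`ax` example above (the base URL, dataset, config, and split are illustrative assumptions, not part of this commit):

```python
# Hypothetical client-side sketch: the base URL, dataset, config, and split
# are assumptions for illustration only.
import requests

API_URL = "http://localhost:8000"  # assumed local instance of the backend
params = {"dataset": "glue", "config": "ax", "split": "test"}
content = requests.get(f"{API_URL}/rows", params=params).json()

# Build a map from each ClassLabel column to its list of human-readable names.
label_names = {
    column: spec["names"]
    for feature_item in content["features"]
    for column, spec in feature_item["features"].items()
    if spec["_type"] == "ClassLabel"
}

# Decode integer class labels into their names, e.g. 0 -> "entailment".
for row_item in content["rows"]:
    row = row_item["row"]
    decoded = dict(row)
    for column, names in label_names.items():
        value = row.get(column)
        if isinstance(value, int) and 0 <= value < len(names):
            decoded[column] = names[value]
    print(decoded)
```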
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "datasets-preview-backend"
version = "0.8.0"
version = "0.8.1"
description = "API to extract rows of 🤗 datasets"
authors = ["Sylvain Lesage <[email protected]>"]

46 changes: 40 additions & 6 deletions src/datasets_preview_backend/queries/rows.py
@@ -9,10 +9,13 @@
from datasets_preview_backend.constants import DATASETS_BLOCKLIST
from datasets_preview_backend.exceptions import Status400Error, Status404Error
from datasets_preview_backend.queries.configs import get_configs_response
from datasets_preview_backend.queries.infos import get_infos_response
from datasets_preview_backend.queries.splits import get_splits_response
from datasets_preview_backend.responses import CachedResponse
from datasets_preview_backend.types import (
ConfigsContent,
FeatureItem,
InfosContent,
RowItem,
RowsContent,
SplitsContent,
@@ -41,6 +44,9 @@ def get_rows(
raise TypeError("split argument should be a string")
num_rows = EXTRACT_ROWS_LIMIT

rowItems: List[RowItem] = []
featureItems: List[FeatureItem] = []

if config is not None and split is not None:
try:
iterable_dataset = load_dataset(dataset, name=config, split=split, streaming=True, use_auth_token=token)
@@ -81,7 +87,28 @@ def get_rows(
f"could not read all the required rows ({len(rows)} / {num_rows}) from dataset {dataset} -"
f" {config} - {split}"
)
return {"rows": [{"dataset": dataset, "config": config, "split": split, "row": row} for row in rows]}

rowItems = [{"dataset": dataset, "config": config, "split": split, "row": row} for row in rows]

content = get_infos_response(dataset=dataset, config=config, token=token).content
if "infos" not in content:
error = cast(StatusErrorContent, content)
if "status_code" in error and error["status_code"] == 404:
raise Status404Error("features could not be found")
raise Status400Error("features could not be found")
infos_content = cast(InfosContent, content)
infoItems = [infoItem["info"] for infoItem in infos_content["infos"]]

if len(infoItems) != 1:
raise Exception("a dataset config should have exactly one info")
infoItem = infoItems[0]
if "features" not in infoItem:
raise Status400Error("a dataset config info should contain a 'features' property")
localFeaturesItems: List[FeatureItem] = [
{"dataset": dataset, "config": config, "features": infoItem["features"]}
]

return {"features": localFeaturesItems, "rows": rowItems}

if config is None:
content = get_configs_response(dataset=dataset, token=token).content
@@ -95,7 +122,6 @@ def get_rows(
else:
configs = [config]

rowItems: List[RowItem] = []
# Note that we raise on the first error
for config in configs:
content = get_splits_response(dataset=dataset, config=config, token=token).content
@@ -108,12 +134,20 @@
splits = [splitItem["split"] for splitItem in splits_content["splits"]]

for split in splits:
rows_content = cast(
RowsContent, get_rows_response(dataset=dataset, config=config, split=split, token=token).content
)
content = get_rows_response(dataset=dataset, config=config, split=split, token=token).content
if "rows" not in content:
error = cast(StatusErrorContent, content)
if "status_code" in error and error["status_code"] == 404:
raise Status404Error("rows could not be found")
raise Status400Error("rows could not be found")
rows_content = cast(RowsContent, content)
rowItems += rows_content["rows"]
for featureItem in rows_content["features"]:
# there should be only one element. Anyway, let's loop
if featureItem not in featureItems:
featureItems.append(featureItem)

return {"rows": rowItems}
return {"features": featureItems, "rows": rowItems}


@memoize(cache, expire=CACHE_TTL_SECONDS) # type:ignore
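Whichever branch is taken, `get_rows` now returns both keys. A rough usage sketch (the dataset, config, and split values are assumed examples):

```python
# Usage sketch only: "glue" / "ax" / "test" are assumed example values.
from datasets_preview_backend.queries.rows import get_rows

content = get_rows("glue", "ax", "test")
feature_items = content["features"]  # one FeatureItem per (dataset, config) pair
row_items = content["rows"]          # one RowItem per extracted row
```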
6 changes: 3 additions & 3 deletions src/datasets_preview_backend/reports.py
@@ -33,14 +33,14 @@ def __init__(
if response is not None:
# response might be too heavy (we don't want to replicate the cache)
# we get the essence of the response, depending on the case
if "info" in response:
self.result = {"info_num_keys": len(response["info"])}
if "infos" in response:
self.result = {"infos_length": len(response["infos"])}
elif "configs" in response:
self.result = {"configs": [c["config"] for c in response["configs"]]}
elif "splits" in response:
self.result = {"splits": response["splits"]}
elif "rows" in response:
self.result = {"rows_length": len(response["rows"])}
self.result = {"rows_length": len(response["rows"]), "features_length": len(response["features"])}
else:
self.result = {}

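For a `/rows` response, the cached report summary now looks roughly like this (the counts are illustrative, not taken from a real run):

```python
# Illustrative report summary for a /rows response; the counts are assumed values.
result = {"rows_length": 100, "features_length": 1}
```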
7 changes: 7 additions & 0 deletions src/datasets_preview_backend/types.py
@@ -29,6 +29,12 @@ class RowItem(TypedDict):
row: Any


class FeatureItem(TypedDict):
dataset: str
config: str
features: Any


# Content of endpoint responses


@@ -49,6 +55,7 @@ class SplitsContent(TypedDict):


class RowsContent(TypedDict):
features: List[FeatureItem]
rows: List[RowItem]


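As an illustration, a `FeatureItem` built from the `glue`/`ax` sample in the README above would look like this (field values copied from that sample):

```python
# Sketch of a FeatureItem instance, reusing the glue/ax sample shown in the README diff.
from datasets_preview_backend.types import FeatureItem

feature_item: FeatureItem = {
    "dataset": "glue",
    "config": "ax",
    "features": {
        "premise": {"dtype": "string", "id": None, "_type": "Value"},
        "hypothesis": {"dtype": "string", "id": None, "_type": "Value"},
        "label": {
            "num_classes": 3,
            "names": ["entailment", "neutral", "contradiction"],
            "names_file": None,
            "id": None,
            "_type": "ClassLabel",
        },
        "idx": {"dtype": "int32", "id": None, "_type": "Value"},
    },
}
```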
20 changes: 18 additions & 2 deletions tests/queries/test_rows.py
@@ -30,25 +30,41 @@ def test_get_split_rows() -> None:
assert rowItem["row"]["tokens"][0] == "What"


def test_get_split_features() -> None:
dataset = "acronym_identification"
config = DEFAULT_CONFIG_NAME
split = "train"
response = get_rows(dataset, config, split)
assert "features" in response
assert len(response["features"]) == 1
featureItem = response["features"][0]
assert "dataset" in featureItem
assert "config" in featureItem
assert "features" in featureItem
assert featureItem["features"]["tokens"]["_type"] == "Sequence"


def test_get_split_rows_without_split() -> None:
dataset = "acronym_identification"
response = get_rows(dataset, DEFAULT_CONFIG_NAME)
rows = response["rows"]
assert len(rows) == 3 * EXTRACT_ROWS_LIMIT
assert len(response["rows"]) == 3 * EXTRACT_ROWS_LIMIT
assert len(response["features"]) == 1


def test_get_split_rows_without_config() -> None:
dataset = "acronym_identification"
split = "train"
response1 = get_rows(dataset)
assert len(response1["rows"]) == 1 * 3 * EXTRACT_ROWS_LIMIT
assert len(response1["features"]) == 1

response2 = get_rows(dataset, None, split)
assert response1 == response2

dataset = "adversarial_qa"
response3 = get_rows(dataset)
assert len(response3["rows"]) == 4 * 3 * EXTRACT_ROWS_LIMIT
assert len(response3["features"]) == 4


def test_get_unknown_dataset() -> None:
