diff --git a/README.md b/README.md
index 99fa1e3ffb..cf5c66218e 100644
--- a/README.md
+++ b/README.md
@@ -461,10 +461,40 @@ Parameters:
 
 Responses:
 
-- `200`: JSON content with the following structure:
+- `200`: JSON content that provides the types of the columns (see https://huggingface.co/docs/datasets/about_dataset_features.html) and the data rows, with the following structure:
 
   ```json
   {
+    "features": [
+      {
+        "dataset": "glue",
+        "config": "ax",
+        "features": {
+          "premise": {
+            "dtype": "string",
+            "id": null,
+            "_type": "Value"
+          },
+          "hypothesis": {
+            "dtype": "string",
+            "id": null,
+            "_type": "Value"
+          },
+          "label": {
+            "num_classes": 3,
+            "names": ["entailment", "neutral", "contradiction"],
+            "names_file": null,
+            "id": null,
+            "_type": "ClassLabel"
+          },
+          "idx": {
+            "dtype": "int32",
+            "id": null,
+            "_type": "Value"
+          }
+        }
+      }
+    ],
     "rows": [
       {
         "dataset": "glue",
diff --git a/pyproject.toml b/pyproject.toml
index 5fce1edb4e..6072f9aec7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "datasets-preview-backend"
-version = "0.8.0"
+version = "0.8.1"
 description = "API to extract rows of 🤗 datasets"
 authors = ["Sylvain Lesage "]
 
diff --git a/src/datasets_preview_backend/queries/rows.py b/src/datasets_preview_backend/queries/rows.py
index 0c32ee230f..8fe309c8c9 100644
--- a/src/datasets_preview_backend/queries/rows.py
+++ b/src/datasets_preview_backend/queries/rows.py
@@ -9,10 +9,13 @@
 from datasets_preview_backend.constants import DATASETS_BLOCKLIST
 from datasets_preview_backend.exceptions import Status400Error, Status404Error
 from datasets_preview_backend.queries.configs import get_configs_response
+from datasets_preview_backend.queries.infos import get_infos_response
 from datasets_preview_backend.queries.splits import get_splits_response
 from datasets_preview_backend.responses import CachedResponse
 from datasets_preview_backend.types import (
     ConfigsContent,
+    FeatureItem,
+    InfosContent,
     RowItem,
     RowsContent,
     SplitsContent,
@@ -41,6 +44,9 @@ def get_rows(
         raise TypeError("split argument should be a string")
     num_rows = EXTRACT_ROWS_LIMIT
 
+    rowItems: List[RowItem] = []
+    featureItems: List[FeatureItem] = []
+
     if config is not None and split is not None:
         try:
             iterable_dataset = load_dataset(dataset, name=config, split=split, streaming=True, use_auth_token=token)
@@ -81,7 +87,28 @@
                 f"could not read all the required rows ({len(rows)} / {num_rows}) from dataset {dataset} -"
                 f" {config} - {split}"
             )
-        return {"rows": [{"dataset": dataset, "config": config, "split": split, "row": row} for row in rows]}
+
+        rowItems = [{"dataset": dataset, "config": config, "split": split, "row": row} for row in rows]
+
+        content = get_infos_response(dataset=dataset, config=config, token=token).content
+        if "infos" not in content:
+            error = cast(StatusErrorContent, content)
+            if "status_code" in error and error["status_code"] == 404:
+                raise Status404Error("features could not be found")
+            raise Status400Error("features could not be found")
+        infos_content = cast(InfosContent, content)
+        infoItems = [infoItem["info"] for infoItem in infos_content["infos"]]
+
+        if len(infoItems) != 1:
+            raise Exception("a dataset config should have exactly one info")
+        infoItem = infoItems[0]
+        if "features" not in infoItem:
+            raise Status400Error("a dataset config info should contain a 'features' property")
+        localFeaturesItems: List[FeatureItem] = [
+            {"dataset": dataset, "config": config, "features": infoItem["features"]}
+        ]
+
+        return {"features": localFeaturesItems, "rows": rowItems}
 
     if config is None:
         content = get_configs_response(dataset=dataset, token=token).content
@@ -95,7 +122,6 @@
     else:
         configs = [config]
 
-    rowItems: List[RowItem] = []
     # Note that we raise on the first error
     for config in configs:
         content = get_splits_response(dataset=dataset, config=config, token=token).content
@@ -108,12 +134,20 @@
         splits = [splitItem["split"] for splitItem in splits_content["splits"]]
 
         for split in splits:
-            rows_content = cast(
-                RowsContent, get_rows_response(dataset=dataset, config=config, split=split, token=token).content
-            )
+            content = get_rows_response(dataset=dataset, config=config, split=split, token=token).content
+            if "rows" not in content:
+                error = cast(StatusErrorContent, content)
+                if "status_code" in error and error["status_code"] == 404:
+                    raise Status404Error("rows could not be found")
+                raise Status400Error("rows could not be found")
+            rows_content = cast(RowsContent, content)
             rowItems += rows_content["rows"]
+            for featureItem in rows_content["features"]:
+                # there should be only one element. Anyway, let's loop
+                if featureItem not in featureItems:
+                    featureItems.append(featureItem)
 
-    return {"rows": rowItems}
+    return {"features": featureItems, "rows": rowItems}
 
 
 @memoize(cache, expire=CACHE_TTL_SECONDS)  # type:ignore
diff --git a/src/datasets_preview_backend/reports.py b/src/datasets_preview_backend/reports.py
index 63281c6507..7655a173c0 100644
--- a/src/datasets_preview_backend/reports.py
+++ b/src/datasets_preview_backend/reports.py
@@ -33,14 +33,14 @@ def __init__(
         if response is not None:
             # response might be too heavy (we don't want to replicate the cache)
             # we get the essence of the response, depending on the case
-            if "info" in response:
-                self.result = {"info_num_keys": len(response["info"])}
+            if "infos" in response:
+                self.result = {"infos_length": len(response["infos"])}
             elif "configs" in response:
                 self.result = {"configs": [c["config"] for c in response["configs"]]}
             elif "splits" in response:
                 self.result = {"splits": response["splits"]}
             elif "rows" in response:
-                self.result = {"rows_length": len(response["rows"])}
+                self.result = {"rows_length": len(response["rows"]), "features_length": len(response["features"])}
             else:
                 self.result = {}
 
diff --git a/src/datasets_preview_backend/types.py b/src/datasets_preview_backend/types.py
index 81f7c2ce57..f23d476120 100644
--- a/src/datasets_preview_backend/types.py
+++ b/src/datasets_preview_backend/types.py
@@ -29,6 +29,12 @@ class RowItem(TypedDict):
     row: Any
 
 
+class FeatureItem(TypedDict):
+    dataset: str
+    config: str
+    features: Any
+
+
 # Content of endpoint responses
 
 
@@ -49,6 +55,7 @@ class SplitsContent(TypedDict):
 
 
 class RowsContent(TypedDict):
+    features: List[FeatureItem]
     rows: List[RowItem]
 
 
diff --git a/tests/queries/test_rows.py b/tests/queries/test_rows.py
index e684a6c02e..756f4d535a 100644
--- a/tests/queries/test_rows.py
+++ b/tests/queries/test_rows.py
@@ -30,11 +30,25 @@ def test_get_split_rows() -> None:
     assert rowItem["row"]["tokens"][0] == "What"
 
 
+def test_get_split_features() -> None:
+    dataset = "acronym_identification"
+    config = DEFAULT_CONFIG_NAME
+    split = "train"
+    response = get_rows(dataset, config, split)
+    assert "features" in response
+    assert len(response["features"]) == 1
+    featureItem = response["features"][0]
+    assert "dataset" in featureItem
+    assert "config" in featureItem
+    assert "features" in featureItem
+    assert featureItem["features"]["tokens"]["_type"] == "Sequence"
+
+
 def test_get_split_rows_without_split() -> None:
     dataset = "acronym_identification"
     response = get_rows(dataset, DEFAULT_CONFIG_NAME)
-    rows = response["rows"]
-    assert len(rows) == 3 * EXTRACT_ROWS_LIMIT
+    assert len(response["rows"]) == 3 * EXTRACT_ROWS_LIMIT
+    assert len(response["features"]) == 1
 
 
 def test_get_split_rows_without_config() -> None:
@@ -42,6 +56,7 @@ def test_get_split_rows_without_config() -> None:
     split = "train"
     response1 = get_rows(dataset)
     assert len(response1["rows"]) == 1 * 3 * EXTRACT_ROWS_LIMIT
+    assert len(response1["features"]) == 1
     response2 = get_rows(dataset, None, split)
     assert response1 == response2
@@ -49,6 +64,7 @@
     dataset = "adversarial_qa"
     response3 = get_rows(dataset)
     assert len(response3["rows"]) == 4 * 3 * EXTRACT_ROWS_LIMIT
+    assert len(response3["features"]) == 4
 
 
 def test_get_unknown_dataset() -> None:
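For context, here is a minimal client-side sketch of how the extended response can be consumed. It assumes the endpoint documented in the README hunk above is served as `GET /rows` with `dataset`, `config` and `split` query parameters on a local instance at `http://localhost:8000`; the URL, the port and the `requests` dependency are illustration choices, not part of this patch. The dataset/config pair mirrors the README example.

```python
# Client-side sketch (assumed URL and endpoint shape; adjust to your deployment).
import requests

params = {"dataset": "glue", "config": "ax", "split": "test"}
response = requests.get("http://localhost:8000/rows", params=params)
response.raise_for_status()
content = response.json()

# "features" holds one FeatureItem per config, mapping each column name
# to its serialized type ("Value", "ClassLabel", ...) as in the README example.
for featureItem in content["features"]:
    for column, feature in featureItem["features"].items():
        print(f"{column}: {feature['_type']}")

# Each RowItem is tagged with the dataset/config/split it came from.
for rowItem in content["rows"][:5]:
    print(rowItem["row"])
```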
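When `config` and `split` are omitted, `get_rows` aggregates over every config and split; the membership test `featureItem not in featureItems` in the loop above is what keeps exactly one `FeatureItem` per config, even though every split of a config reports the same features. A direct-call sketch mirroring the call patterns of `tests/queries/test_rows.py` (the dataset name is the one used in the tests):

```python
# Direct-call sketch, following tests/queries/test_rows.py.
from datasets_preview_backend.queries.rows import get_rows

# No config/split given: rows and features are aggregated over the whole dataset.
response = get_rows("acronym_identification")

# One FeatureItem per config; EXTRACT_ROWS_LIMIT rows per (config, split) pair.
print(len(response["features"]))  # 1 (a single default config)
print(len(response["rows"]))      # 3 splits x EXTRACT_ROWS_LIMIT rows each
```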