Skip to content

Commit

Permalink
Sketch support for writing, reading sliced AwkwardArrays.
Browse files Browse the repository at this point in the history
  • Loading branch information
danielballan committed Aug 9, 2023
1 parent 8de8e8a commit 830e312
Show file tree
Hide file tree
Showing 18 changed files with 459 additions and 6 deletions.
41 changes: 41 additions & 0 deletions tiled/_tests/test_awkward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import awkward

from ..catalog import in_memory
from ..client import Context, from_context, record_history
from ..server.app import build_app


def test_awkward(tmpdir):
catalog = in_memory(writable_storage=tmpdir)
app = build_app(catalog)
with Context.from_app(app) as context:
client = from_context(context)

# Write data into catalog. It will be stored as directory of buffers
# named like 'node0-offsets' and 'node2-data'.
array = awkward.Array(
[
[{"x": 1.1, "y": [1]}, {"x": 2.2, "y": [1, 2]}],
[],
[{"x": 3.3, "y": [1, 2, 3]}],
]
)
aac = client.write_awkward(array, key="test")

# Read the data back out from the AwkwardArrrayClient, progressively sliced.
assert awkward.almost_equal(aac.read(), array)
assert awkward.almost_equal(aac[:], array)
assert awkward.almost_equal(aac[0], array[0])
assert awkward.almost_equal(aac[0, "y"], array[0, "y"])
assert awkward.almost_equal(aac[0, "y", :1], array[0, "y", :1])

# When sliced, the serer sends less data.
with record_history() as h:
aac[:]
assert len(h.responses) == 1 # sanity check
full_response_size = len(h.responses[0].content)
with record_history() as h:
aac[0, "y"]
assert len(h.responses) == 1 # sanity check
sliced_response_size = len(h.responses[0].content)
assert sliced_response_size < full_response_size
59 changes: 59 additions & 0 deletions tiled/adapters/awkward_buffers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
A directory containing awkward buffers, one file per form key.
"""
from urllib import parse

from ..structures.core import StructureFamily


class AwkwardBuffersAdapter:
structure_family = StructureFamily.awkward

def __init__(
self,
directory,
structure,
metadata=None,
specs=None,
access_policy=None,
):
self.directory = directory
self._metadata = metadata or {}
self._structure = structure
self.specs = list(specs or [])
self.access_policy = access_policy

def metadata(self):
return self._metadata

@classmethod
def init_storage(cls, directory, structure):
from ..server.schemas import Asset

directory.mkdir()
data_uri = parse.urlunparse(("file", "localhost", str(directory), "", "", None))
return [Asset(data_uri=data_uri, is_directory=True)]

def write(self, data):
for form_key, value in data.items():
with open(self.directory / form_key, "wb") as file:
file.write(value)

def read(self, form_keys=None):
selected_suffixed_form_keys = []
if form_keys is None:
# Read all.
selected_suffixed_form_keys.extend(self._structure.suffixed_form_keys)
else:
for form_key in form_keys:
for suffixed_form_key in self._structure.suffixed_form_keys:
if suffixed_form_key.startswith(form_key):
selected_suffixed_form_keys.append(suffixed_form_key)
buffers = {}
for form_key in selected_suffixed_form_keys:
with open(self.directory / form_key, "rb") as file:
buffers[form_key] = file.read()
return buffers

def structure(self):
return self._structure
14 changes: 14 additions & 0 deletions tiled/catalog/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
PARQUET_MIMETYPE,
SPARSE_BLOCKS_PARQUET_MIMETYPE,
ZARR_MIMETYPE,
ZIP_MIMETYPE,
)
from .utils import SCHEME_PATTERN, ensure_uri, safe_path

Expand All @@ -57,6 +58,7 @@

DEFAULT_CREATION_MIMETYPE = {
StructureFamily.array: ZARR_MIMETYPE,
StructureFamily.awkward: ZIP_MIMETYPE,
StructureFamily.table: PARQUET_MIMETYPE,
StructureFamily.sparse: SPARSE_BLOCKS_PARQUET_MIMETYPE,
}
Expand All @@ -65,6 +67,9 @@
ZARR_MIMETYPE: lambda: importlib.import_module(
"...adapters.zarr", __name__
).ZarrArrayAdapter.init_storage,
ZIP_MIMETYPE: lambda: importlib.import_module(
"...adapters.awkward_buffers", __name__
).AwkwardBuffersAdapter.init_storage,
PARQUET_MIMETYPE: lambda: importlib.import_module(
"...adapters.parquet", __name__
).ParquetDatasetAdapter.init_storage,
Expand Down Expand Up @@ -820,6 +825,14 @@ async def write_block(self, *args, **kwargs):
)


class CatalogAwkwardAdapter(CatalogNodeAdapter):
async def read(self, *args, **kwargs):
return await ensure_awaitable((await self.get_adapter()).read, *args, **kwargs)

async def write(self, *args, **kwargs):
return await ensure_awaitable((await self.get_adapter()).write, *args, **kwargs)


class CatalogSparseAdapter(CatalogArrayAdapter):
pass

Expand Down Expand Up @@ -1082,6 +1095,7 @@ def json_serializer(obj):
STRUCTURES = {
StructureFamily.container: CatalogContainerAdapter,
StructureFamily.array: CatalogArrayAdapter,
StructureFamily.awkward: CatalogAwkwardAdapter,
StructureFamily.table: CatalogTableAdapter,
StructureFamily.sparse: CatalogSparseAdapter,
}
4 changes: 2 additions & 2 deletions tiled/catalog/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

# This is the alembic revision ID of the database revision
# required by this version of Tiled.
REQUIRED_REVISION = "83889e049ddc"
REQUIRED_REVISION = "0b033e7fbe30"

# This is list of all valid revisions (from current to oldest).
ALL_REVISIONS = ["83889e049ddc", "6825c778aa3c"]
ALL_REVISIONS = ["0b033e7fbe30", "83889e049ddc", "6825c778aa3c"]


async def initialize_database(engine):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Add 'awkward' to structurefamily enum.
Revision ID: 0b033e7fbe30
Revises: 83889e049ddc
Create Date: 2023-08-08 21:10:20.181470
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "0b033e7fbe30"
down_revision = "83889e049ddc"
branch_labels = None
depends_on = None


def upgrade():
connection = op.get_bind()

if connection.engine.dialect.name == "postgresql":
with op.get_context().autocommit_block():
op.execute(
sa.text(
"ALTER TYPE structurefamily ADD VALUE IF NOT EXISTS 'awkward' AFTER 'array'"
)
)


def downgrade():
# This _could_ be implemented but we will wait for a need since we are
# still in alpha releases.
raise NotImplementedError
4 changes: 4 additions & 0 deletions tiled/catalog/mimetypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# for importing Readers that we will not actually use.
PARQUET_MIMETYPE = "application/x-parquet"
SPARSE_BLOCKS_PARQUET_MIMETYPE = "application/x-parquet-sparse" # HACK!
ZIP_MIMETYPE = "application/zip"
ZARR_MIMETYPE = "application/x-zarr"
DEFAULT_ADAPTERS_BY_MIMETYPE = OneShotCachedMap(
{
Expand Down Expand Up @@ -38,6 +39,9 @@
ZARR_MIMETYPE: lambda: importlib.import_module(
"...adapters.zarr", __name__
).read_zarr,
ZIP_MIMETYPE: lambda: importlib.import_module(
"...adapters.awkward_buffers", __name__
).AwkwardBuffersAdapter,
}
)

Expand Down
52 changes: 52 additions & 0 deletions tiled/client/awkward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import awkward

from ..serialization.awkward import from_zipped_buffers, to_zipped_buffers
from ..structures.awkward import project_form
from .base import BaseClient
from .utils import handle_error


class AwkwardArrayClient(BaseClient):
def __repr__(self):
# TODO Include some summary of the structure. Probably
# lift __repr__ code from awkward itself here.
return f"<{type(self).__name__}>"

def write(self, container):
handle_error(
self.context.http_client.put(
self.item["links"]["full"],
content=bytes(to_zipped_buffers(container, {})),
headers={"Content-Type": "application/zip"},
)
)

def read(self, slice=...):
structure = self.structure()
form = awkward.forms.from_dict(structure.form)
typetracer, report = awkward.typetracer.typetracer_with_report(
form,
forget_length=True,
)
proxy_array = awkward.Array(typetracer)
# TODO Ask awkward to promote _touch_data to a public method.
proxy_array[slice].layout._touch_data(recursive=True)
form_keys_touched = set(report.data_touched)
projected_form = project_form(form, form_keys_touched)
# The order is not important, but sort so that the request is deterministic.
params = {"form_key": sorted(list(form_keys_touched))}
content = handle_error(
self.context.http_client.get(
self.item["links"]["full"],
headers={"Accept": "application/zip"},
params=params,
)
).read()
container = from_zipped_buffers(content)
projected_array = awkward.from_buffers(
projected_form, structure.length, container
)
return projected_array[slice]

def __getitem__(self, slice):
return self.read(slice=slice)
3 changes: 3 additions & 0 deletions tiled/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,9 @@ def metadata_revisions(self):
StructureFamily.array: lambda: importlib.import_module(
"...structures.array", BaseClient.__module__
).ArrayStructure,
StructureFamily.awkward: lambda: importlib.import_module(
"...structures.awkward", BaseClient.__module__
).AwkwardStructure,
StructureFamily.table: lambda: importlib.import_module(
"...structures.table", BaseClient.__module__
).TableStructure,
Expand Down
51 changes: 51 additions & 0 deletions tiled/client/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,52 @@ def write_block(x, block_id, client):
da.map_blocks(write_block, dtype=da.dtype, client=client).compute()
return client

def write_awkward(
self,
array,
*,
key=None,
metadata=None,
dims=None,
specs=None,
):
"""
Write an AwkwardArray.
Parameters
----------
array: awkward.Array
key : str, optional
Key (name) for this new node. If None, the server will provide a unique key.
metadata : dict, optional
User metadata. May be nested. Must contain only basic types
(e.g. numbers, strings, lists, dicts) that are JSON-serializable.
dims : List[str], optional
A label for each dimension of the array.
specs : List[Spec], optional
List of names that are used to label that the data and/or metadata
conform to some named standard specification.
"""
import awkward

from ..structures.awkward import AwkwardStructure

form, length, container = awkward.to_buffers(array)
structure = AwkwardStructure(
length=length,
form=form.to_dict(),
suffixed_form_keys=list(container),
)
client = self.new(
StructureFamily.awkward,
structure,
key=key,
metadata=metadata,
specs=specs,
)
client.write(container)
return client

def write_sparse(
self,
coords,
Expand Down Expand Up @@ -921,6 +967,9 @@ def __call__(self):
{
"container": _Wrap(Container),
"array": _LazyLoad(("..array", Container.__module__), "ArrayClient"),
"awkward": _LazyLoad(
("..awkward", Container.__module__), "AwkwardArrayClient"
),
"dataframe": _LazyLoad(
("..dataframe", Container.__module__), "DataFrameClient"
),
Expand All @@ -937,6 +986,8 @@ def __call__(self):
{
"container": _Wrap(Container),
"array": _LazyLoad(("..array", Container.__module__), "DaskArrayClient"),
# TODO Create DaskAwkwardArrayClient
# "awkward": _LazyLoad(("..awkward", Container.__module__), "DaskAwkwardArrayClient"),
"dataframe": _LazyLoad(
("..dataframe", Container.__module__), "DaskDataFrameClient"
),
Expand Down
2 changes: 1 addition & 1 deletion tiled/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def client_for_item(context, structure_clients, item, structure=None):
def params_from_slice(slice):
"Generate URL query param ?slice=... from Python slice object."
params = {}
if slice is not None:
if (slice is not None) and (slice is not ...):
if isinstance(slice, (int, builtins.slice)):
slice = [slice]
slices = []
Expand Down
4 changes: 4 additions & 0 deletions tiled/serialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ def register_builtin_serializers():
from ..serialization import array as _array # noqa: F401

del _array
if modules_available("awkward"):
from ..serialization import awkward as _awkward # noqa: F401

del _awkward
if modules_available("pandas", "pyarrow", "dask.dataframe"):
from ..serialization import table as _table # noqa: F401

Expand Down
28 changes: 28 additions & 0 deletions tiled/serialization/awkward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import io
import zipfile

from ..media_type_registration import deserialization_registry, serialization_registry


@serialization_registry.register("awkward", "application/zip")
def to_zipped_buffers(container, metadata):
file = io.BytesIO()
# Pack multiple buffers into a zipfile, uncompressed. This enables
# multiple buffers in a single response, with random access. The
# entire payload *may* be compressed using Tiled's normal compression
# mechanisms.
with zipfile.ZipFile(file, "w", compresslevel=zipfile.ZIP_STORED) as zip:
for form_key, buffer in container.items():
zip.writestr(form_key, buffer)
return file.getbuffer()


@deserialization_registry.register("awkward", "application/zip")
def from_zipped_buffers(buffer):
file = io.BytesIO(buffer)
with zipfile.ZipFile(file, "r") as zip:
form_keys = zip.namelist()
buffers = {}
for form_key in form_keys:
buffers[form_key] = zip.read(form_key)
return buffers
Loading

0 comments on commit 830e312

Please sign in to comment.