Skip to content

Commit

Permalink
Add __binsparse_descriptor__ and __binsparse_dlpack__.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed Sep 3, 2024
1 parent fb0affe commit a2a562b
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 1 deletion.
38 changes: 37 additions & 1 deletion sparse/numba_backend/_compressed/compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,15 @@ def isinf(self):
def isnan(self):
return self.tocoo().isnan().asformat("gcxs", compressed_axes=self.compressed_axes)

# `GCXS` is a reshaped/transposed `CSR`, but it can't (usually)
# be expressed in the `binsparse` 0.1 language.
# We are missing index maps.
def __binsparse_descriptor__(self) -> dict:
return super().__binsparse_descriptor__()

def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
return super().__binsparse_dlpack__()


class _Compressed2d(GCXS):
class_compressed_axes: tuple[int]
Expand Down Expand Up @@ -883,6 +892,34 @@ def from_numpy(cls, x, fill_value=0, idx_dtype=None):
coo = COO.from_numpy(x, fill_value=fill_value, idx_dtype=idx_dtype)
return cls.from_coo(coo, cls.class_compressed_axes, idx_dtype)

def __binsparse_descriptor__(self) -> dict:
from sparse._version import __version__

data_dt = str(self.data.dtype)
if np.issubdtype(data_dt, np.complexfloating):
data_dt = f"complex[float{self.data.dtype.itemsize // 2}]"
return {
"binsparse": {
"version": "0.1",
"format": self.format.upper(),
"shape": list(self.shape),
"number_of_stored_values": self.nnz,
"data_types": {
"pointers_to_1": str(self.indices.dtype),
"indices_1": str(self.indptr.dtype),
"values": data_dt,
},
},
"original_source": f"`sparse`, version {__version__}",
}

def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
return {
"pointers_to_1": self.indices,
"indices_1": self.indptr,
"values": self.data,
}


class CSR(_Compressed2d):
"""
Expand Down Expand Up @@ -915,7 +952,6 @@ def transpose(self, axes: None = None, copy: bool = False) -> Union["CSC", "CSR"
return self
return CSC((self.data, self.indices, self.indptr), self.shape[::-1])


class CSC(_Compressed2d):
"""
The CSC or CCS scheme stores a n-dimensional array using n+1 one-dimensional arrays.
Expand Down
38 changes: 38 additions & 0 deletions sparse/numba_backend/_coo/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1537,6 +1537,44 @@ def isnan(self):
prune=True,
)

def __binsparse_descriptor__(self) -> dict:
from sparse._version import __version__

data_dt = str(self.data.dtype)
if np.issubdtype(data_dt, np.complexfloating):
data_dt = f"complex[float{self.data.dtype.itemsize // 2}]"
return {
"binsparse": {
"version": "0.1",
"format": {
"custom": {
"level": {
"level_desc": "sparse",
"rank": self.ndim,
"level": {
"level_desc": "element",
},
}
}
},
"shape": list(self.shape),
"number_of_stored_values": self.nnz,
"data_types": {
"pointers_to_1": "uint8",
"indices_1": str(self.coords.dtype),
"values": data_dt,
},
},
"original_source": f"`sparse`, version {__version__}",
}

def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
return {
"pointers_to_1": np.array([0, self.nnz], dtype=np.uint8),
"indices_1": self.coords,
"values": self.data,
}


def as_coo(x, shape=None, fill_value=None, idx_dtype=None):
"""
Expand Down
6 changes: 6 additions & 0 deletions sparse/numba_backend/_dok.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,12 @@ def reshape(self, shape, order="C"):

return DOK.from_coo(self.to_coo().reshape(shape))

def __binsparse_descriptor__(self) -> dict:
raise RuntimeError("`DOK` doesn't support the `__binsparse_descriptor__` protocol.")

def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
raise RuntimeError("`DOK` doesn't support the `__binsparse_dlpack__` protocol.")


def to_slice(k):
"""Convert integer indices to one-element slices for consistency"""
Expand Down
25 changes: 25 additions & 0 deletions sparse/numba_backend/_sparse_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,31 @@ def _str_impl(self, summary):
except (ImportError, ValueError):
return summary

@abstractmethod
def __binsparse_descriptor__(self) -> dict:
"""Return a `dict` equivalent to a parsed JSON [`binsparse` descriptor](https://graphblas.org/binsparse-specification/#descriptor)
of this array.
Returns
-------
dict
Parsed `binsparse` descriptor.
"""
raise NotImplementedError

@abstractmethod
def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
"""A `dict` containing the constituent arrays of this sparse array. The keys are compatible with the
[`binsparse`](https://graphblas.org/binsparse-specification/) scheme, and the values are [`__dlpack__`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html)
compatible objects.
Returns
-------
dict[str, np.ndarray]
The constituent arrays.
"""
raise NotImplementedError

@abstractmethod
def asformat(self, format):
"""
Expand Down

0 comments on commit a2a562b

Please sign in to comment.