-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add seamless
numpy
-like functionality
- Loading branch information
Showing
7 changed files
with
166 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,55 @@ | ||
"""Useful helper functions.""" | ||
from collections.abc import Sequence | ||
import logging | ||
from textwrap import dedent | ||
|
||
import numpy as np | ||
|
||
from .pycuvec import vec_types | ||
from .pycuvec import cu_copy, cu_zeros, vec_types | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class CuVec(np.ndarray): | ||
""" | ||
A `numpy.ndarray` compatible view with a `cuvec` member containing the | ||
underlying `cuvec.Vector_*` object (for use in CPython API function calls). | ||
""" | ||
_Vector_types = tuple(vec_types.values()) | ||
|
||
def __new__(cls, arr, cuvec=None): | ||
"""arr: `cuvec.CuVec`, raw `cuvec.Vector_*`, or `numpy.ndarray`""" | ||
if isinstance(arr, CuVec._Vector_types): | ||
log.debug("wrap raw %s", type(arr)) | ||
obj = np.asarray(arr).view(cls) | ||
obj.cuvec = arr | ||
return obj | ||
if isinstance(arr, CuVec): | ||
log.debug("new view") | ||
obj = np.asarray(arr).view(cls) | ||
obj.cuvec = arr.cuvec | ||
return obj | ||
if isinstance(arr, np.ndarray): | ||
log.debug("copy") | ||
return copy(arr) | ||
raise NotImplementedError( | ||
dedent("""\ | ||
Not intended for explicit construction | ||
(do not do `cuvec.CuVec((42, 1337))`; | ||
instead use `cuvec.zeros((42, 137))`""")) | ||
|
||
|
||
def zeros(shape, dtype="float32"): | ||
""" | ||
Returns a new `Vector_*` of the specified shape and data type | ||
(`cuvec` equivalent of `numpy.zeros`). | ||
Returns a `cuvec.CuVec` view of a new `numpy.ndarray` | ||
of the specified shape and data type (`cuvec` equivalent of `numpy.zeros`). | ||
""" | ||
return vec_types[np.dtype(dtype)](shape if isinstance(shape, Sequence) else (shape,)) | ||
return CuVec(cu_zeros(shape, dtype)) | ||
|
||
|
||
def from_numpy(arr): | ||
def copy(arr): | ||
""" | ||
Returns a new `Vector_*` of the specified shape and data type | ||
Returns a `cuvec.CuVec` view of a new `numpy.ndarray` | ||
with data copied from the specified `arr` | ||
(`cuvec` equivalent of `numpy.copy`). | ||
""" | ||
res = zeros(arr.shape, arr.dtype) | ||
np.asarray(res)[:] = arr[:] | ||
return res | ||
return CuVec(cu_copy(arr)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import logging | ||
|
||
import numpy as np | ||
from pytest import mark, raises | ||
|
||
import cuvec | ||
|
||
shape = 127, 344, 344 | ||
|
||
|
||
@mark.parametrize("spec,result", [("i", np.int32), ("d", np.float64)]) | ||
def test_zeros(spec, result): | ||
a = np.asarray(cuvec.zeros(shape, spec)) | ||
assert a.dtype == result | ||
assert a.shape == shape | ||
assert not a.any() | ||
|
||
|
||
def test_copy(): | ||
a = np.random.random(shape) | ||
b = np.asarray(cuvec.copy(a)) | ||
assert a.shape == b.shape | ||
assert a.dtype == b.dtype | ||
assert (a == b).all() | ||
|
||
|
||
def test_CuVec_creation(caplog): | ||
with raises(TypeError): | ||
cuvec.CuVec() | ||
|
||
with raises(NotImplementedError): | ||
cuvec.CuVec(shape) | ||
|
||
caplog.set_level(logging.DEBUG) | ||
caplog.clear() | ||
v = cuvec.CuVec(np.ones(shape, dtype='h')) | ||
assert [i[1:] for i in caplog.record_tuples] == [(10, 'copy'), | ||
(10, "wrap raw <class 'Vector_h'>")] | ||
assert v.shape == shape | ||
assert v.dtype.char == 'h' | ||
assert (v == 1).all() | ||
|
||
caplog.clear() | ||
v = cuvec.zeros(shape, 'd') | ||
assert [i[1:] for i in caplog.record_tuples] == [(10, "wrap raw <class 'Vector_d'>")] | ||
|
||
caplog.clear() | ||
v[0, 0, 0] = 1 | ||
assert not caplog.record_tuples | ||
w = cuvec.CuVec(v) | ||
assert [i[1:] for i in caplog.record_tuples] == [(10, "new view")] | ||
|
||
caplog.clear() | ||
assert w[0, 0, 0] == 1 | ||
v[0, 0, 0] = 9 | ||
assert w[0, 0, 0] == 9 | ||
assert v.cuvec is w.cuvec | ||
assert v.data == w.data | ||
assert not caplog.record_tuples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters