Skip to content

Commit

Permalink
add zeros, from_numpy helpers
Browse files Browse the repository at this point in the history
- deprecate `vector`
- add tests
  • Loading branch information
casperdcl committed Jan 17, 2021
1 parent 8bac9be commit 8ee6305
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 25 deletions.
9 changes: 5 additions & 4 deletions cuvec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,22 @@
# config
'cmake_prefix', 'include_path',
# functions
'dev_sync', 'vector',
'dev_sync', 'from_numpy', 'zeros',
# data
'typecodes'] # yapf: disable
'typecodes', 'vec_types'] # yapf: disable

from pathlib import Path

from pkg_resources import resource_filename

try:
from .cuvec import dev_sync
from .pycuvec import typecodes, vector
except ImportError as err:
from warnings import warn
warn(str(err), UserWarning)
dev_sync = vector = None
else:
from .helpers import from_numpy, zeros
from .pycuvec import typecodes, vec_types

# for use in `cmake -DCMAKE_PREFIX_PATH=...`
cmake_prefix = Path(resource_filename(__name__, "cmake")).resolve()
Expand Down
24 changes: 24 additions & 0 deletions cuvec/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Useful helper functions."""
from collections.abc import Sequence

import numpy as np

from .pycuvec import vec_types


def zeros(shape, dtype="float32"):
"""
Returns a new `Vector_*` of the specified shape and data type
(`cuvec` equivalent of `numpy.zeros`).
"""
return vec_types[np.dtype(dtype)](shape if isinstance(shape, Sequence) else (shape,))


def from_numpy(arr):
"""
Returns a new `Vector_*` of the specified shape and data type
(`cuvec` equivalent of `numpy.copy`).
"""
res = zeros(arr.shape, arr.dtype)
np.asarray(res)[:] = arr[:]
return res
5 changes: 0 additions & 5 deletions cuvec/pycuvec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Thin wrappers around `cuvec` C++/CUDA module"""
import array
from collections.abc import Sequence

import numpy as np

Expand Down Expand Up @@ -31,7 +30,3 @@
np.dtype('uint64'): Vector_Q,
np.dtype('float32'): Vector_f,
np.dtype('float64'): Vector_d}


def vector(shape, dtype=np.float32):
return vec_types[np.dtype(dtype)](shape if isinstance(shape, Sequence) else (shape,))
23 changes: 7 additions & 16 deletions tests/test_cuvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@
import cuvec


@mark.parametrize("vtype", list(cuvec.typecodes))
def test_Vector_asarray(vtype):
"""vtype(char): any of bBhHiIqQfd"""
v = getattr(cuvec.cuvec, f"Vector_{vtype}")((1, 2, 3))
assert str(v) == f"Vector_{vtype}((1, 2, 3))"
@mark.parametrize("tp", list(cuvec.typecodes))
def test_Vector_asarray(tp):
"""tp(char): any of bBhHiIqQfd"""
v = getattr(cuvec.cuvec, f"Vector_{tp}")((1, 2, 3))
assert str(v) == f"Vector_{tp}((1, 2, 3))"
a = np.asarray(v)
assert not a.any()
a[0, 0] = 42
b = np.asarray(v)
assert (b[0, 0] == 42).all()
assert not b[1:, 1:].any()
assert a.dtype.char == vtype
del a, b
assert a.dtype.char == tp
del a, b, v


def test_Vector_strides():
Expand All @@ -25,12 +25,3 @@ def test_Vector_strides():
a = np.asarray(v)
assert a.shape == shape
assert a.strides == (473344, 1376, 4)


def test_vector():
shape = 127, 344, 344
a = np.asarray(cuvec.vector(shape, "i"))
assert a.dtype == np.int32

a = np.asarray(cuvec.vector(shape, "d"))
assert a.dtype == np.float64
22 changes: 22 additions & 0 deletions tests/test_pycuvec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np
from pytest import mark

import cuvec


@mark.parametrize("spec,result", [("i", np.int32), ("d", np.float64)])
def test_zeros(spec, result):
shape = 127, 344, 344
a = np.asarray(cuvec.zeros(shape, spec))
assert a.dtype == result
assert a.shape == shape
assert not a.any()


def test_from_numpy():
shape = 127, 344, 344
a = np.random.random(shape)
b = np.asarray(cuvec.from_numpy(a))
assert a.shape == b.shape
assert a.dtype == b.dtype
assert (a == b).all()

0 comments on commit 8ee6305

Please sign in to comment.