Skip to content

Commit

Permalink
Add error checking to CUDA version of getNeighborPairs (#80)
Browse files Browse the repository at this point in the history
* Add error checking to CUDA version of getNeighborPairs

* Add a new bool optional parameter to getNeighborPairs, setting it to true will
force the function to synchronize and throw an exception if some error
was found, so it can be caught.
The default will throw the error asynchronously, which will crash the
program.
In both cases a meaningful message is printed.

* Remove unnecessary static variable

* Change the error handling of getNeighborPairs.
- Add a new optional flag, sync_exceptions on top of the current
check_errors.
- Three behaviors are possible:
  1. Default (both false). Operation is CUDA-graph compatible and an
  uncatchable exception is thrown in case of number of pairs being too
  high.
  2. check_errors=True. Operation is CUDA-graph compatible. No
  exception is thrown and the number of found pairs is returned, which
  can be higher than max_number_pairs.
  3. check_errors=False and sync_exceptions=True. Operation is NOT
  CUDA-graph compatible. The operation synchronizes to check for
  errors and throws a catchable exception if necessary.

* Make getNeighborPairs CUDA-graph compatible, add test for it

* Remove incorrect comment

* Replace `not` with `!`

* Move all torch.ops.load calls to the __init__.py scripts

* Change how the location of libNNPOpsPyTorch.so is found at __init__ scripts

* Remove spurious lines in CMakeLists.txt

* Update again how libNNPOpsPyTorch.so is found in __init__.py

* Remove redundant torch load

* Skip CUDA graph test if no GPU is available

* Remove incorrect path in __init__

* Use relative path to load NNPOps library in __init__.py

* Copy test scripts to build directory, run them there

* Remove unnecessary import

* Some fixes for CUDA graph support in getNeighborPairs

* Reverse logic for check_errors in getNeighborPairs.py

* Reverse check_errors flag in the rest of the getNeighborPair-related files

* Clarify documentation on the error raised by getNeighborPairs

* Always return the number of found pairs in getNeighborPairs

* Revert "Always return the number of found pairs in getNeighborPairs"

This reverts commit c36243b.

* Fix check_error interpretation in getNeighborPairs.py

* Add return number of pairs functionality again

This reverts commit 8da1c5d.

* Update tests with new getNeighborPairs interface

* Fix type decorator preventing jit.script from working on getNeighborPairs

* Remove sync_exceptions flag, simplifying the behavior and relation
with CUDA graphs.
If check_errors=False (the default) getNeighborPairs does not check
for errors and is compatible with graphs.
If check_errors=True, the function raises if necessary but it is
incompatible with graphs

* Remove unused function

* Remove unnecessary synchronization in test

* Clarify documentation of check_errors

* Clarify documentation of number_found_pairs

* Clarify documentation of CUDA graph functionality

* Remove obsolete comment

* Fix formatting

* Fix formatting

* Update documentation

* Change the (misleading) num_pairs variable name to max_num_pairs.
Enforce that the number of found pairs does not exceed max_num_pairs

* Add test that checks if the max_num_neighbors per particle is
enforced.
Right now this does not pass, since the function allows that an atom
has more neighbors than max_num_neighbors as long as num_found_pairs<num_atoms*max_num_neighbors

* Change the meaning and name from max_num_neighbors (maximum number of neighbors per particle) to max_num_pairs (maximum number of total pairs).

* Fix typo in comment
  • Loading branch information
RaulPPelaez authored Apr 14, 2023
1 parent c5b12ba commit b27ec97
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 110 deletions.
90 changes: 70 additions & 20 deletions src/pytorch/neighbors/TestNeighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ def test_neighbor_values(device, dtype, num_atoms, cutoff, all_pairs):

# Find the number of neighbors
num_neighbors = np.count_nonzero(np.logical_not(np.isnan(ref_distances)))
max_num_neighbors = -1 if all_pairs else max(int(np.ceil(num_neighbors / num_atoms)), 1)
max_num_pairs = -1 if all_pairs else max(int(num_neighbors), 1)

# Compute results
neighbors, deltas, distances = getNeighborPairs(positions, cutoff=cutoff, max_num_neighbors=max_num_neighbors)
neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff, max_num_pairs=max_num_pairs)

# Check device
assert neighbors.device == positions.device
Expand All @@ -83,7 +83,7 @@ def test_neighbor_values(device, dtype, num_atoms, cutoff, all_pairs):
neighbors, deltas, distances = sort_neighbors(neighbors, deltas, distances)

# Resize the reference
ref_neighbors, ref_deltas, ref_distances = resize_neighbors(ref_neighbors, ref_deltas, ref_distances, num_atoms * max_num_neighbors)
ref_neighbors, ref_deltas, ref_distances = resize_neighbors(ref_neighbors, ref_deltas, ref_distances, max_num_pairs)

assert np.all(ref_neighbors == neighbors)
assert np.allclose(ref_deltas, deltas, equal_nan=True)
Expand All @@ -94,7 +94,7 @@ def test_neighbor_values(device, dtype, num_atoms, cutoff, all_pairs):
@pytest.mark.parametrize('num_atoms', [1, 2, 3, 4, 5, 10, 100, 1000])
@pytest.mark.parametrize('grad', ['deltas', 'distances', 'combined'])
def test_neighbor_grads(device, dtype, num_atoms, grad):

if not pt.cuda.is_available() and device == 'cuda':
pytest.skip('No GPU')

Expand All @@ -114,8 +114,8 @@ def test_neighbor_grads(device, dtype, num_atoms, grad):
# Compute values using NNPOps
positions.requires_grad_(True)
print(positions)
neighbors, deltas, distances = getNeighborPairs(positions, cutoff=cutoff)
neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff)

assert pt.all(neighbors > -1)
assert pt.all(neighbors == ref_neighbors)
assert pt.allclose(deltas, ref_deltas)
Expand All @@ -133,28 +133,78 @@ def test_neighbor_grads(device, dtype, num_atoms, grad):
(deltas.sum() + distances.sum()).backward()
else:
raise ValueError('grad')

if dtype == pt.float32:
assert pt.allclose(ref_positions.grad, positions.grad, atol=1e-3, rtol=1e-3)
else:
assert pt.allclose(ref_positions.grad, positions.grad, atol=1e-8, rtol=1e-5)


# The following test is only run on the CPU. Running it on the GPU triggers a
# CUDA assertion, which causes all tests run after it to fail.

@pytest.mark.parametrize('device', ['cpu'])
@pytest.mark.parametrize('device', ['cpu', 'cuda'])
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
def test_too_many_neighbors(device, dtype):

if not pt.cuda.is_available() and device == 'cuda':
pytest.skip('No GPU')

# 4 points result into 6 pairs, but there is a storage just for 4.
positions = pt.zeros((4, 3,), device=device, dtype=dtype)
with pytest.raises(RuntimeError):
positions = pt.zeros((4, 3,), device=device, dtype=dtype)
getNeighborPairs(positions, cutoff=1, max_num_neighbors=1)
pt.cuda.synchronize()
# checkErrors = True will raise due to exceeding neighbours
getNeighborPairs(positions, cutoff=1, max_num_pairs=1, check_errors=True)

# checkErrors = False will never throw due to exceeding neighbours. In addition, the call will be compatible with CUDA graphs
neighbors, deltas, distances, number_found_pairs = getNeighborPairs(positions, cutoff=1, max_num_pairs=1, check_errors=False)
assert number_found_pairs == 6

@pytest.mark.parametrize('device', ['cpu', 'cuda'])
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
def test_max_pairs_means_total(device, dtype):
if not pt.cuda.is_available() and device == 'cuda':
pytest.skip('No GPU')
# 4 points result into 6 pairs.
positions = pt.zeros((4, 3,), device=device, dtype=dtype)
with pytest.raises(RuntimeError):
# checkErrors = True should raise due to exceeding neighbours
getNeighborPairs(positions, cutoff=1, max_num_pairs=5, check_errors=True)
getNeighborPairs(positions, cutoff=1, max_num_pairs=6, check_errors=True)

def test_is_cuda_graph_compatible():
if not pt.cuda.is_available():
pytest.skip('No GPU')
device = 'cuda'
dtype = pt.float32
num_atoms = 100
# Generate random positions
positions = 10 * pt.randn((num_atoms, 3), device=device, dtype=dtype)
cutoff = 5
# Get neighbor pairs
ref_neighbors = np.vstack(np.tril_indices(num_atoms, -1))
ref_positions = positions.cpu().numpy()
ref_deltas = ref_positions[ref_neighbors[0]] - ref_positions[ref_neighbors[1]]
ref_distances = np.linalg.norm(ref_deltas, axis=1)

# Filter the neighbor pairs
mask = ref_distances > cutoff
ref_neighbors[:, mask] = -1
ref_deltas[mask, :] = np.nan
ref_distances[mask] = np.nan

# Find the number of neighbors
num_neighbors = np.count_nonzero(np.logical_not(np.isnan(ref_distances)))

graph = pt.cuda.CUDAGraph()
s = pt.cuda.Stream()
s.wait_stream(pt.cuda.current_stream())
with pt.cuda.stream(s):
for _ in range(3):
neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff, max_num_pairs=num_neighbors+1)
pt.cuda.synchronize()

with pt.cuda.graph(graph):
neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff, max_num_pairs=num_neighbors+1)

graph.replay()
pt.cuda.synchronize()


@pytest.mark.parametrize('device', ['cpu', 'cuda'])
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
Expand Down Expand Up @@ -187,10 +237,10 @@ def test_periodic_neighbors(device, dtype):

# Find the number of neighbors
num_neighbors = np.count_nonzero(np.logical_not(np.isnan(ref_distances)))
max_num_neighbors = max(int(np.ceil(num_neighbors / num_atoms)), 1)
max_num_pairs = max(int(num_neighbors), 1)

# Compute results
neighbors, deltas, distances = getNeighborPairs(positions, cutoff=cutoff, max_num_neighbors=max_num_neighbors, box_vectors=box_vectors)
neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff, max_num_pairs=max_num_pairs, box_vectors=box_vectors)

# Check device
assert neighbors.device == positions.device
Expand All @@ -213,7 +263,7 @@ def test_periodic_neighbors(device, dtype):
neighbors, deltas, distances = sort_neighbors(neighbors, deltas, distances)

# Resize the reference
ref_neighbors, ref_deltas, ref_distances = resize_neighbors(ref_neighbors, ref_deltas, ref_distances, num_atoms * max_num_neighbors)
ref_neighbors, ref_deltas, ref_distances = resize_neighbors(ref_neighbors, ref_deltas, ref_distances, max_num_pairs)

assert np.all(ref_neighbors == neighbors)
assert np.allclose(ref_deltas, deltas, equal_nan=True)
Expand All @@ -228,7 +278,7 @@ class ForceModule(pt.nn.Module):

def forward(self, positions):

neighbors, deltas, distances = getNeighborPairs(positions, cutoff=1.0)
neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=1.0)
mask = pt.isnan(distances)
distances = distances[~mask]
return pt.sum(distances**2)
Expand Down
97 changes: 59 additions & 38 deletions src/pytorch/neighbors/getNeighborPairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,23 @@
from typing import Optional, Tuple


def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = -1, box_vectors: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
'''
Returns indices and distances of atom pairs within a given cutoff distance.
If `max_num_neighbors == -1` (default), all the atom pairs are returned,
def getNeighborPairs(
positions: Tensor,
cutoff: float,
max_num_pairs: int = -1,
box_vectors: Optional[Tensor] = None,
check_errors: bool = False
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Returns indices and distances of atom pairs within a given cutoff distance.
If `max_num_pairs == -1` (default), all the atom pairs are returned,
i.e. `num_pairs = num_atoms * (num_atoms - 1) / 2`. This is intended for
the small molecules, where almost all the atoms are within the cutoff
distance of each other.
If `max_num_neighbors > 0`, a fixed number of the atom pair are returned,
i.e. `num_pairs = num_atoms * max_num_neighbors`. This is indeded for large
molecule, where most of the atoms are beyond the cutoff distance of each
other.
If `max_num_pairs > 0`, a fixed number of the atom pairs are
returned. This is intended for large molecules, where most of the
atoms are beyond the cutoff distance of each other.
This function optionally supports periodic boundary conditions with
arbitrary triclinic boxes. The box vectors `a`, `b`, and `c` must satisfy
Expand All @@ -37,13 +41,20 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int =
data type has to be`torch.float32` or `torch.float64`.
cutoff: float
Maximum distance between atom pairs.
max_num_neighbors: int, optional
Maximum number of neighbors per atom. If set to `-1` (default),
max_num_pairs: int, optional
Maximum number of pairs (total number of neighbors). If set to `-1` (default),
all possible combinations of atom pairs are included.
box_vectors: `torch.Tensor`, optional
The vectors defining the periodic box. This must have shape `(3, 3)`,
where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`.
If this is omitted, periodic boundary conditions are not applied.
check_errors: bool, optional
If True, a RuntimeError is raised if more than max_num_pairs pairs are found.
The error checking requires synchronization, which adds cost and makes this function
incompatible with CUDA graphs. If this argument is False, no error checking is performed.
This makes it faster and compatible with CUDA graphs, but it is your responsibility
to check the return value for number_found_pairs to make sure that no neighbors were missed.
Default: False
Returns
-------
Expand All @@ -63,17 +74,27 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int =
If an atom pair is separated by a larger distance than the cutoff,
the distance is set to `NaN`.
number_found_pairs: `torch.Tensor`
Contains the total number of pairs found. Be aware that if
check_errors is False, this might be larger than
max_num_pairs. In that case, the output tensors contain
only a subset of the pairs that were found, and the others are
omitted. Which pairs get omitted may vary between invocations.
Exceptions
----------
If `max_num_neighbors > 0` and too small, `RuntimeError` is raised.
If `max_num_pairs > 0` and too small, `RuntimeError` is raised if check_errors=True.
Note
----
The operation is compatible with CUDA Grahps, i.e. the shapes of the output
tensors are independed of the values of input tensors.
The operation can be compatible with CUDA Graphs: the shapes of
the output tensors are independent of the values of input tensors,
and no synchronization is performed.
For this to be true, check_errors must be False.
The CUDA implementation returns the atom pairs in non-deterministic order,
if `max_num_neighbors > 0`.
if `max_num_pairs > 0`.
Examples
--------
Expand All @@ -88,39 +109,39 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int =
tensor([[1., 0., 0.],
[2., 0., 0.],
[1., 0., 0.]]),
tensor([1., 2., 1.]))
tensor([1., 2., 1.]), tensor([3], dtype=torch.int32))
>>> getNeighborPairs(positions, cutoff=1.5) # doctest: +NORMALIZE_WHITESPACE
(tensor([[ 1, -1, 2],
[ 0, -1, 1]], dtype=torch.int32),
tensor([[1., 0., 0.],
[nan, nan, nan],
[1., 0., 0.]]),
tensor([1., nan, 1.]))
tensor([1., nan, 1.]), tensor([3], dtype=torch.int32))
>>> getNeighborPairs(positions, cutoff=3.0, max_num_neighbors=2) # doctest: +NORMALIZE_WHITESPACE
>>> getNeighborPairs(positions, cutoff=3.0, max_num_pairs=6) # doctest: +NORMALIZE_WHITESPACE
(tensor([[ 1, 2, 2, -1, -1, -1],
[ 0, 0, 1, -1, -1, -1]], dtype=torch.int32),
tensor([[1., 0., 0.],
[2., 0., 0.],
[1., 0., 0.],
[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]),
tensor([1., 2., 1., nan, nan, nan]))
>>> getNeighborPairs(positions, cutoff=1.5, max_num_neighbors=2) # doctest: +NORMALIZE_WHITESPACE
[ 0, 0, 1, -1, -1, -1]], dtype=torch.int32), tensor([[1., 0., 0.],
[2., 0., 0.],
[1., 0., 0.],
[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]), tensor([1., 2., 1., nan, nan, nan]), tensor([6], dtype=torch.int32))
>>> getNeighborPairs(positions, cutoff=1.5, max_num_pairs=6) # doctest: +NORMALIZE_WHITESPACE
(tensor([[ 1, 2, -1, -1, -1, -1],
[ 0, 1, -1, -1, -1, -1]], dtype=torch.int32),
tensor([[1., 0., 0.],
[1., 0., 0.],
[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]),
tensor([1., 1., nan, nan, nan, nan]))
'''
[ 0, 1, -1, -1, -1, -1]], dtype=torch.int32), tensor([[1., 0., 0.],
[1., 0., 0.],
[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]), tensor([1., 1., nan, nan, nan, nan]), tensor([6], dtype=torch.int32))
"""

if box_vectors is None:
box_vectors = empty((0, 0), device=positions.device, dtype=positions.dtype)
return ops.neighbors.getNeighborPairs(positions, cutoff, max_num_neighbors, box_vectors)
neighbors, deltas, distances, number_found_pairs = ops.neighbors.getNeighborPairs(
positions, cutoff, max_num_pairs, box_vectors, check_errors
)
return neighbors, deltas, distances, number_found_pairs
39 changes: 23 additions & 16 deletions src/pytorch/neighbors/getNeighborPairsCPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ using torch::Tensor;
using torch::outer;
using torch::round;

static tuple<Tensor, Tensor, Tensor> forward(const Tensor& positions,
const Scalar& cutoff,
const Scalar& max_num_neighbors,
const Tensor& box_vectors) {
static tuple<Tensor, Tensor, Tensor, Tensor> forward(const Tensor& positions,
const Scalar& cutoff,
const Scalar& max_num_pairs,
const Tensor& box_vectors,
bool checkErrors) {

TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions");
TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0");
Expand Down Expand Up @@ -47,9 +48,9 @@ static tuple<Tensor, Tensor, Tensor> forward(const Tensor& positions,
TORCH_CHECK(v[1][1] >= 2*v[2][1], "Invalid box vectors: box_vectors[1][1] < 2*box_vectors[2][1]");
}

const int max_num_neighbors_ = max_num_neighbors.to<int>();
TORCH_CHECK(max_num_neighbors_ > 0 || max_num_neighbors_ == -1,
"Expected \"max_num_neighbors\" to be positive or equal to -1");
const int max_num_pairs_ = max_num_pairs.to<int>();
TORCH_CHECK(max_num_pairs_ > 0 || max_num_pairs_ == -1,
"Expected \"max_num_pairs\" to be positive or equal to -1");

const int num_atoms = positions.size(0);
const int num_pairs = num_atoms * (num_atoms - 1) / 2;
Expand All @@ -68,7 +69,7 @@ static tuple<Tensor, Tensor, Tensor> forward(const Tensor& positions,
}
Tensor distances = frobenius_norm(deltas, 1);

if (max_num_neighbors_ == -1) {
if (max_num_pairs_ == -1) {
const Tensor mask = distances > cutoff;
neighbors.index_put_({Slice(), mask}, -1);
deltas = deltas.clone(); // Break an autograd loop
Expand All @@ -82,20 +83,26 @@ static tuple<Tensor, Tensor, Tensor> forward(const Tensor& positions,
deltas = deltas.index({mask, Slice()});
distances = distances.index({mask});

const int num_pad = num_atoms * max_num_neighbors_ - distances.size(0);
TORCH_CHECK(num_pad >= 0,
"The maximum number of pairs has been exceed! Increase \"max_num_neighbors\"");

const int num_pad = max_num_pairs_ - distances.size(0);
if (checkErrors) {
TORCH_CHECK(num_pad >= 0,
"The maximum number of pairs has been exceed! Increase \"max_num_pairs\"");
}
if (num_pad > 0) {
neighbors = hstack({neighbors, full({2, num_pad}, -1, neighbors.options())});
deltas = vstack({deltas, full({num_pad, 3}, NAN, deltas.options())});
distances = hstack({distances, full({num_pad}, NAN, distances.options())});
}
}

return {neighbors, deltas, distances};
Tensor num_pairs_found = torch::empty(1, indices.options().dtype(kInt32));
num_pairs_found[0] = distances.size(0);
return {neighbors, deltas, distances, num_pairs_found};
}

TORCH_LIBRARY_IMPL(neighbors, CPU, m) {
m.impl("getNeighborPairs", &forward);
}
m.impl("getNeighborPairs",
[](const Tensor& positions, const Scalar& cutoff, const Scalar& max_num_pairs,
const Tensor& box_vectors, const bool &checkErrors){
return forward(positions, cutoff, max_num_pairs, box_vectors, checkErrors);
});
}
Loading

0 comments on commit b27ec97

Please sign in to comment.