From b27ec979946d39fe3846ca5ab73aa9cbc1286289 Mon Sep 17 00:00:00 2001 From: Raul Date: Fri, 14 Apr 2023 14:14:19 +0200 Subject: [PATCH] Add error checking to CUDA version of getNeighborPairs (#80) * Add error checking to CUDA version of getNeighborPairs * Add a new bool optional parameter to getNeighborPairs, setting it to true will force the function to synchronize and throw an exception if some error was found, so it can be caught. The default will throw the error asynchronously, which will crash the program. In both cases a meaningful message is printed. * Remove unnecessary static variable * Change the error handling of getNeighborPairs. - Add a new optional flag, sync_exceptions on top of the current check_errors. - Three behaviors are possible: 1. Default (both false). Operation is CUDA-graph compatible and an uncatchable exception is thrown in case of number of pairs being too high. 2. check_errors=True. Operation is CUDA-graph compatible. No exception is thrown and the number of found pairs is returned, which can be higher than max_number_pairs. 3. check_errors=False and sync_exceptions=True. Operation is NOT CUDA-graph compatible. The operation synchronizes to check for errors and throws a catchable exception if necessary. * Make getNeighborPairs CUDA-graph compatible, add test for it * Remove incorrect comment * Change not by ! 
* Move all torch.ops.load calls to the __init__.py scripts * Change how the location of libNNPOpsPyTorch.so is found in the __init__ scripts * Remove spurious lines in CMakeLists.txt * Update again how libNNPOpsPyTorch.so is found in __init__.py * Remove redundant torch load * Skip CUDA graph test if no GPU is available * Remove incorrect path in __init__ * Use relative path to load NNPOps library in __init__.py * Copy test scripts to build directory, run them there * Remove unnecessary import * Some fixes for CUDA graph support in getNeighborPairs * Reverse logic for check_errors in getNeighborPairs.py * Reverse check_errors flag in the rest of the getNeighborPair-related files * Clarify documentation on the error raised by getNeighborPairs * Always return the number of found pairs in getNeighborPairs * Revert "Always return the number of found pairs in getNeighborPairs" This reverts commit c36243b36955fa13e4c2f0f1dc552bfa11a60d66. * Fix check_errors interpretation in getNeighborPairs.py * Add return number of pairs functionality again This reverts commit 8da1c5d77bbe0c95036dcb8382d29ee4aefd62b0. * Update tests with new getNeighborPairs interface * Fix type decorator preventing jit.script from working on getNeighborPairs * Remove sync_exceptions flag, simplifying the behavior and relation with CUDA graphs. If check_errors=False (the default) getNeighborPairs does not check for errors and is compatible with graphs. If check_errors=True, the function raises if necessary but it is incompatible with graphs * Remove unused function * Remove unnecessary synchronization in test * Clarify documentation of check_errors * Clarify documentation of number_found_pairs * Clarify documentation of CUDA graph functionality * Remove obsolete comment * Fix formatting * Fix formatting * Update documentation * Change the (misleading) num_pairs variable name to max_num_pairs. 
Enforce that the found number of pairs is less than num_pairs * Add test that checks if the max_num_neighbors per particle is enforced. Right now this does not pass, since the function allows that an atom has more neighbors than max_num_neighbors as long as num_found_pairs -1) assert pt.all(neighbors == ref_neighbors) assert pt.allclose(deltas, ref_deltas) @@ -133,28 +133,78 @@ def test_neighbor_grads(device, dtype, num_atoms, grad): (deltas.sum() + distances.sum()).backward() else: raise ValueError('grad') - + if dtype == pt.float32: assert pt.allclose(ref_positions.grad, positions.grad, atol=1e-3, rtol=1e-3) else: assert pt.allclose(ref_positions.grad, positions.grad, atol=1e-8, rtol=1e-5) -# The following test is only run on the CPU. Running it on the GPU triggers a -# CUDA assertion, which causes all tests run after it to fail. - -@pytest.mark.parametrize('device', ['cpu']) +@pytest.mark.parametrize('device', ['cpu', 'cuda']) @pytest.mark.parametrize('dtype', [pt.float32, pt.float64]) def test_too_many_neighbors(device, dtype): - if not pt.cuda.is_available() and device == 'cuda': pytest.skip('No GPU') - # 4 points result into 6 pairs, but there is a storage just for 4. + positions = pt.zeros((4, 3,), device=device, dtype=dtype) with pytest.raises(RuntimeError): - positions = pt.zeros((4, 3,), device=device, dtype=dtype) - getNeighborPairs(positions, cutoff=1, max_num_neighbors=1) - pt.cuda.synchronize() + # checkErrors = True will raise due to exceeding neighbours + getNeighborPairs(positions, cutoff=1, max_num_pairs=1, check_errors=True) + + # checkErrors = False will never throw due to exceeding neighbours. 
In addition, the call will be compatible with CUDA graphs + neighbors, deltas, distances, number_found_pairs = getNeighborPairs(positions, cutoff=1, max_num_pairs=1, check_errors=False) + assert number_found_pairs == 6 + +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize('dtype', [pt.float32, pt.float64]) +def test_max_pairs_means_total(device, dtype): + if not pt.cuda.is_available() and device == 'cuda': + pytest.skip('No GPU') + # 4 points result into 6 pairs. + positions = pt.zeros((4, 3,), device=device, dtype=dtype) + with pytest.raises(RuntimeError): + # checkErrors = True should raise due to exceeding neighbours + getNeighborPairs(positions, cutoff=1, max_num_pairs=5, check_errors=True) + getNeighborPairs(positions, cutoff=1, max_num_pairs=6, check_errors=True) + +def test_is_cuda_graph_compatible(): + if not pt.cuda.is_available(): + pytest.skip('No GPU') + device = 'cuda' + dtype = pt.float32 + num_atoms = 100 + # Generate random positions + positions = 10 * pt.randn((num_atoms, 3), device=device, dtype=dtype) + cutoff = 5 + # Get neighbor pairs + ref_neighbors = np.vstack(np.tril_indices(num_atoms, -1)) + ref_positions = positions.cpu().numpy() + ref_deltas = ref_positions[ref_neighbors[0]] - ref_positions[ref_neighbors[1]] + ref_distances = np.linalg.norm(ref_deltas, axis=1) + + # Filter the neighbor pairs + mask = ref_distances > cutoff + ref_neighbors[:, mask] = -1 + ref_deltas[mask, :] = np.nan + ref_distances[mask] = np.nan + + # Find the number of neighbors + num_neighbors = np.count_nonzero(np.logical_not(np.isnan(ref_distances))) + + graph = pt.cuda.CUDAGraph() + s = pt.cuda.Stream() + s.wait_stream(pt.cuda.current_stream()) + with pt.cuda.stream(s): + for _ in range(3): + neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff, max_num_pairs=num_neighbors+1) + pt.cuda.synchronize() + + with pt.cuda.graph(graph): + neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff, 
max_num_pairs=num_neighbors+1) + + graph.replay() + pt.cuda.synchronize() + @pytest.mark.parametrize('device', ['cpu', 'cuda']) @pytest.mark.parametrize('dtype', [pt.float32, pt.float64]) @@ -187,10 +237,10 @@ def test_periodic_neighbors(device, dtype): # Find the number of neighbors num_neighbors = np.count_nonzero(np.logical_not(np.isnan(ref_distances))) - max_num_neighbors = max(int(np.ceil(num_neighbors / num_atoms)), 1) + max_num_pairs = max(int(num_neighbors), 1) # Compute results - neighbors, deltas, distances = getNeighborPairs(positions, cutoff=cutoff, max_num_neighbors=max_num_neighbors, box_vectors=box_vectors) + neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=cutoff, max_num_pairs=max_num_pairs, box_vectors=box_vectors) # Check device assert neighbors.device == positions.device @@ -213,7 +263,7 @@ def test_periodic_neighbors(device, dtype): neighbors, deltas, distances = sort_neighbors(neighbors, deltas, distances) # Resize the reference - ref_neighbors, ref_deltas, ref_distances = resize_neighbors(ref_neighbors, ref_deltas, ref_distances, num_atoms * max_num_neighbors) + ref_neighbors, ref_deltas, ref_distances = resize_neighbors(ref_neighbors, ref_deltas, ref_distances, max_num_pairs) assert np.all(ref_neighbors == neighbors) assert np.allclose(ref_deltas, deltas, equal_nan=True) @@ -228,7 +278,7 @@ class ForceModule(pt.nn.Module): def forward(self, positions): - neighbors, deltas, distances = getNeighborPairs(positions, cutoff=1.0) + neighbors, deltas, distances, _ = getNeighborPairs(positions, cutoff=1.0) mask = pt.isnan(distances) distances = distances[~mask] return pt.sum(distances**2) diff --git a/src/pytorch/neighbors/getNeighborPairs.py b/src/pytorch/neighbors/getNeighborPairs.py index b6f30072..12a4b03c 100644 --- a/src/pytorch/neighbors/getNeighborPairs.py +++ b/src/pytorch/neighbors/getNeighborPairs.py @@ -2,19 +2,23 @@ from typing import Optional, Tuple -def getNeighborPairs(positions: Tensor, cutoff: float, 
max_num_neighbors: int = -1, box_vectors: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]: - ''' - Returns indices and distances of atom pairs within a given cutoff distance. - - If `max_num_neighbors == -1` (default), all the atom pairs are returned, +def getNeighborPairs( + positions: Tensor, + cutoff: float, + max_num_pairs: int = -1, + box_vectors: Optional[Tensor] = None, + check_errors: bool = False +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Returns indices and distances of atom pairs within a given cutoff distance. + + If `max_num_pairs == -1` (default), all the atom pairs are returned, i.e. `num_pairs = num_atoms * (num_atoms + 1) / 2`. This is intended for the small molecules, where almost all the atoms are within the cutoff distance of each other. - If `max_num_neighbors > 0`, a fixed number of the atom pair are returned, - i.e. `num_pairs = num_atoms * max_num_neighbors`. This is indeded for large - molecule, where most of the atoms are beyond the cutoff distance of each - other. + If `max_num_pairs > 0`, a fixed number of the atom pairs are + returned. This is intended for large molecule, where most of the + atoms are beyond the cutoff distance of each other. This function optionally supports periodic boundary conditions with arbitrary triclinic boxes. The box vectors `a`, `b`, and `c` must satisfy @@ -37,13 +41,20 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = data type has to be`torch.float32` or `torch.float64`. cutoff: float Maximum distance between atom pairs. - max_num_neighbors: int, optional - Maximum number of neighbors per atom. If set to `-1` (default), + max_num_pairs: int, optional + Maximum number of pairs (total number of neighbors). If set to `-1` (default), all possible combinations of atom pairs are included. box_vectors: `torch.Tensor`, optional The vectors defining the periodic box. 
This must have shape `(3, 3)`, where `box_vectors[0] = a`, `box_vectors[1] = b`, and `box_vectors[2] = c`. If this is omitted, periodic boundary conditions are not applied. + check_errors: bool, optional + If True, a RuntimeError is raised if more than max_num_pairs pairs are found. + The error checking requires synchronization, which adds cost and makes this function + incompatible with CUDA graphs. If this argument is False, no error checking is performed. + This makes it faster and compatible with CUDA graphs, but it is your responsibility + to check the return value for number_found_pairs to make sure that no neighbors were missed. + Default: False Returns ------- @@ -63,17 +74,27 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = If an atom pair is separated by a larger distance than the cutoff, the distance is set to `NaN`. + number_found_pairs: `torch.Tensor` + Contains the total number of pairs found. Be aware that if + check_errors is False, this might be larger than + max_num_pairs. In that case, the output tensors contain + only a subset of the pairs that were found, and the others are + omitted. Which pairs get omitted may vary between invocations. + Exceptions ---------- - If `max_num_neighbors > 0` and too small, `RuntimeError` is raised. + If `max_num_pairs > 0` and too small, `RuntimeError` is raised if check_errors=True. Note ---- - The operation is compatible with CUDA Grahps, i.e. the shapes of the output - tensors are independed of the values of input tensors. + The operation can be compatible with CUDA Graphs: the shapes of + the output tensors are independent of the values of input tensors, + and no synchronization is performed. + For this to be true, check_errors must be False. The CUDA implementation returns the atom pairs in non-determinist order, - if `max_num_neighbors > 0`. + if `max_num_pairs > 0`. 
+ Examples -------- @@ -88,7 +109,7 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = tensor([[1., 0., 0.], [2., 0., 0.], [1., 0., 0.]]), - tensor([1., 2., 1.])) + tensor([1., 2., 1.]), tensor([3], dtype=torch.int32)) >>> getNeighborPairs(positions, cutoff=1.5) # doctest: +NORMALIZE_WHITESPACE (tensor([[ 1, -1, 2], @@ -96,31 +117,31 @@ def getNeighborPairs(positions: Tensor, cutoff: float, max_num_neighbors: int = tensor([[1., 0., 0.], [nan, nan, nan], [1., 0., 0.]]), - tensor([1., nan, 1.])) + tensor([1., nan, 1.]), tensor([3], dtype=torch.int32)) - >>> getNeighborPairs(positions, cutoff=3.0, max_num_neighbors=2) # doctest: +NORMALIZE_WHITESPACE + >>> getNeighborPairs(positions, cutoff=3.0, max_num_pairs=6) # doctest: +NORMALIZE_WHITESPACE (tensor([[ 1, 2, 2, -1, -1, -1], - [ 0, 0, 1, -1, -1, -1]], dtype=torch.int32), - tensor([[1., 0., 0.], - [2., 0., 0.], - [1., 0., 0.], - [nan, nan, nan], - [nan, nan, nan], - [nan, nan, nan]]), - tensor([1., 2., 1., nan, nan, nan])) - - >>> getNeighborPairs(positions, cutoff=1.5, max_num_neighbors=2) # doctest: +NORMALIZE_WHITESPACE + [ 0, 0, 1, -1, -1, -1]], dtype=torch.int32), tensor([[1., 0., 0.], + [2., 0., 0.], + [1., 0., 0.], + [nan, nan, nan], + [nan, nan, nan], + [nan, nan, nan]]), tensor([1., 2., 1., nan, nan, nan]), tensor([6], dtype=torch.int32)) + + >>> getNeighborPairs(positions, cutoff=1.5, max_num_pairs=6) # doctest: +NORMALIZE_WHITESPACE (tensor([[ 1, 2, -1, -1, -1, -1], - [ 0, 1, -1, -1, -1, -1]], dtype=torch.int32), - tensor([[1., 0., 0.], - [1., 0., 0.], - [nan, nan, nan], - [nan, nan, nan], - [nan, nan, nan], - [nan, nan, nan]]), - tensor([1., 1., nan, nan, nan, nan])) - ''' + [ 0, 1, -1, -1, -1, -1]], dtype=torch.int32), tensor([[1., 0., 0.], + [1., 0., 0.], + [nan, nan, nan], + [nan, nan, nan], + [nan, nan, nan], + [nan, nan, nan]]), tensor([1., 1., nan, nan, nan, nan]), tensor([6], dtype=torch.int32)) + + """ if box_vectors is None: box_vectors = empty((0, 0), 
device=positions.device, dtype=positions.dtype) - return ops.neighbors.getNeighborPairs(positions, cutoff, max_num_neighbors, box_vectors) \ No newline at end of file + neighbors, deltas, distances, number_found_pairs = ops.neighbors.getNeighborPairs( + positions, cutoff, max_num_pairs, box_vectors, check_errors + ) + return neighbors, deltas, distances, number_found_pairs diff --git a/src/pytorch/neighbors/getNeighborPairsCPU.cpp b/src/pytorch/neighbors/getNeighborPairsCPU.cpp index 9df95b85..d63e24b2 100644 --- a/src/pytorch/neighbors/getNeighborPairsCPU.cpp +++ b/src/pytorch/neighbors/getNeighborPairsCPU.cpp @@ -16,10 +16,11 @@ using torch::Tensor; using torch::outer; using torch::round; -static tuple forward(const Tensor& positions, - const Scalar& cutoff, - const Scalar& max_num_neighbors, - const Tensor& box_vectors) { +static tuple forward(const Tensor& positions, + const Scalar& cutoff, + const Scalar& max_num_pairs, + const Tensor& box_vectors, + bool checkErrors) { TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); @@ -47,9 +48,9 @@ static tuple forward(const Tensor& positions, TORCH_CHECK(v[1][1] >= 2*v[2][1], "Invalid box vectors: box_vectors[1][1] < 2*box_vectors[2][1]"); } - const int max_num_neighbors_ = max_num_neighbors.to(); - TORCH_CHECK(max_num_neighbors_ > 0 || max_num_neighbors_ == -1, - "Expected \"max_num_neighbors\" to be positive or equal to -1"); + const int max_num_pairs_ = max_num_pairs.to(); + TORCH_CHECK(max_num_pairs_ > 0 || max_num_pairs_ == -1, + "Expected \"max_num_pairs\" to be positive or equal to -1"); const int num_atoms = positions.size(0); const int num_pairs = num_atoms * (num_atoms - 1) / 2; @@ -68,7 +69,7 @@ static tuple forward(const Tensor& positions, } Tensor distances = frobenius_norm(deltas, 1); - if (max_num_neighbors_ == -1) { + if (max_num_pairs_ == -1) { const Tensor mask = 
distances > cutoff; neighbors.index_put_({Slice(), mask}, -1); deltas = deltas.clone(); // Break an autograd loop @@ -82,20 +83,26 @@ static tuple forward(const Tensor& positions, deltas = deltas.index({mask, Slice()}); distances = distances.index({mask}); - const int num_pad = num_atoms * max_num_neighbors_ - distances.size(0); - TORCH_CHECK(num_pad >= 0, - "The maximum number of pairs has been exceed! Increase \"max_num_neighbors\""); - + const int num_pad = max_num_pairs_ - distances.size(0); + if (checkErrors) { + TORCH_CHECK(num_pad >= 0, + "The maximum number of pairs has been exceed! Increase \"max_num_pairs\""); + } if (num_pad > 0) { neighbors = hstack({neighbors, full({2, num_pad}, -1, neighbors.options())}); deltas = vstack({deltas, full({num_pad, 3}, NAN, deltas.options())}); distances = hstack({distances, full({num_pad}, NAN, distances.options())}); } } - - return {neighbors, deltas, distances}; + Tensor num_pairs_found = torch::empty(1, indices.options().dtype(kInt32)); + num_pairs_found[0] = distances.size(0); + return {neighbors, deltas, distances, num_pairs_found}; } TORCH_LIBRARY_IMPL(neighbors, CPU, m) { - m.impl("getNeighborPairs", &forward); -} \ No newline at end of file + m.impl("getNeighborPairs", + [](const Tensor& positions, const Scalar& cutoff, const Scalar& max_num_pairs, + const Tensor& box_vectors, const bool &checkErrors){ + return forward(positions, cutoff, max_num_pairs, box_vectors, checkErrors); + }); +} diff --git a/src/pytorch/neighbors/getNeighborPairsCUDA.cu b/src/pytorch/neighbors/getNeighborPairsCUDA.cu index 2d820a4a..23540af6 100644 --- a/src/pytorch/neighbors/getNeighborPairsCUDA.cu +++ b/src/pytorch/neighbors/getNeighborPairsCUDA.cu @@ -1,5 +1,7 @@ #include #include +#include +#include #include #include #include @@ -64,14 +66,15 @@ template __global__ void forward_kernel( if (distance2 > cutoff2) return; const int32_t i_pair = store_all_pairs ? 
index : atomicAdd(&i_curr_pair[0], 1); - assert(i_pair < neighbors.size(1)); - - neighbors[0][i_pair] = row; - neighbors[1][i_pair] = column; - deltas[i_pair][0] = delta_x; - deltas[i_pair][1] = delta_y; - deltas[i_pair][2] = delta_z; - distances[i_pair] = sqrt_(distance2); + //We handle too many neighbors outside of the kernel + if (i_pair < neighbors.size(1)) { + neighbors[0][i_pair] = row; + neighbors[1][i_pair] = column; + deltas[i_pair][0] = delta_x; + deltas[i_pair][1] = delta_y; + deltas[i_pair][2] = delta_z; + distances[i_pair] = sqrt_(distance2); + } } template __global__ void backward_kernel( @@ -102,17 +105,17 @@ public: static tensor_list forward(AutogradContext* ctx, const Tensor& positions, const Scalar& cutoff, - const Scalar& max_num_neighbors, - const Tensor& box_vectors) { - + const Scalar& max_num_pairs, + const Tensor& box_vectors, + bool checkErrors) { + const auto stream = getCurrentCUDAStream(positions.get_device()); + const CUDAStreamGuard guard(stream); TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions"); TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0"); TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3"); TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous"); - - const int max_num_neighbors_ = max_num_neighbors.to(); - TORCH_CHECK(max_num_neighbors_ > 0 || max_num_neighbors_ == -1, - "Expected \"max_num_neighbors\" to be positive or equal to -1"); + TORCH_CHECK(max_num_pairs.toInt() > 0 || max_num_pairs.toInt() == -1, + "Expected \"max_num_pairs\" to be positive or equal to -1"); const bool use_periodic = (box_vectors.size(0) != 0); if (use_periodic) { @@ -121,25 +124,23 @@ public: } // Decide the algorithm - const bool store_all_pairs = max_num_neighbors_ == -1; + const bool store_all_pairs = max_num_pairs.toInt() == -1; const int num_atoms = positions.size(0); const int 
num_all_pairs = num_atoms * (num_atoms - 1) / 2; - const int num_pairs = store_all_pairs ? num_all_pairs : num_atoms * max_num_neighbors_; + const int max_num_pairs_ = store_all_pairs ? num_all_pairs : (max_num_pairs.toInt()); const int num_threads = 128; const int num_blocks = max((num_all_pairs + num_threads - 1) / num_threads, 1); - const auto stream = getCurrentCUDAStream(positions.get_device()); const TensorOptions options = positions.options(); const Tensor i_curr_pair = zeros(1, options.dtype(kInt32)); - const Tensor neighbors = full({2, num_pairs}, -1, options.dtype(kInt32)); - const Tensor deltas = full({num_pairs, 3}, NAN, options); - const Tensor distances = full(num_pairs, NAN, options); + const Tensor neighbors = full({2, max_num_pairs_}, -1, options.dtype(kInt32)); + const Tensor deltas = full({max_num_pairs_, 3}, NAN, options); + const Tensor distances = full(max_num_pairs_, NAN, options); AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "getNeighborPairs::forward", [&]() { - const CUDAStreamGuard guard(stream); - const scalar_t cutoff_ = cutoff.to(); - TORCH_CHECK(cutoff_ > 0, "Expected \"cutoff\" to be positive"); + const scalar_t cutoff_ = cutoff.to(); + TORCH_CHECK(cutoff_ > 0, "Expected \"cutoff\" to be positive"); forward_kernel<<>>( num_all_pairs, get_accessor(positions), @@ -152,11 +153,14 @@ public: get_accessor(distances), get_accessor(box_vectors)); }); - + // Synchronize and check the number of pairs found. Note that this is incompatible with CUDA graphs + if (checkErrors) { + int num_found_pairs = i_curr_pair.item(); + TORCH_CHECK(num_found_pairs <= max_num_pairs_, "Too many neighbor pairs found. 
Maximum is " + std::to_string(max_num_pairs_), " but found " + std::to_string(num_found_pairs)); + } ctx->save_for_backward({neighbors, deltas, distances}); ctx->saved_data["num_atoms"] = num_atoms; - - return {neighbors, deltas, distances}; + return {neighbors, deltas, distances, i_curr_pair}; } static tensor_list backward(AutogradContext* ctx, tensor_list grad_inputs) { @@ -187,14 +191,16 @@ public: get_accessor(grad_positions)); }); - return {grad_positions, Tensor(), Tensor(), Tensor()}; + return {grad_positions, Tensor(), Tensor(), Tensor(), Tensor(), Tensor()}; } }; TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) { - m.impl("getNeighborPairs", - [](const Tensor& positions, const Scalar& cutoff, const Scalar& max_num_neighbors, const Tensor& box_vectors){ - const tensor_list results = Autograd::apply(positions, cutoff, max_num_neighbors, box_vectors); - return make_tuple(results[0], results[1], results[2]); - }); -} \ No newline at end of file + m.impl("getNeighborPairs", + [](const Tensor& positions, const Scalar& cutoff, const Scalar& max_num_pairs, + const Tensor& box_vectors, const bool &checkErrors){ + const tensor_list results = Autograd::apply(positions, cutoff, max_num_pairs, + box_vectors, checkErrors); + return make_tuple(results[0], results[1], results[2], results[3]); + }); +} diff --git a/src/pytorch/neighbors/neighbors.cpp b/src/pytorch/neighbors/neighbors.cpp index d8dd5c5b..e5911907 100644 --- a/src/pytorch/neighbors/neighbors.cpp +++ b/src/pytorch/neighbors/neighbors.cpp @@ -1,5 +1,5 @@ #include TORCH_LIBRARY(neighbors, m) { - m.def("getNeighborPairs(Tensor positions, Scalar cutoff, Scalar max_num_neighbors, Tensor box_vectors) -> (Tensor neighbors, Tensor deltas, Tensor distances)"); -} \ No newline at end of file + m.def("getNeighborPairs(Tensor positions, Scalar cutoff, Scalar max_num_neighbors, Tensor box_vectors, bool checkErrors) -> (Tensor neighbors, Tensor deltas, Tensor distances, Tensor num_pairs)"); +}