forked from chrischoy/pytorch_knn_cuda
-
Notifications
You must be signed in to change notification settings - Fork 1
/
__init__.py
38 lines (25 loc) · 835 Bytes
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import unittest
import torch
from torch.autograd import Variable, Function
import knn_pytorch
class KNearestNeighbor(object):
    """Compute the k nearest reference points for each query point.

    Thin callable wrapper around the ``knn_pytorch.knn`` extension.
    Usage: ``inds = KNearestNeighbor(k)(ref, query)``.

    This class previously subclassed ``torch.autograd.Function`` with an
    ``__init__`` and a non-static ``forward``.  Recent PyTorch releases reject
    that legacy style: calling the instance raises ``RuntimeError``.  The base
    class also bought nothing, because the result is an integer index tensor
    and no ``backward`` was defined.  A plain callable keeps the
    ``KNearestNeighbor(k)(ref, query)`` call syntax working.
    """

    def __init__(self, k):
        # Number of neighbours requested per query point; sizes dim 1 of the
        # returned index tensor.
        self.k = k

    def __call__(self, ref, query):
        """Alias for :meth:`forward`, so instances can be called directly."""
        return self.forward(ref, query)

    def forward(self, ref, query):
        """Return the indices of the ``k`` nearest ``ref`` points per query.

        Args:
            ref: Reference point set, converted to a float32 CUDA tensor.
                Presumably shaped (batch, dim, n_ref) -- TODO confirm against
                the extension.
            query: Query point set, converted to a float32 CUDA tensor.  It
                must be at least 3-D, because dims 0 and 2 size the output.
                Presumably shaped (batch, dim, n_query).

        Returns:
            A ``torch.long`` tensor of shape
            ``(query.shape[0], k, query.shape[2])`` on the current CUDA device,
            filled in place by ``knn_pytorch.knn``.
        """
        # Cast to float32, move to the current CUDA device and make the memory
        # layout contiguous before handing the tensors to the extension
        # (presumably what the kernel expects).
        ref = ref.float().cuda().contiguous()
        query = query.float().cuda().contiguous()
        # Allocate the int64 output directly on the query's device, instead of
        # building a float CPU tensor and then converting and copying it.
        inds = torch.empty(
            query.shape[0], self.k, query.shape[2],
            dtype=torch.long, device=query.device,
        )
        # A freshly allocated tensor is already contiguous, so pass it as-is.
        # Calling .contiguous() on a non-contiguous tensor would hand the
        # kernel a temporary copy, and its results would be lost.
        knn_pytorch.knn(ref, query, inds)
        return inds


@unittest.skipUnless(torch.cuda.is_available(),
                     "KNearestNeighbor requires a CUDA device")
class TestKNearestNeighbor(unittest.TestCase):
    """Smoke tests for the ``KNearestNeighbor`` CUDA wrapper."""

    def test_forward(self):
        """The wrapper returns one int64 index per (batch, neighbour, query)."""
        # Assumed (batch, dim, points) layout: D features, N reference points,
        # M query points.
        D, N, M = 128, 100, 1000
        k = 2
        # Plain tensors: torch.autograd.Variable is a deprecated no-op wrapper.
        ref = torch.rand(2, D, N)
        query = torch.rand(2, D, M)
        inds = KNearestNeighbor(k)(ref, query)
        # The output is allocated as (query.shape[0], k, query.shape[2]).
        self.assertEqual(tuple(inds.shape), (2, k, M))
        self.assertEqual(inds.dtype, torch.long)


if __name__ == '__main__':
    # Run as a script: collect the TestCase classes defined in this module and
    # run them.  module='__main__' is unittest.main's default, spelled out.
    unittest.main(module='__main__')