success
This commit is contained in:
17
baselines/grasping/GSNet/knn/knn_modules.py
Executable file
17
baselines/grasping/GSNet/knn/knn_modules.py
Executable file
@@ -0,0 +1,17 @@
|
||||
import unittest
|
||||
import gc
|
||||
import operator as op
|
||||
import functools
|
||||
import torch
|
||||
from torch.autograd import Variable, Function
|
||||
from knn_pytorch import knn_pytorch
|
||||
# import knn_pytorch
|
||||
def knn(ref, query, k=1):
|
||||
""" Compute k nearest neighbors for each query point.
|
||||
"""
|
||||
device = ref.device
|
||||
ref = ref.float().to(device)
|
||||
query = query.float().to(device)
|
||||
inds = torch.empty(query.shape[0], k, query.shape[2]).long().to(device)
|
||||
knn_pytorch.knn(ref, query, inds)
|
||||
return inds
|
Reference in New Issue
Block a user