update fps algo and fps mask
This commit is contained in:
@@ -24,13 +24,12 @@ class PtsUtil:
|
||||
return point_cloud[idx]
|
||||
|
||||
@staticmethod
|
||||
def fps_downsample_point_cloud(point_cloud, num_points, require_mask=False):
|
||||
def fps_downsample_point_cloud(point_cloud, num_points, require_idx=False):
|
||||
N = point_cloud.shape[0]
|
||||
mask = np.zeros(N, dtype=bool)
|
||||
|
||||
sampled_indices = np.zeros(num_points, dtype=int)
|
||||
sampled_indices[0] = np.random.randint(0, N)
|
||||
mask[sampled_indices[0]] = True
|
||||
distances = np.linalg.norm(point_cloud - point_cloud[sampled_indices[0]], axis=1)
|
||||
for i in range(1, num_points):
|
||||
farthest_index = np.argmax(distances)
|
||||
@@ -41,8 +40,8 @@ class PtsUtil:
|
||||
distances = np.minimum(distances, new_distances)
|
||||
|
||||
sampled_points = point_cloud[sampled_indices]
|
||||
if require_mask:
|
||||
return sampled_points, mask
|
||||
if require_idx:
|
||||
return sampled_points, sampled_indices
|
||||
return sampled_points
|
||||
|
||||
@staticmethod
|
||||
|
Reference in New Issue
Block a user