upd cluster inference
This commit is contained in:
parent
88d44f020e
commit
1535a48a3f
162
beans/predict_result.py
Normal file
162
beans/predict_result.py
Normal file
@ -0,0 +1,162 @@
|
||||
import numpy as np
|
||||
from sklearn.cluster import DBSCAN
|
||||
|
||||
class PredictResult:
|
||||
def __init__(self, raw_predict_result, input_pts=None, cluster_params=dict(eps=0.5, min_samples=2)):
|
||||
self.input_pts = input_pts
|
||||
self.cluster_params = cluster_params
|
||||
self.sampled_9d_pose = raw_predict_result
|
||||
self.sampled_matrix_pose = self.get_sampled_matrix_pose()
|
||||
self.distance_matrix = self.calculate_distance_matrix()
|
||||
self.clusters = self.get_cluster_result()
|
||||
self.candidate_matrix_poses = self.get_candidate_poses()
|
||||
self.candidate_9d_poses = [np.concatenate((self.matrix_to_rotation_6d_numpy(matrix[:3,:3]), matrix[:3,3].reshape(-1,)), axis=-1) for matrix in self.candidate_matrix_poses]
|
||||
self.cluster_num = len(self.clusters)
|
||||
|
||||
@staticmethod
|
||||
def rotation_6d_to_matrix_numpy(d6):
|
||||
a1, a2 = d6[:3], d6[3:]
|
||||
b1 = a1 / np.linalg.norm(a1)
|
||||
b2 = a2 - np.dot(b1, a2) * b1
|
||||
b2 = b2 / np.linalg.norm(b2)
|
||||
b3 = np.cross(b1, b2)
|
||||
return np.stack((b1, b2, b3), axis=-2)
|
||||
|
||||
@staticmethod
|
||||
def matrix_to_rotation_6d_numpy(matrix):
|
||||
return np.copy(matrix[:2, :]).reshape((6,))
|
||||
|
||||
def __str__(self):
|
||||
info = "Predict Result:\n"
|
||||
info += f" Predicted pose number: {len(self.sampled_9d_pose)}\n"
|
||||
info += f" Cluster number: {self.cluster_num}\n"
|
||||
for i, cluster in enumerate(self.clusters):
|
||||
info += f" - Cluster {i} size: {len(cluster)}\n"
|
||||
max_distance = np.max(self.distance_matrix[self.distance_matrix != 0])
|
||||
min_distance = np.min(self.distance_matrix[self.distance_matrix != 0])
|
||||
info += f" Max distance: {max_distance}\n"
|
||||
info += f" Min distance: {min_distance}\n"
|
||||
return info
|
||||
|
||||
def get_sampled_matrix_pose(self):
|
||||
sampled_matrix_pose = []
|
||||
for pose in self.sampled_9d_pose:
|
||||
rotation = pose[:6]
|
||||
translation = pose[6:]
|
||||
pose = self.rotation_6d_to_matrix_numpy(rotation)
|
||||
pose = np.concatenate((pose, translation.reshape(-1, 1)), axis=-1)
|
||||
pose = np.concatenate((pose, np.array([[0, 0, 0, 1]])), axis=-2)
|
||||
sampled_matrix_pose.append(pose)
|
||||
return np.array(sampled_matrix_pose)
|
||||
|
||||
def rotation_distance(self, R1, R2):
|
||||
R = np.dot(R1.T, R2)
|
||||
trace = np.trace(R)
|
||||
angle = np.arccos(np.clip((trace - 1) / 2, -1, 1))
|
||||
return angle
|
||||
|
||||
def calculate_distance_matrix(self):
|
||||
n = len(self.sampled_matrix_pose)
|
||||
dist_matrix = np.zeros((n, n))
|
||||
for i in range(n):
|
||||
for j in range(n):
|
||||
dist_matrix[i, j] = self.rotation_distance(self.sampled_matrix_pose[i][:3, :3], self.sampled_matrix_pose[j][:3, :3])
|
||||
return dist_matrix
|
||||
|
||||
def cluster_rotations(self):
|
||||
clustering = DBSCAN(eps=self.cluster_params['eps'], min_samples=self.cluster_params['min_samples'], metric='precomputed')
|
||||
labels = clustering.fit_predict(self.distance_matrix)
|
||||
return labels
|
||||
|
||||
def get_cluster_result(self):
|
||||
labels = self.cluster_rotations()
|
||||
cluster_num = len(set(labels)) - (1 if -1 in labels else 0)
|
||||
clusters = []
|
||||
for _ in range(cluster_num):
|
||||
clusters.append([])
|
||||
for matrix_pose, label in zip(self.sampled_matrix_pose, labels):
|
||||
if label != -1:
|
||||
clusters[label].append(matrix_pose)
|
||||
clusters.sort(key=len, reverse=True)
|
||||
return clusters
|
||||
|
||||
def get_center_matrix_pose_from_cluster(self, cluster):
|
||||
min_total_distance = float('inf')
|
||||
center_matrix_pose = None
|
||||
|
||||
for matrix_pose in cluster:
|
||||
total_distance = 0
|
||||
for other_matrix_pose in cluster:
|
||||
rot_distance = self.rotation_distance(matrix_pose[:3, :3], other_matrix_pose[:3, :3])
|
||||
total_distance += rot_distance
|
||||
|
||||
if total_distance < min_total_distance:
|
||||
min_total_distance = total_distance
|
||||
center_matrix_pose = matrix_pose
|
||||
|
||||
return center_matrix_pose
|
||||
|
||||
def get_candidate_poses(self):
|
||||
candidate_poses = []
|
||||
for cluster in self.clusters:
|
||||
candidate_poses.append(self.get_center_matrix_pose_from_cluster(cluster))
|
||||
return candidate_poses
|
||||
|
||||
def visualize(self):
|
||||
import plotly.graph_objects as go
|
||||
fig = go.Figure()
|
||||
if self.input_pts is not None:
|
||||
fig.add_trace(go.Scatter3d(
|
||||
x=self.input_pts[:, 0], y=self.input_pts[:, 1], z=self.input_pts[:, 2],
|
||||
mode='markers', marker=dict(size=1, color='gray', opacity=0.5), name='Input Points'
|
||||
))
|
||||
colors = ['aggrnyl', 'agsunset', 'algae', 'amp', 'armyrose', 'balance',
|
||||
'blackbody', 'bluered', 'blues', 'blugrn', 'bluyl', 'brbg']
|
||||
for i, cluster in enumerate(self.clusters):
|
||||
color = colors[i]
|
||||
candidate_pose = self.candidate_matrix_poses[i]
|
||||
origin_candidate = candidate_pose[:3, 3]
|
||||
z_axis_candidate = candidate_pose[:3, 2]
|
||||
for pose in cluster:
|
||||
origin = pose[:3, 3]
|
||||
z_axis = pose[:3, 2]
|
||||
fig.add_trace(go.Cone(
|
||||
x=[origin[0]], y=[origin[1]], z=[origin[2]],
|
||||
u=[z_axis[0]], v=[z_axis[1]], w=[z_axis[2]],
|
||||
colorscale=color,
|
||||
sizemode="absolute", sizeref=0.05, anchor="tail", showscale=False
|
||||
))
|
||||
fig.add_trace(go.Cone(
|
||||
x=[origin_candidate[0]], y=[origin_candidate[1]], z=[origin_candidate[2]],
|
||||
u=[z_axis_candidate[0]], v=[z_axis_candidate[1]], w=[z_axis_candidate[2]],
|
||||
colorscale=color,
|
||||
sizemode="absolute", sizeref=0.1, anchor="tail", showscale=False
|
||||
))
|
||||
|
||||
fig.update_layout(
|
||||
title="Clustered Poses and Input Points",
|
||||
scene=dict(
|
||||
xaxis_title='X',
|
||||
yaxis_title='Y',
|
||||
zaxis_title='Z'
|
||||
),
|
||||
margin=dict(l=0, r=0, b=0, t=40),
|
||||
scene_camera=dict(eye=dict(x=1.25, y=1.25, z=1.25))
|
||||
)
|
||||
|
||||
fig.show()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
step = 0
|
||||
raw_predict_result = np.load(f"inference_result_pack/inference_result_pack/{step}/all_pred_pose_9d.npy")
|
||||
input_pts = np.loadtxt(f"inference_result_pack/inference_result_pack/{step}/input_pts.txt")
|
||||
print(raw_predict_result.shape)
|
||||
predict_result = PredictResult(raw_predict_result, input_pts, cluster_params=dict(eps=0.25, min_samples=3))
|
||||
print(predict_result)
|
||||
print(len(predict_result.candidate_matrix_poses))
|
||||
print(predict_result.distance_matrix)
|
||||
#import ipdb; ipdb.set_trace()
|
||||
predict_result.visualize()
|
||||
|
@ -7,7 +7,7 @@ runner:
|
||||
parallel: False
|
||||
|
||||
experiment:
|
||||
name: train_ab_global_only_with_accept_probability
|
||||
name: train_ab_global_only_with_wp_p++_dense
|
||||
root_dir: "experiments"
|
||||
use_checkpoint: False
|
||||
epoch: -1 # -1 stands for last epoch
|
||||
@ -80,7 +80,7 @@ dataset:
|
||||
pipeline:
|
||||
nbv_reconstruction_pipeline:
|
||||
modules:
|
||||
pts_encoder: pointnet_encoder
|
||||
pts_encoder: pointnet++_encoder
|
||||
seq_encoder: transformer_seq_encoder
|
||||
pose_encoder: pose_encoder
|
||||
view_finder: gf_view_finder
|
||||
@ -98,6 +98,7 @@ module:
|
||||
|
||||
pointnet++_encoder:
|
||||
in_dim: 3
|
||||
params_name: dense
|
||||
|
||||
transformer_seq_encoder:
|
||||
embed_dim: 256
|
||||
|
@ -4,6 +4,7 @@ from utils.render import RenderUtil
|
||||
from utils.pose import PoseUtil
|
||||
from utils.pts import PtsUtil
|
||||
from utils.reconstruction import ReconstructionUtil
|
||||
from beans.predict_result import PredictResult
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
Loading…
x
Reference in New Issue
Block a user