From 1535a48a3f1375ad2dc5349d25b6b758da3a4145 Mon Sep 17 00:00:00 2001 From: hofee Date: Sun, 5 Jan 2025 15:50:04 +0000 Subject: [PATCH] upd cluster inference --- beans/predict_result.py | 162 ++++++++++++++++++++++++ configs/server/server_train_config.yaml | 5 +- runners/inferencer.py | 1 + 3 files changed, 166 insertions(+), 2 deletions(-) create mode 100644 beans/predict_result.py diff --git a/beans/predict_result.py b/beans/predict_result.py new file mode 100644 index 0000000..db99270 --- /dev/null +++ b/beans/predict_result.py @@ -0,0 +1,162 @@ +import numpy as np +from sklearn.cluster import DBSCAN + +class PredictResult: + def __init__(self, raw_predict_result, input_pts=None, cluster_params=dict(eps=0.5, min_samples=2)): + self.input_pts = input_pts + self.cluster_params = cluster_params + self.sampled_9d_pose = raw_predict_result + self.sampled_matrix_pose = self.get_sampled_matrix_pose() + self.distance_matrix = self.calculate_distance_matrix() + self.clusters = self.get_cluster_result() + self.candidate_matrix_poses = self.get_candidate_poses() + self.candidate_9d_poses = [np.concatenate((self.matrix_to_rotation_6d_numpy(matrix[:3,:3]), matrix[:3,3].reshape(-1,)), axis=-1) for matrix in self.candidate_matrix_poses] + self.cluster_num = len(self.clusters) + + @staticmethod + def rotation_6d_to_matrix_numpy(d6): + a1, a2 = d6[:3], d6[3:] + b1 = a1 / np.linalg.norm(a1) + b2 = a2 - np.dot(b1, a2) * b1 + b2 = b2 / np.linalg.norm(b2) + b3 = np.cross(b1, b2) + return np.stack((b1, b2, b3), axis=-2) + + @staticmethod + def matrix_to_rotation_6d_numpy(matrix): + return np.copy(matrix[:2, :]).reshape((6,)) + + def __str__(self): + info = "Predict Result:\n" + info += f" Predicted pose number: {len(self.sampled_9d_pose)}\n" + info += f" Cluster number: {self.cluster_num}\n" + for i, cluster in enumerate(self.clusters): + info += f" - Cluster {i} size: {len(cluster)}\n" + max_distance = np.max(self.distance_matrix[self.distance_matrix != 0]) + min_distance = np.min(self.distance_matrix[self.distance_matrix != 0]) + info += f" Max distance: {max_distance}\n" + info += f" Min distance: {min_distance}\n" + return info + + def get_sampled_matrix_pose(self): + sampled_matrix_pose = [] + for pose in self.sampled_9d_pose: + rotation = pose[:6] + translation = pose[6:] + pose = self.rotation_6d_to_matrix_numpy(rotation) + pose = np.concatenate((pose, translation.reshape(-1, 1)), axis=-1) + pose = np.concatenate((pose, np.array([[0, 0, 0, 1]])), axis=-2) + sampled_matrix_pose.append(pose) + return np.array(sampled_matrix_pose) + + def rotation_distance(self, R1, R2): + R = np.dot(R1.T, R2) + trace = np.trace(R) + angle = np.arccos(np.clip((trace - 1) / 2, -1, 1)) + return angle + + def calculate_distance_matrix(self): + n = len(self.sampled_matrix_pose) + dist_matrix = np.zeros((n, n)) + for i in range(n): + for j in range(n): + dist_matrix[i, j] = self.rotation_distance(self.sampled_matrix_pose[i][:3, :3], self.sampled_matrix_pose[j][:3, :3]) + return dist_matrix + + def cluster_rotations(self): + clustering = DBSCAN(eps=self.cluster_params['eps'], min_samples=self.cluster_params['min_samples'], metric='precomputed') + labels = clustering.fit_predict(self.distance_matrix) + return labels + + def get_cluster_result(self): + labels = self.cluster_rotations() + cluster_num = len(set(labels)) - (1 if -1 in labels else 0) + clusters = [] + for _ in range(cluster_num): + clusters.append([]) + for matrix_pose, label in zip(self.sampled_matrix_pose, labels): + if label != -1: + clusters[label].append(matrix_pose) + clusters.sort(key=len, reverse=True) + return clusters + + def get_center_matrix_pose_from_cluster(self, cluster): + min_total_distance = float('inf') + center_matrix_pose = None + + for matrix_pose in cluster: + total_distance = 0 + for other_matrix_pose in cluster: + rot_distance = self.rotation_distance(matrix_pose[:3, :3], other_matrix_pose[:3, :3]) + total_distance += rot_distance + + if total_distance < min_total_distance: + min_total_distance = total_distance + center_matrix_pose = matrix_pose + + return center_matrix_pose + + def get_candidate_poses(self): + candidate_poses = [] + for cluster in self.clusters: + candidate_poses.append(self.get_center_matrix_pose_from_cluster(cluster)) + return candidate_poses + + def visualize(self): + import plotly.graph_objects as go + fig = go.Figure() + if self.input_pts is not None: + fig.add_trace(go.Scatter3d( + x=self.input_pts[:, 0], y=self.input_pts[:, 1], z=self.input_pts[:, 2], + mode='markers', marker=dict(size=1, color='gray', opacity=0.5), name='Input Points' + )) + colors = ['aggrnyl', 'agsunset', 'algae', 'amp', 'armyrose', 'balance', + 'blackbody', 'bluered', 'blues', 'blugrn', 'bluyl', 'brbg'] + for i, cluster in enumerate(self.clusters): + color = colors[i] + candidate_pose = self.candidate_matrix_poses[i] + origin_candidate = candidate_pose[:3, 3] + z_axis_candidate = candidate_pose[:3, 2] + for pose in cluster: + origin = pose[:3, 3] + z_axis = pose[:3, 2] + fig.add_trace(go.Cone( + x=[origin[0]], y=[origin[1]], z=[origin[2]], + u=[z_axis[0]], v=[z_axis[1]], w=[z_axis[2]], + colorscale=color, + sizemode="absolute", sizeref=0.05, anchor="tail", showscale=False + )) + fig.add_trace(go.Cone( + x=[origin_candidate[0]], y=[origin_candidate[1]], z=[origin_candidate[2]], + u=[z_axis_candidate[0]], v=[z_axis_candidate[1]], w=[z_axis_candidate[2]], + colorscale=color, + sizemode="absolute", sizeref=0.1, anchor="tail", showscale=False + )) + + fig.update_layout( + title="Clustered Poses and Input Points", + scene=dict( + xaxis_title='X', + yaxis_title='Y', + zaxis_title='Z' + ), + margin=dict(l=0, r=0, b=0, t=40), + scene_camera=dict(eye=dict(x=1.25, y=1.25, z=1.25)) + ) + + fig.show() + + + +if __name__ == "__main__": + step = 0 + raw_predict_result = np.load(f"inference_result_pack/inference_result_pack/{step}/all_pred_pose_9d.npy") + input_pts = np.loadtxt(f"inference_result_pack/inference_result_pack/{step}/input_pts.txt") + print(raw_predict_result.shape) + predict_result = PredictResult(raw_predict_result, input_pts, cluster_params=dict(eps=0.25, min_samples=3)) + print(predict_result) + print(len(predict_result.candidate_matrix_poses)) + print(predict_result.distance_matrix) + #import ipdb; ipdb.set_trace() + predict_result.visualize() + diff --git a/configs/server/server_train_config.yaml b/configs/server/server_train_config.yaml index a3d8bec..87ba6e3 100644 --- a/configs/server/server_train_config.yaml +++ b/configs/server/server_train_config.yaml @@ -7,7 +7,7 @@ runner: parallel: False experiment: - name: train_ab_global_only_with_accept_probability + name: train_ab_global_only_with_wp_p++_dense root_dir: "experiments" use_checkpoint: False epoch: -1 # -1 stands for last epoch @@ -80,7 +80,7 @@ dataset: pipeline: nbv_reconstruction_pipeline: modules: - pts_encoder: pointnet_encoder + pts_encoder: pointnet++_encoder seq_encoder: transformer_seq_encoder pose_encoder: pose_encoder view_finder: gf_view_finder @@ -98,6 +98,7 @@ module: pointnet++_encoder: in_dim: 3 + params_name: dense transformer_seq_encoder: embed_dim: 256 diff --git a/runners/inferencer.py b/runners/inferencer.py index 66e202e..e79b6cc 100644 --- a/runners/inferencer.py +++ b/runners/inferencer.py @@ -4,6 +4,7 @@ from utils.render import RenderUtil from utils.pose import PoseUtil from utils.pts import PtsUtil from utils.reconstruction import ReconstructionUtil +from beans.predict_result import PredictResult import torch from tqdm import tqdm