diff --git a/__pycache__/app_cad.cpython-39.pyc b/__pycache__/app_cad.cpython-39.pyc
new file mode 100644
index 0000000..f47a076
Binary files /dev/null and b/__pycache__/app_cad.cpython-39.pyc differ
diff --git a/__pycache__/cad_strategy.cpython-39.pyc b/__pycache__/cad_strategy.cpython-39.pyc
new file mode 100644
index 0000000..257559c
Binary files /dev/null and b/__pycache__/cad_strategy.cpython-39.pyc differ
diff --git a/app_cad.py b/app_cad.py
new file mode 100644
index 0000000..70ab16c
--- /dev/null
+++ b/app_cad.py
@@ -0,0 +1,9 @@
+from PytorchBoot.application import PytorchBootApplication
+from runners.cad_strategy import CADStrategyRunner
+
+@PytorchBootApplication("cad")
+class AppCAD:
+    @staticmethod
+    def start():
+        CADStrategyRunner("configs/cad_config.yaml").run()
+    
\ No newline at end of file
diff --git a/configs/cad_config.yaml b/configs/cad_config.yaml
new file mode 100644
index 0000000..3596d7a
--- /dev/null
+++ b/configs/cad_config.yaml
@@ -0,0 +1,27 @@
+
+runner:
+  general:
+    seed: 1
+    device: cpu
+    cuda_visible_devices: "0,1,2,3,4,5,6,7"
+
+  experiment:
+    name: debug
+    root_dir: "experiments"
+
+  generate:
+    model_dir: "/home/yan20/nbv_rec/data/test_CAD/test_model"
+    model_start_idx: 0
+    voxel_size: 0.005
+    max_view: 512
+    min_view: 128
+    max_diag: 0.7
+    min_diag: 0.01
+    random_view_ratio: 0.2
+    min_cam_table_included_degree: 20
+
+  reconstruct:
+    soft_overlap_threshold: 0.3
+    hard_overlap_threshold: 0.6
+    scan_points_threshold: 10
+    
\ No newline at end of file
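Note: the `generate` and `reconstruct` blocks above are the knobs the runner reads (via PytorchBoot's ConfigManager in cad_strategy.py below). A minimal sketch of inspecting the same file without the framework — it assumes only PyYAML and is illustrative, not the project's config path:

    import yaml

    with open("configs/cad_config.yaml") as f:
        runner_cfg = yaml.safe_load(f)["runner"]
    print(runner_cfg["generate"]["min_view"], runner_cfg["generate"]["max_view"])  # 128 512
    print(runner_cfg["reconstruct"]["soft_overlap_threshold"])                     # 0.3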
self.reconstruct_config["scan_points_threshold"] + + def create_experiment(self, backup_name=None): + super().create_experiment(backup_name) + + def load_experiment(self, backup_name=None): + super().load_experiment(backup_name) + + def run_one_model(self, model_name): + + ''' init robot ''' + ControlUtil.init() + ''' load CAD model ''' + model_path = os.path.join(self.model_dir, model_name) + cad_model = trimesh.load(model_path) + ''' take first view ''' + view_data = CommunicateUtil.get_view_data() + first_cam_pts = None + ''' register ''' + cad_to_cam = PtsUtil.register_icp(first_cam_pts, cad_model) + cam_to_world = ControlUtil.get_pose() + cad_to_world = cam_to_world @ cad_to_cam + cad_model:trimesh.Trimesh = cad_model.apply_transform(cad_to_world) + ''' sample view ''' + min_corner = cad_model.bounds[0] + max_corner = cad_model.bounds[1] + diag = np.linalg.norm(max_corner - min_corner) + view_num = int(self.min_view + (diag - self.min_diag)/(self.max_diag - self.min_diag) * (self.max_view - self.min_view)) + sampled_view_data = ViewSampleUtil.sample_view_data_world_space( + cad_model, cad_to_world, + voxel_size = self.voxel_size, + max_views = view_num, + min_cam_table_included_degree= self.min_cam_table_included_degree, + random_view_ratio = self.random_view_ratio + ) + cam_to_world_poses = sampled_view_data["cam_to_world_poses"] + world_model_points = sampled_view_data["voxel_down_sampled_points"] + + ''' take sample view ''' + scan_points_idx_list = [] + sample_view_pts_list = [] + for cam_to_world in cam_to_world_poses: + ControlUtil.move_to(cam_to_world) + ''' get world pts ''' + view_data = CommunicateUtil.get_view_data() + cam_pts = None + scan_points_idx = None + world_pts = PtsUtil.transform_point_cloud(cam_pts, cam_to_world) + sample_view_pts_list.append(world_pts) + scan_points_idx_list.append(scan_points_idx) + + ''' generate strategy ''' + limited_useful_view, _, _ = ReconstructionUtil.compute_next_best_view_sequence_with_overlap( + world_model_points, sample_view_pts_list, + scan_points_indices_list = scan_points_idx_list, + init_view=0, + threshold=self.voxel_size, + soft_overlap_threshold= self.soft_overlap_threshold, + hard_overlap_threshold= self.hard_overlap_threshold, + scan_points_threshold = self.scan_points_threshold, + status_info=self.status_info + ) + + ''' extract cam_to world sequence ''' + cam_to_world_seq = [] + coveraget_rate_seq = [] + + for idx, coverage_rate in limited_useful_view: + cam_to_world_seq.append(cam_to_world_poses[idx]) + coveraget_rate_seq.append(coverage_rate) + + ''' take best seq view ''' + for cam_to_world in cam_to_world_seq: + ControlUtil.move_to(cam_to_world) + ''' get world pts ''' + view_data = CommunicateUtil.get_view_data() + cam_pts = None + scan_points_idx = None + world_pts = PtsUtil.transform_point_cloud(cam_pts, cam_to_world) + sample_view_pts_list.append(world_pts) + scan_points_idx_list.append(scan_points_idx) + + + def run(self): + total = len(os.listdir(self.model_dir)) + model_start_idx = self.generate_config["model_start_idx"] + count_object = model_start_idx + for model_name in os.listdir(self.model_dir[model_start_idx:]): + Log.info(f"[{count_object}/{total}]Processing {model_name}") + self.run_one_model(model_name) + Log.success(f"[{count_object}/{total}]Finished processing {model_name}") + + +if __name__ == "__main__": + model_path = "/home/yan20/nbv_rec/data/test_CAD/test_model/bear_scaled.ply" + model = trimesh.load(model_path) + test_pts_L = np.loadtxt("/home/yan20/nbv_rec/data/test_CAD/cam_pts_0_L.txt") + 
diff --git a/runners/inference.py b/runners/inference.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/__pycache__/control_util.cpython-39.pyc b/utils/__pycache__/control_util.cpython-39.pyc
new file mode 100644
index 0000000..c52b63d
Binary files /dev/null and b/utils/__pycache__/control_util.cpython-39.pyc differ
diff --git a/utils/communicate_util.py b/utils/communicate_util.py
new file mode 100644
index 0000000..47a5820
--- /dev/null
+++ b/utils/communicate_util.py
@@ -0,0 +1,15 @@
+
+class CommunicateUtil:
+    VIEW_HOST = "127.0.0.1:5000"
+    INFERENCE_HOST = "127.0.0.1:5000"
+
+    @staticmethod
+    def get_view_data() -> dict:
+        data = None  # placeholder: fetch a capture from the view service at VIEW_HOST
+        return data
+
+    @staticmethod
+    def get_inference_data() -> dict:
+        data = None  # placeholder: fetch results from the inference service at INFERENCE_HOST
+        return data
+    
\ No newline at end of file
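`CommunicateUtil` above only pins down the service addresses; both getters are stubs. One plausible shape for `get_view_data`, assuming the view service speaks JSON over HTTP — the `/view` route and payload are assumptions, not part of this diff:

    import json
    import urllib.request

    def get_view_data_http(host: str = "127.0.0.1:5000") -> dict:
        # fetch one capture (point cloud, scan-point hits, ...) from the view service
        with urllib.request.urlopen(f"http://{host}/view", timeout=10) as resp:
            return json.loads(resp.read())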
diff --git a/utils/control_util.py b/utils/control_util.py
index df9027b..a0f565b 100644
--- a/utils/control_util.py
+++ b/utils/control_util.py
@@ -6,10 +6,10 @@ class ControlUtil:
 
     __fa = FrankaArm(robot_num=2)
 
-    BASE_TO_DISPLAYTBLE:np.ndarray = np.asarray([
-        [1, 0, 0, 0],
+    BASE_TO_WORLD: np.ndarray = np.asarray([
+        [1, 0, 0, -0.5],
         [0, 1, 0, 0],
-        [0, 0, 1, 0],
+        [0, 0, 1, -0.2],
         [0, 0, 0, 1]
     ])
@@ -19,38 +19,127 @@
         [0, 0, 1, 0],
         [0, 0, 0, 1]
     ])
+    theta = np.radians(25)
+    INIT_POSE: np.ndarray = np.asarray([
+        [np.cos(theta), 0, -np.sin(theta), 0],
+        [0, -1, 0, 0],
+        [-np.sin(theta), 0, -np.cos(theta), 0.35],
+        [0, 0, 0, 1]
+    ])
+
+    AXIS_THRESHOLD = (-(np.pi+5e-2)/2, (np.pi+5e-2)/2)
 
     @staticmethod
-    def get_franka_arm() -> FrankaArm:
-        return ControlUtil.__fa
-    
+    def franka_reset() -> None:
+        ControlUtil.__fa.reset_joints()
+
+    @staticmethod
+    def init() -> None:
+        ControlUtil.set_pose(ControlUtil.INIT_POSE)
+
     @staticmethod
     def get_pose() -> np.ndarray:
-        gripper_to_base = ControlUtil.__fa.get_pose().matrix
-        cam_to_world = ControlUtil.BASE_TO_DISPLAYTBLE @ gripper_to_base @ ControlUtil.CAMERA_TO_GRIPPER
+        gripper_to_base = ControlUtil.get_curr_gripper_to_base_pose()
+        cam_to_world = ControlUtil.BASE_TO_WORLD @ gripper_to_base @ ControlUtil.CAMERA_TO_GRIPPER
         return cam_to_world
 
     @staticmethod
     def set_pose(cam_to_world: np.ndarray) -> None:
-        gripper_to_base = np.linalg.inv(ControlUtil.BASE_TO_DISPLAYTBLE) @ cam_to_world @ np.linalg.inv(ControlUtil.CAMERA_TO_GRIPPER)
+        gripper_to_base = ControlUtil.solve_gripper_to_base(cam_to_world)
         gripper_to_base = RigidTransform(rotation=gripper_to_base[:3, :3], translation=gripper_to_base[:3, 3], from_frame="franka_tool", to_frame="world")
         ControlUtil.__fa.goto_pose(gripper_to_base, use_impedance=False, ignore_errors=False)
+
+    @staticmethod
+    def rotate_display_table(degree):
+        pass
 
     @staticmethod
-    def reset() -> None:
-        ControlUtil.__fa.reset_joints()
+    def get_curr_gripper_to_base_pose() -> np.ndarray:
+        return ControlUtil.__fa.get_pose().matrix
+
+    @staticmethod
+    def solve_gripper_to_base(cam_to_world: np.ndarray) -> np.ndarray:
+        return np.linalg.inv(ControlUtil.BASE_TO_WORLD) @ cam_to_world @ np.linalg.inv(ControlUtil.CAMERA_TO_GRIPPER)
+
+    @staticmethod
+    def solve_cam_to_world(gripper_to_base: np.ndarray) -> np.ndarray:
+        return ControlUtil.BASE_TO_WORLD @ gripper_to_base @ ControlUtil.CAMERA_TO_GRIPPER
+
+    @staticmethod
+    def solve_display_table_rot_and_cam_to_world(cam_to_world: np.ndarray) -> tuple:
+        gripper_to_base = ControlUtil.solve_gripper_to_base(cam_to_world)
+        gripper_to_base_axis_angle = ControlUtil.get_gripper_to_base_axis_angle(gripper_to_base)
+
+        if ControlUtil.AXIS_THRESHOLD[0] <= gripper_to_base_axis_angle <= ControlUtil.AXIS_THRESHOLD[1]:
+            return 0, cam_to_world
+        else:
+            for display_table_rot in np.linspace(0.1, 180, 1800):
+                display_table_rot_z_pose = ControlUtil.get_z_axis_rot_mat(display_table_rot)
+                new_cam_to_world = display_table_rot_z_pose @ cam_to_world
+                # the reachability check needs the gripper pose, not the camera pose
+                new_angle = ControlUtil.get_gripper_to_base_axis_angle(ControlUtil.solve_gripper_to_base(new_cam_to_world))
+                if ControlUtil.AXIS_THRESHOLD[0] <= new_angle <= ControlUtil.AXIS_THRESHOLD[1]:
+                    return -display_table_rot, new_cam_to_world
+
+                display_table_rot = -display_table_rot
+                display_table_rot_z_pose = ControlUtil.get_z_axis_rot_mat(display_table_rot)
+                new_cam_to_world = display_table_rot_z_pose @ cam_to_world
+                new_angle = ControlUtil.get_gripper_to_base_axis_angle(ControlUtil.solve_gripper_to_base(new_cam_to_world))
+                if ControlUtil.AXIS_THRESHOLD[0] <= new_angle <= ControlUtil.AXIS_THRESHOLD[1]:
+                    return -display_table_rot, new_cam_to_world
+            raise ValueError("no display table rotation keeps the gripper within AXIS_THRESHOLD")
+
+    @staticmethod
+    def get_z_axis_rot_mat(degree):
+        radian = np.radians(degree)
+        return np.array([
+            [np.cos(radian), -np.sin(radian), 0, 0],
+            [np.sin(radian), np.cos(radian), 0, 0],
+            [0, 0, 1, 0],
+            [0, 0, 0, 1]
+        ])
+
+    @staticmethod
+    def get_gripper_to_base_axis_angle(gripper_to_base: np.ndarray) -> float:
+        rot_mat = gripper_to_base[:3, :3]
+        gripper_z_axis = rot_mat[:, 2]
+        base_x_axis = np.array([1, 0, 0])
+        angle = np.arccos(np.dot(gripper_z_axis, base_x_axis))
+        return angle
+
+    @staticmethod
+    def move_to(pose: np.ndarray):
+        rot_degree, cam_to_world = ControlUtil.solve_display_table_rot_and_cam_to_world(pose)
+        print("table rot degree:", rot_degree)
+        ControlUtil.rotate_display_table(rot_degree)
+        ControlUtil.set_pose(cam_to_world)
 
 # ----------- Debug Test -------------
 if __name__ == "__main__":
-    test_pose = np.asarray([
-        [1, 0, 0, 0.4],
-        [0, -1, 0, 0],
-        [0, 0, -1, 0.6],
-        [0, 0, 0, 1]
-    ])
-    ControlUtil.set_pose(test_pose)
-    print(ControlUtil.get_pose())
-    ControlUtil.reset()
-    print(ControlUtil.get_pose())
\ No newline at end of file
+    #ControlUtil.init()
+    import time
+    start = time.time()
+    rot_degree, cam_to_world = ControlUtil.solve_display_table_rot_and_cam_to_world(ControlUtil.INIT_POSE)
+    end = time.time()
+    print(f"Time: {end-start}")
+    print(rot_degree, cam_to_world)
+    # test_pose = np.asarray([
+    #     [1, 0, 0, 0.4],
+    #     [0, -1, 0, 0],
+    #     [0, 0, -1, 0.6],
+    #     [0, 0, 0, 1]
+    # ])
+    # ControlUtil.set_pose(test_pose)
+    # print(ControlUtil.get_pose())
+    # ControlUtil.reset()
+    # print(ControlUtil.get_pose())
+
+    angle = ControlUtil.get_gripper_to_base_axis_angle(ControlUtil.solve_gripper_to_base(cam_to_world))
+    threshold = ControlUtil.AXIS_THRESHOLD
+
+    angle_degree = np.degrees(angle)
+    threshold_degree = np.degrees(threshold[0]), np.degrees(threshold[1])
+    print(f"Angle: {angle_degree}, range: {threshold_degree}")
+    ControlUtil.set_pose(cam_to_world)
\ No newline at end of file
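`solve_gripper_to_base` and `solve_cam_to_world` above should be exact inverses of one another. A quick numpy round-trip check of that identity, using `BASE_TO_WORLD` from the diff and a stand-in `CAMERA_TO_GRIPPER` (the real calibration matrix is truncated in this hunk):

    import numpy as np

    BASE_TO_WORLD = np.array([[1, 0, 0, -0.5],
                              [0, 1, 0, 0],
                              [0, 0, 1, -0.2],
                              [0, 0, 0, 1]], dtype=float)
    CAMERA_TO_GRIPPER = np.eye(4)  # stand-in for the calibrated matrix

    cam_to_world = np.eye(4)
    cam_to_world[:3, 3] = [0.1, 0.2, 0.35]
    gripper_to_base = np.linalg.inv(BASE_TO_WORLD) @ cam_to_world @ np.linalg.inv(CAMERA_TO_GRIPPER)
    assert np.allclose(BASE_TO_WORLD @ gripper_to_base @ CAMERA_TO_GRIPPER, cam_to_world)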
diff --git a/utils/pose_util.py b/utils/pose_util.py
new file mode 100644
index 0000000..ddaed4a
--- /dev/null
+++ b/utils/pose_util.py
@@ -0,0 +1,151 @@
+import numpy as np
+
+class PoseUtil:
+    ROTATION = 1
+    TRANSLATION = 2
+    SCALE = 3
+
+    @staticmethod
+    def get_uniform_translation(trans_m_min, trans_m_max, trans_unit, debug=False):
+        if isinstance(trans_m_min, list):
+            x_min, y_min, z_min = trans_m_min
+            x_max, y_max, z_max = trans_m_max
+        else:
+            x_min, y_min, z_min = trans_m_min, trans_m_min, trans_m_min
+            x_max, y_max, z_max = trans_m_max, trans_m_max, trans_m_max
+
+        x = np.random.uniform(x_min, x_max)
+        y = np.random.uniform(y_min, y_max)
+        z = np.random.uniform(z_min, z_max)
+        translation = np.array([x, y, z])
+        if trans_unit == "cm":
+            translation = translation / 100
+        if debug:
+            print("uniform translation:", translation)
+        return translation
+
+    @staticmethod
+    def get_uniform_rotation(rot_degree_min=0, rot_degree_max=180, debug=False):
+        axis = np.random.randn(3)
+        axis /= np.linalg.norm(axis)
+        theta = np.random.uniform(
+            rot_degree_min / 180 * np.pi, rot_degree_max / 180 * np.pi
+        )
+
+        K = np.array(
+            [[0, -axis[2], axis[1]], [axis[2], 0, -axis[0]], [-axis[1], axis[0], 0]]
+        )
+        R = np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * (K @ K)
+        if debug:
+            print("uniform rotation:", theta * 180 / np.pi)
+        return R
+
+    @staticmethod
+    def get_uniform_pose(
+        trans_min, trans_max, rot_min=0, rot_max=180, trans_unit="cm", debug=False
+    ):
+        translation = PoseUtil.get_uniform_translation(
+            trans_min, trans_max, trans_unit, debug
+        )
+        rotation = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
+        pose = np.eye(4)
+        pose[:3, :3] = rotation
+        pose[:3, 3] = translation
+        return pose
+
+    @staticmethod
+    def get_n_uniform_pose(
+        trans_min,
+        trans_max,
+        rot_min=0,
+        rot_max=180,
+        n=1,
+        trans_unit="cm",
+        fix=None,
+        contain_canonical=True,
+        debug=False,
+    ):
+        if fix == PoseUtil.ROTATION:
+            translations = np.zeros((n, 3))
+            for i in range(n):
+                translations[i] = PoseUtil.get_uniform_translation(
+                    trans_min, trans_max, trans_unit, debug
+                )
+            if contain_canonical:
+                translations[0] = np.zeros(3)
+            rotations = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
+        elif fix == PoseUtil.TRANSLATION:
+            rotations = np.zeros((n, 3, 3))
+            for i in range(n):
+                rotations[i] = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
+            if contain_canonical:
+                rotations[0] = np.eye(3)
+            translations = PoseUtil.get_uniform_translation(
+                trans_min, trans_max, trans_unit, debug
+            )
+        else:
+            translations = np.zeros((n, 3))
+            rotations = np.zeros((n, 3, 3))
+            for i in range(n):
+                translations[i] = PoseUtil.get_uniform_translation(
+                    trans_min, trans_max, trans_unit, debug
+                )
+            for i in range(n):
+                rotations[i] = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug)
+            if contain_canonical:
+                translations[0] = np.zeros(3)
+                rotations[0] = np.eye(3)
+
+        pose = np.eye(4, 4, k=0)[np.newaxis, :].repeat(n, axis=0)
+        pose[:, :3, :3] = rotations
+        pose[:, :3, 3] = translations
+
+        return pose
+
+    @staticmethod
+    def get_n_uniform_pose_batch(
+        trans_min,
+        trans_max,
+        rot_min=0,
+        rot_max=180,
+        n=1,
+        batch_size=1,
+        trans_unit="cm",
+        fix=None,
+        contain_canonical=False,
+        debug=False,
+    ):
+
+        batch_poses = []
+        for i in range(batch_size):
+            pose = PoseUtil.get_n_uniform_pose(
+                trans_min,
+                trans_max,
+                rot_min,
+                rot_max,
+                n,
+                trans_unit,
+                fix,
+                contain_canonical,
+                debug,
+            )
+            batch_poses.append(pose)
+        pose_batch = np.stack(batch_poses, axis=0)
+        return pose_batch
+
+    @staticmethod
+    def get_uniform_scale(scale_min, scale_max, debug=False):
+        if isinstance(scale_min, list):
+            x_min, y_min, z_min = scale_min
+            x_max, y_max, z_max = scale_max
+        else:
+            x_min, y_min, z_min = scale_min, scale_min, scale_min
+            x_max, y_max, z_max = scale_max, scale_max, scale_max
+
+        x = np.random.uniform(x_min, x_max)
+        y = np.random.uniform(y_min, y_max)
+        z = np.random.uniform(z_min, z_max)
+        scale = np.array([x, y, z])
+        if debug:
+            print("uniform scale:", scale)
+        return scale
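`get_uniform_rotation` above is Rodrigues' formula, R = I + sin(θ)K + (1 − cos(θ))K², for a uniformly drawn axis and angle. A quick sanity check that the output is a proper rotation:

    import numpy as np
    from utils.pose_util import PoseUtil

    R = PoseUtil.get_uniform_rotation(0, 180)
    assert np.allclose(R @ R.T, np.eye(3), atol=1e-8)  # orthonormal
    assert np.isclose(np.linalg.det(R), 1.0)           # proper rotation, det = +1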
diff --git a/utils/pts_util.py b/utils/pts_util.py
new file mode 100644
index 0000000..6ec45fa
--- /dev/null
+++ b/utils/pts_util.py
@@ -0,0 +1,124 @@
+import numpy as np
+import open3d as o3d
+import torch
+import trimesh
+from scipy.spatial import cKDTree
+
+class PtsUtil:
+
+    @staticmethod
+    def voxel_downsample_point_cloud(point_cloud, voxel_size=0.005):
+        o3d_pc = o3d.geometry.PointCloud()
+        o3d_pc.points = o3d.utility.Vector3dVector(point_cloud)
+        downsampled_pc = o3d_pc.voxel_down_sample(voxel_size)
+        return np.asarray(downsampled_pc.points)
+
+    @staticmethod
+    def random_downsample_point_cloud(point_cloud, num_points, require_idx=False):
+        if point_cloud.shape[0] == 0:
+            if require_idx:
+                return point_cloud, np.array([])
+            return point_cloud
+        # sample without replacement unless more points are requested than exist
+        idx = np.random.choice(len(point_cloud), num_points, replace=len(point_cloud) < num_points)
+        if require_idx:
+            return point_cloud[idx], idx
+        return point_cloud[idx]
+
+    @staticmethod
+    def fps_downsample_point_cloud(point_cloud, num_points, require_idx=False):
+        N = point_cloud.shape[0]
+        mask = np.zeros(N, dtype=bool)
+
+        sampled_indices = np.zeros(num_points, dtype=int)
+        sampled_indices[0] = np.random.randint(0, N)
+        distances = np.linalg.norm(point_cloud - point_cloud[sampled_indices[0]], axis=1)
+        for i in range(1, num_points):
+            farthest_index = np.argmax(distances)
+            sampled_indices[i] = farthest_index
+            mask[farthest_index] = True
+
+            new_distances = np.linalg.norm(point_cloud - point_cloud[farthest_index], axis=1)
+            distances = np.minimum(distances, new_distances)
+
+        sampled_points = point_cloud[sampled_indices]
+        if require_idx:
+            return sampled_points, sampled_indices
+        return sampled_points
+
+    @staticmethod
+    def random_downsample_point_cloud_tensor(point_cloud, num_points):
+        idx = torch.randint(0, len(point_cloud), (num_points,))
+        return point_cloud[idx]
+
+    @staticmethod
+    def voxelize_points(points, voxel_size):
+        voxel_indices = np.floor(points / voxel_size).astype(np.int32)
+        unique_voxels = np.unique(voxel_indices, axis=0, return_inverse=True)
+        return unique_voxels
+
+    @staticmethod
+    def transform_point_cloud(points, pose_mat):
+        points_h = np.concatenate([points, np.ones((points.shape[0], 1))], axis=1)
+        points_h = np.dot(pose_mat, points_h.T).T
+        return points_h[:, :3]
+
+    @staticmethod
+    def get_overlapping_points(point_cloud_L, point_cloud_R, voxel_size=0.005, require_idx=False):
+        voxels_L, indices_L = PtsUtil.voxelize_points(point_cloud_L, voxel_size)
+        voxels_R, _ = PtsUtil.voxelize_points(point_cloud_R, voxel_size)
+
+        voxel_indices_L = voxels_L.view([("", voxels_L.dtype)] * 3)
+        voxel_indices_R = voxels_R.view([("", voxels_R.dtype)] * 3)
+        overlapping_voxels = np.intersect1d(voxel_indices_L, voxel_indices_R)
+        mask_L = np.isin(
+            indices_L, np.where(np.isin(voxel_indices_L, overlapping_voxels))[0]
+        )
+        overlapping_points = point_cloud_L[mask_L]
+        if require_idx:
+            return overlapping_points, mask_L
+        return overlapping_points
+
+    @staticmethod
+    def filter_points(points, points_normals, cam_pose, voxel_size=0.002, theta=45, z_range=(0.2, 0.45)):
+
+        """ filter with z range """
+        points_cam = PtsUtil.transform_point_cloud(points, np.linalg.inv(cam_pose))
+        idx = (points_cam[:, 2] > z_range[0]) & (points_cam[:, 2] < z_range[1])
+        z_filtered_points = points[idx]
+
+        """ filter with normal """
+        sampled_points = PtsUtil.voxel_downsample_point_cloud(z_filtered_points, voxel_size)
+        kdtree = cKDTree(points_normals[:, :3])
+        _, indices = kdtree.query(sampled_points)
+        nearest_points = points_normals[indices]
+
+        normals = nearest_points[:, 3:]
+        camera_axis = -cam_pose[:3, 2]
+        normals_normalized = normals / np.linalg.norm(normals, axis=1, keepdims=True)
+        cos_theta = np.dot(normals_normalized, camera_axis)
+        theta_rad = np.deg2rad(theta)
+        idx = cos_theta > np.cos(theta_rad)
+        filtered_sampled_points = sampled_points[idx]
+        return filtered_sampled_points[:, :3]
+
+    @staticmethod
+    def register_icp(pcl: np.ndarray, model: trimesh.Trimesh, threshold=0.005) -> np.ndarray:
+        """
+        Register point cloud to CAD model.
+        Returns the transformation matrix.
+        """
+        mesh_points = np.asarray(model.vertices)
+        mesh_point_cloud = o3d.geometry.PointCloud()
+        mesh_point_cloud.points = o3d.utility.Vector3dVector(mesh_points)
+
+        pcl_point_cloud = o3d.geometry.PointCloud()
+        pcl_point_cloud.points = o3d.utility.Vector3dVector(pcl)
+
+        reg_icp = o3d.pipelines.registration.registration_icp(
+            pcl_point_cloud, mesh_point_cloud, threshold,
+            np.eye(4),
+            o3d.pipelines.registration.TransformationEstimationPointToPoint()
+        )
+        return reg_icp.transformation
\ No newline at end of file
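`get_overlapping_points` above voxelizes both clouds and views the integer (ix, iy, iz) rows as structured scalars so `np.intersect1d` can match whole voxels at once. A tiny demonstration on toy clouds:

    import numpy as np
    from utils.pts_util import PtsUtil

    L = np.array([[0.001, 0.001, 0.001], [0.101, 0.0, 0.0]])
    R = np.array([[0.002, 0.003, 0.004], [0.9, 0.9, 0.9]])
    print(PtsUtil.get_overlapping_points(L, R, voxel_size=0.005))
    # only L[0] shares a voxel with R -> [[0.001 0.001 0.001]]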
diff --git a/utils/reconstruction_util.py b/utils/reconstruction_util.py
new file mode 100644
index 0000000..0e7082d
--- /dev/null
+++ b/utils/reconstruction_util.py
@@ -0,0 +1,160 @@
+import numpy as np
+from scipy.spatial import cKDTree
+from utils.pts_util import PtsUtil
+
+class ReconstructionUtil:
+
+    @staticmethod
+    def compute_coverage_rate(target_point_cloud, combined_point_cloud, threshold=0.01):
+        kdtree = cKDTree(combined_point_cloud)
+        distances, _ = kdtree.query(target_point_cloud)
+        covered_points_num = np.sum(distances < threshold)
+        coverage_rate = covered_points_num / target_point_cloud.shape[0]
+        return coverage_rate, covered_points_num
+
+    @staticmethod
+    def compute_overlap_rate(new_point_cloud, combined_point_cloud, threshold=0.01):
+        kdtree = cKDTree(combined_point_cloud)
+        distances, _ = kdtree.query(new_point_cloud)
+        overlapping_points = np.sum(distances < threshold)
+        if new_point_cloud.shape[0] == 0:
+            overlap_rate = 0
+        else:
+            overlap_rate = overlapping_points / new_point_cloud.shape[0]
+        return overlap_rate
+
+    @staticmethod
+    def get_new_added_points(old_combined_pts, new_pts, threshold=0.005):
+        if old_combined_pts.size == 0:
+            return new_pts
+        if new_pts.size == 0:
+            return np.array([])
+
+        tree = cKDTree(old_combined_pts)
+        distances, _ = tree.query(new_pts, k=1)
+        new_added_points = new_pts[distances > threshold]
+        return new_added_points
+
+    @staticmethod
+    def compute_next_best_view_sequence_with_overlap(target_point_cloud, point_cloud_list, scan_points_indices_list, threshold=0.01, soft_overlap_threshold=0.5, hard_overlap_threshold=0.7, init_view=0, scan_points_threshold=5, status_info=None):
+        selected_views = [init_view]
+        combined_point_cloud = point_cloud_list[init_view]
+        history_indices = [scan_points_indices_list[init_view]]
+
+        max_rec_pts = np.vstack(point_cloud_list)
+        downsampled_max_rec_pts = PtsUtil.voxel_downsample_point_cloud(max_rec_pts, threshold)
+
+        max_rec_pts_num = downsampled_max_rec_pts.shape[0]
+        max_real_rec_pts_coverage, _ = ReconstructionUtil.compute_coverage_rate(target_point_cloud, downsampled_max_rec_pts, threshold)
+
+        new_coverage, new_covered_num = ReconstructionUtil.compute_coverage_rate(downsampled_max_rec_pts, combined_point_cloud, threshold)
+        current_coverage = new_coverage
+        current_covered_num = new_covered_num
+
+        remaining_views = list(range(len(point_cloud_list)))
+        view_sequence = [(init_view, current_coverage)]
+        cnt_processed_view = 0
+        remaining_views.remove(init_view)
+        curr_rec_pts_num = combined_point_cloud.shape[0]
+
+        while remaining_views:
+            best_view = None
+            best_coverage_increase = -1
+            best_combined_point_cloud = None
+            best_covered_num = 0
+
+            for view_index in remaining_views:
+                if point_cloud_list[view_index].shape[0] == 0:
+                    continue
+                if selected_views:
+                    new_scan_points_indices = scan_points_indices_list[view_index]
+                    if not ReconstructionUtil.check_scan_points_overlap(history_indices, new_scan_points_indices, scan_points_threshold):
+                        overlap_threshold = hard_overlap_threshold
+                    else:
+                        overlap_threshold = soft_overlap_threshold
+                    overlap_rate = ReconstructionUtil.compute_overlap_rate(point_cloud_list[view_index], combined_point_cloud, threshold)
+                    if overlap_rate < overlap_threshold:
+                        continue
+                new_combined_point_cloud = np.vstack([combined_point_cloud, point_cloud_list[view_index]])
+                new_downsampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(new_combined_point_cloud, threshold)
+                new_coverage, new_covered_num = ReconstructionUtil.compute_coverage_rate(downsampled_max_rec_pts, new_downsampled_combined_point_cloud, threshold)
+
+                coverage_increase = new_coverage - current_coverage
+                if coverage_increase > best_coverage_increase:
+                    best_coverage_increase = coverage_increase
+                    best_view = view_index
+                    best_covered_num = new_covered_num
+                    best_combined_point_cloud = new_downsampled_combined_point_cloud
+
+            if best_view is not None:
+                if best_coverage_increase <= 1e-3 or best_covered_num - current_covered_num <= 5:
+                    break
+
+                selected_views.append(best_view)
+                best_rec_pts_num = best_combined_point_cloud.shape[0]
+                print(f"Current rec pts num: {curr_rec_pts_num}, Best rec pts num: {best_rec_pts_num}, Best cover pts: {best_covered_num}, Max rec pts num: {max_rec_pts_num}")
+                print(f"Current coverage: {current_coverage}, Best coverage increase: {best_coverage_increase}, Max Real coverage: {max_real_rec_pts_coverage}")
+                current_covered_num = best_covered_num
+                curr_rec_pts_num = best_rec_pts_num
+                combined_point_cloud = best_combined_point_cloud
+                remaining_views.remove(best_view)
+                history_indices.append(scan_points_indices_list[best_view])
+                current_coverage += best_coverage_increase
+                cnt_processed_view += 1
+                if status_info is not None:
+                    sm = status_info["status_manager"]
+                    app_name = status_info["app_name"]
+                    runner_name = status_info["runner_name"]
+                    sm.set_status(app_name, runner_name, "current coverage", current_coverage)
+                    sm.set_progress(app_name, runner_name, "processed view", cnt_processed_view, len(point_cloud_list))
+
+                view_sequence.append((best_view, current_coverage))
+
+            else:
+                break
+        if status_info is not None:
+            sm = status_info["status_manager"]
+            app_name = status_info["app_name"]
+            runner_name = status_info["runner_name"]
+            sm.set_progress(app_name, runner_name, "processed view", len(point_cloud_list), len(point_cloud_list))
+        return view_sequence, remaining_views, combined_point_cloud
+
+    @staticmethod
+    def generate_scan_points(display_table_top, display_table_radius, min_distance=0.03, max_points_num=500, max_attempts=1000):
+        points = []
+        attempts = 0
+        while len(points) < max_points_num and attempts < max_attempts:
+            angle = np.random.uniform(0, 2 * np.pi)
+            r = np.random.uniform(0, display_table_radius)
+            x = r * np.cos(angle)
+            y = r * np.sin(angle)
+            z = display_table_top
+            new_point = (x, y, z)
+            if all(np.linalg.norm(np.array(new_point) - np.array(existing_point)) >= min_distance for existing_point in points):
+                points.append(new_point)
+            attempts += 1
+        return points
+
+    @staticmethod
+    def compute_covered_scan_points(scan_points, point_cloud, threshold=0.01):
+        tree = cKDTree(point_cloud)
+        covered_points = []
+        indices = []
+        for i, scan_point in enumerate(scan_points):
+            if tree.query_ball_point(scan_point, threshold):
+                covered_points.append(scan_point)
+                indices.append(i)
+        return covered_points, indices
+
+    @staticmethod
+    def check_scan_points_overlap(history_indices, indices2, threshold=5):
+        for indices1 in history_indices:
+            if len(set(indices1).intersection(set(indices2))) >= threshold:
+                return True
+        return False
+    
\ No newline at end of file
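`compute_next_best_view_sequence_with_overlap` above greedily appends whichever remaining view adds the most coverage, gating candidates by a soft or hard overlap threshold depending on whether they share enough scan-point hits with the history. A toy invocation on synthetic, heavily overlapping scans (no robot or status manager needed; the thresholds mirror cad_config.yaml):

    import numpy as np
    from utils.reconstruction_util import ReconstructionUtil

    rng = np.random.default_rng(0)
    target = rng.uniform(0, 1, (500, 3))
    views = [target[:300], target[150:450], target[100:]]      # overlapping partial scans
    scan_idx = [[0, 1, 2], [1, 2, 3], [2, 3, 4]]               # shared scan-point indices
    seq, remaining, rec = ReconstructionUtil.compute_next_best_view_sequence_with_overlap(
        target, views, scan_points_indices_list=scan_idx,
        threshold=0.05, soft_overlap_threshold=0.3, hard_overlap_threshold=0.6,
        scan_points_threshold=2)
    print([v for v, _ in seq], "final coverage:", seq[-1][1])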
diff --git a/utils/view_sample_util.py b/utils/view_sample_util.py
new file mode 100644
index 0000000..ce70ff1
--- /dev/null
+++ b/utils/view_sample_util.py
@@ -0,0 +1,161 @@
+
+import numpy as np
+from utils.pose_util import PoseUtil
+import trimesh
+from collections import defaultdict
+from scipy.spatial.transform import Rotation as R
+import random
+
+class ViewSampleUtil:
+    @staticmethod
+    def farthest_point_sampling(points, num_samples):
+        num_points = points.shape[0]
+        if num_samples >= num_points:
+            return points, np.arange(num_points)
+        sampled_indices = np.zeros(num_samples, dtype=int)
+        sampled_indices[0] = np.random.randint(num_points)
+        min_distances = np.full(num_points, np.inf)
+        for i in range(1, num_samples):
+            current_point = points[sampled_indices[i - 1]]
+            dist_to_current_point = np.linalg.norm(points - current_point, axis=1)
+            min_distances = np.minimum(min_distances, dist_to_current_point)
+            sampled_indices[i] = np.argmax(min_distances)
+        downsampled_points = points[sampled_indices]
+        return downsampled_points, sampled_indices
+
+    @staticmethod
+    def voxel_downsample(points, voxel_size):
+        voxel_grid = defaultdict(list)
+        for i, point in enumerate(points):
+            voxel_index = tuple((point // voxel_size).astype(int))
+            voxel_grid[voxel_index].append(i)
+
+        downsampled_points = []
+        downsampled_indices = []
+        for indices in voxel_grid.values():
+            selected_index = indices[0]
+            downsampled_points.append(points[selected_index])
+            downsampled_indices.append(selected_index)
+
+        return np.array(downsampled_points), downsampled_indices
+
+    @staticmethod
+    def sample_view_data(mesh: trimesh.Trimesh, distance_range: tuple = (0.25, 0.5), voxel_size: float = 0.005, max_views: int = 1, perturb_repeat: int = 1) -> dict:
+        view_data = {
+            "look_at_points": [],
+            "cam_positions": [],
+        }
+
+        vertices = mesh.vertices
+        look_at_points = []
+        cam_positions = []
+        normals = []
+        vertex_normals = mesh.vertex_normals
+
+        for i, vertex in enumerate(vertices):
+            look_at_point = vertex
+
+            normal = vertex_normals[i]
+            if np.isnan(normal).any():
+                continue
+            if np.dot(normal, look_at_point) < 0:
+                normal = -normal
+
+            normals.append(normal)
+
+            for _ in range(perturb_repeat):
+                perturb_angle = np.radians(np.random.uniform(0, 30))
+                perturb_axis = np.random.normal(size=3)
+                perturb_axis /= np.linalg.norm(perturb_axis)
+                rotation_matrix = R.from_rotvec(perturb_angle * perturb_axis).as_matrix()
+                perturbed_normal = np.dot(rotation_matrix, normal)
+
+                distance = np.random.uniform(*distance_range)
+                cam_position = look_at_point + distance * perturbed_normal
+                look_at_points.append(look_at_point)
+                cam_positions.append(cam_position)
+
+        look_at_points = np.array(look_at_points)
+        cam_positions = np.array(cam_positions)
+
+        voxel_downsampled_look_at_points, selected_indices = ViewSampleUtil.voxel_downsample(look_at_points, voxel_size)
+        voxel_downsampled_cam_positions = cam_positions[selected_indices]
+        voxel_downsampled_normals = np.array(normals)[selected_indices]
+
+        fps_downsampled_look_at_points, selected_indices = ViewSampleUtil.farthest_point_sampling(voxel_downsampled_look_at_points, max_views * 2)
+        fps_downsampled_cam_positions = voxel_downsampled_cam_positions[selected_indices]
+
+        view_data["look_at_points"] = fps_downsampled_look_at_points.tolist()
+        view_data["cam_positions"] = fps_downsampled_cam_positions.tolist()
+        view_data["normals"] = voxel_downsampled_normals.tolist()
+        view_data["voxel_down_sampled_points"] = voxel_downsampled_look_at_points
+
+        return view_data
+
+    @staticmethod
+    def get_world_points_and_normals(view_data: dict, obj_world_pose: np.ndarray) -> tuple:
+        world_points = []
+        world_normals = []
+        for voxel_down_sampled_points, normal in zip(view_data["voxel_down_sampled_points"], view_data["normals"]):
+            voxel_down_sampled_points_world = obj_world_pose @ np.append(voxel_down_sampled_points, 1.0)
+            normal_world = obj_world_pose[:3, :3] @ normal
+            world_points.append(voxel_down_sampled_points_world[:3])
+            world_normals.append(normal_world)
+        return np.array(world_points), np.array(world_normals)
+
+    @staticmethod
+    def get_cam_pose(view_data: dict, obj_world_pose: np.ndarray, max_views: int, min_cam_table_included_degree: int, random_view_ratio: float) -> np.ndarray:
+        cam_poses = []
+        min_height_z = 1000
+        for look_at_point, cam_position in zip(view_data["look_at_points"], view_data["cam_positions"]):
+            look_at_point_world = obj_world_pose @ np.append(look_at_point, 1.0)
+            cam_position_world = obj_world_pose @ np.append(cam_position, 1.0)
+            if look_at_point_world[2] < min_height_z:
+                min_height_z = look_at_point_world[2]
+            look_at_point_world = look_at_point_world[:3]
+            cam_position_world = cam_position_world[:3]
+
+            forward_vector = cam_position_world - look_at_point_world
+            forward_vector /= np.linalg.norm(forward_vector)
+
+            up_vector = np.array([0, 0, 1])
+
+            right_vector = np.cross(up_vector, forward_vector)
+            right_vector /= np.linalg.norm(right_vector)
+
+            corrected_up_vector = np.cross(forward_vector, right_vector)
+            rotation_matrix = np.array([right_vector, corrected_up_vector, forward_vector]).T
+
+            cam_pose = np.eye(4)
+            cam_pose[:3, :3] = rotation_matrix
+            cam_pose[:3, 3] = cam_position_world
+            cam_poses.append(cam_pose)
+
+        filtered_cam_poses = []
+        for cam_pose in cam_poses:
+            if cam_pose[2, 3] > min_height_z:
+                direction_vector = cam_pose[:3, 2]
+                horizontal_normal = np.array([0, 0, 1])
+                cos_angle = np.dot(direction_vector, horizontal_normal) / (np.linalg.norm(direction_vector) * np.linalg.norm(horizontal_normal))
+                angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))
+                angle_degree = np.degrees(angle)
+                if angle_degree < 90 - min_cam_table_included_degree:
+                    filtered_cam_poses.append(cam_pose)
+                    if random.random() < random_view_ratio:
+                        perturb_pose = PoseUtil.get_uniform_pose([0.1, 0.1, 0.1], [3, 3, 3], 0, 180, "cm")
+                        filtered_cam_poses.append(perturb_pose @ cam_pose)
+
+        if len(filtered_cam_poses) > max_views:
+            indices = np.random.choice(len(filtered_cam_poses), max_views, replace=False)
+            filtered_cam_poses = [filtered_cam_poses[i] for i in indices]
+
+        return np.array(filtered_cam_poses)
+
+    @staticmethod
+    def sample_view_data_world_space(mesh: trimesh.Trimesh, cad_to_world: np.ndarray, distance_range: tuple = (0.25, 0.5), voxel_size: float = 0.005, max_views: int = 1, min_cam_table_included_degree: int = 20, random_view_ratio: float = 0.2) -> dict:
+        view_data = ViewSampleUtil.sample_view_data(mesh, distance_range, voxel_size, max_views)
+        view_data["cam_to_world_poses"] = ViewSampleUtil.get_cam_pose(view_data, cad_to_world, max_views, min_cam_table_included_degree, random_view_ratio)
+        view_data["voxel_down_sampled_points"], view_data["normals"] = ViewSampleUtil.get_world_points_and_normals(view_data, cad_to_world)
+        return view_data
+
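For reference, `get_cam_pose` above builds each rotation from [right, up, forward] columns with forward pointing from the look-at point toward the camera, so the camera sights its target along the local −z axis. A short numeric check of that convention:

    import numpy as np

    look_at = np.array([0.0, 0.0, 0.1])
    cam_pos = np.array([0.3, 0.2, 0.5])
    forward = (cam_pos - look_at) / np.linalg.norm(cam_pos - look_at)
    right = np.cross(np.array([0.0, 0.0, 1.0]), forward)
    right /= np.linalg.norm(right)
    up = np.cross(forward, right)
    R = np.stack([right, up, forward], axis=1)

    target_cam = R.T @ (look_at - cam_pos)  # look-at point in the camera frame
    assert np.allclose(target_cam[:2], 0, atol=1e-12) and target_cam[2] < 0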