# nbv_rec_franka_control/runners/cad_close_loop_new.py
# 2025-01-15 14:42:54 +08:00
#
# 341 lines
# 15 KiB
# Python

import os
import time
import trimesh
import tempfile
import subprocess
import numpy as np
from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager
import PytorchBoot.stereotype as stereotype
from PytorchBoot.utils.log_util import Log
from PytorchBoot.status import status_manager
from utils.control_util import ControlUtil
from utils.communicate_util import CommunicateUtil
from utils.pts_util import PtsUtil
from utils.reconstruction_util import ReconstructionUtil
from utils.preprocess_util import save_scene_data
from utils.data_load import DataLoadUtil
from utils.view_util import ViewUtil
class PointCloud:
    """Simple container for one captured point cloud."""

    def __init__(self, points, camera, name):
        """Store the constructor arguments on the instance.

        Args:
            points: Point data of this cloud.
            camera: Camera information associated with the cloud.
            name: Identifier of the cloud.
        """
        # The original signature lacked ``self``, so instantiating the class
        # with three arguments raised TypeError and nothing was stored.
        self.points = points
        self.camera = camera
        self.name = name
class PointCloudGroup:
    """Simple container for a named collection of point clouds."""

    def __init__(self, point_clouds, name):
        """Store the constructor arguments on the instance.

        Args:
            point_clouds: The point clouds belonging to this group.
            name: Identifier of the group.
        """
        # The original signature lacked ``self``, so instantiating the class
        # with two arguments raised TypeError and nothing was stored.
        self.point_clouds = point_clouds
        self.name = name
@stereotype.runner("temp")
class CADCloseLoopOnlineRegStrategyRunner(Runner):
    """Closed-loop next-best-view scanning against a registered CAD model.

    The runner captures an initial scan while rotating the display table,
    registers the merged scan against the CAD mesh, and then repeatedly
    selects a next best view among the candidate views stored in the scene
    directory, moves the robot there and merges the new shot into the
    accumulated scan.

    The scanning loop pauses via the builtin ``breakpoint()`` so an operator
    can step through it; set the environment variable ``PYTHONBREAKPOINT=0``
    to run without pauses, or ``PYTHONBREAKPOINT=ipdb.set_trace`` to use ipdb.
    """

    def __init__(self, config_path: str):
        """Load the experiment and cache the generate/reconstruct settings.

        Args:
            config_path: Path of the configuration file, forwarded to the
                base ``Runner`` constructor.
        """
        super().__init__(config_path)
        self.load_experiment("cad_strategy")
        self.generate_config = ConfigManager.get("runner", "generate")
        self.reconstruct_config = ConfigManager.get("runner", "reconstruct")

        # Settings of the "generate" section.
        self.output_dir = self.generate_config["output_dir"]
        self.model_dir = self.generate_config["model_dir"]
        self.object_name = self.generate_config["object_name"]
        self.blender_bin_path = self.generate_config["blender_bin_path"]
        self.generator_script_path = self.generate_config["generator_script_path"]
        self.voxel_size = self.generate_config["voxel_size"]

        # Settings of the "reconstruct" section.
        self.max_shot_view_num = self.reconstruct_config["max_shot_view_num"]
        self.min_shot_new_pts_num = self.reconstruct_config["min_shot_new_pts_num"]
        self.min_coverage_increase = self.reconstruct_config["min_coverage_increase"]
        self.scan_points_threshold = self.reconstruct_config["scan_points_threshold"]

    def create_experiment(self, backup_name=None):
        """Delegate experiment creation to the base ``Runner``."""
        super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        """Delegate experiment loading to the base ``Runner``."""
        super().load_experiment(backup_name)

    def split_scan_pts_and_obj_pts(self, world_pts, z_threshold=0):
        """Split points by their z coordinate.

        Args:
            world_pts: Array whose third column (index 2) is the z coordinate.
            z_threshold: Height separating the two groups.

        Returns:
            Tuple ``(scan_pts, obj_pts)``: the points with ``z < z_threshold``
            and the points with ``z >= z_threshold``.
        """
        z_values = world_pts[:, 2]
        scan_pts = world_pts[z_values < z_threshold]
        obj_pts = world_pts[z_values >= z_threshold]
        return scan_pts, obj_pts

    def _extract_object_pts(self, view_data, cam_to_real_world):
        """Convert one shot into world-frame points at or above z = 0."""
        cam_pts = ViewUtil.get_pts(view_data)
        world_pts = PtsUtil.transform_point_cloud(cam_pts, cam_to_real_world)
        _, obj_pts = self.split_scan_pts_and_obj_pts(world_pts)
        return obj_pts

    def loop_scan(self, first_cam_to_real_world):
        """Capture four shots, rotating the display table between them.

        Args:
            first_cam_to_real_world: Transform applied to every shot's points
                before the table/object split.

        Returns:
            List with one object point array per successfully received shot.
            Shots for which no view data was received are logged and skipped.
        """
        view_pts_list = []
        first_view_data = CommunicateUtil.get_view_data(init=True)
        ControlUtil.absolute_rotate_display_table(90)
        # Treat a missing first shot like any other missing shot instead of
        # failing inside the point extraction.
        if first_view_data is None:
            Log.error("No view data received")
        else:
            view_pts_list.append(
                self._extract_object_pts(first_view_data, first_cam_to_real_world)
            )

        shot_num = 4
        for i in range(shot_num - 1):
            view_data = CommunicateUtil.get_view_data()
            # No rotation is issued after the final shot.
            if i != shot_num - 2:
                ControlUtil.absolute_rotate_display_table(90)
            time.sleep(0.5)
            if view_data is None:
                Log.error("No view data received")
                continue
            view_pts_list.append(
                self._extract_object_pts(view_data, first_cam_to_real_world)
            )
        return view_pts_list

    def register(self):
        """Scan the object, register it to the CAD model and save the result.

        Writes the CAD mesh transformed into the real-world frame
        (``mesh.obj``) and the points used for registration
        (``pts_for_init_reg.txt``) into ``<output_dir>/<object_name>``.

        Returns:
            The CAD-to-real-world transform, i.e. the inverse of the matrix
            returned by ``PtsUtil.register``.
        """
        ControlUtil.connect_robot()

        # init robot
        Log.info("start init")
        ControlUtil.init()
        first_cam_to_real_world = ControlUtil.get_pose()

        # loop shooting
        Log.info("start loop shooting")
        view_pts_list = self.loop_scan(first_cam_to_real_world)

        # register
        Log.info("start register")
        pts = np.vstack(view_pts_list)
        scene_dir = os.path.join(self.output_dir, self.object_name)
        os.makedirs(scene_dir, exist_ok=True)
        model_path = os.path.join(self.model_dir, self.object_name, "mesh.obj")
        cad_model = trimesh.load(model_path)
        real_world_to_cad = PtsUtil.register(pts, cad_model)
        cad_to_real_world = np.linalg.inv(real_world_to_cad)
        Log.success("finish init and register")

        cad_model_real_world: trimesh.Trimesh = cad_model.apply_transform(
            cad_to_real_world
        )
        cad_model_real_world.export(os.path.join(scene_dir, "mesh.obj"))
        np.savetxt(os.path.join(scene_dir, "pts_for_init_reg.txt"), pts)
        return cad_to_real_world

    def render_data(self):
        """Run the Blender generator script on this object's scene directory.

        Blender is invoked with ``-b -P <generator_script_path> -- <scene_dir>``
        and its captured output is printed. A non-zero exit code is logged as
        an error.
        """
        scene_dir = os.path.join(self.output_dir, self.object_name)
        result = subprocess.run(
            [
                self.blender_bin_path,
                "-b",
                "-P",
                self.generator_script_path,
                "--",
                scene_dir,
            ],
            capture_output=True,
            text=True,
        )
        print(result)
        if result.returncode != 0:
            Log.error(f"Blender exited with code {result.returncode}")

    def preprocess_data(self):
        """Run ``save_scene_data`` for this object with ``file_type="npy"``."""
        save_scene_data(self.output_dir, self.object_name, file_type="npy")

    def get_scan_points_indices(self, scan_points, mask, object_mask_label=(0, 255, 0, 255), cam_intrinsic=None, cam_extrinsic=None):
        """Find scan points that project onto non-object pixels of a mask.

        Each point is moved into the camera frame with the inverse of
        ``cam_extrinsic``, projected with ``cam_intrinsic`` and looked up in
        ``mask``.

        Args:
            scan_points: ``(N, 3)`` array of points.
            mask: Image indexed as ``mask[row, column]`` whose pixel values
                are compared against ``object_mask_label``.
            object_mask_label: Pixel value that marks the object.
            cam_intrinsic: 3x3 projection matrix. Required.
            cam_extrinsic: 4x4 camera-to-world matrix (it is inverted before
                being applied to the points). Required.

        Returns:
            Indices into ``scan_points`` of the points that lie in front of
            the camera, fall inside the image and hit a pixel that differs
            from ``object_mask_label``.

        Raises:
            ValueError: If ``cam_intrinsic`` or ``cam_extrinsic`` is ``None``.
        """
        if cam_intrinsic is None or cam_extrinsic is None:
            raise ValueError("cam_intrinsic and cam_extrinsic must both be provided")

        ones = np.ones((scan_points.shape[0], 1))
        scan_points_homogeneous = np.hstack((scan_points, ones))
        world_to_cam = np.linalg.inv(cam_extrinsic)
        points_camera = np.dot(world_to_cam, scan_points_homogeneous.T).T[:, :3]
        points_image_homogeneous = np.dot(cam_intrinsic, points_camera.T).T

        # Points at or behind the camera plane cannot be seen; dividing by
        # their non-positive depth would mirror them onto valid pixels (or
        # divide by zero), so drop them before the perspective division.
        # NOTE(review): assumes the camera looks along +z - confirm.
        in_front_idx = np.flatnonzero(points_image_homogeneous[:, 2] > 0)
        projected = points_image_homogeneous[in_front_idx]
        pixels = projected[:, :2] / projected[:, 2:]
        pixel_x = pixels[:, 0].astype(int)
        pixel_y = pixels[:, 1].astype(int)

        h, w = mask.shape[:2]
        in_image = (pixel_x >= 0) & (pixel_x < w) & (pixel_y >= 0) & (pixel_y < h)
        candidate_idx = in_front_idx[in_image]
        mask_colors = mask[pixel_y[in_image], pixel_x[in_image]]
        not_object = (mask_colors != object_mask_label).any(axis=-1)
        return candidate_idx[not_object]

    def run_one_model(self, model_name):
        """Run the closed-loop next-best-view scan for one model.

        Reads ``mesh.obj``, ``pts/<idx>.npy``, ``scan_points.txt`` and the
        per-frame masks and camera info from ``<output_dir>/<model_name>``,
        and writes per-shot debug point files into its ``debug`` subfolder.

        The loop stops when no next best view is found, when the coverage
        increase drops below ``min_coverage_increase``, or when
        ``max_shot_view_num`` shots have been taken.

        Args:
            model_name: Name of the scene folder inside ``output_dir``.
        """
        scene_dir = os.path.join(self.output_dir, model_name)
        ControlUtil.connect_robot()

        # init robot
        Log.info("start init")
        ControlUtil.init()
        first_cam_to_real_world = ControlUtil.get_pose()

        # loop shooting
        Log.info("start loop shooting")
        view_pts_list = self.loop_scan(first_cam_to_real_world)

        # register
        cad_path = os.path.join(scene_dir, "mesh.obj")
        cad_model = trimesh.load(cad_path)
        Log.info("start register")
        init_pts = np.vstack(view_pts_list)
        real_world_to_cad = PtsUtil.register(init_pts, cad_model)
        curr_cad_to_real_world = np.linalg.inv(real_world_to_cad)

        # Load the stored point cloud of every candidate view, downsampled.
        pts_dir = os.path.join(scene_dir, "pts")
        sample_view_pts_list = []
        frame_num = len(os.listdir(pts_dir))
        for frame_idx in range(frame_num):
            pts_path = os.path.join(pts_dir, f"{frame_idx}.npy")
            point_cloud = np.load(pts_path)
            if point_cloud.shape[0] != 0:
                sampled_point_cloud = PtsUtil.voxel_downsample_point_cloud(
                    point_cloud, self.voxel_size
                )
                sample_view_pts_list.append(sampled_point_cloud)
            else:
                sample_view_pts_list.append(np.zeros((0, 3)))

        # close-loop online registery strategy
        scanned_pts = PtsUtil.voxel_downsample_point_cloud(
            init_pts, voxel_size=self.voxel_size
        )
        shot_pts_list = []
        last_coverage = 0
        Log.info("start close-loop control")
        cnt = 0
        mask_list = []
        cam_to_cad_list = []
        cam_R_to_cad_list = []
        shot_view_idx_list = []
        # Use the same scene folder (model_name) for every input of this run.
        scan_points_path = os.path.join(self.output_dir, model_name, "scan_points.txt")
        display_table_scan_points = np.loadtxt(scan_points_path)

        cam_intrinsic = None
        for i in range(frame_num):
            path = DataLoadUtil.get_path(self.output_dir, model_name, i)
            mask_L, mask_R = DataLoadUtil.load_seg(path, binocular=True)
            mask_list.append((mask_L, mask_R))
            cam_info = DataLoadUtil.load_cam_info(path, binocular=True)
            cam_to_cad_list.append(cam_info["cam_to_world"])
            cam_R_to_cad_list.append(cam_info["cam_to_world_R"])
            # As in the original code, the intrinsics of the last frame are
            # used for every view.
            cam_intrinsic = cam_info["cam_intrinsic"]

        selected_view = []
        while True:
            # Operator pause before planning the next move
            # (disable with PYTHONBREAKPOINT=0).
            breakpoint()

            # Table points visible from both cameras of each candidate view,
            # evaluated with the current CAD-to-world estimate.
            scan_points_idx_list = []
            for (mask_L, mask_R), cam_to_cad, cam_R_to_cad in zip(
                mask_list, cam_to_cad_list, cam_R_to_cad_list
            ):
                curr_cam_L_to_world = curr_cad_to_real_world @ cam_to_cad
                curr_cam_R_to_world = curr_cad_to_real_world @ cam_R_to_cad
                scan_points_indices_L = self.get_scan_points_indices(
                    display_table_scan_points,
                    mask_L,
                    cam_intrinsic=cam_intrinsic,
                    cam_extrinsic=curr_cam_L_to_world,
                )
                scan_points_indices_R = self.get_scan_points_indices(
                    display_table_scan_points,
                    mask_R,
                    cam_intrinsic=cam_intrinsic,
                    cam_extrinsic=curr_cam_R_to_world,
                )
                scan_points_idx_list.append(
                    np.intersect1d(scan_points_indices_L, scan_points_indices_R)
                )
            history_indices = [
                scan_points_idx_list[shot_view_idx]
                for shot_view_idx in shot_view_idx_list
            ]

            cad_scanned_pts = PtsUtil.transform_point_cloud(
                scanned_pts, np.linalg.inv(curr_cad_to_real_world)
            )
            next_best_view, next_best_coverage, next_best_covered_num = (
                ReconstructionUtil.compute_next_best_view_with_overlap(
                    cad_scanned_pts,
                    sample_view_pts_list,
                    selected_view,
                    history_indices,
                    scan_points_idx_list,
                    threshold=self.voxel_size,
                    overlap_area_threshold=25,
                    scan_points_threshold=self.scan_points_threshold,
                )
            )
            if next_best_view is None:
                # Without a candidate there is nothing to move to; continuing
                # would use None as a frame index below.
                Log.warning("No next best view found")
                break
            selected_view.append(next_best_view)

            nbv_path = DataLoadUtil.get_path(self.output_dir, model_name, next_best_view)
            nbv_cam_info = DataLoadUtil.load_cam_info(nbv_path, binocular=True)
            nbv_cam_to_cad = nbv_cam_info["cam_to_world_O"]
            nbv_cam_to_world = curr_cad_to_real_world @ nbv_cam_to_cad
            target_camera_pose = nbv_cam_to_world @ ControlUtil.CAMERA_CORRECTION
            ControlUtil.move_to(target_camera_pose)

            # get world pts
            time.sleep(0.5)
            view_data = CommunicateUtil.get_view_data()
            if view_data is None:
                Log.error("No view data received")
                continue
            shot_view_idx_list.append(next_best_view)
            # NOTE(review): shots taken after moving are still transformed
            # with the initial camera pose; presumably the view data is
            # already expressed in that frame - confirm.
            world_splitted_shot_pts = self._extract_object_pts(
                view_data, first_cam_to_real_world
            )
            shot_pts_list.append(world_splitted_shot_pts)

            debug_dir = os.path.join(scene_dir, "debug")
            os.makedirs(debug_dir, exist_ok=True)
            last_scanned_pts_num = scanned_pts.shape[0]

            # Operator pause before merging the new shot
            # (disable with PYTHONBREAKPOINT=0).
            breakpoint()
            new_scanned_pts = PtsUtil.voxel_downsample_point_cloud(
                np.vstack([scanned_pts, world_splitted_shot_pts]), self.voxel_size
            )

            # NOTE: online fine re-registration is currently disabled, so
            # real_world_to_cad still holds the initial registration and this
            # inverse equals the value computed before the loop.
            curr_cad_to_real_world = np.linalg.inv(real_world_to_cad)

            np.savetxt(os.path.join(debug_dir, f"shot_pts_{cnt}.txt"), world_splitted_shot_pts)
            np.savetxt(os.path.join(debug_dir, f"render_pts_{cnt}.txt"), sample_view_pts_list[next_best_view])
            np.savetxt(os.path.join(debug_dir, f"reg_scanned_pts_{cnt}.txt"), new_scanned_pts)
            cad_pts = cad_model.vertices
            world_cad_pts = PtsUtil.transform_point_cloud(cad_pts, curr_cad_to_real_world)
            np.savetxt(os.path.join(debug_dir, f"world_cad_pts_{cnt}.txt"), world_cad_pts)

            new_scanned_pts_num = new_scanned_pts.shape[0]
            scanned_pts = new_scanned_pts
            Log.info(
                f"Next Best cover pts: {next_best_covered_num}, Best coverage: {next_best_coverage}"
            )

            coverage_rate_increase = next_best_coverage - last_coverage
            if coverage_rate_increase < self.min_coverage_increase:
                Log.info(f"Coverage rate = {coverage_rate_increase} < {self.min_coverage_increase}, stop scanning")
                break
            last_coverage = next_best_coverage

            # Informational only: a small gain does not stop the scan.
            new_added_pts_num = new_scanned_pts_num - last_scanned_pts_num
            if new_added_pts_num < self.min_shot_new_pts_num:
                Log.info(f"New added pts num = {new_added_pts_num} < {self.min_shot_new_pts_num}")

            if len(shot_pts_list) >= self.max_shot_view_num:
                Log.info(f"Scanned view num = {len(shot_pts_list)} >= {self.max_shot_view_num}, stop scanning")
                break
            cnt += 1

        Log.success("[Part 4/4] finish close-loop control")

    def run(self):
        """Run the closed-loop scan for the configured object."""
        self.run_one_model(self.object_name)
# ---------------------------- test ---------------------------- #
if __name__ == "__main__":
    # Manual check: load the workpiece mesh from a fixed local path.
    model_path = r"/home/yan20/nbv_rec/data/models/workpiece_1/mesh.obj"
    model = trimesh.load(model_path)