nbv_rec_franka_control/runners/cad_close_loop_new.py

import os
import time
import trimesh
import tempfile
import subprocess
import numpy as np
from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager
import PytorchBoot.stereotype as stereotype
from PytorchBoot.utils.log_util import Log
from PytorchBoot.status import status_manager
from utils.control_util import ControlUtil
from utils.communicate_util import CommunicateUtil
from utils.pts_util import PtsUtil
from utils.reconstruction_util import ReconstructionUtil
from utils.preprocess_util import save_scene_data
from utils.data_load import DataLoadUtil
from utils.view_util import ViewUtil
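

# NOTE: PointCloud and PointCloudGroup below are empty placeholder containers;
# they are not referenced elsewhere in this file.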
class PointCloud:
    def __init__(self, points, camera, name):
        pass


class PointCloudGroup:
    def __init__(self, point_clouds, name):
        pass

@stereotype.runner("temp")
class CADCloseLoopOnlineRegStrategyRunner(Runner):
    def __init__(self, config_path: str):
        super().__init__(config_path)
        self.load_experiment("cad_strategy")
        self.generate_config = ConfigManager.get("runner", "generate")
        self.reconstruct_config = ConfigManager.get("runner", "reconstruct")
        self.output_dir = self.generate_config["output_dir"]
        self.model_dir = self.generate_config["model_dir"]
        self.object_name = self.generate_config["object_name"]
        self.blender_bin_path = self.generate_config["blender_bin_path"]
        self.generator_script_path = self.generate_config["generator_script_path"]
        self.voxel_size = self.generate_config["voxel_size"]
        self.max_shot_view_num = self.reconstruct_config["max_shot_view_num"]
        self.min_shot_new_pts_num = self.reconstruct_config["min_shot_new_pts_num"]
        self.min_coverage_increase = self.reconstruct_config["min_coverage_increase"]
        self.scan_points_threshold = self.reconstruct_config["scan_points_threshold"]

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)
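
    # Split a world-frame point cloud by height: points below z_threshold belong to the
    # display table / scan plane, points at or above it belong to the object.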
    def split_scan_pts_and_obj_pts(self, world_pts, z_threshold=0):
        scan_pts = world_pts[world_pts[:, 2] < z_threshold]
        obj_pts = world_pts[world_pts[:, 2] >= z_threshold]
        return scan_pts, obj_pts
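
    # Capture four views while rotating the display table between shots, transform each
    # view into the world frame, and keep only the object points of every shot.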
    def loop_scan(self, first_cam_to_real_world):
        view_pts_list = []
        first_view_data = CommunicateUtil.get_view_data(init=True)
        ControlUtil.absolute_rotate_display_table(90)
        first_pts = ViewUtil.get_pts(first_view_data)
        first_real_world_pts = PtsUtil.transform_point_cloud(
            first_pts, first_cam_to_real_world
        )
        _, first_splitted_real_world_pts = self.split_scan_pts_and_obj_pts(
            first_real_world_pts
        )
        view_pts_list.append(first_splitted_real_world_pts)
        shot_num = 4
        for i in range(shot_num - 1):
            view_data = CommunicateUtil.get_view_data()
            if i != shot_num - 2:
                ControlUtil.absolute_rotate_display_table(90)
                time.sleep(0.5)
            if view_data is None:
                Log.error("No view data received")
                continue
            view_pts = ViewUtil.get_pts(view_data)
            real_world_pts = PtsUtil.transform_point_cloud(
                view_pts, first_cam_to_real_world
            )
            _, splitted_real_world_pts = self.split_scan_pts_and_obj_pts(
                real_world_pts
            )
            view_pts_list.append(splitted_real_world_pts)
        return view_pts_list
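
    # Initial registration pass: initialize the robot, run a loop scan, register the merged
    # scan against the CAD model, and export the CAD mesh transformed into world coordinates.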
    def register(self):
        ControlUtil.connect_robot()
        """ init robot """
        Log.info("start init")
        ControlUtil.init()
        first_cam_to_real_world = ControlUtil.get_pose()
        """ loop shooting """
        Log.info("start loop shooting")
        view_pts_list = self.loop_scan(first_cam_to_real_world)
        """ register """
        Log.info("start register")
        pts = np.vstack(view_pts_list)
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        if not os.path.exists(os.path.join(self.output_dir, self.object_name)):
            os.makedirs(os.path.join(self.output_dir, self.object_name))
        scene_dir = os.path.join(self.output_dir, self.object_name)
        model_path = os.path.join(self.model_dir, self.object_name, "mesh.obj")
        cad_model = trimesh.load(model_path)
        real_world_to_cad = PtsUtil.register(pts, cad_model)
        cad_to_real_world = np.linalg.inv(real_world_to_cad)
        Log.success("finish init and register")
        real_world_to_blender_world = np.eye(4)
        real_world_to_blender_world[:3, 3] = np.asarray([0, 0, 0.9215])
        cad_model_real_world: trimesh.Trimesh = cad_model.apply_transform(
            cad_to_real_world
        )
        cad_model_real_world.export(os.path.join(scene_dir, "mesh.obj"))
        # downsampled_pts = PtsUtil.voxel_downsample_point_cloud(pts, self.voxel_size)
        np.savetxt(os.path.join(scene_dir, "pts_for_init_reg.txt"), pts)
        return cad_to_real_world
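
    # Render candidate views offline by launching Blender in background mode ("-b") with the
    # generator script; the scene directory is passed after the "--" separator.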
    def render_data(self):
        scene_dir = os.path.join(self.output_dir, self.object_name)
        result = subprocess.run(
            [
                self.blender_bin_path,
                "-b",
                "-P",
                self.generator_script_path,
                "--",
                scene_dir,
            ],
            capture_output=True,
            text=True,
        )
        print(result)
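
    # Convert the rendered scene data into per-frame .npy files (later loaded from pts/<idx>.npy).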
    def preprocess_data(self):
        save_scene_data(self.output_dir, self.object_name, file_type="npy")
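
    # Project the display-table scan points into the image and return the indices of points
    # that fall inside the image but are not covered by the object mask color.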
    def get_scan_points_indices(
        self, scan_points, mask, object_mask_label=(0, 255, 0, 255), cam_intrinsic=None, cam_extrinsic=None
    ):
        scan_points_homogeneous = np.hstack((scan_points, np.ones((scan_points.shape[0], 1))))
        points_camera = np.dot(np.linalg.inv(cam_extrinsic), scan_points_homogeneous.T).T[:, :3]
        points_image_homogeneous = np.dot(cam_intrinsic, points_camera.T).T
        points_image_homogeneous /= points_image_homogeneous[:, 2:]
        pixel_x = points_image_homogeneous[:, 0].astype(int)
        pixel_y = points_image_homogeneous[:, 1].astype(int)
        h, w = mask.shape[:2]
        valid_indices = (pixel_x >= 0) & (pixel_x < w) & (pixel_y >= 0) & (pixel_y < h)
        mask_colors = mask[pixel_y[valid_indices], pixel_x[valid_indices]]
        selected_points_indices = np.where((mask_colors != object_mask_label).any(axis=-1))[0]
        selected_points_indices = np.where(valid_indices)[0][selected_points_indices]
        return selected_points_indices
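
    # Close-loop online-registration strategy: scan and register against the CAD model, then
    # repeatedly pick the next best view among the pre-rendered candidates, move the camera
    # there, capture, fuse the new points, and re-check the stop conditions.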
    def run_one_model(self, model_name):
        scene_dir = os.path.join(self.output_dir, model_name)
        ControlUtil.connect_robot()
        """ init robot """
        Log.info("start init")
        ControlUtil.init()
        first_cam_to_real_world = ControlUtil.get_pose()
        """ loop shooting """
        Log.info("start loop shooting")
        view_pts_list = self.loop_scan(first_cam_to_real_world)
        """ register """
        cad_path = os.path.join(scene_dir, "mesh.obj")
        cad_model = trimesh.load(cad_path)
        Log.info("start register")
        init_pts = np.vstack(view_pts_list)
        real_world_to_cad = PtsUtil.register(init_pts, cad_model)
        curr_cad_to_real_world = np.linalg.inv(real_world_to_cad)
        # np.savetxt(os.path.join("/home/yan20/nbv_rec/project/franka_control/debug", "pts_for_init_reg.txt"), init_pts)
        # debug_cad = cad_model.apply_transform(curr_cad_to_real_world)
        # debug_cad.export(os.path.join("/home/yan20/nbv_rec/project/franka_control/debug", "cad_for_init_reg.obj"))
        pts_dir = os.path.join(scene_dir, "pts")
        sample_view_pts_list = []
        frame_num = len(os.listdir(pts_dir))
        for frame_idx in range(frame_num):
            pts_path = os.path.join(scene_dir, "pts", f"{frame_idx}.npy")
            point_cloud = np.load(pts_path)
            if point_cloud.shape[0] != 0:
                sampled_point_cloud = PtsUtil.voxel_downsample_point_cloud(
                    point_cloud, self.voxel_size
                )
                sample_view_pts_list.append(sampled_point_cloud)
            else:
                sample_view_pts_list.append(np.zeros((0, 3)))
""" close-loop online registery strategy """
scanned_pts = PtsUtil.voxel_downsample_point_cloud(init_pts, voxel_size=self.voxel_size)
shot_pts_list = []
last_coverage = 0
Log.info("start close-loop control")
cnt = 0
mask_list = []
cam_to_cad_list = []
cam_R_to_cad_list = []
shot_view_idx_list = []
scan_points_path = os.path.join(self.output_dir, self.object_name, "scan_points.txt")
display_table_scan_points = np.loadtxt(scan_points_path)
for i in range(frame_num):
path = DataLoadUtil.get_path(self.output_dir, self.object_name, i)
mask_L, mask_R = DataLoadUtil.load_seg(path, binocular=True)
mask_list.append((mask_L, mask_R))
cam_info = DataLoadUtil.load_cam_info(path, binocular=True)
cam_to_cad = cam_info["cam_to_world"]
cam_to_cad_list.append(cam_to_cad)
cam_R_to_cad = cam_info["cam_to_world_R"]
cam_R_to_cad_list.append(cam_R_to_cad)
selected_view = []
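
        # Main NBV loop: with the current registration, re-project the table scan points into
        # every candidate view, score the candidates against the accumulated scan, move to the
        # best one, and merge the newly captured points into the scanned cloud.
        # (cam_intrinsic is taken from the last loaded cam_info; all frames are assumed to
        # share the same intrinsics.)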
        while True:
            # import ipdb; ipdb.set_trace()  # debug breakpoint (disabled)
            history_indices = []
            scan_points_idx_list = []
            for i in range(frame_num):
                cam_to_cad = cam_to_cad_list[i]
                cam_R_to_cad = cam_R_to_cad_list[i]
                curr_cam_L_to_world = curr_cad_to_real_world @ cam_to_cad
                curr_cam_R_to_world = curr_cad_to_real_world @ cam_R_to_cad
                scan_points_indices_L = self.get_scan_points_indices(display_table_scan_points, mask_list[i][0], cam_intrinsic=cam_info["cam_intrinsic"], cam_extrinsic=curr_cam_L_to_world)
                scan_points_indices_R = self.get_scan_points_indices(display_table_scan_points, mask_list[i][1], cam_intrinsic=cam_info["cam_intrinsic"], cam_extrinsic=curr_cam_R_to_world)
                scan_points_indices = np.intersect1d(scan_points_indices_L, scan_points_indices_R)
                scan_points_idx_list.append(scan_points_indices)
            for shot_view_idx in shot_view_idx_list:
                history_indices.append(scan_points_idx_list[shot_view_idx])
            cad_scanned_pts = PtsUtil.transform_point_cloud(scanned_pts, np.linalg.inv(curr_cad_to_real_world))
            next_best_view, next_best_coverage, next_best_covered_num = (
                ReconstructionUtil.compute_next_best_view_with_overlap(
                    cad_scanned_pts,
                    sample_view_pts_list,
                    selected_view,
                    history_indices,
                    scan_points_idx_list,
                    threshold=self.voxel_size,
                    overlap_area_threshold=25,
                    scan_points_threshold=self.scan_points_threshold,
                )
            )
            if next_best_view is None:
                Log.warning("No next best view found")
                break
            selected_view.append(next_best_view)
            nbv_path = DataLoadUtil.get_path(self.output_dir, self.object_name, next_best_view)
            nbv_cam_info = DataLoadUtil.load_cam_info(nbv_path, binocular=True)
            nbv_cam_to_cad = nbv_cam_info["cam_to_world_O"]
            nbv_cam_to_world = curr_cad_to_real_world @ nbv_cam_to_cad
            target_camera_pose = nbv_cam_to_world @ ControlUtil.CAMERA_CORRECTION
            ControlUtil.move_to(target_camera_pose)
            ''' get world pts '''
            time.sleep(0.5)
            view_data = CommunicateUtil.get_view_data()
            if view_data is None:
                Log.error("No view data received")
                continue
            shot_view_idx_list.append(next_best_view)
            cam_shot_pts = ViewUtil.get_pts(view_data)
            world_shot_pts = PtsUtil.transform_point_cloud(
                cam_shot_pts, first_cam_to_real_world
            )
            _, world_splitted_shot_pts = self.split_scan_pts_and_obj_pts(
                world_shot_pts
            )
            shot_pts_list.append(world_splitted_shot_pts)
            debug_dir = os.path.join(scene_dir, "debug")
            if not os.path.exists(debug_dir):
                os.makedirs(debug_dir)
            last_scanned_pts_num = scanned_pts.shape[0]
            # import ipdb; ipdb.set_trace()  # debug breakpoint (disabled)
            new_scanned_pts = PtsUtil.voxel_downsample_point_cloud(
                np.vstack([scanned_pts, world_splitted_shot_pts]), self.voxel_size
            )
            # last_real_world_to_cad = real_world_to_cad
            # real_world_to_cad = PtsUtil.register_fine(new_scanned_pts, cad_model)
            # # rot distance of two rotation matrix
            # rot_dist = np.arccos(
            #     (np.trace(real_world_to_cad[:3, :3].T @ last_real_world_to_cad[:3, :3]) - 1) / 2
            # )
            # print(f"-----rot dist: {rot_dist}")
            curr_cad_to_real_world = np.linalg.inv(real_world_to_cad)
            cad_splitted_shot_pts = PtsUtil.transform_point_cloud(world_splitted_shot_pts, real_world_to_cad)
            np.savetxt(os.path.join(debug_dir, f"shot_pts_{cnt}.txt"), world_splitted_shot_pts)
            np.savetxt(os.path.join(debug_dir, f"render_pts_{cnt}.txt"), sample_view_pts_list[next_best_view])
            np.savetxt(os.path.join(debug_dir, f"reg_scanned_pts_{cnt}.txt"), new_scanned_pts)
            cad_pts = cad_model.vertices
            world_cad_pts = PtsUtil.transform_point_cloud(cad_pts, curr_cad_to_real_world)
            np.savetxt(os.path.join(debug_dir, f"world_cad_pts_{cnt}.txt"), world_cad_pts)
            # import ipdb; ipdb.set_trace()
            new_scanned_pts_num = new_scanned_pts.shape[0]
            scanned_pts = new_scanned_pts
            Log.info(
                f"Next Best cover pts: {next_best_covered_num}, Best coverage: {next_best_coverage}"
            )
            coverage_rate_increase = next_best_coverage - last_coverage
            if coverage_rate_increase < self.min_coverage_increase:
                Log.info(f"Coverage rate increase = {coverage_rate_increase} < {self.min_coverage_increase}, stop scanning")
                # break
            last_coverage = next_best_coverage
            new_added_pts_num = new_scanned_pts_num - last_scanned_pts_num
            if new_added_pts_num < self.min_shot_new_pts_num:
                Log.info(f"New added pts num = {new_added_pts_num} < {self.min_shot_new_pts_num}")
                # ipdb.set_trace()
            if len(shot_pts_list) >= self.max_shot_view_num:
                Log.info(f"Scanned view num = {len(shot_pts_list)} >= {self.max_shot_view_num}, stop scanning")
                break
            cnt += 1
        Log.success("[Part 4/4] finish close-loop control")

    def run(self):
        self.run_one_model(self.object_name)

# ---------------------------- test ---------------------------- #
if __name__ == "__main__":
    model_path = r"/home/yan20/nbv_rec/data/models/workpiece_1/mesh.obj"
    model = trimesh.load(model_path)