update config
This commit is contained in:
224
runners/cad_open_loop_strategy.py
Normal file
224
runners/cad_open_loop_strategy.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import os
|
||||
import time
|
||||
import trimesh
|
||||
import tempfile
|
||||
import subprocess
|
||||
import numpy as np
|
||||
from PytorchBoot.runners.runner import Runner
|
||||
from PytorchBoot.config import ConfigManager
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
from PytorchBoot.utils.log_util import Log
|
||||
from PytorchBoot.status import status_manager
|
||||
|
||||
from utils.control_util import ControlUtil
|
||||
from utils.communicate_util import CommunicateUtil
|
||||
from utils.pts_util import PtsUtil
|
||||
from utils.reconstruction_util import ReconstructionUtil
|
||||
from utils.preprocess_util import save_scene_data, save_scene_data_multithread
|
||||
from utils.data_load import DataLoadUtil
|
||||
from utils.view_util import ViewUtil
|
||||
|
||||
|
||||
@stereotype.runner("CAD_open_loop_strategy_runner")
|
||||
class CADOpenLoopStrategyRunner(Runner):
|
||||
|
||||
def __init__(self, config_path: str):
|
||||
super().__init__(config_path)
|
||||
self.load_experiment("cad_open_loop_strategy")
|
||||
self.status_info = {
|
||||
"status_manager": status_manager,
|
||||
"app_name": "cad",
|
||||
"runner_name": "CAD_open_loop_strategy_runner"
|
||||
}
|
||||
self.generate_config = ConfigManager.get("runner", "generate")
|
||||
self.reconstruct_config = ConfigManager.get("runner", "reconstruct")
|
||||
self.blender_bin_path = self.generate_config["blender_bin_path"]
|
||||
self.generator_script_path = self.generate_config["generator_script_path"]
|
||||
self.model_dir = self.generate_config["model_dir"]
|
||||
self.voxel_size = self.generate_config["voxel_size"]
|
||||
self.max_view = self.generate_config["max_view"]
|
||||
self.min_view = self.generate_config["min_view"]
|
||||
self.max_diag = self.generate_config["max_diag"]
|
||||
self.min_diag = self.generate_config["min_diag"]
|
||||
self.min_cam_table_included_degree = self.generate_config["min_cam_table_included_degree"]
|
||||
self.random_view_ratio = self.generate_config["random_view_ratio"]
|
||||
|
||||
self.soft_overlap_threshold = self.reconstruct_config["soft_overlap_threshold"]
|
||||
self.hard_overlap_threshold = self.reconstruct_config["hard_overlap_threshold"]
|
||||
self.scan_points_threshold = self.reconstruct_config["scan_points_threshold"]
|
||||
|
||||
def create_experiment(self, backup_name=None):
|
||||
super().create_experiment(backup_name)
|
||||
|
||||
def load_experiment(self, backup_name=None):
|
||||
super().load_experiment(backup_name)
|
||||
|
||||
def split_scan_pts_and_obj_pts(self, world_pts, z_threshold = 0):
|
||||
scan_pts = world_pts[world_pts[:,2] < z_threshold]
|
||||
obj_pts = world_pts[world_pts[:,2] >= z_threshold]
|
||||
return scan_pts, obj_pts
|
||||
|
||||
def run_one_model(self, model_name):
|
||||
temp_dir = "/home/yan20/nbv_rec/project/franka_control/temp_output"
|
||||
result = dict()
|
||||
|
||||
shot_pts_list = []
|
||||
ControlUtil.connect_robot()
|
||||
''' init robot '''
|
||||
Log.info("[Part 1/5] start init and register")
|
||||
ControlUtil.init()
|
||||
|
||||
''' load CAD model '''
|
||||
model_path = os.path.join(self.model_dir, model_name,"mesh.ply")
|
||||
temp_name = "cad_model_world"
|
||||
cad_model = trimesh.load(model_path)
|
||||
''' take first view '''
|
||||
Log.info("[Part 1/5] take first view data")
|
||||
view_data = CommunicateUtil.get_view_data(init=True)
|
||||
first_cam_pts = ViewUtil.get_pts(view_data)
|
||||
first_cam_to_real_world = ControlUtil.get_pose()
|
||||
first_real_world_pts = PtsUtil.transform_point_cloud(first_cam_pts, first_cam_to_real_world)
|
||||
_, first_splitted_real_world_pts = self.split_scan_pts_and_obj_pts(first_real_world_pts)
|
||||
np.savetxt(f"first_real_pts_{model_name}.txt", first_splitted_real_world_pts)
|
||||
''' register '''
|
||||
Log.info("[Part 1/5] do registeration")
|
||||
real_world_to_cad = PtsUtil.register(first_splitted_real_world_pts, cad_model)
|
||||
cad_to_real_world = np.linalg.inv(real_world_to_cad)
|
||||
Log.success("[Part 1/5] finish init and register")
|
||||
real_world_to_blender_world = np.eye(4)
|
||||
real_world_to_blender_world[:3, 3] = np.asarray([0, 0, 0.9215])
|
||||
cad_model_real_world:trimesh.Trimesh = cad_model.apply_transform(cad_to_real_world)
|
||||
cad_model_real_world.export(os.path.join(temp_dir, f"real_world_{temp_name}.obj"))
|
||||
cad_model_blender_world:trimesh.Trimesh = cad_model.apply_transform(real_world_to_blender_world)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_dir = "/home/yan20/nbv_rec/project/franka_control/temp_output"
|
||||
cad_model_blender_world.export(os.path.join(temp_dir, f"{temp_name}.obj"))
|
||||
scene_dir = os.path.join(temp_dir, temp_name)
|
||||
''' sample view '''
|
||||
Log.info("[Part 2/5] start running renderer")
|
||||
subprocess.run([
|
||||
self.blender_bin_path, '-b', '-P', self.generator_script_path, '--', temp_dir
|
||||
], capture_output=True, text=True)
|
||||
Log.success("[Part 2/5] finish running renderer")
|
||||
|
||||
|
||||
world_model_points = np.loadtxt(os.path.join(scene_dir, "points_and_normals.txt"))[:,:3]
|
||||
''' preprocess '''
|
||||
Log.info("[Part 3/5] start preprocessing data")
|
||||
save_scene_data(temp_dir, temp_name)
|
||||
Log.success("[Part 3/5] finish preprocessing data")
|
||||
|
||||
pts_dir = os.path.join(temp_dir,temp_name,"pts")
|
||||
sample_view_pts_list = []
|
||||
scan_points_idx_list = []
|
||||
frame_num = len(os.listdir(pts_dir))
|
||||
|
||||
for frame_idx in range(frame_num):
|
||||
pts_path = os.path.join(temp_dir,temp_name, "pts", f"{frame_idx}.txt")
|
||||
idx_path = os.path.join(temp_dir,temp_name, "scan_points_indices", f"{frame_idx}.npy")
|
||||
point_cloud = np.loadtxt(pts_path)
|
||||
if point_cloud.shape[0] != 0:
|
||||
sampled_point_cloud = PtsUtil.voxel_downsample_point_cloud(point_cloud, self.voxel_size)
|
||||
indices = np.load(idx_path)
|
||||
try:
|
||||
len(indices)
|
||||
except:
|
||||
indices = np.array([indices])
|
||||
sample_view_pts_list.append(sampled_point_cloud)
|
||||
scan_points_idx_list.append(indices)
|
||||
|
||||
''' generate strategy '''
|
||||
|
||||
Log.info("[Part 4/5] start generating strategy")
|
||||
limited_useful_view, _, _ = ReconstructionUtil.compute_next_best_view_sequence_with_overlap(
|
||||
world_model_points, sample_view_pts_list,
|
||||
scan_points_indices_list = scan_points_idx_list,
|
||||
init_view=0,
|
||||
threshold=self.voxel_size,
|
||||
soft_overlap_threshold = self.soft_overlap_threshold,
|
||||
hard_overlap_threshold = self.hard_overlap_threshold,
|
||||
scan_points_threshold = self.scan_points_threshold,
|
||||
status_info=self.status_info
|
||||
)
|
||||
Log.success("[Part 4/5] finish generating strategy")
|
||||
|
||||
''' extract cam_to_world sequence '''
|
||||
cam_to_world_seq = []
|
||||
coveraget_rate_seq = []
|
||||
render_pts = []
|
||||
idx_seq = []
|
||||
for idx, coverage_rate in limited_useful_view:
|
||||
path = DataLoadUtil.get_path(temp_dir, temp_name, idx)
|
||||
cam_info = DataLoadUtil.load_cam_info(path, binocular=True)
|
||||
cam_to_world_seq.append(cam_info["cam_to_world_O"])
|
||||
coveraget_rate_seq.append(coverage_rate)
|
||||
idx_seq.append(idx)
|
||||
render_pts.append(sample_view_pts_list[idx])
|
||||
|
||||
Log.info("[Part 5/5] start running robot")
|
||||
''' take best seq view '''
|
||||
#import ipdb; ipdb.set_trace()
|
||||
target_scanned_pts = np.concatenate(sample_view_pts_list)
|
||||
voxel_downsampled_target_scanned_pts = PtsUtil.voxel_downsample_point_cloud(target_scanned_pts, self.voxel_size)
|
||||
result = dict()
|
||||
gt_scanned_pts = np.concatenate(render_pts, axis=0)
|
||||
voxel_down_sampled_gt_scanned_pts = PtsUtil.voxel_downsample_point_cloud(gt_scanned_pts, self.voxel_size)
|
||||
result["gt_final_coverage_rate_cad"] = ReconstructionUtil.compute_coverage_rate(voxel_downsampled_target_scanned_pts, voxel_down_sampled_gt_scanned_pts, self.voxel_size)
|
||||
step = 1
|
||||
result["real_coverage_rate_seq"] = []
|
||||
for cam_to_world in cam_to_world_seq:
|
||||
try:
|
||||
ControlUtil.move_to(cam_to_world)
|
||||
''' get world pts '''
|
||||
time.sleep(0.5)
|
||||
view_data = CommunicateUtil.get_view_data()
|
||||
if view_data is None:
|
||||
Log.error("Failed to get view data")
|
||||
continue
|
||||
cam_pts = ViewUtil.get_pts(view_data)
|
||||
shot_pts_list.append(cam_pts)
|
||||
scanned_pts = np.concatenate(shot_pts_list, axis=0)
|
||||
voxel_down_sampled_scanned_pts = PtsUtil.voxel_downsample_point_cloud(scanned_pts, self.voxel_size)
|
||||
voxel_down_sampled_scanned_pts_world = PtsUtil.transform_point_cloud(voxel_down_sampled_scanned_pts, first_cam_to_real_world)
|
||||
curr_CR = ReconstructionUtil.compute_coverage_rate(voxel_downsampled_target_scanned_pts, voxel_down_sampled_scanned_pts_world, self.voxel_size)
|
||||
Log.success(f"(step {step}/{len(cam_to_world_seq)}) current coverage: {curr_CR} | gt coverage: {result['gt_final_coverage_rate_cad']}")
|
||||
result["real_final_coverage_rate"] = curr_CR
|
||||
result["real_coverage_rate_seq"].append(curr_CR)
|
||||
step += 1
|
||||
except Exception as e:
|
||||
Log.error(f"Failed to move to {cam_to_world}")
|
||||
Log.error(e)
|
||||
|
||||
#import ipdb;ipdb.set_trace()
|
||||
|
||||
for idx in range(len(shot_pts_list)):
|
||||
if not os.path.exists(os.path.join(temp_dir, temp_name, "shot_pts")):
|
||||
os.makedirs(os.path.join(temp_dir, temp_name, "shot_pts"))
|
||||
if not os.path.exists(os.path.join(temp_dir, temp_name, "render_pts")):
|
||||
os.makedirs(os.path.join(temp_dir, temp_name, "render_pts"))
|
||||
shot_pts = PtsUtil.transform_point_cloud(shot_pts_list[idx], first_cam_to_real_world)
|
||||
np.savetxt(os.path.join(temp_dir, temp_name, "shot_pts", f"{idx}.txt"), shot_pts)
|
||||
np.savetxt(os.path.join(temp_dir, temp_name, "render_pts", f"{idx}.txt"), render_pts[idx])
|
||||
|
||||
|
||||
Log.success("[Part 5/5] finish running robot")
|
||||
|
||||
Log.debug(result)
|
||||
|
||||
def run(self):
|
||||
total = len(os.listdir(self.model_dir))
|
||||
model_start_idx = self.generate_config["model_start_idx"]
|
||||
count_object = model_start_idx
|
||||
for model_name in os.listdir(self.model_dir[model_start_idx:]):
|
||||
Log.info(f"[{count_object}/{total}]Processing {model_name}")
|
||||
self.run_one_model(model_name)
|
||||
Log.success(f"[{count_object}/{total}]Finished processing {model_name}")
|
||||
|
||||
|
||||
# ---------------------------- test ---------------------------- #
|
||||
if __name__ == "__main__":
|
||||
|
||||
model_path = r"C:\Users\hofee\Downloads\mesh.obj"
|
||||
model = trimesh.load(model_path)
|
||||
|
Reference in New Issue
Block a user