add multi seq training

2024-09-23 14:30:51 +08:00
parent 6cdff9c83f
commit 3c4077ec4f
12 changed files with 152 additions and 93 deletions


@@ -92,12 +92,16 @@ class Inferencer(Runner):
         model_points_normals = data["model_points_normals"][0]
         model_pts = model_points_normals[:,:3]
         down_sampled_model_pts = PtsUtil.voxel_downsample_point_cloud(model_pts, voxel_threshold)
-        first_frame_to_world = data["first_frame_to_world"][0]
+        first_frame_to_world_9d = data["first_to_world_9d"][0]
+        first_frame_to_world = torch.eye(4, device=first_frame_to_world_9d.device)
+        first_frame_to_world[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(first_frame_to_world_9d[:,:6])[0]
+        first_frame_to_world[:3,3] = first_frame_to_world_9d[0,6:]
+        first_frame_to_world = first_frame_to_world.to(self.device)
         ''' data for inference '''
         input_data = {}
         input_data["scanned_pts"] = [data["first_pts"][0].to(self.device)]
-        input_data["scanned_n_to_world_pose_9d"] = [data["first_frame_to_world"][0].to(self.device)]
+        input_data["scanned_n_to_world_pose_9d"] = [data["first_to_world_9d"][0].to(self.device)]
         input_data["mode"] = namespace.Mode.TEST
         input_pts_N = input_data["scanned_pts"][0].shape[1]
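
Note on the pose encoding used in this hunk: the 9D vector packs a 6D rotation representation (first six values) and a translation (last three). A minimal sketch of decoding it into a 4x4 camera-to-world matrix is given below; the Gram-Schmidt construction is an assumption about what PoseUtil.rotation_6d_to_matrix_tensor_batch computes, since its implementation is not part of this diff.

import torch
import torch.nn.functional as F

def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    # assumed Gram-Schmidt construction; only the call signature comes from the diff
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = F.normalize(a2 - (b1 * a2).sum(-1, keepdim=True) * b1, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)  # (B, 3, 3)

def pose_9d_to_mat(pose_9d: torch.Tensor) -> torch.Tensor:
    # mirrors the first_frame_to_world assembly above: rotation block plus translation column
    mat = torch.eye(4, device=pose_9d.device)
    mat[:3, :3] = rotation_6d_to_matrix(pose_9d[:, :6])[0]
    mat[:3, 3] = pose_9d[0, 6:]
    return mat
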
@@ -113,20 +117,19 @@ class Inferencer(Runner):
         while len(pred_cr_seq) < max_iter and retry < max_retry:
             output = self.pipeline(input_data)
-            next_pose_9d = output["pred_pose_9d"]
-            pred_pose = torch.eye(4, device=next_pose_9d.device)
+            pred_pose_9d = output["pred_pose_9d"]
+            pred_pose = torch.eye(4, device=pred_pose_9d.device)
-            pred_pose[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(next_pose_9d[:,:6])[0]
-            pred_pose[:3,3] = next_pose_9d[0,6:]
-            pred_n_to_world_pose_mat = torch.matmul(first_frame_to_world, pred_pose)
+            pred_pose[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(pred_pose_9d[:,:6])[0]
+            pred_pose[:3,3] = pred_pose_9d[0,6:]
             try:
-                new_target_pts_world, new_pts_world = RenderUtil.render_pts(pred_n_to_world_pose_mat, scene_path, self.script_path, model_points_normals, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose, require_full_scene=True)
+                new_target_pts_world, new_pts_world = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, model_points_normals, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose, require_full_scene=True)
             except Exception as e:
                 Log.warning(f"Error in scene {scene_path}, {e}")
                 print("current pose: ", pred_pose)
                 print("curr_pred_cr: ", last_pred_cr)
-                retry_no_pts_pose.append(pred_n_to_world_pose_mat.cpu().numpy().tolist())
+                retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
                 retry += 1
                 continue
@@ -138,7 +141,7 @@ class Inferencer(Runner):
                 break
             if pred_cr <= last_pred_cr + cr_increase_threshold:
                 retry += 1
-                retry_duplication_pose.append(pred_n_to_world_pose_mat.cpu().numpy().tolist())
+                retry_duplication_pose.append(pred_pose.cpu().numpy().tolist())
                 continue
             retry = 0
@@ -151,7 +154,7 @@ class Inferencer(Runner):
             new_pts_tensor = torch.tensor(new_pts, dtype=torch.float32).unsqueeze(0).to(self.device)
             input_data["scanned_pts"] = [torch.cat([input_data["scanned_pts"][0] , new_pts_tensor], dim=0)]
-            input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], next_pose_9d], dim=0)]
+            input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)]
             last_pred_cr = pred_cr
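
Taken together, the hunks above change the inference loop to work directly with the predicted 9D pose: decode it, render a candidate scan, and either retry (render failure or insufficient coverage gain) or append the new scan and pose to the growing sequence. A condensed sketch of that control flow follows; predict_fn, render_fn and coverage_fn stand in for self.pipeline, RenderUtil.render_pts and the coverage-rate computation, and the parameter defaults are placeholders, not values from the codebase.

import torch

def run_nbv_loop(input_data, predict_fn, render_fn, coverage_fn,
                 max_iter=20, max_retry=5, cr_increase_threshold=0.01, device="cuda"):
    pred_cr_seq, last_pred_cr, retry = [], 0.0, 0
    while len(pred_cr_seq) < max_iter and retry < max_retry:
        pred_pose_9d = predict_fn(input_data)              # (1, 9) predicted n-to-world pose
        try:
            new_target_pts, new_pts = render_fn(pred_pose_9d)  # simulate a scan from that pose
        except Exception:
            retry += 1                                     # rendering failed: ask for another pose
            continue
        pred_cr = coverage_fn(new_target_pts)
        if pred_cr <= last_pred_cr + cr_increase_threshold:
            retry += 1                                     # near-duplicate view: not enough new coverage
            continue
        retry = 0
        pred_cr_seq.append(pred_cr)
        last_pred_cr = pred_cr
        # grow the scanned sequence, as in the last hunk above
        new_pts_tensor = torch.tensor(new_pts, dtype=torch.float32).unsqueeze(0).to(device)
        input_data["scanned_pts"] = [torch.cat([input_data["scanned_pts"][0], new_pts_tensor], dim=0)]
        input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)]
    return pred_cr_seq
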


@@ -25,15 +25,17 @@ class StrategyGenerator(Runner):
         self.to_specified_dir = ConfigManager.get("runner", "generate", "to_specified_dir")
         self.save_best_combined_pts = ConfigManager.get("runner", "generate", "save_best_combined_points")
         self.save_mesh = ConfigManager.get("runner", "generate", "save_mesh")
+        self.load_pts = ConfigManager.get("runner", "generate", "load_points")
         self.filter_degree = ConfigManager.get("runner", "generate", "filter_degree")
         self.overwrite = ConfigManager.get("runner", "generate", "overwrite")
+        self.save_pts = ConfigManager.get("runner","generate","save_points")
+        self.seq_num = ConfigManager.get("runner","generate","seq_num")
 
     def run(self):
         dataset_name_list = ConfigManager.get("runner", "generate", "dataset_list")
         voxel_threshold, overlap_threshold = ConfigManager.get("runner","generate","voxel_threshold"), ConfigManager.get("runner","generate","overlap_threshold")
-        self.save_pts = ConfigManager.get("runner","generate","save_points")
         for dataset_idx in range(len(dataset_name_list)):
             dataset_name = dataset_name_list[dataset_idx]
             status_manager.set_progress("generate_strategy", "strategy_generator", "dataset", dataset_idx, len(dataset_name_list))
@@ -48,7 +50,7 @@ class StrategyGenerator(Runner):
                 diag = DataLoadUtil.get_bbox_diag(model_dir, scene_name)
                 voxel_threshold = diag*0.02
                 status_manager.set_status("generate_strategy", "strategy_generator", "voxel_threshold", voxel_threshold)
-                output_label_path = DataLoadUtil.get_label_path(root_dir, scene_name)
+                output_label_path = DataLoadUtil.get_label_path(root_dir, scene_name,0)
                 if os.path.exists(output_label_path) and not self.overwrite:
                     Log.info(f"Scene <{scene_name}> Already Exists, Skip")
                     cnt += 1
@@ -79,43 +81,52 @@ class StrategyGenerator(Runner):
                 pts_list = []
                 for frame_idx in range(frame_num):
-                    path = DataLoadUtil.get_path(root, scene_name, frame_idx)
-                    cam_params = DataLoadUtil.load_cam_info(path, binocular=True)
-                    status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_idx, frame_num)
-                    point_cloud = DataLoadUtil.get_target_point_cloud_world_from_path(path, binocular=True)
-                    #display_table = None #DataLoadUtil.get_target_point_cloud_world_from_path(path, binocular=True, target_mask_label=()) #TODO
-                    sampled_point_cloud = ReconstructionUtil.filter_points(point_cloud, model_points_normals, cam_pose=cam_params["cam_to_world"], voxel_size=voxel_threshold, theta=self.filter_degree)
+                    if self.load_pts and os.path.exists(os.path.join(root,scene_name, "pts", f"{frame_idx}.txt")):
+                        sampled_point_cloud = np.loadtxt(os.path.join(root,scene_name, "pts", f"{frame_idx}.txt"))
+                        status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_idx, frame_num)
+                        pts_list.append(sampled_point_cloud)
+                        continue
+                    else:
+                        path = DataLoadUtil.get_path(root, scene_name, frame_idx)
+                        cam_params = DataLoadUtil.load_cam_info(path, binocular=True)
+                        status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_idx, frame_num)
+                        point_cloud = DataLoadUtil.get_target_point_cloud_world_from_path(path, binocular=True)
+                        sampled_point_cloud = ReconstructionUtil.filter_points(point_cloud, model_points_normals, cam_pose=cam_params["cam_to_world"], voxel_size=voxel_threshold, theta=self.filter_degree)
+                        if self.save_pts:
+                            pts_dir = os.path.join(root,scene_name, "pts")
+                            if not os.path.exists(pts_dir):
+                                os.makedirs(pts_dir)
+                            np.savetxt(os.path.join(pts_dir, f"{frame_idx}.txt"), sampled_point_cloud)
+                        pts_list.append(sampled_point_cloud)
-                    if self.save_pts:
-                        pts_dir = os.path.join(root,scene_name, "pts")
-                        if not os.path.exists(pts_dir):
-                            os.makedirs(pts_dir)
-                        np.savetxt(os.path.join(pts_dir, f"{frame_idx}.txt"), sampled_point_cloud)
-                    pts_list.append(sampled_point_cloud)
                 status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_num, frame_num)
+                seq_num = min(self.seq_num, len(pts_list))
+                init_view_list = range(seq_num)
+                seq_idx = 0
+                for init_view in init_view_list:
+                    status_manager.set_progress("generate_strategy", "strategy_generator", "computing sequence", seq_idx, len(init_view_list))
+                    limited_useful_view, _, _ = ReconstructionUtil.compute_next_best_view_sequence_with_overlap(down_sampled_model_pts, pts_list,init_view=init_view, threshold=voxel_threshold, overlap_threshold=overlap_threshold, status_info=self.status_info)
+                    data_pairs = self.generate_data_pairs(limited_useful_view)
+                    seq_save_data = {
+                        "data_pairs": data_pairs,
+                        "best_sequence": limited_useful_view,
+                        "max_coverage_rate": limited_useful_view[-1][1]
+                    }
-                limited_useful_view, _, best_combined_pts = ReconstructionUtil.compute_next_best_view_sequence_with_overlap(down_sampled_model_pts, pts_list, threshold=voxel_threshold, overlap_threshold=overlap_threshold, status_info=self.status_info)
-                data_pairs = self.generate_data_pairs(limited_useful_view)
-                seq_save_data = {
-                    "data_pairs": data_pairs,
-                    "best_sequence": limited_useful_view,
-                    "max_coverage_rate": limited_useful_view[-1][1]
-                }
-                status_manager.set_status("generate_strategy", "strategy_generator", "max_coverage_rate", limited_useful_view[-1][1])
-                Log.success(f"Scene <{scene_name}> Finished, Max Coverage Rate: {limited_useful_view[-1][1]}, Best Sequence length: {len(limited_useful_view)}")
-                output_label_path = DataLoadUtil.get_label_path(root, scene_name)
-                output_best_reconstructed_pts_path = os.path.join(root,scene_name, f"best_reconstructed_pts.txt")
-                with open(output_label_path, 'w') as f:
-                    json.dump(seq_save_data, f)
-                if self.save_best_combined_pts:
-                    np.savetxt(output_best_reconstructed_pts_path, best_combined_pts)
+                    status_manager.set_status("generate_strategy", "strategy_generator", "max_coverage_rate", limited_useful_view[-1][1])
+                    Log.success(f"Scene <{scene_name}> Finished, Max Coverage Rate: {limited_useful_view[-1][1]}, Best Sequence length: {len(limited_useful_view)}")
+                    output_label_path = DataLoadUtil.get_label_path(root, scene_name, seq_idx)
+                    with open(output_label_path, 'w') as f:
+                        json.dump(seq_save_data, f)
+                    seq_idx += 1
                 if self.save_mesh:
                     DataLoadUtil.save_target_mesh_at_world_space(root, model_dir, scene_name)
+                status_manager.set_progress("generate_strategy", "strategy_generator", "computing sequence", len(init_view_list), len(init_view_list))
 
     def generate_data_pairs(self, useful_view):
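
The frame-loading change in this hunk amounts to a per-frame point-cloud cache: load pts/<frame_idx>.txt when load_points is enabled, otherwise recompute the filtered cloud and optionally write it back. A minimal sketch of that pattern follows; compute_frame_pts stands in for the DataLoadUtil/ReconstructionUtil calls in the diff and is an illustrative name only.

import os
import numpy as np

def get_frame_pts(root, scene_name, frame_idx, compute_frame_pts,
                  load_pts=True, save_pts=True):
    cache_path = os.path.join(root, scene_name, "pts", f"{frame_idx}.txt")
    if load_pts and os.path.exists(cache_path):
        return np.loadtxt(cache_path)            # cache hit: skip rendering and filtering
    sampled = compute_frame_pts(root, scene_name, frame_idx)
    if save_pts:
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        np.savetxt(cache_path, sampled)          # cache the filtered cloud for later runs
    return sampled

The rest of the hunk then runs ReconstructionUtil.compute_next_best_view_sequence_with_overlap once per initial view (up to seq_num) and writes one label file per sequence via DataLoadUtil.get_label_path(root, scene_name, seq_idx), which is the multi-sequence part of this commit.
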