This commit is contained in:
2024-11-07 19:33:18 +08:00
parent 5bcd0fc6e3
commit 2cd811c1b7
8 changed files with 77 additions and 105 deletions

View File

@@ -76,6 +76,8 @@ class Inferencer(Runner):
for i in tqdm(range(total), desc=f"Processing {test_set_name}", ncols=100):
data = test_set.__getitem__(i)
scene_name = data["scene_name"]
if scene_name != "omniobject3d-book_004":
continue
inference_result_path = os.path.join(self.output_dir, test_set_name, f"{scene_name}.pkl")
if os.path.exists(inference_result_path):
Log.info(f"Inference result already exists for scene: {scene_name}")
@@ -87,7 +89,7 @@ class Inferencer(Runner):
status_manager.set_progress("inference", "inferencer", f"dataset", len(self.test_set_list), len(self.test_set_list))
def predict_sequence(self, data, cr_increase_threshold=0, overlap_area_threshold=25, scan_points_threshold=10, max_iter=50, max_retry = 7):
def predict_sequence(self, data, cr_increase_threshold=0, overlap_area_threshold=25, scan_points_threshold=10, max_iter=50, max_retry = 5):
scene_name = data["scene_name"]
Log.info(f"Processing scene: {scene_name}")
status_manager.set_status("inference", "inferencer", "scene", scene_name)
@@ -110,10 +112,13 @@ class Inferencer(Runner):
input_data["scanned_n_to_world_pose_9d"] = [torch.tensor(data["first_scanned_n_to_world_pose_9d"], dtype=torch.float32).to(self.device)]
input_data["mode"] = namespace.Mode.TEST
input_pts_N = input_data["combined_scanned_pts"].shape[1]
root = os.path.dirname(scene_path)
display_table_info = DataLoadUtil.get_display_table_info(root, scene_name)
radius = display_table_info["radius"]
scan_points = np.asarray(ReconstructionUtil.generate_scan_points(display_table_top=0,display_table_radius=radius))
first_frame_target_pts, first_frame_target_normals, first_frame_scan_points_indices = RenderUtil.render_pts(first_frame_to_world, scene_path, self.script_path, scan_points, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
scanned_view_pts = [first_frame_target_pts]
history_indices = [first_frame_scan_points_indices]
@@ -124,6 +129,7 @@ class Inferencer(Runner):
retry = 0
pred_cr_seq = [last_pred_cr]
success = 0
last_pts_num = PtsUtil.voxel_downsample_point_cloud(data["first_scanned_pts"][0], 0.002).shape[0]
import time
while len(pred_cr_seq) < max_iter and retry < max_retry:
start_time = time.time()
@@ -146,7 +152,7 @@ class Inferencer(Runner):
curr_overlap_area_threshold = overlap_area_threshold * 0.5
downsampled_new_target_pts = PtsUtil.voxel_downsample_point_cloud(new_target_pts, voxel_threshold)
overlap, new_added_pts_num = ReconstructionUtil.check_overlap(downsampled_new_target_pts, down_sampled_model_pts, overlap_area_threshold = curr_overlap_area_threshold, voxel_size=voxel_threshold, require_new_added_pts_num = True)
overlap, _ = ReconstructionUtil.check_overlap(downsampled_new_target_pts, down_sampled_model_pts, overlap_area_threshold = curr_overlap_area_threshold, voxel_size=voxel_threshold, require_new_added_pts_num = True)
if not overlap:
retry += 1
retry_overlap_pose.append(pred_pose.cpu().numpy().tolist())
@@ -170,31 +176,22 @@ class Inferencer(Runner):
continue
start_time = time.time()
pred_cr, covered_pts_num = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold)
pred_cr, _ = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold)
end_time = time.time()
print(f"Time taken for coverage rate computation: {end_time - start_time} seconds")
print(pred_cr, last_pred_cr, " max: ", data["seq_max_coverage_rate"])
print("new added pts num: ", new_added_pts_num)
if pred_cr >= data["seq_max_coverage_rate"] - 1e-3:
print("max coverage rate reached!: ", pred_cr)
success += 1
elif new_added_pts_num < 5:
#success += 1
print("min added pts num reached!: ", new_added_pts_num)
if pred_cr <= last_pred_cr + cr_increase_threshold:
retry += 1
retry_duplication_pose.append(pred_pose.cpu().numpy().tolist())
continue
retry = 0
pred_cr_seq.append(pred_cr)
scanned_view_pts.append(new_target_pts)
down_sampled_new_pts_world = PtsUtil.random_downsample_point_cloud(new_target_pts, input_pts_N)
new_pts = down_sampled_new_pts_world
input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)]
combined_scanned_pts = np.concatenate([input_data["combined_scanned_pts"][0].cpu().numpy(), new_pts], axis=0)
combined_scanned_pts = np.vstack(scanned_view_pts)
voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_pts, 0.002)
random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N)
input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device)
@@ -202,29 +199,12 @@ class Inferencer(Runner):
if success > 3:
break
last_pred_cr = pred_cr
input_data["scanned_n_to_world_pose_9d"] = input_data["scanned_n_to_world_pose_9d"][0].cpu().numpy().tolist()
result = {
"pred_pose_9d_seq": input_data["scanned_n_to_world_pose_9d"],
"combined_scanned_pts": input_data["combined_scanned_pts"],
"target_pts_seq": scanned_view_pts,
"coverage_rate_seq": pred_cr_seq,
"max_coverage_rate": data["seq_max_coverage_rate"],
"pred_max_coverage_rate": max(pred_cr_seq),
"scene_name": scene_name,
"retry_no_pts_pose": retry_no_pts_pose,
"retry_duplication_pose": retry_duplication_pose,
"retry_overlap_pose": retry_overlap_pose,
"best_seq_len": data["best_seq_len"],
}
self.stat_result[scene_name] = {
"coverage_rate_seq": pred_cr_seq,
"pred_max_coverage_rate": max(pred_cr_seq),
"pred_seq_len": len(pred_cr_seq),
}
print('success rate: ', max(pred_cr_seq))
return result
pts_num = voxel_downsampled_combined_scanned_pts_np.shape[0]
if pts_num - last_pts_num < 10 and pred_cr < data["seq_max_coverage_rate"] - 1e-3:
retry += 1
retry_duplication_pose.append(pred_pose.cpu().numpy().tolist())
print("delta pts num < 10:", pts_num, last_pts_num)
last_pts_num = pts_num
def compute_coverage_rate(self, scanned_view_pts, new_pts, model_pts, threshold=0.005):
if new_pts is not None: