update inferencer; add load_from_preprocessed_pts

This commit is contained in:
2024-09-20 19:00:08 +08:00
parent a621749cc9
commit 8517255245
7 changed files with 69 additions and 40 deletions

View File

@@ -79,8 +79,7 @@ class Inferencer(Runner):
status_manager.set_progress("inference", "inferencer", f"dataset", len(self.test_set_list), len(self.test_set_list))
def predict_sequence(self, data, cr_increase_threshold=0, max_iter=100):
pred_cr_seq = []
def predict_sequence(self, data, cr_increase_threshold=0, max_iter=50, max_retry=5):
scene_name = data["scene_name"][0]
Log.info(f"Processing scene: {scene_name}")
status_manager.set_status("inference", "inferencer", "scene", scene_name)
@@ -98,7 +97,7 @@ class Inferencer(Runner):
''' data for inference '''
input_data = {}
input_data["scanned_pts"] = [data["first_pts"][0].to(self.device)]
input_data["scanned_n_to_world_pose_9d"] = [data["first_to_first_9d"][0].to(self.device)]
input_data["scanned_n_to_world_pose_9d"] = [data["first_frame_to_world"][0].to(self.device)]
input_data["mode"] = namespace.Mode.TEST
input_pts_N = input_data["scanned_pts"][0].shape[1]
@@ -107,9 +106,11 @@ class Inferencer(Runner):
scanned_view_pts = [first_frame_target_pts]
last_pred_cr = self.compute_coverage_rate(scanned_view_pts, None, down_sampled_model_pts, threshold=voxel_threshold)
while len(pred_cr_seq) < max_iter:
retry_duplication_pose = []
retry_no_pts_pose = []
retry = 0
pred_cr_seq = [last_pred_cr]
while len(pred_cr_seq) < max_iter and retry < max_retry:
output = self.pipeline(input_data)
next_pose_9d = output["pred_pose_9d"]
@@ -118,22 +119,30 @@ class Inferencer(Runner):
pred_pose[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(next_pose_9d[:,:6])[0]
pred_pose[:3,3] = next_pose_9d[0,6:]
pred_n_to_world_pose_mat = torch.matmul(first_frame_to_world, pred_pose)
try:
new_target_pts_world, new_pts_world = RenderUtil.render_pts(pred_n_to_world_pose_mat, scene_path, self.script_path, model_points_normals, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose, require_full_scene=True)
except Exception as e:
Log.warning(f"Error in scene {scene_path}, {e}")
print("current pose: ", pred_pose)
print("curr_pred_cr: ", last_pred_cr)
retry_no_pts_pose.append(pred_n_to_world_pose_mat.cpu().numpy().tolist())
retry += 1
continue
pred_cr = self.compute_coverage_rate(scanned_view_pts, new_target_pts_world, down_sampled_model_pts, threshold=voxel_threshold)
pred_cr_seq.append(pred_cr)
print(pred_cr, last_pred_cr)
print(pred_cr, last_pred_cr, " max: ", data["max_coverage_rate"])
if pred_cr >= data["max_coverage_rate"]:
break
if pred_cr <= last_pred_cr + cr_increase_threshold:
break
retry += 1
retry_duplication_pose.append(pred_n_to_world_pose_mat.cpu().numpy().tolist())
continue
retry = 0
pred_cr_seq.append(pred_cr)
scanned_view_pts.append(new_target_pts_world)
down_sampled_new_pts_world = PtsUtil.random_downsample_point_cloud(new_pts_world, input_pts_N)
new_pts_world_aug = np.hstack([down_sampled_new_pts_world, np.ones((down_sampled_new_pts_world.shape[0], 1))])
@@ -145,7 +154,7 @@ class Inferencer(Runner):
input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], next_pose_9d], dim=0)]
last_pred_cr = pred_cr
print(last_pred_cr)
input_data["scanned_pts"] = input_data["scanned_pts"][0].cpu().numpy().tolist()
input_data["scanned_n_to_world_pose_9d"] = input_data["scanned_n_to_world_pose_9d"][0].cpu().numpy().tolist()
@@ -154,8 +163,12 @@ class Inferencer(Runner):
"pts_seq": input_data["scanned_pts"],
"target_pts_seq": scanned_view_pts,
"coverage_rate_seq": pred_cr_seq,
"max_coverage_rate": data["max_coverage_rate"],
"pred_max_coverage_rate": max(pred_cr_seq)
"max_coverage_rate": data["max_coverage_rate"][0],
"pred_max_coverage_rate": max(pred_cr_seq),
"scene_name": scene_name,
"retry_no_pts_pose": retry_no_pts_pose,
"retry_duplication_pose": retry_duplication_pose,
"best_seq_len": data["best_seq_len"][0],
}
return result