global: debug inference
This commit is contained in:
@@ -77,17 +77,17 @@ class Inferencer(Runner):
|
||||
status_manager.set_progress("inference", "inferencer", f"dataset", len(self.test_set_list), len(self.test_set_list))
|
||||
|
||||
def predict_sequence(self, data, cr_increase_threshold=0, max_iter=50, max_retry=5):
|
||||
scene_name = data["scene_name"][0]
|
||||
scene_name = data["scene_name"]
|
||||
Log.info(f"Processing scene: {scene_name}")
|
||||
status_manager.set_status("inference", "inferencer", "scene", scene_name)
|
||||
|
||||
''' data for rendering '''
|
||||
scene_path = data["scene_path"][0]
|
||||
O_to_L_pose = data["O_to_L_pose"][0]
|
||||
scene_path = data["scene_path"]
|
||||
O_to_L_pose = data["O_to_L_pose"]
|
||||
voxel_threshold = self.voxel_size
|
||||
filter_degree = 75
|
||||
down_sampled_model_pts = data["gt_pts"]
|
||||
import ipdb; ipdb.set_trace()
|
||||
|
||||
first_frame_to_world_9d = data["first_scanned_n_to_world_pose_9d"][0]
|
||||
first_frame_to_world = np.eye(4)
|
||||
first_frame_to_world[:3,:3] = PoseUtil.rotation_6d_to_matrix_numpy(first_frame_to_world_9d[:6])
|
||||
@@ -95,14 +95,13 @@ class Inferencer(Runner):
|
||||
|
||||
''' data for inference '''
|
||||
input_data = {}
|
||||
input_data["combined_scanned_pts"] = torch.tensor(data["first_scanned_pts"][0], dtype=torch.float32).to(self.device)
|
||||
input_data["combined_scanned_pts"] = torch.tensor(data["first_scanned_pts"][0], dtype=torch.float32).to(self.device).unsqueeze(0)
|
||||
input_data["scanned_n_to_world_pose_9d"] = [torch.tensor(data["first_scanned_n_to_world_pose_9d"], dtype=torch.float32).to(self.device)]
|
||||
input_data["mode"] = namespace.Mode.TEST
|
||||
input_pts_N = input_data["combined_scanned_pts"].shape[1]
|
||||
|
||||
first_frame_target_pts, _ = RenderUtil.render_pts(first_frame_to_world, scene_path, self.script_path, down_sampled_model_pts, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
|
||||
first_frame_target_pts, first_frame_target_normals = RenderUtil.render_pts(first_frame_to_world, scene_path, self.script_path, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
|
||||
scanned_view_pts = [first_frame_target_pts]
|
||||
last_pred_cr = self.compute_coverage_rate(scanned_view_pts, None, down_sampled_model_pts, threshold=voxel_threshold)
|
||||
last_pred_cr, added_pts_num = self.compute_coverage_rate(scanned_view_pts, None, down_sampled_model_pts, threshold=voxel_threshold)
|
||||
|
||||
retry_duplication_pose = []
|
||||
retry_no_pts_pose = []
|
||||
@@ -118,7 +117,7 @@ class Inferencer(Runner):
|
||||
pred_pose[:3,3] = pred_pose_9d[0,6:]
|
||||
|
||||
try:
|
||||
new_target_pts_world, new_pts_world = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, model_points_normals, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose, require_full_scene=True)
|
||||
new_target_pts, new_target_normals = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
|
||||
except Exception as e:
|
||||
Log.warning(f"Error in scene {scene_path}, {e}")
|
||||
print("current pose: ", pred_pose)
|
||||
@@ -127,12 +126,18 @@ class Inferencer(Runner):
|
||||
retry += 1
|
||||
continue
|
||||
|
||||
if new_target_pts.shape[0] == 0:
|
||||
print("no pts in new target")
|
||||
retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
|
||||
retry += 1
|
||||
continue
|
||||
|
||||
pred_cr = self.compute_coverage_rate(scanned_view_pts, new_target_pts_world, down_sampled_model_pts, threshold=voxel_threshold)
|
||||
|
||||
print(pred_cr, last_pred_cr, " max: ", data["max_coverage_rate"])
|
||||
if pred_cr >= data["max_coverage_rate"]:
|
||||
print("max coverage rate reached!")
|
||||
pred_cr, new_added_pts_num = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold)
|
||||
print(pred_cr, last_pred_cr, " max: ", data["seq_max_coverage_rate"])
|
||||
if pred_cr >= data["seq_max_coverage_rate"] - 1e-3:
|
||||
print("max coverage rate reached!: ", pred_cr)
|
||||
if new_added_pts_num < 10:
|
||||
print("min added pts num reached!: ", new_added_pts_num)
|
||||
if pred_cr <= last_pred_cr + cr_increase_threshold:
|
||||
retry += 1
|
||||
retry_duplication_pose.append(pred_pose.cpu().numpy().tolist())
|
||||
@@ -140,17 +145,14 @@ class Inferencer(Runner):
|
||||
|
||||
retry = 0
|
||||
pred_cr_seq.append(pred_cr)
|
||||
scanned_view_pts.append(new_target_pts_world)
|
||||
down_sampled_new_pts_world = PtsUtil.random_downsample_point_cloud(new_pts_world, input_pts_N)
|
||||
new_pts_world_aug = np.hstack([down_sampled_new_pts_world, np.ones((down_sampled_new_pts_world.shape[0], 1))])
|
||||
new_pts = np.dot(np.linalg.inv(first_frame_to_world.cpu()), new_pts_world_aug.T).T[:,:3]
|
||||
|
||||
new_pts_tensor = torch.tensor(new_pts, dtype=torch.float32).unsqueeze(0).to(self.device)
|
||||
scanned_view_pts.append(new_target_pts)
|
||||
down_sampled_new_pts_world = PtsUtil.random_downsample_point_cloud(new_target_pts, input_pts_N)
|
||||
|
||||
input_data["scanned_pts"] = [torch.cat([input_data["scanned_pts"][0] , new_pts_tensor], dim=0)]
|
||||
new_pts = down_sampled_new_pts_world
|
||||
input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)]
|
||||
combined_scanned_views_pts = np.concatenate(input_data["scanned_pts"][0].tolist(), axis=0)
|
||||
voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_views_pts, 0.002)
|
||||
|
||||
combined_scanned_pts = np.concatenate([input_data["combined_scanned_pts"][0].cpu().numpy(), new_pts], axis=0)
|
||||
voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_pts, 0.002)
|
||||
random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N)
|
||||
input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device)
|
||||
|
||||
|
Reference in New Issue
Block a user