upd inference

This commit is contained in:
2025-01-05 23:57:33 +08:00
parent 9c2625b11e
commit dec67e8255
4 changed files with 245 additions and 77 deletions

View File

@@ -4,6 +4,7 @@ from utils.render import RenderUtil
from utils.pose import PoseUtil
from utils.pts import PtsUtil
from utils.reconstruction import ReconstructionUtil
from beans.predict_result import PredictResult
import torch
from tqdm import tqdm
@@ -82,6 +83,7 @@ class Inferencer(Runner):
data = test_set.__getitem__(i)
scene_name = data["scene_name"]
inference_result_path = os.path.join(self.output_dir, test_set_name, f"{scene_name}.pkl")
if os.path.exists(inference_result_path):
Log.info(f"Inference result already exists for scene: {scene_name}")
continue
@@ -142,88 +144,87 @@ class Inferencer(Runner):
voxel_downsampled_combined_scanned_pts_np, inverse = self.voxel_downsample_with_mapping(combined_scanned_pts, voxel_threshold)
output = self.pipeline(input_data)
pred_pose_9d = output["pred_pose_9d"]
import ipdb; ipdb.set_trace()
pred_pose = torch.eye(4, device=pred_pose_9d.device)
pred_pose[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(pred_pose_9d[:,:6])[0]
pred_pose[:3,3] = pred_pose_9d[0,6:]
# ----- Debug -----
from utils.vis import visualizeUtil
import ipdb; ipdb.set_trace()
all_directions = []
np.savetxt("input_pts.txt", input_data["combined_scanned_pts"].cpu().numpy()[0])
for i in range(50):
output = self.pipeline(input_data)
pred_pose_9d = output["pred_pose_9d"]
cam_pos, sample_points = visualizeUtil.get_cam_pose_and_cam_axis(pred_pose_9d.cpu().numpy()[0], is_6d_pose=True)
all_directions.append(sample_points)
all_directions = np.array(all_directions)
reshape_all_directions = all_directions.reshape(-1, 3)
np.savetxt("all_directions.txt", reshape_all_directions)
# ----- ----- -----
try:
new_target_pts, new_target_normals, new_scan_points_indices = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, scan_points, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
# # save pred_pose_9d ------
# root = "/media/hofee/data/project/python/nbv_reconstruction/nbv_reconstruction/temp_output_result"
# scene_dir = os.path.join(root, scene_name)
# if not os.path.exists(scene_dir):
# os.makedirs(scene_dir)
# pred_9d_path = os.path.join(scene_dir,f"pred_pose_9d_{len(pred_cr_seq)}.npy")
# pts_path = os.path.join(scene_dir,f"combined_scanned_pts_{len(pred_cr_seq)}.txt")
# np_combined_scanned_pts = input_data["combined_scanned_pts"][0].cpu().numpy()
# np.save(pred_9d_path, pred_pose_9d.cpu().numpy())
# np.savetxt(pts_path, np_combined_scanned_pts)
# # ----- ----- -----
pred_pose_9d_candidates = PredictResult(pred_pose_9d.cpu().numpy(), input_pts=input_data["combined_scanned_pts"][0].cpu().numpy(), cluster_params=dict(eps=0.25, min_samples=3)).candidate_9d_poses
for pred_pose_9d in pred_pose_9d_candidates:
#import ipdb; ipdb.set_trace()
if not ReconstructionUtil.check_scan_points_overlap(history_indices, new_scan_points_indices, scan_points_threshold):
curr_overlap_area_threshold = overlap_area_threshold
else:
curr_overlap_area_threshold = overlap_area_threshold * 0.5
pred_pose_9d = torch.tensor(pred_pose_9d, dtype=torch.float32).to(self.device).unsqueeze(0)
pred_pose[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(pred_pose_9d[:,:6])[0]
pred_pose[:3,3] = pred_pose_9d[0,6:]
try:
new_target_pts, new_target_normals, new_scan_points_indices = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, scan_points, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
#import ipdb; ipdb.set_trace()
if not ReconstructionUtil.check_scan_points_overlap(history_indices, new_scan_points_indices, scan_points_threshold):
curr_overlap_area_threshold = overlap_area_threshold
else:
curr_overlap_area_threshold = overlap_area_threshold * 0.5
downsampled_new_target_pts = PtsUtil.voxel_downsample_point_cloud(new_target_pts, voxel_threshold)
overlap, _ = ReconstructionUtil.check_overlap(downsampled_new_target_pts, voxel_downsampled_combined_scanned_pts_np, overlap_area_threshold = curr_overlap_area_threshold, voxel_size=voxel_threshold, require_new_added_pts_num = True)
# if not overlap:
# Log.yellow("no overlap!")
# retry += 1
# retry_overlap_pose.append(pred_pose.cpu().numpy().tolist())
# continue
downsampled_new_target_pts = PtsUtil.voxel_downsample_point_cloud(new_target_pts, voxel_threshold)
overlap, _ = ReconstructionUtil.check_overlap(downsampled_new_target_pts, voxel_downsampled_combined_scanned_pts_np, overlap_area_threshold = curr_overlap_area_threshold, voxel_size=voxel_threshold, require_new_added_pts_num = True)
if not overlap:
Log.yellow("no overlap!")
retry += 1
retry_overlap_pose.append(pred_pose.cpu().numpy().tolist())
continue
history_indices.append(new_scan_points_indices)
except Exception as e:
Log.error(f"Error in scene {scene_path}, {e}")
print("current pose: ", pred_pose)
print("curr_pred_cr: ", last_pred_cr)
retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
retry += 1
continue
history_indices.append(new_scan_points_indices)
except Exception as e:
Log.error(f"Error in scene {scene_path}, {e}")
print("current pose: ", pred_pose)
print("curr_pred_cr: ", last_pred_cr)
retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
retry += 1
continue
if new_target_pts.shape[0] == 0:
Log.red("no pts in new target")
retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
retry += 1
continue
pred_cr, _ = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold)
Log.yellow(f"{pred_cr}, {last_pred_cr}, max: , {data['seq_max_coverage_rate']}")
if pred_cr >= data["seq_max_coverage_rate"] - 1e-3:
print("max coverage rate reached!: ", pred_cr)
if new_target_pts.shape[0] == 0:
Log.red("no pts in new target")
retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
retry += 1
continue
pred_cr_seq.append(pred_cr)
scanned_view_pts.append(new_target_pts)
pred_cr, _ = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold)
Log.yellow(f"{pred_cr}, {last_pred_cr}, max: , {data['seq_max_coverage_rate']}")
if pred_cr >= data["seq_max_coverage_rate"] - 1e-3:
print("max coverage rate reached!: ", pred_cr)
input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)]
combined_scanned_pts = np.vstack(scanned_view_pts)
voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_pts, voxel_threshold)
random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N)
input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device)
last_pred_cr = pred_cr
pts_num = voxel_downsampled_combined_scanned_pts_np.shape[0]
Log.info(f"delta pts num:,{pts_num - last_pts_num },{pts_num}, {last_pts_num}")
pred_cr_seq.append(pred_cr)
scanned_view_pts.append(new_target_pts)
input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)]
combined_scanned_pts = np.vstack(scanned_view_pts)
voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_pts, voxel_threshold)
random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N)
input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device)
if pts_num - last_pts_num < self.min_new_pts_num and pred_cr <= data["seq_max_coverage_rate"] - 1e-2:
retry += 1
retry_duplication_pose.append(pred_pose.cpu().numpy().tolist())
Log.red(f"delta pts num < {self.min_new_pts_num}:, {pts_num}, {last_pts_num}")
elif pts_num - last_pts_num < self.min_new_pts_num and pred_cr > data["seq_max_coverage_rate"] - 1e-2:
success += 1
Log.success(f"delta pts num < {self.min_new_pts_num}:, {pts_num}, {last_pts_num}")
last_pred_cr = pred_cr
pts_num = voxel_downsampled_combined_scanned_pts_np.shape[0]
Log.info(f"delta pts num:,{pts_num - last_pts_num },{pts_num}, {last_pts_num}")
last_pts_num = pts_num
if pts_num - last_pts_num < self.min_new_pts_num and pred_cr <= data["seq_max_coverage_rate"] - 1e-2:
retry += 1
retry_duplication_pose.append(pred_pose.cpu().numpy().tolist())
Log.red(f"delta pts num < {self.min_new_pts_num}:, {pts_num}, {last_pts_num}")
elif pts_num - last_pts_num < self.min_new_pts_num and pred_cr > data["seq_max_coverage_rate"] - 1e-2:
success += 1
Log.success(f"delta pts num < {self.min_new_pts_num}:, {pts_num}, {last_pts_num}")
last_pts_num = pts_num
break
input_data["scanned_n_to_world_pose_9d"] = input_data["scanned_n_to_world_pose_9d"][0].cpu().numpy().tolist()