upd ab_global_only

This commit is contained in:
2024-11-20 15:24:45 +08:00
parent 493639287e
commit 2c8ef20321
5 changed files with 80 additions and 99 deletions

View File

@@ -23,11 +23,15 @@ from utils.data_load import DataLoadUtil
@stereotype.runner("inferencer")
class Inferencer(Runner):
def __init__(self, config_path):
super().__init__(config_path)
self.script_path = ConfigManager.get(namespace.Stereotype.RUNNER, "blender_script_path")
self.output_dir = ConfigManager.get(namespace.Stereotype.RUNNER, "output_dir")
self.voxel_size = ConfigManager.get(namespace.Stereotype.RUNNER, "voxel_size")
self.min_new_area = ConfigManager.get(namespace.Stereotype.RUNNER, "min_new_area")
CM = 0.01
self.min_new_pts_num = self.min_new_area * (CM / self.voxel_size) **2
''' Pipeline '''
self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
self.pipeline:torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
@@ -74,22 +78,24 @@ class Inferencer(Runner):
total=int(len(test_set))
for i in tqdm(range(total), desc=f"Processing {test_set_name}", ncols=100):
data = test_set.__getitem__(i)
scene_name = data["scene_name"]
if scene_name != "omniobject3d-suitcase_001":
try:
data = test_set.__getitem__(i)
scene_name = data["scene_name"]
inference_result_path = os.path.join(self.output_dir, test_set_name, f"{scene_name}.pkl")
if os.path.exists(inference_result_path):
Log.info(f"Inference result already exists for scene: {scene_name}")
continue
status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i+1, total)
output = self.predict_sequence(data)
self.save_inference_result(test_set_name, data["scene_name"], output)
except Exception as e:
Log.error(f"Error in scene {scene_name}, {e}")
continue
inference_result_path = os.path.join(self.output_dir, test_set_name, f"{scene_name}.pkl")
if os.path.exists(inference_result_path):
Log.info(f"Inference result already exists for scene: {scene_name}")
continue
status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i+1, total)
output = self.predict_sequence(data)
self.save_inference_result(test_set_name, data["scene_name"], output)
status_manager.set_progress("inference", "inferencer", f"dataset", len(self.test_set_list), len(self.test_set_list))
def predict_sequence(self, data, cr_increase_threshold=0, overlap_area_threshold=25, scan_points_threshold=10, max_iter=50, max_retry = 5):
def predict_sequence(self, data, cr_increase_threshold=0, overlap_area_threshold=25, scan_points_threshold=10, max_iter=50, max_retry = 10, max_success=3):
scene_name = data["scene_name"]
Log.info(f"Processing scene: {scene_name}")
status_manager.set_status("inference", "inferencer", "scene", scene_name)
@@ -128,13 +134,11 @@ class Inferencer(Runner):
retry = 0
pred_cr_seq = [last_pred_cr]
success = 0
last_pts_num = PtsUtil.voxel_downsample_point_cloud(data["first_scanned_pts"][0], 0.002).shape[0]
last_pts_num = PtsUtil.voxel_downsample_point_cloud(data["first_scanned_pts"][0], voxel_threshold).shape[0]
import time
while len(pred_cr_seq) < max_iter and retry < max_retry:
start_time = time.time()
while len(pred_cr_seq) < max_iter and retry < max_retry and success < max_success:
Log.green(f"iter: {len(pred_cr_seq)}, retry: {retry}/{max_retry}, success: {success}/{max_success}")
output = self.pipeline(input_data)
end_time = time.time()
print(f"Time taken for inference: {end_time - start_time} seconds")
pred_pose_9d = output["pred_pose_9d"]
pred_pose = torch.eye(4, device=pred_pose_9d.device)
@@ -142,7 +146,6 @@ class Inferencer(Runner):
pred_pose[:3,3] = pred_pose_9d[0,6:]
try:
start_time = time.time()
new_target_pts, new_target_normals, new_scan_points_indices = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, scan_points, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
#import ipdb; ipdb.set_trace()
if not ReconstructionUtil.check_scan_points_overlap(history_indices, new_scan_points_indices, scan_points_threshold):
@@ -153,15 +156,14 @@ class Inferencer(Runner):
downsampled_new_target_pts = PtsUtil.voxel_downsample_point_cloud(new_target_pts, voxel_threshold)
overlap, _ = ReconstructionUtil.check_overlap(downsampled_new_target_pts, down_sampled_model_pts, overlap_area_threshold = curr_overlap_area_threshold, voxel_size=voxel_threshold, require_new_added_pts_num = True)
if not overlap:
Log.yellow("no overlap!")
retry += 1
retry_overlap_pose.append(pred_pose.cpu().numpy().tolist())
continue
history_indices.append(new_scan_points_indices)
end_time = time.time()
print(f"Time taken for rendering: {end_time - start_time} seconds")
except Exception as e:
Log.warning(f"Error in scene {scene_path}, {e}")
Log.error(f"Error in scene {scene_path}, {e}")
print("current pose: ", pred_pose)
print("curr_pred_cr: ", last_pred_cr)
retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
@@ -169,40 +171,41 @@ class Inferencer(Runner):
continue
if new_target_pts.shape[0] == 0:
print("no pts in new target")
Log.red("no pts in new target")
retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
retry += 1
continue
start_time = time.time()
pred_cr, _ = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold)
end_time = time.time()
print(f"Time taken for coverage rate computation: {end_time - start_time} seconds")
print(pred_cr, last_pred_cr, " max: ", data["seq_max_coverage_rate"])
Log.yellow(f"{pred_cr}, {last_pred_cr}, max: , {data['seq_max_coverage_rate']}")
if pred_cr >= data["seq_max_coverage_rate"] - 1e-3:
print("max coverage rate reached!: ", pred_cr)
success += 1
retry = 0
pred_cr_seq.append(pred_cr)
scanned_view_pts.append(new_target_pts)
input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)]
combined_scanned_pts = np.vstack(scanned_view_pts)
voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_pts, 0.002)
voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_pts, voxel_threshold)
random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N)
input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device)
if success > 3:
break
last_pred_cr = pred_cr
pts_num = voxel_downsampled_combined_scanned_pts_np.shape[0]
if pts_num - last_pts_num < 10 and pred_cr < data["seq_max_coverage_rate"] - 1e-3:
Log.info(f"delta pts num:,{pts_num - last_pts_num },{pts_num}, {last_pts_num}")
if pts_num - last_pts_num < self.min_new_pts_num and pred_cr <= data["seq_max_coverage_rate"] - 1e-2:
retry += 1
retry_duplication_pose.append(pred_pose.cpu().numpy().tolist())
print("delta pts num < 10:", pts_num, last_pts_num)
Log.red(f"delta pts num < {self.min_new_pts_num}:, {pts_num}, {last_pts_num}")
elif pts_num - last_pts_num < self.min_new_pts_num and pred_cr > data["seq_max_coverage_rate"] - 1e-2:
success += 1
Log.success(f"delta pts num < {self.min_new_pts_num}:, {pts_num}, {last_pts_num}")
last_pts_num = pts_num