update inferencer: success rate
@@ -20,7 +20,7 @@ from PytorchBoot.runners.runner import Runner
from PytorchBoot.utils import Log
from PytorchBoot.status import status_manager

@stereotype.runner("inferencer", comment="not tested")
@stereotype.runner("inferencer")
class Inferencer(Runner):
    def __init__(self, config_path):
        super().__init__(config_path)
@@ -34,6 +34,7 @@ class Inferencer(Runner):

        ''' Experiment '''
        self.load_experiment("nbv_evaluator")
        self.stat_result = {}

        ''' Test '''
        self.test_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TEST)
@@ -103,9 +104,9 @@ class Inferencer(Runner):
        input_data["scanned_pts"] = [data["first_pts"][0].to(self.device)]
        input_data["scanned_n_to_world_pose_9d"] = [data["first_to_world_9d"][0].to(self.device)]
        input_data["mode"] = namespace.Mode.TEST
        input_data["combined_scanned_pts"] = data["combined_scanned_pts"]
        input_pts_N = input_data["scanned_pts"][0].shape[1]

        first_frame_target_pts, _ = RenderUtil.render_pts(first_frame_to_world, scene_path, self.script_path, model_points_normals, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
        scanned_view_pts = [first_frame_target_pts]
        last_pred_cr = self.compute_coverage_rate(scanned_view_pts, None, down_sampled_model_pts, threshold=voxel_threshold)
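Note: only the call site of compute_coverage_rate appears in this diff. A minimal sketch of what a voxel-overlap coverage rate can look like (the voxel-set intersection below is an assumption, not the repository's implementation):

import numpy as np

def voxel_coverage_rate(scanned_pts, model_pts, voxel_size=0.005):
    # Quantize both clouds into voxel indices and compare the occupied sets.
    scanned_voxels = set(map(tuple, np.floor(scanned_pts / voxel_size).astype(int)))
    model_voxels = set(map(tuple, np.floor(model_pts / voxel_size).astype(int)))
    # Coverage = fraction of model voxels touched by the scanned points.
    return len(scanned_voxels & model_voxels) / max(len(model_voxels), 1)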
@@ -138,7 +139,7 @@ class Inferencer(Runner):

            print(pred_cr, last_pred_cr, " max: ", data["max_coverage_rate"])
            if pred_cr >= data["max_coverage_rate"]:
                break
                print("max coverage rate reached!")
            if pred_cr <= last_pred_cr + cr_increase_threshold:
                retry += 1
                retry_duplication_pose.append(pred_pose.cpu().numpy().tolist())
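Note: the loop above stops either when the predicted coverage reaches the reference maximum or when a view adds too little new coverage too often. A minimal sketch of that check, with max_retry as an assumed parameter not shown in this diff:

def should_stop(pred_cr, last_pred_cr, max_cr, cr_increase_threshold, retry, max_retry=3):
    # Stop when the ground-truth maximum coverage is reached...
    if pred_cr >= max_cr:
        return True
    # ...or when the gain stayed below the threshold after several retries.
    return pred_cr <= last_pred_cr + cr_increase_threshold and retry >= max_retry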
@@ -155,6 +156,11 @@ class Inferencer(Runner):

            input_data["scanned_pts"] = [torch.cat([input_data["scanned_pts"][0], new_pts_tensor], dim=0)]
            input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)]
            combined_scanned_views_pts = np.concatenate(input_data["scanned_pts"][0].tolist(), axis=0)
            voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_views_pts, 0.002)
            random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N)
            input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device)

            last_pred_cr = pred_cr
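Note: PtsUtil's internals are not part of this diff. A minimal numpy sketch of the voxel-then-random downsampling used above (one representative point per 0.002 voxel, then a fixed budget of input_pts_N points); treat it as an approximation of the helper's behaviour, not its actual code:

import numpy as np

def voxel_downsample(points, voxel_size=0.002):
    # Keep the first point that falls into each occupied voxel.
    keys = np.floor(points / voxel_size).astype(int)
    _, unique_idx = np.unique(keys, axis=0, return_index=True)
    return points[unique_idx]

def random_downsample(points, n):
    # Sample exactly n points; sample with replacement if fewer remain.
    idx = np.random.choice(points.shape[0], n, replace=points.shape[0] < n)
    return points[idx]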
@@ -173,6 +179,15 @@ class Inferencer(Runner):
            "retry_duplication_pose": retry_duplication_pose,
            "best_seq_len": data["best_seq_len"][0],
        }
        self.stat_result[scene_name] = {
            "max_coverage_rate": data["max_coverage_rate"][0],
            "success_rate": max(pred_cr_seq) / data["max_coverage_rate"][0],
            "coverage_rate_seq": pred_cr_seq,
            "pred_max_coverage_rate": max(pred_cr_seq),
            "pred_seq_len": len(pred_cr_seq),
        }
        print('success rate: ', max(pred_cr_seq) / data["max_coverage_rate"][0])

        return result
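Note: the new success_rate statistic is the best predicted coverage in the sequence divided by the dataset's max_coverage_rate, so a sequence peaking at 0.87 against a 0.90 reference scores roughly 0.97. As a one-line sketch:

def success_rate(pred_cr_seq, max_coverage_rate):
    # Ratio of best achieved coverage to best achievable coverage (1.0 = matched the reference).
    return max(pred_cr_seq) / max_coverage_rate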
    def compute_coverage_rate(self, scanned_view_pts, new_pts, model_pts, threshold=0.005):
@@ -191,6 +206,8 @@ class Inferencer(Runner):
        os.makedirs(dataset_dir)
        output_path = os.path.join(dataset_dir, f"{scene_name}.pkl")
        pickle.dump(output, open(output_path, "wb"))
        with open(os.path.join(dataset_dir, "stat.json"), "w") as f:
            json.dump(self.stat_result, f)
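Note: stat.json accumulates one entry per scene as inference proceeds. A hedged sketch for averaging the per-scene success rates afterwards (the file layout is taken from the diff; the aggregation itself is an assumption):

import json, os

def mean_success_rate(dataset_dir):
    # Average the per-scene success_rate values written by the inferencer.
    with open(os.path.join(dataset_dir, "stat.json")) as f:
        stat = json.load(f)
    rates = [entry["success_rate"] for entry in stat.values()]
    return sum(rates) / len(rates) if rates else 0.0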

    def get_checkpoint_path(self, is_last=False):