global: inference debug
This commit is contained in:
@@ -72,7 +72,7 @@ class Inferencer(Runner):
|
||||
data = test_set.__getitem__(i)
|
||||
status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i+1, total)
|
||||
output = self.predict_sequence(data)
|
||||
self.save_inference_result(test_set_name, data["scene_name"][0], output)
|
||||
self.save_inference_result(test_set_name, data["scene_name"], output)
|
||||
|
||||
status_manager.set_progress("inference", "inferencer", f"dataset", len(self.test_set_list), len(self.test_set_list))
|
||||
|
||||
@@ -107,6 +107,7 @@ class Inferencer(Runner):
|
||||
retry_no_pts_pose = []
|
||||
retry = 0
|
||||
pred_cr_seq = [last_pred_cr]
|
||||
success = 0
|
||||
while len(pred_cr_seq) < max_iter and retry < max_retry:
|
||||
|
||||
output = self.pipeline(input_data)
|
||||
@@ -132,11 +133,13 @@ class Inferencer(Runner):
|
||||
retry += 1
|
||||
continue
|
||||
|
||||
|
||||
pred_cr, new_added_pts_num = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold)
|
||||
print(pred_cr, last_pred_cr, " max: ", data["seq_max_coverage_rate"])
|
||||
if pred_cr >= data["seq_max_coverage_rate"] - 1e-3:
|
||||
print("max coverage rate reached!: ", pred_cr)
|
||||
if new_added_pts_num < 10:
|
||||
success += 1
|
||||
elif new_added_pts_num < 10:
|
||||
print("min added pts num reached!: ", new_added_pts_num)
|
||||
if pred_cr <= last_pred_cr + cr_increase_threshold:
|
||||
retry += 1
|
||||
@@ -156,32 +159,29 @@ class Inferencer(Runner):
|
||||
random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N)
|
||||
input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device)
|
||||
|
||||
|
||||
if success > 3:
|
||||
break
|
||||
last_pred_cr = pred_cr
|
||||
|
||||
|
||||
input_data["scanned_pts"] = input_data["scanned_pts"][0].cpu().numpy().tolist()
|
||||
input_data["scanned_n_to_world_pose_9d"] = input_data["scanned_n_to_world_pose_9d"][0].cpu().numpy().tolist()
|
||||
result = {
|
||||
"pred_pose_9d_seq": input_data["scanned_n_to_world_pose_9d"],
|
||||
"pts_seq": input_data["scanned_pts"],
|
||||
"combined_scanned_pts": input_data["combined_scanned_pts"],
|
||||
"target_pts_seq": scanned_view_pts,
|
||||
"coverage_rate_seq": pred_cr_seq,
|
||||
"max_coverage_rate": data["max_coverage_rate"][0],
|
||||
"max_coverage_rate": data["seq_max_coverage_rate"],
|
||||
"pred_max_coverage_rate": max(pred_cr_seq),
|
||||
"scene_name": scene_name,
|
||||
"retry_no_pts_pose": retry_no_pts_pose,
|
||||
"retry_duplication_pose": retry_duplication_pose,
|
||||
"best_seq_len": data["best_seq_len"][0],
|
||||
"best_seq_len": data["best_seq_len"],
|
||||
}
|
||||
self.stat_result[scene_name] = {
|
||||
"max_coverage_rate": data["max_coverage_rate"][0],
|
||||
"success_rate": max(pred_cr_seq)/ data["max_coverage_rate"][0],
|
||||
"coverage_rate_seq": pred_cr_seq,
|
||||
"pred_max_coverage_rate": max(pred_cr_seq),
|
||||
"pred_seq_len": len(pred_cr_seq),
|
||||
}
|
||||
print('success rate: ', max(pred_cr_seq) / data["max_coverage_rate"][0])
|
||||
print('success rate: ', max(pred_cr_seq))
|
||||
|
||||
return result
|
||||
|
||||
|
Reference in New Issue
Block a user