finish PoseDiff and NBVDataset
This commit is contained in:
@@ -6,7 +6,7 @@ import PytorchBoot.namespace as namespace
|
||||
def get_view_data(cam_pose, scene_name):
|
||||
pass
|
||||
|
||||
@stereotype.evaluation_method("pose_diff", comment="not tested")
|
||||
@stereotype.evaluation_method("pose_diff")
|
||||
class PoseDiff:
|
||||
def __init__(self, _):
|
||||
pass
|
||||
@@ -16,7 +16,7 @@ class PoseDiff:
|
||||
rot_angle_list = []
|
||||
trans_dist_list = []
|
||||
for output, data in zip(output_list, data_list):
|
||||
gt_pose_9d = data['nbv_cam_pose']
|
||||
gt_pose_9d = data['best_to_1_pose_9d']
|
||||
pred_pose_9d = output['pred_pose_9d']
|
||||
gt_rot_6d = gt_pose_9d[:, :6]
|
||||
gt_trans = gt_pose_9d[:, 6:]
|
||||
@@ -49,9 +49,9 @@ class ConverageRateIncrease:
|
||||
cr_diff_list = []
|
||||
for output, data in zip(output_list, data_list):
|
||||
scanned_cr = data['scanned_coverages_rate']
|
||||
gt_cr = data["nbv_coverage_rate"]
|
||||
gt_cr = data["best_coverage_rate"]
|
||||
scene_name_list = data['scene_name']
|
||||
scanned_view_pts_list = data['scanned_views_pts']
|
||||
scanned_view_pts_list = data['scanned_pts']
|
||||
pred_pose_9ds = output['pred_pose_9d']
|
||||
pred_rot_mats = PoseUtil.rotation_6d_to_matrix_tensor_batch(pred_pose_9ds[:, :6])
|
||||
pred_pose_mats = torch.cat([pred_rot_mats, pred_pose_9ds[:, 6:]], dim=-1)
|
||||
|
Reference in New Issue
Block a user