change DataLoadUtil and Dataset to blender version
This commit is contained in:
@@ -1,23 +1,37 @@
|
||||
import torch
|
||||
from utils.pose import PoseUtil
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
import PytorchBoot.namespace as namespace
|
||||
|
||||
@stereotype.evaluation_method("delta_pose_diff")
|
||||
class DeltaPoseDiff:
|
||||
def get_view_data(cam_pose, scene_name):
|
||||
pass
|
||||
|
||||
@stereotype.evaluation_method("pose_diff", comment="not tested")
|
||||
class PoseDiff:
|
||||
def __init__(self, _):
|
||||
pass
|
||||
|
||||
def evaluate(self, output_list, data_list):
|
||||
results = {namespace.TensorBoard.SCALAR: {}}
|
||||
rot_angle_list = []
|
||||
trans_dist_list = []
|
||||
for output, data in zip(output_list, data_list):
|
||||
gt_delta_rot_6d = data['delta_rot_6d']
|
||||
est_delta_rot_6d = output['estimated_delta_rot_6d']
|
||||
gt_delta_rot_mat = PoseUtil.rotation_6d_to_matrix_tensor_batch(gt_delta_rot_6d)
|
||||
est_delta_rot_mat = PoseUtil.rotation_6d_to_matrix_tensor_batch(est_delta_rot_6d)
|
||||
rotation_angles = PoseUtil.rotation_angle_distance(gt_delta_rot_mat, est_delta_rot_mat)
|
||||
gt_pose_9d = data['nbv_cam_pose']
|
||||
pred_pose_9d = output['pred_pose_9d']
|
||||
gt_rot_6d = gt_pose_9d[:, :6]
|
||||
gt_trans = gt_pose_9d[:, 6:]
|
||||
pred_rot_6d = pred_pose_9d[:, :6]
|
||||
pred_trans = pred_pose_9d[:, 6:]
|
||||
gt_rot_mat = PoseUtil.rotation_6d_to_matrix_tensor_batch(gt_rot_6d)
|
||||
pred_rot_mat = PoseUtil.rotation_6d_to_matrix_tensor_batch(pred_rot_6d)
|
||||
rotation_angles = PoseUtil.rotation_angle_distance(gt_rot_mat, pred_rot_mat)
|
||||
rot_angle_list.extend(list(rotation_angles))
|
||||
trans_dist = torch.norm(gt_trans-pred_trans)
|
||||
trans_dist_list.append(trans_dist)
|
||||
|
||||
|
||||
results[namespace.TensorBoard.SCALAR]["delta_rotation"] = float(sum(rot_angle_list) / len(rot_angle_list))
|
||||
results[namespace.TensorBoard.SCALAR]["rot_diff"] = float(sum(rot_angle_list) / len(rot_angle_list))
|
||||
results[namespace.TensorBoard.SCALAR]["trans_diff"] = float(sum(trans_dist_list) / len(trans_dist_list))
|
||||
return results
|
||||
|
||||
|
||||
@@ -25,8 +39,40 @@ class DeltaPoseDiff:
|
||||
@stereotype.evaluation_method("coverage_rate_increase",comment="unfinished")
|
||||
class ConverageRateIncrease:
|
||||
def __init__(self, config):
|
||||
pass
|
||||
self.config = config
|
||||
|
||||
|
||||
def evaluate(self, output_list, data_list):
|
||||
return
|
||||
results = {namespace.TensorBoard.SCALAR: {}}
|
||||
gt_coverate_increase_list = []
|
||||
pred_coverate_increase_list = []
|
||||
cr_diff_list = []
|
||||
for output, data in zip(output_list, data_list):
|
||||
scanned_cr = data['scanned_coverages_rate']
|
||||
gt_cr = data["nbv_coverage_rate"]
|
||||
scene_name_list = data['scene_name']
|
||||
scanned_view_pts_list = data['scanned_views_pts']
|
||||
pred_pose_9ds = output['pred_pose_9d']
|
||||
pred_rot_mats = PoseUtil.rotation_6d_to_matrix_tensor_batch(pred_pose_9ds[:, :6])
|
||||
pred_pose_mats = torch.cat([pred_rot_mats, pred_pose_9ds[:, 6:]], dim=-1)
|
||||
|
||||
for idx in range(len(scanned_cr)):
|
||||
gt_coverate_increase_list.append(gt_cr-scanned_cr[idx])
|
||||
scene_name = scene_name_list[idx]
|
||||
pred_pose = pred_pose_mats[idx]
|
||||
scanned_view_pts = scanned_view_pts_list[idx]
|
||||
view_data = get_view_data(pred_pose, scene_name)
|
||||
pred_cr = self.compute_coverage_rate(pred_pose, scanned_view_pts, view_data)
|
||||
pred_coverate_increase_list.append(pred_cr-scanned_cr[idx])
|
||||
cr_diff_list.append(gt_cr-pred_cr)
|
||||
|
||||
results[namespace.TensorBoard.SCALAR]["gt_cr_increase"] = float(sum(gt_coverate_increase_list) / len(gt_coverate_increase_list))
|
||||
results[namespace.TensorBoard.SCALAR]["pred_cr_increase"] = float(sum(pred_coverate_increase_list) / len(pred_coverate_increase_list))
|
||||
results[namespace.TensorBoard.SCALAR]["cr_diff"] = float(sum(cr_diff_list) / len(cr_diff_list))
|
||||
return results
|
||||
|
||||
def compute_coverage_rate(self, pred_pose, scanned_view_pts, view_data):
|
||||
pass
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user