change DataLoadUtil and Dataset to blender version
This commit is contained in:
@@ -38,7 +38,7 @@ class NBVReconstructionPipeline(nn.Module):
|
||||
def forward_train(self, data):
|
||||
pts_list = data['pts_list']
|
||||
pose_list = data['pose_list']
|
||||
gt_delta_rot_6d = data["delta_rot_6d"]
|
||||
gt_rot_6d = data["nbv_cam_pose"]
|
||||
pts_feat_list = []
|
||||
pose_feat_list = []
|
||||
for pts,pose in zip(pts_list,pose_list):
|
||||
@@ -46,7 +46,7 @@ class NBVReconstructionPipeline(nn.Module):
|
||||
pose_feat_list.append(self.pose_encoder.encode_pose(pose))
|
||||
seq_feat = self.seq_encoder.encode_sequence(pts_feat_list, pose_feat_list)
|
||||
''' get std '''
|
||||
perturbed_x, random_t, target_score, std = self.pertube_data(gt_delta_rot_6d)
|
||||
perturbed_x, random_t, target_score, std = self.pertube_data(gt_rot_6d)
|
||||
input_data = {
|
||||
"sampled_pose": perturbed_x,
|
||||
"t": random_t,
|
||||
@@ -69,9 +69,9 @@ class NBVReconstructionPipeline(nn.Module):
|
||||
pts_feat_list.append(self.pts_encoder.encode_points(pts))
|
||||
pose_feat_list.append(self.pose_encoder.encode_pose(pose))
|
||||
seq_feat = self.seq_encoder.encode_sequence(pts_feat_list, pose_feat_list)
|
||||
estimated_delta_rot_6d, in_process_sample = self.view_finder.next_best_view(seq_feat)
|
||||
estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(seq_feat)
|
||||
result = {
|
||||
"estimated_delta_rot_6d": estimated_delta_rot_6d,
|
||||
"pred_pose_9d": estimated_delta_rot_9d,
|
||||
"in_process_sample": in_process_sample
|
||||
}
|
||||
return result
|
||||
|
Reference in New Issue
Block a user