add scan points check
This commit is contained in:
@@ -76,16 +76,16 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
|
||||
|
||||
device = next(self.parameters()).device
|
||||
|
||||
pose_feat_seq_list = []
|
||||
pts_num_feat_seq_list = []
|
||||
embedding_list_batch = []
|
||||
|
||||
for scanned_n_to_world_pose_9d,scanned_target_pts_num in zip(scanned_n_to_world_pose_9d_batch,scanned_target_pts_num_batch):
|
||||
scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device)
|
||||
scanned_target_pts_num = scanned_target_pts_num.to(device)
|
||||
pose_feat_seq_list.append(self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d))
|
||||
pts_num_feat_seq_list.append(self.pts_num_encoder.encode_pts_num(scanned_target_pts_num))
|
||||
pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d)
|
||||
pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num)
|
||||
embedding_list_batch.append(torch.cat([pose_feat_seq, pts_num_feat_seq], dim=-1))
|
||||
|
||||
main_feat = self.pose_n_num_seq_encoder.encode_sequence(pts_num_feat_seq_list, pose_feat_seq_list)
|
||||
main_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch)
|
||||
|
||||
|
||||
combined_scanned_pts_batch = data['combined_scanned_pts']
|
||||
|
Reference in New Issue
Block a user