add scan points check

This commit is contained in:
hofee
2024-09-30 00:55:34 +08:00
parent 2f6d156abd
commit cef7ab4429
3 changed files with 52 additions and 19 deletions

View File

@@ -76,16 +76,16 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
device = next(self.parameters()).device
pose_feat_seq_list = []
pts_num_feat_seq_list = []
embedding_list_batch = []
for scanned_n_to_world_pose_9d,scanned_target_pts_num in zip(scanned_n_to_world_pose_9d_batch,scanned_target_pts_num_batch):
scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device)
scanned_target_pts_num = scanned_target_pts_num.to(device)
pose_feat_seq_list.append(self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d))
pts_num_feat_seq_list.append(self.pts_num_encoder.encode_pts_num(scanned_target_pts_num))
pose_feat_seq = self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d)
pts_num_feat_seq = self.pts_num_encoder.encode_pts_num(scanned_target_pts_num)
embedding_list_batch.append(torch.cat([pose_feat_seq, pts_num_feat_seq], dim=-1))
main_feat = self.pose_n_num_seq_encoder.encode_sequence(pts_num_feat_seq_list, pose_feat_seq_list)
main_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch)
combined_scanned_pts_batch = data['combined_scanned_pts']