update inferencer: success rate

This commit is contained in:
2024-09-27 16:01:07 +08:00
parent 030bf55192
commit 3bc56af3d5
6 changed files with 81 additions and 52 deletions

View File

@@ -73,7 +73,6 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
device = next(self.parameters()).device
pts_feat_seq_list = []
pose_feat_seq_list = []
for scanned_n_to_world_pose_9d in scanned_n_to_world_pose_9d_batch:
@@ -82,10 +81,10 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
main_feat = self.pose_seq_encoder.encode_sequence(pose_feat_seq_list)
if self.enable_global_scanned_feat:
combined_scanned_pts_batch = data['combined_scanned_pts']
global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch)
main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1)
combined_scanned_pts_batch = data['combined_scanned_pts']
global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch)
main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1)
if torch.isnan(main_feat).any():