update inferencer: success rate
This commit is contained in:
@@ -73,7 +73,6 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
|
||||
|
||||
device = next(self.parameters()).device
|
||||
|
||||
pts_feat_seq_list = []
|
||||
pose_feat_seq_list = []
|
||||
|
||||
for scanned_n_to_world_pose_9d in scanned_n_to_world_pose_9d_batch:
|
||||
@@ -82,10 +81,10 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
|
||||
|
||||
main_feat = self.pose_seq_encoder.encode_sequence(pose_feat_seq_list)
|
||||
|
||||
if self.enable_global_scanned_feat:
|
||||
combined_scanned_pts_batch = data['combined_scanned_pts']
|
||||
global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch)
|
||||
main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1)
|
||||
|
||||
combined_scanned_pts_batch = data['combined_scanned_pts']
|
||||
global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch)
|
||||
main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1)
|
||||
|
||||
|
||||
if torch.isnan(main_feat).any():
|
||||
|
Reference in New Issue
Block a user