diff --git a/core/pipeline.py b/core/pipeline.py index e295c87..f3558f5 100644 --- a/core/pipeline.py +++ b/core/pipeline.py @@ -108,7 +108,10 @@ class NBVReconstructionPipeline(nn.Module): partial_point_feat_seq = [] for j in range(seq_len): partial_per_point_feat = per_point_feat[scanned_pts_mask[j]] - partial_point_feat = torch.mean(partial_per_point_feat, dim=0) # Tensor(Dp) + if partial_per_point_feat.shape[0] == 0: + partial_point_feat = torch.zeros(per_point_feat.shape[1], device=device) + else: + partial_point_feat = torch.mean(partial_per_point_feat, dim=0) # Tensor(Dp) partial_point_feat_seq.append(partial_point_feat) partial_point_feat_seq = torch.stack(partial_point_feat_seq, dim=0) # Tensor(S x Dp) @@ -122,6 +125,13 @@ class NBVReconstructionPipeline(nn.Module): main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1) # Tensor(B x (Ds+Dg)) if torch.isnan(main_feat).any(): + for i in range(len(main_feat)): + if torch.isnan(main_feat[i]).any(): + scanned_pts_mask = scanned_pts_mask_batch[i] + Log.info(f"scanned_pts_mask shape: {scanned_pts_mask.shape}") + Log.info(f"scanned_pts_mask sum: {scanned_pts_mask.sum()}") + import ipdb + ipdb.set_trace() Log.error("nan in main_feat", True) return main_feat