From 1123e69bffb19d59f7e6b664604ea51b5d588775 Mon Sep 17 00:00:00 2001 From: hofee Date: Thu, 31 Oct 2024 12:02:48 +0000 Subject: [PATCH] fix nan --- core/pipeline.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/core/pipeline.py b/core/pipeline.py index e295c87..f3558f5 100644 --- a/core/pipeline.py +++ b/core/pipeline.py @@ -108,7 +108,10 @@ class NBVReconstructionPipeline(nn.Module): partial_point_feat_seq = [] for j in range(seq_len): partial_per_point_feat = per_point_feat[scanned_pts_mask[j]] - partial_point_feat = torch.mean(partial_per_point_feat, dim=0) # Tensor(Dp) + if partial_per_point_feat.shape[0] == 0: + partial_point_feat = torch.zeros(per_point_feat.shape[1], device=device) + else: + partial_point_feat = torch.mean(partial_per_point_feat, dim=0) # Tensor(Dp) partial_point_feat_seq.append(partial_point_feat) partial_point_feat_seq = torch.stack(partial_point_feat_seq, dim=0) # Tensor(S x Dp) @@ -122,6 +125,13 @@ class NBVReconstructionPipeline(nn.Module): main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1) # Tensor(B x (Ds+Dg)) if torch.isnan(main_feat).any(): + for i in range(len(main_feat)): + if torch.isnan(main_feat[i]).any(): + scanned_pts_mask = scanned_pts_mask_batch[i] + Log.info(f"scanned_pts_mask shape: {scanned_pts_mask.shape}") + Log.info(f"scanned_pts_mask sum: {scanned_pts_mask.sum()}") + import ipdb + ipdb.set_trace() Log.error("nan in main_feat", True) return main_feat