solve merge

This commit is contained in:
2024-10-28 18:25:53 +00:00
parent bd27226f0f
commit 3c9e2c8d12
4 changed files with 22 additions and 11 deletions

View File

@@ -1,4 +1,5 @@
import torch
import time
from torch import nn
import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
@@ -58,7 +59,10 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
return perturbed_x, random_t, target_score, std
def forward_train(self, data):
start_time = time.time()
main_feat = self.get_main_feat(data)
end_time = time.time()
print("get_main_feat time: ", end_time - start_time)
""" get std """
best_to_world_pose_9d_batch = data["best_to_world_pose_9d"]
perturbed_x, random_t, target_score, std = self.pertube_data(
@@ -117,7 +121,7 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
for seq_idx in range(seq_len):
partial_idx_in_combined_pts = scanned_mask == seq_idx # Ndarray(V), N->V idx mask
partial_perpoint_feat = perpoint_scanned_feat[partial_idx_in_combined_pts] # Ndarray(V x Dl)
partial_feat = torch.mean(partial_perpoint_feat, dim=0) # Tensor(Dl)
partial_feat = torch.max(partial_perpoint_feat, dim=0) # Tensor(Dl)
partial_feat_seq.append(partial_feat)
scanned_target_pts_num.append(partial_perpoint_feat.shape[0])