solve merge
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import time
|
||||
from torch import nn
|
||||
import PytorchBoot.namespace as namespace
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
@@ -58,7 +59,10 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
|
||||
return perturbed_x, random_t, target_score, std
|
||||
|
||||
def forward_train(self, data):
|
||||
start_time = time.time()
|
||||
main_feat = self.get_main_feat(data)
|
||||
end_time = time.time()
|
||||
print("get_main_feat time: ", end_time - start_time)
|
||||
""" get std """
|
||||
best_to_world_pose_9d_batch = data["best_to_world_pose_9d"]
|
||||
perturbed_x, random_t, target_score, std = self.pertube_data(
|
||||
@@ -117,7 +121,7 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
|
||||
for seq_idx in range(seq_len):
|
||||
partial_idx_in_combined_pts = scanned_mask == seq_idx # Ndarray(V), N->V idx mask
|
||||
partial_perpoint_feat = perpoint_scanned_feat[partial_idx_in_combined_pts] # Ndarray(V x Dl)
|
||||
partial_feat = torch.mean(partial_perpoint_feat, dim=0) # Tensor(Dl)
|
||||
partial_feat = torch.max(partial_perpoint_feat, dim=0) # Tensor(Dl)
|
||||
partial_feat_seq.append(partial_feat)
|
||||
scanned_target_pts_num.append(partial_perpoint_feat.shape[0])
|
||||
|
||||
|
Reference in New Issue
Block a user