diff --git a/core/global_pts_n_num_pipeline.py b/core/global_pts_n_num_pipeline.py
new file mode 100644
index 0000000..948b047
--- /dev/null
+++ b/core/global_pts_n_num_pipeline.py
@@ -0,0 +1,100 @@
+import torch
+from torch import nn
+import PytorchBoot.namespace as namespace
+import PytorchBoot.stereotype as stereotype
+from PytorchBoot.factory.component_factory import ComponentFactory
+from PytorchBoot.utils import Log
+
+
+@stereotype.pipeline("nbv_reconstruction_global_pts_n_num_pipeline")
+class NBVReconstructionGlobalPointsPipeline(nn.Module):
+    def __init__(self, config):
+        super(NBVReconstructionGlobalPointsPipeline, self).__init__()
+        self.config = config
+        self.module_config = config["modules"]
+        self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_encoder"])
+        self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_encoder"])
+        self.pose_n_num_seq_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_n_num_seq_encoder"])
+        self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["view_finder"])
+        self.pts_num_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"])
+
+        self.eps = float(self.config["eps"])
+        self.enable_global_scanned_feat = self.config["global_scanned_feat"]
+
+    def forward(self, data):
+        mode = data["mode"]
+
+        if mode == namespace.Mode.TRAIN:
+            return self.forward_train(data)
+        elif mode == namespace.Mode.TEST:
+            return self.forward_test(data)
+        else:
+            Log.error("Unknown mode: {}".format(mode), True)
+
+    def perturb_data(self, gt_delta_9d):
+        # Sample t ~ U(eps, 1) per batch element, then perturb the ground-truth
+        # 9D pose with the diffusion marginal at that time step.
+        bs = gt_delta_9d.shape[0]
+        random_t = torch.rand(bs, device=gt_delta_9d.device) * (1. - self.eps) + self.eps
+        random_t = random_t.unsqueeze(-1)
+        mu, std = self.view_finder.marginal_prob(gt_delta_9d, random_t)
+        std = std.view(-1, 1)
+        z = torch.randn_like(gt_delta_9d)
+        perturbed_x = mu + z * std
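+        # Hedged note: for the Gaussian kernel p_t(x | x_0) = N(mu, std^2 I),
+        # the score is grad_x log p_t = -(x - mu) / std^2 = -z / std; the line
+        # below writes that same quantity as -z * std / std**2.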
+        target_score = - z * std / (std ** 2)
+        return perturbed_x, random_t, target_score, std
+
+    def forward_train(self, data):
+        main_feat = self.get_main_feat(data)
+        # perturb the ground-truth pose to get the training target and std
+        best_to_world_pose_9d_batch = data["best_to_world_pose_9d"]
+        perturbed_x, random_t, target_score, std = self.perturb_data(best_to_world_pose_9d_batch)
+        input_data = {
+            "sampled_pose": perturbed_x,
+            "t": random_t,
+            "main_feat": main_feat,
+        }
+        estimated_score = self.view_finder(input_data)
+        output = {
+            "estimated_score": estimated_score,
+            "target_score": target_score,
+            "std": std,
+        }
+        return output
+
+    def forward_test(self, data):
+        main_feat = self.get_main_feat(data)
+        estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(main_feat)
+        result = {
+            "pred_pose_9d": estimated_delta_rot_9d,
+            "in_process_sample": in_process_sample,
+        }
+        return result
+
+    def get_main_feat(self, data):
+        scanned_n_to_world_pose_9d_batch = data["scanned_n_to_world_pose_9d"]
+        scanned_target_pts_num_batch = data["scanned_target_points_num"]
+
+        device = next(self.parameters()).device
+
+        pose_feat_seq_list = []
+        pts_num_feat_seq_list = []
+
+        for scanned_n_to_world_pose_9d, scanned_target_pts_num in zip(scanned_n_to_world_pose_9d_batch, scanned_target_pts_num_batch):
+            scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device)
+            scanned_target_pts_num = scanned_target_pts_num.to(device)
+            pose_feat_seq_list.append(self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d))
+            pts_num_feat_seq_list.append(self.pts_num_encoder.encode_pts_num(scanned_target_pts_num))
+
+        main_feat = self.pose_n_num_seq_encoder.encode_sequence(pts_num_feat_seq_list, pose_feat_seq_list)
+
+        combined_scanned_pts_batch = data["combined_scanned_pts"]
+        global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch)
+        main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1)
+
+        if torch.isnan(main_feat).any():
+            Log.error("nan in main_feat", True)
+
+        return main_feat
+
diff --git a/modules/pts_num_encoder.py b/modules/pts_num_encoder.py
new file mode 100644
index 0000000..2210c21
--- /dev/null
+++ b/modules/pts_num_encoder.py
@@ -0,0 +1,20 @@
+from torch import nn
+import PytorchBoot.stereotype as stereotype
+
+
+@stereotype.module("pts_num_encoder")
+class PointsNumEncoder(nn.Module):
+    def __init__(self, config):
+        super(PointsNumEncoder, self).__init__()
+        self.config = config
+        out_dim = config["out_dim"]
+        self.act = nn.ReLU(True)
+
+        self.pts_num_encoder = nn.Sequential(
+            nn.Linear(1, out_dim),
+            self.act,
+            nn.Linear(out_dim, out_dim),
+            self.act,
+        )
+
+    def encode_pts_num(self, num_seq):
+        # num_seq is expected to be a float tensor whose trailing feature dim
+        # is 1 (one scalar point count per sequence element).
+        return self.pts_num_encoder(num_seq)
diff --git a/modules/transformer_pose_n_num_seq_encoder.py b/modules/transformer_pose_n_num_seq_encoder.py
new file mode 100644
index 0000000..c8cedcb
--- /dev/null
+++ b/modules/transformer_pose_n_num_seq_encoder.py
@@ -0,0 +1,72 @@
+import torch
+from torch import nn
+from torch.nn.utils.rnn import pad_sequence
+import PytorchBoot.stereotype as stereotype
+
+
+@stereotype.module("transformer_pose_n_num_seq_encoder")
+class TransformerPoseAndNumSequenceEncoder(nn.Module):
+    def __init__(self, config):
+        super(TransformerPoseAndNumSequenceEncoder, self).__init__()
+        self.config = config
+        embed_dim = config["pts_num_embed_dim"] + config["pose_embed_dim"]
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=embed_dim,
+            nhead=config["num_heads"],
+            dim_feedforward=config["ffn_dim"],
+            batch_first=True,
+        )
+        self.transformer_encoder = nn.TransformerEncoder(
+            encoder_layer, num_layers=config["num_layers"]
+        )
+        self.fc = nn.Linear(embed_dim, config["output_dim"])
+
+    def encode_sequence(self, pts_num_embedding_list_batch, pose_embedding_list_batch):
+        combined_features_batch = []
+        lengths = []
+
+        for pts_num_embedding_list, pose_embedding_list in zip(pts_num_embedding_list_batch, pose_embedding_list_batch):
+            combined_features = [
+                torch.cat((pts_num_embed, pose_embed), dim=-1)
+                for pts_num_embed, pose_embed in zip(pts_num_embedding_list, pose_embedding_list)
+            ]
+            combined_features_batch.append(torch.stack(combined_features))
+            lengths.append(len(combined_features))
+
+        combined_tensor = pad_sequence(combined_features_batch, batch_first=True)  # Shape: [batch_size, max_seq_len, embed_dim]
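+        # src_key_padding_mask convention (PyTorch): True marks padded key
+        # positions that attention must ignore, hence 1s for the padded tail.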
config["pose_embed_dim"] + encoder_layer = nn.TransformerEncoderLayer( + d_model=embed_dim, + nhead=config["num_heads"], + dim_feedforward=config["ffn_dim"], + batch_first=True, + ) + self.transformer_encoder = nn.TransformerEncoder( + encoder_layer, num_layers=config["num_layers"] + ) + self.fc = nn.Linear(embed_dim, config["output_dim"]) + + def encode_sequence(self, pts_num_embedding_list_batch, pose_embedding_list_batch): + combined_features_batch = [] + lengths = [] + + for pts_num_embedding_list, pose_embedding_list in zip(pts_num_embedding_list_batch, pose_embedding_list_batch): + combined_features = [ + torch.cat((pts_num_embed, pose_embed), dim=-1) + for pts_num_embed, pose_embed in zip(pts_num_embedding_list, pose_embedding_list) + ] + combined_features_batch.append(torch.stack(combined_features)) + lengths.append(len(combined_features)) + + combined_tensor = pad_sequence(combined_features_batch, batch_first=True) # Shape: [batch_size, max_seq_len, embed_dim] + + max_len = max(lengths) + padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device) + + transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask) + final_feature = transformer_output.mean(dim=1) + final_output = self.fc(final_feature) + + return final_output + + +if __name__ == "__main__": + config = { + "pts_num_embed_dim": 128, + "pose_embed_dim": 256, + "num_heads": 4, + "ffn_dim": 256, + "num_layers": 3, + "output_dim": 2048, + } + + encoder = TransformerPoseAndNumSequenceEncoder(config) + seq_len = [5, 8, 9, 4] + batch_size = 4 + + pts_num_embedding_list_batch = [ + torch.randn(seq_len[idx], config["pts_num_embed_dim"]) for idx in range(batch_size) + ] + pose_embedding_list_batch = [ + torch.randn(seq_len[idx], config["pose_embed_dim"]) for idx in range(batch_size) + ] + output_feature = encoder.encode_sequence( + pts_num_embedding_list_batch, pose_embedding_list_batch + ) + print("Encoded Feature:", output_feature) + print("Feature Shape:", output_feature.shape) diff --git a/modules/transformer_seq_encoder.py b/modules/transformer_pose_n_pts_seq_encoder.py similarity index 97% rename from modules/transformer_seq_encoder.py rename to modules/transformer_pose_n_pts_seq_encoder.py index 1eae505..98def3f 100644 --- a/modules/transformer_seq_encoder.py +++ b/modules/transformer_pose_n_pts_seq_encoder.py @@ -4,7 +4,7 @@ from torch.nn.utils.rnn import pad_sequence import PytorchBoot.stereotype as stereotype -@stereotype.module("transformer_seq_encoder") +@stereotype.module("transformer_pose_n_pts_seq_encoder") class TransformerSequenceEncoder(nn.Module): def __init__(self, config): super(TransformerSequenceEncoder, self).__init__()