success
This commit is contained in:
139
modules/pipeline.py
Executable file
139
modules/pipeline.py
Executable file
@@ -0,0 +1,139 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import inspect
|
||||
|
||||
from configs.config import ConfigManager
|
||||
|
||||
from modules.pts_encoder.pts_encoder_factory import PointsEncoderFactory
|
||||
from modules.view_finder.view_finder_factory import ViewFinderFactory
|
||||
from modules.module_lib.fusion_layer import FeatureFusion
|
||||
from modules.rgb_encoder.rgb_encoder_factory import RGBEncoderFactory
|
||||
|
||||
|
||||
class Pipeline(nn.Module):
    """Wire a point encoder, an optional RGB encoder and a view finder together.

    The module has two entry points, selected by the ``mode`` argument of
    :meth:`forward`:

    * ``TRAIN_MODE`` -> :meth:`forward_gradient`, which perturbs the
      ground-truth ``delta_rot_6d`` and returns the view finder's estimated
      score next to the target score.
    * ``TEST_MODE`` -> :meth:`forward_view`, which asks the view finder for
      its next best view.
    """

    TRAIN_MODE: str = "train"
    TEST_MODE: str = "test"

    def __init__(self, pipeline_config):
        """Build the sub-modules named in ``pipeline_config``.

        Args:
            pipeline_config: Mapping with the keys ``"pts_encoder"`` and
                ``"view_finder"`` and, optionally, ``"rgb_encoder"``. Each
                value is handed to the matching factory together with the
                global ``"modules"`` configuration.
        """
        super(Pipeline, self).__init__()

        self.modules_config = ConfigManager.get("modules")
        self.device = ConfigManager.get("settings", "general", "device")
        self.rgb_feat_cache = ConfigManager.get("datasets", "general", "rgb_feat_cache")
        self.pts_encoder = PointsEncoderFactory.create(pipeline_config["pts_encoder"], self.modules_config)
        self.view_finder = ViewFinderFactory.create(pipeline_config["view_finder"], self.modules_config)
        self.has_rgb_encoder = "rgb_encoder" in pipeline_config
        # The RGB encoder is only instantiated when RGB features are not
        # served from the dataset cache.
        if self.has_rgb_encoder and not self.rgb_feat_cache:
            self.rgb_encoder = RGBEncoderFactory.create(pipeline_config["rgb_encoder"], self.modules_config)
        # Lower bound for the sampled time in forward_gradient (keeps t > 0).
        self.eps = 1e-5
        # NOTE(review): the fusion dimensions are hard-coded; presumably they
        # match the configured encoders' output sizes - confirm.
        self.fusion_layer = FeatureFusion(rgb_dim=384, pts_dim=1024, output_dim=1024)

        self.to(self.device)

    def forward(self, data, mode):
        """Dispatch to the training or inference path.

        Args:
            data: Batch dictionary; the required keys depend on ``mode``
                (see :meth:`forward_gradient` and :meth:`forward_view`).
            mode: Either ``Pipeline.TRAIN_MODE`` or ``Pipeline.TEST_MODE``.

        Returns:
            The result dictionary of the selected path.

        Raises:
            ValueError: If ``mode`` is neither of the two known modes.
        """
        if mode == self.TRAIN_MODE:
            return self.forward_gradient(data)
        if mode == self.TEST_MODE:
            return self.forward_view(data)
        # Report the argument itself: the module has no ``self.mode``
        # attribute, so formatting with it would raise AttributeError instead.
        raise ValueError("Unknown mode: {}".format(mode))

    def _encode_point_features(self, target_pts, scene_pts, rgb_feat=None):
        """Encode target and scene points, mixing in ``rgb_feat`` when given.

        Without ``rgb_feat`` the points encoder is called on the points alone.
        With ``rgb_feat`` the encoder receives it directly if its
        ``encode_points`` signature has an ``rgb_feat`` parameter; otherwise
        the point features are combined with it through ``self.fusion_layer``.

        Returns:
            Tuple ``(target_feat, scene_feat)``.
        """
        encode_points = self.pts_encoder.encode_points
        if rgb_feat is None:
            return encode_points(target_pts), encode_points(scene_pts)

        if "rgb_feat" in inspect.signature(encode_points).parameters:
            target_feat = encode_points(target_pts, rgb_feat)
            scene_feat = encode_points(scene_pts, rgb_feat)
            return target_feat, scene_feat

        target_feat = encode_points(target_pts)
        scene_feat = encode_points(scene_pts)
        target_feat = self.fusion_layer(rgb_feat, target_feat)
        scene_feat = self.fusion_layer(rgb_feat, scene_feat)
        return target_feat, scene_feat

    def forward_gradient(self, data):
        """Training path: perturb the ground-truth pose and estimate its score.

        Args:
            data: Dictionary with ``"target_pts"``, ``"scene_pts"`` and
                ``"delta_rot_6d"``. When an RGB encoder was built, either
                ``"rgb"`` or ``"rgb_feat"`` must be present as well.

        Returns:
            Dictionary with ``"estimated_score"``, ``"target_score"`` and
            ``"std"``.
        """
        target_pts = data["target_pts"]
        scene_pts = data["scene_pts"]
        gt_delta_rot_6d = data["delta_rot_6d"]

        # NOTE(review): this path keys RGB usage on the presence of
        # ``self.rgb_encoder``, whereas forward_view keys it on
        # ``has_rgb_encoder``/``rgb_feat_cache``. With the feature cache
        # enabled no encoder is built, so training skips RGB features that
        # forward_view would use - confirm this asymmetry is intended.
        rgb_feat = None
        if hasattr(self, "rgb_encoder"):
            if "rgb" in data:
                rgb_feat = self.rgb_encoder.encode_rgb(data["rgb"])
            else:
                rgb_feat = data["rgb_feat"]
        target_feat, scene_feat = self._encode_point_features(target_pts, scene_pts, rgb_feat)

        ''' get std '''
        bs = target_pts.shape[0]
        # One time value per sample, drawn uniformly from [eps, 1).
        random_t = torch.rand(bs, device=self.device) * (1. - self.eps) + self.eps
        random_t = random_t.unsqueeze(-1)
        mu, std = self.view_finder.marginal_prob(gt_delta_rot_6d, random_t)
        std = std.view(-1, 1)

        ''' perturb data and get estimated score '''
        z = torch.randn_like(gt_delta_rot_6d)
        perturbed_x = mu + z * std
        input_data = {
            "sampled_pose": perturbed_x,
            "t": random_t,
            "scene_feat": scene_feat,
            "target_feat": target_feat
        }
        estimated_score = self.view_finder(input_data)

        ''' get target score '''
        # Algebraically -z / std; kept in the original form so the numerics
        # are unchanged.
        target_score = - z * std / (std ** 2)

        result = {
            "estimated_score": estimated_score,
            "target_score": target_score,
            "std": std
        }
        return result

    def forward_view(self, data):
        """Inference path: ask the view finder for the next best view.

        Args:
            data: Dictionary with ``"target_pts"`` and ``"scene_pts"``. When
                an RGB encoder is configured it must also hold ``"rgb_feat"``
                (feature cache enabled) or ``"rgb"`` (cache disabled).

        Returns:
            Dictionary with ``"estimated_delta_rot_6d"`` and
            ``"in_process_sample"``, the two values returned by
            ``view_finder.next_best_view``.
        """
        target_pts = data["target_pts"]
        scene_pts = data["scene_pts"]

        rgb_feat = None
        if self.has_rgb_encoder:
            if self.rgb_feat_cache:
                rgb_feat = data["rgb_feat"]
            else:
                rgb_feat = self.rgb_encoder.encode_rgb(data["rgb"])
        target_feat, scene_feat = self._encode_point_features(target_pts, scene_pts, rgb_feat)

        estimated_delta_rot_6d, in_process_sample = self.view_finder.next_best_view(scene_feat, target_feat)
        result = {
            "estimated_delta_rot_6d": estimated_delta_rot_6d,
            "in_process_sample": in_process_sample
        }
        return result
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: build the pipeline from the local training config
    # and run one random batch through both the train and the test path.
    ConfigManager.load_config_with('../configs/local_train_config.yaml')
    ConfigManager.print_config()
    test_pipeline_config = ConfigManager.get("settings", "pipeline")
    pipeline = Pipeline(test_pipeline_config)

    # Place the inputs on the device the pipeline was moved to, rather than
    # a hard-coded "cuda:0", so the check also runs with other device settings.
    device = pipeline.device
    test_scene = torch.rand(32, 1024, 3).to(device)
    test_target = torch.rand(32, 1024, 3).to(device)
    test_delta_rot_6d = torch.rand(32, 6).to(device)

    # Scale each 3-component half of the 6D vector to unit length.
    a = test_delta_rot_6d[:, :3]
    b = test_delta_rot_6d[:, 3:]
    a_norm = a / a.norm(dim=1, keepdim=True)
    b_norm = b / b.norm(dim=1, keepdim=True)
    normalized_test_delta_rot_6d = torch.cat((a_norm, b_norm), dim=1)

    # NOTE(review): no "rgb"/"rgb_feat" entry is supplied, so this presumably
    # only works with a pipeline config that has no RGB encoder - confirm.
    test_data = {
        'target_pts': test_target,
        'scene_pts': test_scene,
        'delta_rot_6d': normalized_test_delta_rot_6d
    }
    out_data = pipeline(test_data, "train")
    print(out_data.keys())
    out_data_test = pipeline(test_data, "test")
    print(out_data_test.keys())
|
Reference in New Issue
Block a user