success
This commit is contained in:
0
modules/view_finder/__init__.py
Executable file
0
modules/view_finder/__init__.py
Executable file
12
modules/view_finder/abstract_view_finder.py
Executable file
12
modules/view_finder/abstract_view_finder.py
Executable file
@@ -0,0 +1,12 @@
|
||||
from abc import abstractmethod
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ViewFinder(nn.Module):
|
||||
def __init__(self):
|
||||
super(ViewFinder, self).__init__()
|
||||
|
||||
@abstractmethod
|
||||
def next_best_view(self, scene_pts_feat, target_pts_feat):
|
||||
pass
|
165
modules/view_finder/gf_view_finder.py
Executable file
165
modules/view_finder/gf_view_finder.py
Executable file
@@ -0,0 +1,165 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from utils.pose_util import PoseUtil
|
||||
from modules.view_finder.abstract_view_finder import ViewFinder
|
||||
import modules.module_lib as mlib
|
||||
import modules.func_lib as flib
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
class GradientFieldViewFinder(ViewFinder):
|
||||
def __init__(self, pose_mode='rot_matrix', regression_head='Rx_Ry', per_point_feature=False,
|
||||
sample_mode="ode", sampling_steps=None, sde_mode="ve"):
|
||||
|
||||
super(GradientFieldViewFinder, self).__init__()
|
||||
self.regression_head = regression_head
|
||||
self.per_point_feature = per_point_feature
|
||||
self.act = nn.ReLU(True)
|
||||
self.sample_mode = sample_mode
|
||||
self.pose_mode = pose_mode
|
||||
pose_dim = PoseUtil.get_pose_dim(pose_mode)
|
||||
self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(sde_mode)
|
||||
self.sampling_steps = sampling_steps
|
||||
|
||||
''' encode pose '''
|
||||
self.pose_encoder = nn.Sequential(
|
||||
nn.Linear(pose_dim, 256),
|
||||
self.act,
|
||||
nn.Linear(256, 256),
|
||||
self.act,
|
||||
)
|
||||
|
||||
''' encode t '''
|
||||
self.t_encoder = nn.Sequential(
|
||||
mlib.GaussianFourierProjection(embed_dim=128),
|
||||
nn.Linear(128, 128),
|
||||
self.act,
|
||||
)
|
||||
|
||||
''' fusion tail '''
|
||||
if self.regression_head == 'Rx_Ry':
|
||||
if pose_mode != 'rot_matrix':
|
||||
raise NotImplementedError
|
||||
if not per_point_feature:
|
||||
''' rotation_x_axis regress head '''
|
||||
self.fusion_tail_rot_x = nn.Sequential(
|
||||
nn.Linear(128 + 256 + 1024 + 1024, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
self.fusion_tail_rot_y = nn.Sequential(
|
||||
nn.Linear(128 + 256 + 1024 + 1024, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, data):
|
||||
"""
|
||||
Args:
|
||||
data, dict {
|
||||
'target_pts_feat': [bs, c]
|
||||
'scene_pts_feat': [bs, c]
|
||||
'pose_sample': [bs, pose_dim]
|
||||
't': [bs, 1]
|
||||
}
|
||||
"""
|
||||
|
||||
scene_pts_feat = data['scene_feat']
|
||||
target_pts_feat = data['target_feat']
|
||||
sampled_pose = data['sampled_pose']
|
||||
t = data['t']
|
||||
t_feat = self.t_encoder(t.squeeze(1))
|
||||
pose_feat = self.pose_encoder(sampled_pose)
|
||||
|
||||
if self.per_point_feature:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
total_feat = torch.cat([scene_pts_feat, target_pts_feat, t_feat, pose_feat], dim=-1)
|
||||
_, std = self.marginal_prob_fn(total_feat, t)
|
||||
|
||||
if self.regression_head == 'Rx_Ry':
|
||||
rot_x = self.fusion_tail_rot_x(total_feat)
|
||||
rot_y = self.fusion_tail_rot_y(total_feat)
|
||||
out_score = torch.cat([rot_x, rot_y], dim=-1) / (std + 1e-7) # normalisation
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return out_score
|
||||
|
||||
def marginal_prob(self, x, t):
|
||||
return self.marginal_prob_fn(x,t)
|
||||
|
||||
def sample(self, data, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True, init_x=None, T0=None):
|
||||
|
||||
if self.sample_mode == 'pc':
|
||||
in_process_sample, res = flib.cond_pc_sampler(
|
||||
score_model=self,
|
||||
data=data,
|
||||
prior=self.prior_fn,
|
||||
sde_coeff=self.sde_fn,
|
||||
num_steps=self.sampling_steps,
|
||||
snr=snr,
|
||||
eps=self.sampling_eps,
|
||||
pose_mode=self.pose_mode,
|
||||
init_x=init_x
|
||||
)
|
||||
|
||||
elif self.sample_mode == 'ode':
|
||||
T0 = self.T if T0 is None else T0
|
||||
in_process_sample, res = flib.cond_ode_sampler(
|
||||
score_model=self,
|
||||
data=data,
|
||||
prior=self.prior_fn,
|
||||
sde_coeff=self.sde_fn,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
eps=self.sampling_eps,
|
||||
T=T0,
|
||||
num_steps=self.sampling_steps,
|
||||
pose_mode=self.pose_mode,
|
||||
denoise=denoise,
|
||||
init_x=init_x
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return in_process_sample, res
|
||||
|
||||
def next_best_view(self, scene_pts_feat, target_pts_feat):
|
||||
data = {
|
||||
'scene_feat': scene_pts_feat,
|
||||
'target_feat': target_pts_feat,
|
||||
}
|
||||
in_process_sample, res = self.sample(data)
|
||||
return res.to(dtype=torch.float32), in_process_sample
|
||||
|
||||
|
||||
''' ----------- DEBUG -----------'''
|
||||
if __name__ == "__main__":
|
||||
test_scene_feat = torch.rand(32, 1024).to("cuda:0")
|
||||
test_target_feat = torch.rand(32, 1024).to("cuda:0")
|
||||
test_pose = torch.rand(32, 6).to("cuda:0")
|
||||
test_t = torch.rand(32, 1).to("cuda:0")
|
||||
view_finder = GradientFieldViewFinder().to("cuda:0")
|
||||
test_data = {
|
||||
'target_feat': test_target_feat,
|
||||
'scene_feat': test_scene_feat,
|
||||
'sampled_pose': test_pose,
|
||||
't': test_t
|
||||
}
|
||||
score = view_finder(test_data)
|
||||
|
||||
result = view_finder.next_best_view(test_scene_feat, test_target_feat)
|
||||
print(result)
|
45
modules/view_finder/view_finder_factory.py
Executable file
45
modules/view_finder/view_finder_factory.py
Executable file
@@ -0,0 +1,45 @@
|
||||
from modules.view_finder.abstract_view_finder import ViewFinder
|
||||
from modules.view_finder.gf_view_finder import GradientFieldViewFinder
|
||||
|
||||
|
||||
class ViewFinderFactory:
|
||||
@staticmethod
|
||||
def create(name, config) -> ViewFinder:
|
||||
general_config = config["general"]
|
||||
view_finder_config = config["view_finder"][name]
|
||||
if name == "gradient_field":
|
||||
return GradientFieldViewFinder(
|
||||
pose_mode=view_finder_config["pose_mode"],
|
||||
regression_head=view_finder_config["regression_head"],
|
||||
per_point_feature=general_config["per_point_feature"],
|
||||
sample_mode=view_finder_config["sample_mode"],
|
||||
sampling_steps=view_finder_config.get("sampling_steps", None),
|
||||
sde_mode=view_finder_config["sde_mode"]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown next-best-view finder name: {name}")
|
||||
|
||||
|
||||
''' ------------ Debug ------------ '''
|
||||
if __name__ == "__main__":
|
||||
from configs.config import ConfigManager
|
||||
import torch
|
||||
|
||||
ConfigManager.load_config_with('../../configs/local_train_config.yaml')
|
||||
ConfigManager.print_config()
|
||||
view_finder = ViewFinderFactory.create(name="gradient_field", config=ConfigManager.get("modules"))
|
||||
test_scene_feat = torch.rand(32, 1024).to("cuda:0")
|
||||
test_target_feat = torch.rand(32, 1024).to("cuda:0")
|
||||
test_pose = torch.rand(32, 6).to("cuda:0")
|
||||
test_t = torch.rand(32, 1).to("cuda:0")
|
||||
view_finder = view_finder.to("cuda:0")
|
||||
test_data = {
|
||||
'target_feat': test_target_feat,
|
||||
'scene_feat': test_scene_feat,
|
||||
'sampled_pose': test_pose,
|
||||
't': test_t
|
||||
}
|
||||
score = view_finder(test_data)
|
||||
print(score.shape)
|
||||
pose_6d = view_finder.next_best_view(scene_pts_feat=test_data["scene_feat"], target_pts_feat=test_data["target_feat"])
|
||||
print(pose_6d.shape)
|
Reference in New Issue
Block a user