first commit
This commit is contained in:
BIN
modules/__pycache__/gf_view_finder.cpython-39.pyc
Normal file
BIN
modules/__pycache__/gf_view_finder.cpython-39.pyc
Normal file
Binary file not shown.
BIN
modules/__pycache__/mlp_view_finder.cpython-39.pyc
Normal file
BIN
modules/__pycache__/mlp_view_finder.cpython-39.pyc
Normal file
Binary file not shown.
BIN
modules/__pycache__/pointnet++_encoder.cpython-39.pyc
Normal file
BIN
modules/__pycache__/pointnet++_encoder.cpython-39.pyc
Normal file
Binary file not shown.
BIN
modules/__pycache__/pointnet_encoder.cpython-39.pyc
Normal file
BIN
modules/__pycache__/pointnet_encoder.cpython-39.pyc
Normal file
Binary file not shown.
BIN
modules/__pycache__/pose_encoder.cpython-39.pyc
Normal file
BIN
modules/__pycache__/pose_encoder.cpython-39.pyc
Normal file
Binary file not shown.
BIN
modules/__pycache__/pts_num_encoder.cpython-39.pyc
Normal file
BIN
modules/__pycache__/pts_num_encoder.cpython-39.pyc
Normal file
Binary file not shown.
BIN
modules/__pycache__/transformer_seq_encoder.cpython-39.pyc
Normal file
BIN
modules/__pycache__/transformer_seq_encoder.cpython-39.pyc
Normal file
Binary file not shown.
6
modules/func_lib/__init__.py
Normal file
6
modules/func_lib/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from modules.func_lib.samplers import (
|
||||
cond_ode_sampler
|
||||
)
|
||||
from modules.func_lib.sde import (
|
||||
init_sde
|
||||
)
|
BIN
modules/func_lib/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
modules/func_lib/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
modules/func_lib/__pycache__/samplers.cpython-39.pyc
Normal file
BIN
modules/func_lib/__pycache__/samplers.cpython-39.pyc
Normal file
Binary file not shown.
BIN
modules/func_lib/__pycache__/sde.cpython-39.pyc
Normal file
BIN
modules/func_lib/__pycache__/sde.cpython-39.pyc
Normal file
Binary file not shown.
95
modules/func_lib/samplers.py
Normal file
95
modules/func_lib/samplers.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from scipy import integrate
|
||||
from utils.pose import PoseUtil
|
||||
|
||||
|
||||
def global_prior_likelihood(z, sigma_max):
|
||||
"""The likelihood of a Gaussian distribution with mean zero and
|
||||
standard deviation sigma."""
|
||||
# z: [bs, pose_dim]
|
||||
shape = z.shape
|
||||
N = np.prod(shape[1:]) # pose_dim
|
||||
return -N / 2.0 * torch.log(2 * np.pi * sigma_max**2) - torch.sum(
|
||||
z**2, dim=-1
|
||||
) / (2 * sigma_max**2)
|
||||
|
||||
|
||||
def cond_ode_sampler(
|
||||
score_model,
|
||||
data,
|
||||
prior,
|
||||
sde_coeff,
|
||||
atol=1e-5,
|
||||
rtol=1e-5,
|
||||
device="cuda",
|
||||
eps=1e-5,
|
||||
T=1.0,
|
||||
num_steps=None,
|
||||
pose_mode="quat_wxyz",
|
||||
denoise=True,
|
||||
init_x=None,
|
||||
):
|
||||
pose_dim = PoseUtil.get_pose_dim(pose_mode)
|
||||
batch_size = data["main_feat"].shape[0]
|
||||
init_x = (
|
||||
prior((batch_size, pose_dim), T=T).to(device)
|
||||
if init_x is None
|
||||
else init_x + prior((batch_size, pose_dim), T=T).to(device)
|
||||
)
|
||||
shape = init_x.shape
|
||||
|
||||
def score_eval_wrapper(data):
|
||||
"""A wrapper of the score-based model for use by the ODE solver."""
|
||||
with torch.no_grad():
|
||||
score = score_model(data)
|
||||
return score.cpu().numpy().reshape((-1,))
|
||||
|
||||
def ode_func(t, x):
|
||||
"""The ODE function for use by the ODE solver."""
|
||||
x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device)
|
||||
time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t
|
||||
drift, diffusion = sde_coeff(torch.tensor(t))
|
||||
drift = drift.cpu().numpy()
|
||||
diffusion = diffusion.cpu().numpy()
|
||||
data["sampled_pose"] = x
|
||||
data["t"] = time_steps
|
||||
return drift - 0.5 * (diffusion**2) * score_eval_wrapper(data)
|
||||
|
||||
# Run the black-box ODE solver, note the
|
||||
t_eval = None
|
||||
if num_steps is not None:
|
||||
# num_steps, from T -> eps
|
||||
t_eval = np.linspace(T, eps, num_steps)
|
||||
res = integrate.solve_ivp(
|
||||
ode_func,
|
||||
(T, eps),
|
||||
init_x.reshape(-1).cpu().numpy(),
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
method="RK45",
|
||||
t_eval=t_eval,
|
||||
)
|
||||
xs = torch.tensor(res.y, device=device).T.view(
|
||||
-1, batch_size, pose_dim
|
||||
) # [num_steps, bs, pose_dim]
|
||||
x = torch.tensor(res.y[:, -1], device=device).reshape(shape) # [bs, pose_dim]
|
||||
# denoise, using the predictor step in P-C sampler
|
||||
if denoise:
|
||||
# Reverse diffusion predictor for denoising
|
||||
vec_eps = torch.ones((x.shape[0], 1), device=x.device) * eps
|
||||
drift, diffusion = sde_coeff(vec_eps)
|
||||
data["sampled_pose"] = x.float()
|
||||
data["t"] = vec_eps
|
||||
grad = score_model(data)
|
||||
drift = drift - diffusion**2 * grad # R-SDE
|
||||
mean_x = x + drift * ((1 - eps) / (1000 if num_steps is None else num_steps))
|
||||
x = mean_x
|
||||
|
||||
num_steps = xs.shape[0]
|
||||
xs = xs.reshape(batch_size*num_steps, -1)
|
||||
xs[:, :-3] = PoseUtil.normalize_rotation(xs[:, :-3], pose_mode)
|
||||
xs = xs.reshape(num_steps, batch_size, -1)
|
||||
x[:, :-3] = PoseUtil.normalize_rotation(x[:, :-3], pose_mode)
|
||||
return xs.permute(1, 0, 2), x
|
121
modules/func_lib/sde.py
Normal file
121
modules/func_lib/sde.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import functools
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ----- VE SDE -----
|
||||
# ------------------
|
||||
def ve_marginal_prob(x, t, sigma_min=0.01, sigma_max=90):
|
||||
std = sigma_min * (sigma_max / sigma_min) ** t
|
||||
mean = x
|
||||
return mean, std
|
||||
|
||||
|
||||
def ve_sde(t, sigma_min=0.01, sigma_max=90):
|
||||
sigma = sigma_min * (sigma_max / sigma_min) ** t
|
||||
drift_coeff = torch.tensor(0)
|
||||
diffusion_coeff = sigma * torch.sqrt(torch.tensor(2 * (np.log(sigma_max) - np.log(sigma_min)), device=t.device))
|
||||
return drift_coeff, diffusion_coeff
|
||||
|
||||
|
||||
def ve_prior(shape, sigma_min=0.01, sigma_max=90, T=1.0):
|
||||
_, sigma_max_prior = ve_marginal_prob(None, T, sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
return torch.randn(*shape) * sigma_max_prior
|
||||
|
||||
|
||||
# ----- VP SDE -----
|
||||
# ------------------
|
||||
def vp_marginal_prob(x, t, beta_0=0.1, beta_1=20):
|
||||
log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
|
||||
mean = torch.exp(log_mean_coeff) * x
|
||||
std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
|
||||
return mean, std
|
||||
|
||||
|
||||
def vp_sde(t, beta_0=0.1, beta_1=20):
|
||||
beta_t = beta_0 + t * (beta_1 - beta_0)
|
||||
drift_coeff = -0.5 * beta_t
|
||||
diffusion_coeff = torch.sqrt(beta_t)
|
||||
return drift_coeff, diffusion_coeff
|
||||
|
||||
|
||||
def vp_prior(shape, beta_0=0.1, beta_1=20):
|
||||
return torch.randn(*shape)
|
||||
|
||||
|
||||
# ----- sub-VP SDE -----
|
||||
# ----------------------
|
||||
def subvp_marginal_prob(x, t, beta_0, beta_1):
|
||||
log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
|
||||
mean = torch.exp(log_mean_coeff) * x
|
||||
std = 1 - torch.exp(2. * log_mean_coeff)
|
||||
return mean, std
|
||||
|
||||
|
||||
def subvp_sde(t, beta_0, beta_1):
|
||||
beta_t = beta_0 + t * (beta_1 - beta_0)
|
||||
drift_coeff = -0.5 * beta_t
|
||||
discount = 1. - torch.exp(-2 * beta_0 * t - (beta_1 - beta_0) * t ** 2)
|
||||
diffusion_coeff = torch.sqrt(beta_t * discount)
|
||||
return drift_coeff, diffusion_coeff
|
||||
|
||||
|
||||
def subvp_prior(shape, beta_0=0.1, beta_1=20):
|
||||
return torch.randn(*shape)
|
||||
|
||||
|
||||
# ----- EDM SDE -----
|
||||
# ------------------
|
||||
def edm_marginal_prob(x, t, sigma_min=0.002, sigma_max=80):
|
||||
std = t
|
||||
mean = x
|
||||
return mean, std
|
||||
|
||||
|
||||
def edm_sde(t, sigma_min=0.002, sigma_max=80):
|
||||
drift_coeff = torch.tensor(0)
|
||||
diffusion_coeff = torch.sqrt(2 * t)
|
||||
return drift_coeff, diffusion_coeff
|
||||
|
||||
|
||||
def edm_prior(shape, sigma_min=0.002, sigma_max=80):
|
||||
return torch.randn(*shape) * sigma_max
|
||||
|
||||
|
||||
def init_sde(sde_mode):
|
||||
# the SDE-related hyperparameters are copied from https://github.com/yang-song/score_sde_pytorch
|
||||
if sde_mode == 'edm':
|
||||
sigma_min = 0.002
|
||||
sigma_max = 80
|
||||
eps = 0.002
|
||||
prior_fn = functools.partial(edm_prior, sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
marginal_prob_fn = functools.partial(edm_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
sde_fn = functools.partial(edm_sde, sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
T = sigma_max
|
||||
elif sde_mode == 've':
|
||||
sigma_min = 0.01
|
||||
sigma_max = 50
|
||||
eps = 1e-5
|
||||
marginal_prob_fn = functools.partial(ve_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
sde_fn = functools.partial(ve_sde, sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
T = 1.0
|
||||
prior_fn = functools.partial(ve_prior, sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
elif sde_mode == 'vp':
|
||||
beta_0 = 0.1
|
||||
beta_1 = 20
|
||||
eps = 1e-3
|
||||
prior_fn = functools.partial(vp_prior, beta_0=beta_0, beta_1=beta_1)
|
||||
marginal_prob_fn = functools.partial(vp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
|
||||
sde_fn = functools.partial(vp_sde, beta_0=beta_0, beta_1=beta_1)
|
||||
T = 1.0
|
||||
elif sde_mode == 'subvp':
|
||||
beta_0 = 0.1
|
||||
beta_1 = 20
|
||||
eps = 1e-3
|
||||
prior_fn = functools.partial(subvp_prior, beta_0=beta_0, beta_1=beta_1)
|
||||
marginal_prob_fn = functools.partial(subvp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
|
||||
sde_fn = functools.partial(subvp_sde, beta_0=beta_0, beta_1=beta_1)
|
||||
T = 1.0
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return prior_fn, marginal_prob_fn, sde_fn, eps, T
|
167
modules/gf_view_finder.py
Normal file
167
modules/gf_view_finder.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
from utils.pose import PoseUtil
|
||||
import modules.module_lib as mlib
|
||||
import modules.func_lib as flib
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
@stereotype.module("gf_view_finder")
|
||||
class GradientFieldViewFinder(nn.Module):
|
||||
def __init__(self, config):
|
||||
|
||||
super(GradientFieldViewFinder, self).__init__()
|
||||
|
||||
self.regression_head = config["regression_head"]
|
||||
self.per_point_feature = config["per_point_feature"]
|
||||
self.act = nn.ReLU(True)
|
||||
self.sample_mode = config["sample_mode"]
|
||||
self.pose_mode = config["pose_mode"]
|
||||
pose_dim = PoseUtil.get_pose_dim(self.pose_mode)
|
||||
self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(config["sde_mode"])
|
||||
self.sampling_steps = config["sampling_steps"]
|
||||
self.t_feat_dim = config["t_feat_dim"]
|
||||
self.pose_feat_dim = config["pose_feat_dim"]
|
||||
self.main_feat_dim = config["main_feat_dim"]
|
||||
|
||||
''' encode pose '''
|
||||
self.pose_encoder = nn.Sequential(
|
||||
nn.Linear(pose_dim, self.pose_feat_dim ),
|
||||
self.act,
|
||||
nn.Linear(self.pose_feat_dim , self.pose_feat_dim ),
|
||||
self.act,
|
||||
)
|
||||
|
||||
''' encode t '''
|
||||
self.t_encoder = nn.Sequential(
|
||||
mlib.GaussianFourierProjection(embed_dim=self.t_feat_dim ),
|
||||
nn.Linear(self.t_feat_dim , self.t_feat_dim ),
|
||||
self.act,
|
||||
)
|
||||
|
||||
''' fusion tail '''
|
||||
if self.regression_head == 'Rx_Ry_and_T':
|
||||
if self.pose_mode != 'rot_matrix':
|
||||
raise NotImplementedError
|
||||
if not self.per_point_feature:
|
||||
''' rotation_x_axis regress head '''
|
||||
self.fusion_tail_rot_x = nn.Sequential(
|
||||
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
self.fusion_tail_rot_y = nn.Sequential(
|
||||
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
''' tranalation regress head '''
|
||||
self.fusion_tail_trans = nn.Sequential(
|
||||
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, data):
|
||||
"""
|
||||
Args:
|
||||
data, dict {
|
||||
'main_feat': [bs, c]
|
||||
'pose_sample': [bs, pose_dim]
|
||||
't': [bs, 1]
|
||||
}
|
||||
"""
|
||||
|
||||
main_feat = data['main_feat']
|
||||
sampled_pose = data['sampled_pose']
|
||||
t = data['t']
|
||||
t_feat = self.t_encoder(t.squeeze(1))
|
||||
pose_feat = self.pose_encoder(sampled_pose)
|
||||
|
||||
if self.per_point_feature:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
total_feat = torch.cat([main_feat, t_feat, pose_feat], dim=-1)
|
||||
_, std = self.marginal_prob_fn(total_feat, t)
|
||||
|
||||
if self.regression_head == 'Rx_Ry_and_T':
|
||||
rot_x = self.fusion_tail_rot_x(total_feat)
|
||||
rot_y = self.fusion_tail_rot_y(total_feat)
|
||||
trans = self.fusion_tail_trans(total_feat)
|
||||
out_score = torch.cat([rot_x, rot_y, trans], dim=-1) / (std+1e-7) # normalisation
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return out_score
|
||||
|
||||
def marginal_prob(self, x, t):
|
||||
return self.marginal_prob_fn(x,t)
|
||||
|
||||
def sample(self, data, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True, init_x=None, T0=None):
|
||||
|
||||
if self.sample_mode == 'ode':
|
||||
T0 = self.T if T0 is None else T0
|
||||
in_process_sample, res = flib.cond_ode_sampler(
|
||||
score_model=self,
|
||||
data=data,
|
||||
prior=self.prior_fn,
|
||||
sde_coeff=self.sde_fn,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
eps=self.sampling_eps,
|
||||
T=T0,
|
||||
num_steps=self.sampling_steps,
|
||||
pose_mode=self.pose_mode,
|
||||
denoise=denoise,
|
||||
init_x=init_x
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return in_process_sample, res
|
||||
|
||||
def next_best_view(self, main_feat):
|
||||
data = {
|
||||
'main_feat': main_feat,
|
||||
}
|
||||
in_process_sample, res = self.sample(data)
|
||||
return res.to(dtype=torch.float32), in_process_sample
|
||||
|
||||
|
||||
''' ----------- DEBUG -----------'''
|
||||
if __name__ == "__main__":
|
||||
config = {
|
||||
"regression_head": "Rx_Ry_and_T",
|
||||
"per_point_feature": False,
|
||||
"pose_mode": "rot_matrix",
|
||||
"sde_mode": "ve",
|
||||
"sampling_steps": 500,
|
||||
"sample_mode": "ode"
|
||||
}
|
||||
test_seq_feat = torch.rand(32, 2048).to("cuda:0")
|
||||
test_pose = torch.rand(32, 9).to("cuda:0")
|
||||
test_t = torch.rand(32, 1).to("cuda:0")
|
||||
view_finder = GradientFieldViewFinder(config).to("cuda:0")
|
||||
test_data = {
|
||||
'seq_feat': test_seq_feat,
|
||||
'sampled_pose': test_pose,
|
||||
't': test_t
|
||||
}
|
||||
score = view_finder(test_data)
|
||||
print(score.shape)
|
||||
res, inprocess = view_finder.next_best_view(test_seq_feat)
|
||||
print(res.shape, inprocess.shape)
|
91
modules/mlp_view_finder.py
Normal file
91
modules/mlp_view_finder.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
from utils.pose import PoseUtil
|
||||
import modules.module_lib as mlib
|
||||
import modules.func_lib as flib
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
@stereotype.module("mlp_view_finder")
|
||||
class MLPViewFinder(nn.Module):
|
||||
def __init__(self, config):
|
||||
|
||||
super(MLPViewFinder, self).__init__()
|
||||
|
||||
self.regression_head = 'Rx_Ry_and_T'
|
||||
self.per_point_feature = False
|
||||
self.act = nn.ReLU(True)
|
||||
self.main_feat_dim = config["main_feat_dim"]
|
||||
|
||||
''' rotation_x_axis regress head '''
|
||||
self.fusion_tail_rot_x = nn.Sequential(
|
||||
nn.Linear(self.main_feat_dim, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
self.fusion_tail_rot_y = nn.Sequential(
|
||||
nn.Linear(self.main_feat_dim, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
''' tranalation regress head '''
|
||||
self.fusion_tail_trans = nn.Sequential(
|
||||
nn.Linear(self.main_feat_dim, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
|
||||
|
||||
def forward(self, data):
|
||||
"""
|
||||
Args:
|
||||
data, dict {
|
||||
'main_feat': [bs, c]
|
||||
}
|
||||
"""
|
||||
|
||||
total_feat = data['main_feat']
|
||||
rot_x = self.fusion_tail_rot_x(total_feat)
|
||||
rot_y = self.fusion_tail_rot_y(total_feat)
|
||||
trans = self.fusion_tail_trans(total_feat)
|
||||
output = torch.cat([rot_x,rot_y,trans], dim=-1)
|
||||
return output
|
||||
|
||||
def next_best_view(self, main_feat):
|
||||
data = {
|
||||
'main_feat': main_feat,
|
||||
}
|
||||
res = self(data)
|
||||
return res.to(dtype=torch.float32), None
|
||||
|
||||
''' ----------- DEBUG -----------'''
|
||||
if __name__ == "__main__":
|
||||
config = {
|
||||
"regression_head": "Rx_Ry_and_T",
|
||||
"per_point_feature": False,
|
||||
"pose_mode": "rot_matrix",
|
||||
"sde_mode": "ve",
|
||||
"sampling_steps": 500,
|
||||
"sample_mode": "ode"
|
||||
}
|
||||
test_seq_feat = torch.rand(32, 2048).to("cuda:0")
|
||||
test_pose = torch.rand(32, 9).to("cuda:0")
|
||||
test_t = torch.rand(32, 1).to("cuda:0")
|
||||
view_finder = GradientFieldViewFinder(config).to("cuda:0")
|
||||
test_data = {
|
||||
'seq_feat': test_seq_feat,
|
||||
'sampled_pose': test_pose,
|
||||
't': test_t
|
||||
}
|
||||
score = view_finder(test_data)
|
||||
print(score.shape)
|
||||
res, inprocess = view_finder.next_best_view(test_seq_feat)
|
||||
print(res.shape, inprocess.shape)
|
2
modules/module_lib/__init__.py
Normal file
2
modules/module_lib/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from modules.module_lib.gaussian_fourier_projection import GaussianFourierProjection
|
||||
from modules.module_lib.linear import Linear
|
BIN
modules/module_lib/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
modules/module_lib/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
modules/module_lib/__pycache__/linear.cpython-39.pyc
Normal file
BIN
modules/module_lib/__pycache__/linear.cpython-39.pyc
Normal file
Binary file not shown.
BIN
modules/module_lib/__pycache__/pointnet2_modules.cpython-39.pyc
Normal file
BIN
modules/module_lib/__pycache__/pointnet2_modules.cpython-39.pyc
Normal file
Binary file not shown.
BIN
modules/module_lib/__pycache__/pointnet2_utils.cpython-39.pyc
Normal file
BIN
modules/module_lib/__pycache__/pointnet2_utils.cpython-39.pyc
Normal file
Binary file not shown.
BIN
modules/module_lib/__pycache__/pytorch_utils.cpython-39.pyc
Normal file
BIN
modules/module_lib/__pycache__/pytorch_utils.cpython-39.pyc
Normal file
Binary file not shown.
17
modules/module_lib/gaussian_fourier_projection.py
Normal file
17
modules/module_lib/gaussian_fourier_projection.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class GaussianFourierProjection(nn.Module):
|
||||
"""Gaussian random features for encoding time steps."""
|
||||
|
||||
def __init__(self, embed_dim, scale=30.):
|
||||
super().__init__()
|
||||
# Randomly sample weights during initialization. These weights are fixed
|
||||
# during optimization and are not trainable.
|
||||
self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
|
||||
|
||||
def forward(self, x):
|
||||
x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
|
||||
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
|
30
modules/module_lib/linear.py
Normal file
30
modules/module_lib/linear.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def weight_init(shape, mode, fan_in, fan_out):
|
||||
if mode == 'xavier_uniform':
|
||||
return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1)
|
||||
if mode == 'xavier_normal':
|
||||
return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape)
|
||||
if mode == 'kaiming_uniform':
|
||||
return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1)
|
||||
if mode == 'kaiming_normal':
|
||||
return np.sqrt(1 / fan_in) * torch.randn(*shape)
|
||||
raise ValueError(f'Invalid init mode "{mode}"')
|
||||
|
||||
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features)
|
||||
self.weight = torch.nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight)
|
||||
self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None
|
||||
|
||||
def forward(self, x):
|
||||
x = x @ self.weight.to(x.dtype).t()
|
||||
if self.bias is not None:
|
||||
x = x.add_(self.bias.to(x.dtype))
|
||||
return x
|
162
modules/module_lib/pointnet2_modules.py
Normal file
162
modules/module_lib/pointnet2_modules.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from . import pointnet2_utils
|
||||
from . import pytorch_utils as pt_utils
|
||||
from typing import List
|
||||
|
||||
|
||||
class _PointnetSAModuleBase(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.npoint = None
|
||||
self.groupers = None
|
||||
self.mlps = None
|
||||
self.pool_method = 'max_pool'
|
||||
|
||||
def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor):
|
||||
"""
|
||||
:param xyz: (B, N, 3) tensor of the xyz coordinates of the features
|
||||
:param features: (B, N, C) tensor of the descriptors of the the features
|
||||
:param new_xyz:
|
||||
:return:
|
||||
new_xyz: (B, npoint, 3) tensor of the new features' xyz
|
||||
new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
|
||||
"""
|
||||
new_features_list = []
|
||||
|
||||
xyz_flipped = xyz.transpose(1, 2).contiguous()
|
||||
if new_xyz is None:
|
||||
new_xyz = pointnet2_utils.gather_operation(
|
||||
xyz_flipped,
|
||||
pointnet2_utils.furthest_point_sample(xyz, self.npoint)
|
||||
).transpose(1, 2).contiguous() if self.npoint is not None else None
|
||||
|
||||
for i in range(len(self.groupers)):
|
||||
new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample)
|
||||
|
||||
new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
|
||||
|
||||
if self.pool_method == 'max_pool':
|
||||
new_features = F.max_pool2d(
|
||||
new_features, kernel_size=[1, new_features.size(3)]
|
||||
) # (B, mlp[-1], npoint, 1)
|
||||
elif self.pool_method == 'avg_pool':
|
||||
new_features = F.avg_pool2d(
|
||||
new_features, kernel_size=[1, new_features.size(3)]
|
||||
) # (B, mlp[-1], npoint, 1)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
|
||||
new_features_list.append(new_features)
|
||||
|
||||
return new_xyz, torch.cat(new_features_list, dim=1)
|
||||
|
||||
|
||||
class PointnetSAModuleMSG(_PointnetSAModuleBase):
|
||||
"""Pointnet set abstraction layer with multiscale grouping"""
|
||||
|
||||
def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
|
||||
use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
|
||||
"""
|
||||
:param npoint: int
|
||||
:param radii: list of float, list of radii to group with
|
||||
:param nsamples: list of int, number of samples in each ball query
|
||||
:param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
|
||||
:param bn: whether to use batchnorm
|
||||
:param use_xyz:
|
||||
:param pool_method: max_pool / avg_pool
|
||||
:param instance_norm: whether to use instance_norm
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
assert len(radii) == len(nsamples) == len(mlps)
|
||||
|
||||
self.npoint = npoint
|
||||
self.groupers = nn.ModuleList()
|
||||
self.mlps = nn.ModuleList()
|
||||
for i in range(len(radii)):
|
||||
radius = radii[i]
|
||||
nsample = nsamples[i]
|
||||
self.groupers.append(
|
||||
pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
|
||||
if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
|
||||
)
|
||||
mlp_spec = mlps[i]
|
||||
if use_xyz:
|
||||
mlp_spec[0] += 3
|
||||
|
||||
self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
|
||||
self.pool_method = pool_method
|
||||
|
||||
|
||||
class PointnetSAModule(PointnetSAModuleMSG):
|
||||
"""Pointnet set abstraction layer"""
|
||||
|
||||
def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
|
||||
bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
|
||||
"""
|
||||
:param mlp: list of int, spec of the pointnet before the global max_pool
|
||||
:param npoint: int, number of features
|
||||
:param radius: float, radius of ball
|
||||
:param nsample: int, number of samples in the ball query
|
||||
:param bn: whether to use batchnorm
|
||||
:param use_xyz:
|
||||
:param pool_method: max_pool / avg_pool
|
||||
:param instance_norm: whether to use instance_norm
|
||||
"""
|
||||
super().__init__(
|
||||
mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz,
|
||||
pool_method=pool_method, instance_norm=instance_norm
|
||||
)
|
||||
|
||||
|
||||
class PointnetFPModule(nn.Module):
|
||||
r"""Propigates the features of one set to another"""
|
||||
|
||||
def __init__(self, *, mlp: List[int], bn: bool = True):
|
||||
"""
|
||||
:param mlp: list of int
|
||||
:param bn: whether to use batchnorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.mlp = pt_utils.SharedMLP(mlp, bn=bn)
|
||||
|
||||
def forward(
|
||||
self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
:param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
|
||||
:param known: (B, m, 3) tensor of the xyz positions of the known features
|
||||
:param unknow_feats: (B, C1, n) tensor of the features to be propigated to
|
||||
:param known_feats: (B, C2, m) tensor of features to be propigated
|
||||
:return:
|
||||
new_features: (B, mlp[-1], n) tensor of the features of the unknown features
|
||||
"""
|
||||
if known is not None:
|
||||
dist, idx = pointnet2_utils.three_nn(unknown, known)
|
||||
dist_recip = 1.0 / (dist + 1e-8)
|
||||
norm = torch.sum(dist_recip, dim=2, keepdim=True)
|
||||
weight = dist_recip / norm
|
||||
|
||||
interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)
|
||||
else:
|
||||
interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))
|
||||
|
||||
if unknow_feats is not None:
|
||||
new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n)
|
||||
else:
|
||||
new_features = interpolated_feats
|
||||
|
||||
new_features = new_features.unsqueeze(-1)
|
||||
|
||||
new_features = self.mlp(new_features)
|
||||
|
||||
return new_features.squeeze(-1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
291
modules/module_lib/pointnet2_utils.py
Normal file
291
modules/module_lib/pointnet2_utils.py
Normal file
@@ -0,0 +1,291 @@
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
from torch.autograd import Function
|
||||
import torch.nn as nn
|
||||
from typing import Tuple
|
||||
import sys
|
||||
|
||||
import pointnet2_cuda as pointnet2
|
||||
|
||||
|
||||
class FurthestPointSampling(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
|
||||
"""
|
||||
Uses iterative furthest point sampling to select a set of npoint features that have the largest
|
||||
minimum distance
|
||||
:param ctx:
|
||||
:param xyz: (B, N, 3) where N > npoint
|
||||
:param npoint: int, number of features in the sampled set
|
||||
:return:
|
||||
output: (B, npoint) tensor containing the set
|
||||
"""
|
||||
assert xyz.is_contiguous()
|
||||
|
||||
B, N, _ = xyz.size()
|
||||
output = torch.cuda.IntTensor(B, npoint)
|
||||
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
|
||||
|
||||
pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(xyz, a=None):
|
||||
return None, None
|
||||
|
||||
|
||||
furthest_point_sample = FurthestPointSampling.apply
|
||||
|
||||
|
||||
class GatherOperation(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
:param ctx:
|
||||
:param features: (B, C, N)
|
||||
:param idx: (B, npoint) index tensor of the features to gather
|
||||
:return:
|
||||
output: (B, C, npoint)
|
||||
"""
|
||||
assert features.is_contiguous()
|
||||
assert idx.is_contiguous()
|
||||
|
||||
B, npoint = idx.size()
|
||||
_, C, N = features.size()
|
||||
output = torch.cuda.FloatTensor(B, C, npoint)
|
||||
|
||||
pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
|
||||
|
||||
ctx.for_backwards = (idx, C, N)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
idx, C, N = ctx.for_backwards
|
||||
B, npoint = idx.size()
|
||||
|
||||
grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
|
||||
grad_out_data = grad_out.data.contiguous()
|
||||
pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
|
||||
return grad_features, None
|
||||
|
||||
|
||||
gather_operation = GatherOperation.apply
|
||||
|
||||
|
||||
class ThreeNN(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Find the three nearest neighbors of unknown in known
|
||||
:param ctx:
|
||||
:param unknown: (B, N, 3)
|
||||
:param known: (B, M, 3)
|
||||
:return:
|
||||
dist: (B, N, 3) l2 distance to the three nearest neighbors
|
||||
idx: (B, N, 3) index of 3 nearest neighbors
|
||||
"""
|
||||
assert unknown.is_contiguous()
|
||||
assert known.is_contiguous()
|
||||
|
||||
B, N, _ = unknown.size()
|
||||
m = known.size(1)
|
||||
dist2 = torch.cuda.FloatTensor(B, N, 3)
|
||||
idx = torch.cuda.IntTensor(B, N, 3)
|
||||
|
||||
pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
|
||||
return torch.sqrt(dist2), idx
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, a=None, b=None):
|
||||
return None, None
|
||||
|
||||
|
||||
three_nn = ThreeNN.apply
|
||||
|
||||
|
||||
class ThreeInterpolate(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Performs weight linear interpolation on 3 features
|
||||
:param ctx:
|
||||
:param features: (B, C, M) Features descriptors to be interpolated from
|
||||
:param idx: (B, n, 3) three nearest neighbors of the target features in features
|
||||
:param weight: (B, n, 3) weights
|
||||
:return:
|
||||
output: (B, C, N) tensor of the interpolated features
|
||||
"""
|
||||
assert features.is_contiguous()
|
||||
assert idx.is_contiguous()
|
||||
assert weight.is_contiguous()
|
||||
|
||||
B, c, m = features.size()
|
||||
n = idx.size(1)
|
||||
ctx.three_interpolate_for_backward = (idx, weight, m)
|
||||
output = torch.cuda.FloatTensor(B, c, n)
|
||||
|
||||
pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
:param ctx:
|
||||
:param grad_out: (B, C, N) tensor with gradients of outputs
|
||||
:return:
|
||||
grad_features: (B, C, M) tensor with gradients of features
|
||||
None:
|
||||
None:
|
||||
"""
|
||||
idx, weight, m = ctx.three_interpolate_for_backward
|
||||
B, c, n = grad_out.size()
|
||||
|
||||
grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
|
||||
grad_out_data = grad_out.data.contiguous()
|
||||
|
||||
pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
|
||||
return grad_features, None, None
|
||||
|
||||
|
||||
three_interpolate = ThreeInterpolate.apply
|
||||
|
||||
|
||||
class GroupingOperation(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
:param ctx:
|
||||
:param features: (B, C, N) tensor of features to group
|
||||
:param idx: (B, npoint, nsample) tensor containing the indicies of features to group with
|
||||
:return:
|
||||
output: (B, C, npoint, nsample) tensor
|
||||
"""
|
||||
assert features.is_contiguous()
|
||||
assert idx.is_contiguous()
|
||||
|
||||
B, nfeatures, nsample = idx.size()
|
||||
_, C, N = features.size()
|
||||
output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
|
||||
|
||||
pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
|
||||
|
||||
ctx.for_backwards = (idx, N)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
:param ctx:
|
||||
:param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
|
||||
:return:
|
||||
grad_features: (B, C, N) gradient of the features
|
||||
"""
|
||||
idx, N = ctx.for_backwards
|
||||
|
||||
B, C, npoint, nsample = grad_out.size()
|
||||
grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
|
||||
|
||||
grad_out_data = grad_out.data.contiguous()
|
||||
pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
|
||||
return grad_features, None
|
||||
|
||||
|
||||
grouping_operation = GroupingOperation.apply
|
||||
|
||||
|
||||
class BallQuery(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
:param ctx:
|
||||
:param radius: float, radius of the balls
|
||||
:param nsample: int, maximum number of features in the balls
|
||||
:param xyz: (B, N, 3) xyz coordinates of the features
|
||||
:param new_xyz: (B, npoint, 3) centers of the ball query
|
||||
:return:
|
||||
idx: (B, npoint, nsample) tensor with the indicies of the features that form the query balls
|
||||
"""
|
||||
assert new_xyz.is_contiguous()
|
||||
assert xyz.is_contiguous()
|
||||
|
||||
B, N, _ = xyz.size()
|
||||
npoint = new_xyz.size(1)
|
||||
idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
|
||||
|
||||
pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
|
||||
return idx
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, a=None):
|
||||
return None, None, None, None
|
||||
|
||||
|
||||
ball_query = BallQuery.apply
|
||||
|
||||
|
||||
class QueryAndGroup(nn.Module):
|
||||
def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
|
||||
"""
|
||||
:param radius: float, radius of ball
|
||||
:param nsample: int, maximum number of features to gather in the ball
|
||||
:param use_xyz:
|
||||
"""
|
||||
super().__init__()
|
||||
self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
|
||||
|
||||
def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]:
|
||||
"""
|
||||
:param xyz: (B, N, 3) xyz coordinates of the features
|
||||
:param new_xyz: (B, npoint, 3) centroids
|
||||
:param features: (B, C, N) descriptors of the features
|
||||
:return:
|
||||
new_features: (B, 3 + C, npoint, nsample)
|
||||
"""
|
||||
idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
|
||||
xyz_trans = xyz.transpose(1, 2).contiguous()
|
||||
grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
|
||||
grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
|
||||
|
||||
if features is not None:
|
||||
grouped_features = grouping_operation(features, idx)
|
||||
if self.use_xyz:
|
||||
new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, C + 3, npoint, nsample)
|
||||
else:
|
||||
new_features = grouped_features
|
||||
else:
|
||||
assert self.use_xyz, "Cannot have not features and not use xyz as a feature!"
|
||||
new_features = grouped_xyz
|
||||
|
||||
return new_features
|
||||
|
||||
|
||||
class GroupAll(nn.Module):
|
||||
def __init__(self, use_xyz: bool = True):
|
||||
super().__init__()
|
||||
self.use_xyz = use_xyz
|
||||
|
||||
def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
|
||||
"""
|
||||
:param xyz: (B, N, 3) xyz coordinates of the features
|
||||
:param new_xyz: ignored
|
||||
:param features: (B, C, N) descriptors of the features
|
||||
:return:
|
||||
new_features: (B, C + 3, 1, N)
|
||||
"""
|
||||
grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
|
||||
if features is not None:
|
||||
grouped_features = features.unsqueeze(2)
|
||||
if self.use_xyz:
|
||||
new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, 3 + C, 1, N)
|
||||
else:
|
||||
new_features = grouped_features
|
||||
else:
|
||||
new_features = grouped_xyz
|
||||
|
||||
return new_features
|
236
modules/module_lib/pytorch_utils.py
Normal file
236
modules/module_lib/pytorch_utils.py
Normal file
@@ -0,0 +1,236 @@
|
||||
import torch.nn as nn
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class SharedMLP(nn.Sequential):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args: List[int],
|
||||
*,
|
||||
bn: bool = False,
|
||||
activation=nn.ReLU(inplace=True),
|
||||
preact: bool = False,
|
||||
first: bool = False,
|
||||
name: str = "",
|
||||
instance_norm: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
for i in range(len(args) - 1):
|
||||
self.add_module(
|
||||
name + 'layer{}'.format(i),
|
||||
Conv2d(
|
||||
args[i],
|
||||
args[i + 1],
|
||||
bn=(not first or not preact or (i != 0)) and bn,
|
||||
activation=activation
|
||||
if (not first or not preact or (i != 0)) else None,
|
||||
preact=preact,
|
||||
instance_norm=instance_norm
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class _ConvBase(nn.Sequential):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_size,
|
||||
out_size,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
activation,
|
||||
bn,
|
||||
init,
|
||||
conv=None,
|
||||
batch_norm=None,
|
||||
bias=True,
|
||||
preact=False,
|
||||
name="",
|
||||
instance_norm=False,
|
||||
instance_norm_func=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
bias = bias and (not bn)
|
||||
conv_unit = conv(
|
||||
in_size,
|
||||
out_size,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=bias
|
||||
)
|
||||
init(conv_unit.weight)
|
||||
if bias:
|
||||
nn.init.constant_(conv_unit.bias, 0)
|
||||
|
||||
if bn:
|
||||
if not preact:
|
||||
bn_unit = batch_norm(out_size)
|
||||
else:
|
||||
bn_unit = batch_norm(in_size)
|
||||
if instance_norm:
|
||||
if not preact:
|
||||
in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False)
|
||||
else:
|
||||
in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False)
|
||||
|
||||
if preact:
|
||||
if bn:
|
||||
self.add_module(name + 'bn', bn_unit)
|
||||
|
||||
if activation is not None:
|
||||
self.add_module(name + 'activation', activation)
|
||||
|
||||
if not bn and instance_norm:
|
||||
self.add_module(name + 'in', in_unit)
|
||||
|
||||
self.add_module(name + 'conv', conv_unit)
|
||||
|
||||
if not preact:
|
||||
if bn:
|
||||
self.add_module(name + 'bn', bn_unit)
|
||||
|
||||
if activation is not None:
|
||||
self.add_module(name + 'activation', activation)
|
||||
|
||||
if not bn and instance_norm:
|
||||
self.add_module(name + 'in', in_unit)
|
||||
|
||||
|
||||
class _BNBase(nn.Sequential):
|
||||
|
||||
def __init__(self, in_size, batch_norm=None, name=""):
|
||||
super().__init__()
|
||||
self.add_module(name + "bn", batch_norm(in_size))
|
||||
|
||||
nn.init.constant_(self[0].weight, 1.0)
|
||||
nn.init.constant_(self[0].bias, 0)
|
||||
|
||||
|
||||
class BatchNorm1d(_BNBase):
|
||||
|
||||
def __init__(self, in_size: int, *, name: str = ""):
|
||||
super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)
|
||||
|
||||
|
||||
class BatchNorm2d(_BNBase):
|
||||
|
||||
def __init__(self, in_size: int, name: str = ""):
|
||||
super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)
|
||||
|
||||
|
||||
class Conv1d(_ConvBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_size: int,
|
||||
out_size: int,
|
||||
*,
|
||||
kernel_size: int = 1,
|
||||
stride: int = 1,
|
||||
padding: int = 0,
|
||||
activation=nn.ReLU(inplace=True),
|
||||
bn: bool = False,
|
||||
init=nn.init.kaiming_normal_,
|
||||
bias: bool = True,
|
||||
preact: bool = False,
|
||||
name: str = "",
|
||||
instance_norm=False
|
||||
):
|
||||
super().__init__(
|
||||
in_size,
|
||||
out_size,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
activation,
|
||||
bn,
|
||||
init,
|
||||
conv=nn.Conv1d,
|
||||
batch_norm=BatchNorm1d,
|
||||
bias=bias,
|
||||
preact=preact,
|
||||
name=name,
|
||||
instance_norm=instance_norm,
|
||||
instance_norm_func=nn.InstanceNorm1d
|
||||
)
|
||||
|
||||
|
||||
class Conv2d(_ConvBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_size: int,
|
||||
out_size: int,
|
||||
*,
|
||||
kernel_size: Tuple[int, int] = (1, 1),
|
||||
stride: Tuple[int, int] = (1, 1),
|
||||
padding: Tuple[int, int] = (0, 0),
|
||||
activation=nn.ReLU(inplace=True),
|
||||
bn: bool = False,
|
||||
init=nn.init.kaiming_normal_,
|
||||
bias: bool = True,
|
||||
preact: bool = False,
|
||||
name: str = "",
|
||||
instance_norm=False
|
||||
):
|
||||
super().__init__(
|
||||
in_size,
|
||||
out_size,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
activation,
|
||||
bn,
|
||||
init,
|
||||
conv=nn.Conv2d,
|
||||
batch_norm=BatchNorm2d,
|
||||
bias=bias,
|
||||
preact=preact,
|
||||
name=name,
|
||||
instance_norm=instance_norm,
|
||||
instance_norm_func=nn.InstanceNorm2d
|
||||
)
|
||||
|
||||
|
||||
class FC(nn.Sequential):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_size: int,
|
||||
out_size: int,
|
||||
*,
|
||||
activation=nn.ReLU(inplace=True),
|
||||
bn: bool = False,
|
||||
init=None,
|
||||
preact: bool = False,
|
||||
name: str = ""
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
fc = nn.Linear(in_size, out_size, bias=not bn)
|
||||
if init is not None:
|
||||
init(fc.weight)
|
||||
if not bn:
|
||||
nn.init.constant(fc.bias, 0)
|
||||
|
||||
if preact:
|
||||
if bn:
|
||||
self.add_module(name + 'bn', BatchNorm1d(in_size))
|
||||
|
||||
if activation is not None:
|
||||
self.add_module(name + 'activation', activation)
|
||||
|
||||
self.add_module(name + 'fc', fc)
|
||||
|
||||
if not preact:
|
||||
if bn:
|
||||
self.add_module(name + 'bn', BatchNorm1d(out_size))
|
||||
|
||||
if activation is not None:
|
||||
self.add_module(name + 'activation', activation)
|
||||
|
149
modules/pointnet++_encoder.py
Normal file
149
modules/pointnet++_encoder.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import os
|
||||
import sys
|
||||
path = os.path.abspath(__file__)
|
||||
for i in range(2):
|
||||
path = os.path.dirname(path)
|
||||
PROJECT_ROOT = path
|
||||
sys.path.append(PROJECT_ROOT)
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
from modules.module_lib.pointnet2_modules import PointnetSAModuleMSG
|
||||
|
||||
|
||||
ClsMSG_CFG_Dense = {
|
||||
'NPOINTS': [512, 256, 128, None],
|
||||
'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
|
||||
'NSAMPLE': [[32, 64], [16, 32], [8, 16], [None, None]],
|
||||
'MLPS': [[[16, 16, 32], [32, 32, 64]],
|
||||
[[64, 64, 128], [64, 96, 128]],
|
||||
[[128, 196, 256], [128, 196, 256]],
|
||||
[[256, 256, 512], [256, 384, 512]]],
|
||||
'DP_RATIO': 0.5,
|
||||
}
|
||||
|
||||
ClsMSG_CFG_Light = {
|
||||
'NPOINTS': [512, 256, 128, None],
|
||||
'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
|
||||
'NSAMPLE': [[16, 32], [16, 32], [16, 32], [None, None]],
|
||||
'MLPS': [[[16, 16, 32], [32, 32, 64]],
|
||||
[[64, 64, 128], [64, 96, 128]],
|
||||
[[128, 196, 256], [128, 196, 256]],
|
||||
[[256, 256, 512], [256, 384, 512]]],
|
||||
'DP_RATIO': 0.5,
|
||||
}
|
||||
|
||||
ClsMSG_CFG_Light_2048 = {
|
||||
'NPOINTS': [512, 256, 128, None],
|
||||
'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
|
||||
'NSAMPLE': [[16, 32], [16, 32], [16, 32], [None, None]],
|
||||
'MLPS': [[[16, 16, 32], [32, 32, 64]],
|
||||
[[64, 64, 128], [64, 96, 128]],
|
||||
[[128, 196, 256], [128, 196, 256]],
|
||||
[[256, 256, 1024], [256, 512, 1024]]],
|
||||
'DP_RATIO': 0.5,
|
||||
}
|
||||
|
||||
ClsMSG_CFG_Strong = {
|
||||
'NPOINTS': [512, 256, 128, 64, None],
|
||||
'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16],[0.16, 0.32], [None, None]],
|
||||
'NSAMPLE': [[16, 32], [16, 32], [16, 32], [16, 32], [None, None]],
|
||||
'MLPS': [[[16, 16, 32], [32, 32, 64]],
|
||||
[[64, 64, 128], [64, 96, 128]],
|
||||
[[128, 196, 256], [128, 196, 256]],
|
||||
[[256, 256, 512], [256, 512, 512]],
|
||||
[[512, 512, 2048], [512, 1024, 2048]]
|
||||
],
|
||||
'DP_RATIO': 0.5,
|
||||
}
|
||||
|
||||
ClsMSG_CFG_Lighter = {
|
||||
'NPOINTS': [512, 256, 128, 64, None],
|
||||
'RADIUS': [[0.01], [0.02], [0.04], [0.08], [None]],
|
||||
'NSAMPLE': [[64], [32], [16], [8], [None]],
|
||||
'MLPS': [[[32, 32, 64]],
|
||||
[[64, 64, 128]],
|
||||
[[128, 196, 256]],
|
||||
[[256, 256, 512]],
|
||||
[[512, 512, 1024]]],
|
||||
'DP_RATIO': 0.5,
|
||||
}
|
||||
|
||||
|
||||
def select_params(name):
|
||||
if name == 'light':
|
||||
return ClsMSG_CFG_Light
|
||||
elif name == 'lighter':
|
||||
return ClsMSG_CFG_Lighter
|
||||
elif name == 'dense':
|
||||
return ClsMSG_CFG_Dense
|
||||
elif name == 'light_2048':
|
||||
return ClsMSG_CFG_Light_2048
|
||||
elif name == 'strong':
|
||||
return ClsMSG_CFG_Strong
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def break_up_pc(pc):
|
||||
xyz = pc[..., 0:3].contiguous()
|
||||
features = (
|
||||
pc[..., 3:].transpose(1, 2).contiguous()
|
||||
if pc.size(-1) > 3 else None
|
||||
)
|
||||
|
||||
return xyz, features
|
||||
|
||||
|
||||
@stereotype.module("pointnet++_encoder")
|
||||
class PointNet2Encoder(nn.Module):
|
||||
def encode_points(self, pts, require_per_point_feat=False):
|
||||
return self.forward(pts)
|
||||
|
||||
def __init__(self, config:dict):
|
||||
super().__init__()
|
||||
|
||||
channel_in = config.get("in_dim", 3) - 3
|
||||
params_name = config.get("params_name", "light")
|
||||
|
||||
self.SA_modules = nn.ModuleList()
|
||||
selected_params = select_params(params_name)
|
||||
for k in range(selected_params['NPOINTS'].__len__()):
|
||||
mlps = selected_params['MLPS'][k].copy()
|
||||
channel_out = 0
|
||||
for idx in range(mlps.__len__()):
|
||||
mlps[idx] = [channel_in] + mlps[idx]
|
||||
channel_out += mlps[idx][-1]
|
||||
|
||||
self.SA_modules.append(
|
||||
PointnetSAModuleMSG(
|
||||
npoint=selected_params['NPOINTS'][k],
|
||||
radii=selected_params['RADIUS'][k],
|
||||
nsamples=selected_params['NSAMPLE'][k],
|
||||
mlps=mlps,
|
||||
use_xyz=True,
|
||||
bn=True
|
||||
)
|
||||
)
|
||||
channel_in = channel_out
|
||||
|
||||
def forward(self, point_cloud: torch.cuda.FloatTensor):
|
||||
xyz, features = break_up_pc(point_cloud)
|
||||
|
||||
l_xyz, l_features = [xyz], [features]
|
||||
for i in range(len(self.SA_modules)):
|
||||
li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
|
||||
l_xyz.append(li_xyz)
|
||||
l_features.append(li_features)
|
||||
return l_features[-1].squeeze(-1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
seed = 100
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
net = PointNet2Encoder(config={"in_dim": 3, "params_name": "strong"}).cuda()
|
||||
pts = torch.randn(2, 2444, 3).cuda()
|
||||
print(torch.mean(pts, dim=1))
|
||||
pre = net.encode_points(pts)
|
||||
print(pre.shape)
|
107
modules/pointnet_encoder.py
Normal file
107
modules/pointnet_encoder.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from __future__ import print_function
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
import torch.utils.data
|
||||
from torch.autograd import Variable
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
@stereotype.module("pointnet_encoder")
|
||||
class PointNetEncoder(nn.Module):
|
||||
|
||||
def __init__(self, config:dict):
|
||||
super(PointNetEncoder, self).__init__()
|
||||
|
||||
self.out_dim = config["out_dim"]
|
||||
self.in_dim = config["in_dim"]
|
||||
self.feature_transform = config.get("feature_transform", False)
|
||||
self.stn = STNkd(k=self.in_dim)
|
||||
self.conv1 = torch.nn.Conv1d(self.in_dim , 64, 1)
|
||||
self.conv2 = torch.nn.Conv1d(64, 128, 1)
|
||||
self.conv3 = torch.nn.Conv1d(128, 512, 1)
|
||||
self.conv4 = torch.nn.Conv1d(512, self.out_dim , 1)
|
||||
if self.feature_transform:
|
||||
self.f_stn = STNkd(k=64)
|
||||
|
||||
def forward(self, x):
|
||||
trans = self.stn(x)
|
||||
x = x.transpose(2, 1)
|
||||
x = torch.bmm(x, trans)
|
||||
x = x.transpose(2, 1)
|
||||
x = F.relu(self.conv1(x))
|
||||
|
||||
if self.feature_transform:
|
||||
trans_feat = self.f_stn(x)
|
||||
x = x.transpose(2, 1)
|
||||
x = torch.bmm(x, trans_feat)
|
||||
x = x.transpose(2, 1)
|
||||
|
||||
point_feat = x
|
||||
x = F.relu(self.conv2(x))
|
||||
x = F.relu(self.conv3(x))
|
||||
x = self.conv4(x)
|
||||
x = torch.max(x, 2, keepdim=True)[0]
|
||||
x = x.view(-1, self.out_dim)
|
||||
return x, point_feat
|
||||
|
||||
def encode_points(self, pts, require_per_point_feat=False):
|
||||
pts = pts.transpose(2, 1)
|
||||
global_pts_feature, per_point_feature = self(pts)
|
||||
if require_per_point_feat:
|
||||
return global_pts_feature, per_point_feature.transpose(2, 1)
|
||||
else:
|
||||
return global_pts_feature
|
||||
|
||||
class STNkd(nn.Module):
|
||||
def __init__(self, k=64):
|
||||
super(STNkd, self).__init__()
|
||||
self.conv1 = torch.nn.Conv1d(k, 64, 1)
|
||||
self.conv2 = torch.nn.Conv1d(64, 128, 1)
|
||||
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
|
||||
self.fc1 = nn.Linear(1024, 512)
|
||||
self.fc2 = nn.Linear(512, 256)
|
||||
self.fc3 = nn.Linear(256, k * k)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
self.k = k
|
||||
|
||||
def forward(self, x):
|
||||
batchsize = x.size()[0]
|
||||
x = F.relu(self.conv1(x))
|
||||
x = F.relu(self.conv2(x))
|
||||
x = F.relu(self.conv3(x))
|
||||
x = torch.max(x, 2, keepdim=True)[0]
|
||||
x = x.view(-1, 1024)
|
||||
|
||||
x = F.relu(self.fc1(x))
|
||||
x = F.relu(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
|
||||
iden = (
|
||||
Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32)))
|
||||
.view(1, self.k * self.k)
|
||||
.repeat(batchsize, 1)
|
||||
)
|
||||
if x.is_cuda:
|
||||
iden = iden.to(x.get_device())
|
||||
x = x + iden
|
||||
x = x.view(-1, self.k, self.k)
|
||||
return x
|
||||
|
||||
if __name__ == "__main__":
|
||||
sim_data = Variable(torch.rand(32, 2500, 3))
|
||||
config = {
|
||||
"in_dim": 3,
|
||||
"out_dim": 1024,
|
||||
"feature_transform": False
|
||||
}
|
||||
pointnet = PointNetEncoder(config)
|
||||
out = pointnet.encode_points(sim_data)
|
||||
|
||||
print("global feat", out.size())
|
||||
|
||||
out, per_point_out = pointnet.encode_points(sim_data, require_per_point_feat=True)
|
||||
print("point feat", out.size())
|
||||
print("per point feat", per_point_out.size())
|
21
modules/pose_encoder.py
Normal file
21
modules/pose_encoder.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from torch import nn
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
@stereotype.module("pose_encoder")
|
||||
class PoseEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(PoseEncoder, self).__init__()
|
||||
self.config = config
|
||||
pose_dim = config["pose_dim"]
|
||||
out_dim = config["out_dim"]
|
||||
self.act = nn.ReLU(True)
|
||||
|
||||
self.pose_encoder = nn.Sequential(
|
||||
nn.Linear(pose_dim, out_dim),
|
||||
self.act,
|
||||
nn.Linear(out_dim, out_dim),
|
||||
self.act,
|
||||
)
|
||||
|
||||
def encode_pose(self, pose):
|
||||
return self.pose_encoder(pose)
|
20
modules/pts_num_encoder.py
Normal file
20
modules/pts_num_encoder.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from torch import nn
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
@stereotype.module("pts_num_encoder")
|
||||
class PointsNumEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(PointsNumEncoder, self).__init__()
|
||||
self.config = config
|
||||
out_dim = config["out_dim"]
|
||||
self.act = nn.ReLU(True)
|
||||
|
||||
self.pts_num_encoder = nn.Sequential(
|
||||
nn.Linear(1, out_dim),
|
||||
self.act,
|
||||
nn.Linear(out_dim, out_dim),
|
||||
self.act,
|
||||
)
|
||||
|
||||
def encode_pts_num(self, num_seq):
|
||||
return self.pts_num_encoder(num_seq)
|
63
modules/transformer_seq_encoder.py
Normal file
63
modules/transformer_seq_encoder.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
|
||||
@stereotype.module("transformer_seq_encoder")
|
||||
class TransformerSequenceEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(TransformerSequenceEncoder, self).__init__()
|
||||
self.config = config
|
||||
embed_dim = config["embed_dim"]
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=embed_dim,
|
||||
nhead=config["num_heads"],
|
||||
dim_feedforward=config["ffn_dim"],
|
||||
batch_first=True,
|
||||
)
|
||||
self.transformer_encoder = nn.TransformerEncoder(
|
||||
encoder_layer, num_layers=config["num_layers"]
|
||||
)
|
||||
self.fc = nn.Linear(embed_dim, config["output_dim"])
|
||||
|
||||
def encode_sequence(self, embedding_list_batch):
|
||||
|
||||
lengths = []
|
||||
|
||||
for embedding_list in embedding_list_batch:
|
||||
lengths.append(len(embedding_list))
|
||||
|
||||
embedding_tensor = pad_sequence(embedding_list_batch, batch_first=True) # Shape: [batch_size, max_seq_len, embed_dim]
|
||||
|
||||
max_len = max(lengths)
|
||||
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(embedding_tensor.device)
|
||||
|
||||
transformer_output = self.transformer_encoder(embedding_tensor, src_key_padding_mask=padding_mask)
|
||||
final_feature = transformer_output.mean(dim=1)
|
||||
final_output = self.fc(final_feature)
|
||||
|
||||
return final_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = {
|
||||
"embed_dim": 256,
|
||||
"num_heads": 4,
|
||||
"ffn_dim": 256,
|
||||
"num_layers": 3,
|
||||
"output_dim": 1024,
|
||||
}
|
||||
|
||||
encoder = TransformerSequenceEncoder(config)
|
||||
seq_len = [5, 8, 9, 4]
|
||||
batch_size = 4
|
||||
|
||||
embedding_list_batch = [
|
||||
torch.randn(seq_len[idx], config["embed_dim"]) for idx in range(batch_size)
|
||||
]
|
||||
output_feature = encoder.encode_sequence(
|
||||
embedding_list_batch
|
||||
)
|
||||
print("Encoded Feature:", output_feature)
|
||||
print("Feature Shape:", output_feature.shape)
|
Reference in New Issue
Block a user