first commit

2025-05-13 09:03:38 +08:00
commit b98753bfbb
121 changed files with 8665 additions and 0 deletions

Binary files not shown.

6
modules/func_lib/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
from modules.func_lib.samplers import (
cond_ode_sampler
)
from modules.func_lib.sde import (
init_sde
)
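
These re-exports flatten the package namespace, so downstream files can reach both entry points through the package itself (gf_view_finder.py below relies on this via `import modules.func_lib as flib`). A minimal usage sketch, assuming the repository root is on PYTHONPATH:

from modules.func_lib import cond_ode_sampler, init_sde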

Binary files not shown.

95
modules/func_lib/samplers.py Normal file
View File

@@ -0,0 +1,95 @@
import torch
import numpy as np
from scipy import integrate
from utils.pose import PoseUtil
def global_prior_likelihood(z, sigma_max):
"""The log-likelihood of a Gaussian prior with mean zero and
standard deviation sigma_max."""
# z: [bs, pose_dim]
shape = z.shape
N = np.prod(shape[1:]) # pose_dim
# np.log, not torch.log: the argument is a plain Python float
return -N / 2.0 * np.log(2 * np.pi * sigma_max**2) - torch.sum(
z**2, dim=-1
) / (2 * sigma_max**2)
def cond_ode_sampler(
score_model,
data,
prior,
sde_coeff,
atol=1e-5,
rtol=1e-5,
device="cuda",
eps=1e-5,
T=1.0,
num_steps=None,
pose_mode="quat_wxyz",
denoise=True,
init_x=None,
):
pose_dim = PoseUtil.get_pose_dim(pose_mode)
batch_size = data["main_feat"].shape[0]
init_x = (
prior((batch_size, pose_dim), T=T).to(device)
if init_x is None
else init_x + prior((batch_size, pose_dim), T=T).to(device)
)
shape = init_x.shape
def score_eval_wrapper(data):
"""A wrapper of the score-based model for use by the ODE solver."""
with torch.no_grad():
score = score_model(data)
return score.cpu().numpy().reshape((-1,))
def ode_func(t, x):
"""The ODE function for use by the ODE solver."""
x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device)
time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t
drift, diffusion = sde_coeff(torch.tensor(t))
drift = drift.cpu().numpy()
diffusion = diffusion.cpu().numpy()
data["sampled_pose"] = x
data["t"] = time_steps
return drift - 0.5 * (diffusion**2) * score_eval_wrapper(data)
# Run the black-box ODE solver; if num_steps is set, evaluate the trajectory at fixed time steps
t_eval = None
if num_steps is not None:
# num_steps, from T -> eps
t_eval = np.linspace(T, eps, num_steps)
res = integrate.solve_ivp(
ode_func,
(T, eps),
init_x.reshape(-1).cpu().numpy(),
rtol=rtol,
atol=atol,
method="RK45",
t_eval=t_eval,
)
xs = torch.tensor(res.y, device=device).T.view(
-1, batch_size, pose_dim
) # [num_steps, bs, pose_dim]
x = torch.tensor(res.y[:, -1], device=device).reshape(shape) # [bs, pose_dim]
# denoise, using the predictor step in P-C sampler
if denoise:
# Reverse diffusion predictor for denoising
vec_eps = torch.ones((x.shape[0], 1), device=x.device) * eps
drift, diffusion = sde_coeff(vec_eps)
data["sampled_pose"] = x.float()
data["t"] = vec_eps
grad = score_model(data)
drift = drift - diffusion**2 * grad # R-SDE
mean_x = x + drift * ((1 - eps) / (1000 if num_steps is None else num_steps))
x = mean_x
num_steps = xs.shape[0]
xs = xs.reshape(batch_size*num_steps, -1)
xs[:, :-3] = PoseUtil.normalize_rotation(xs[:, :-3], pose_mode)
xs = xs.reshape(num_steps, batch_size, -1)
x[:, :-3] = PoseUtil.normalize_rotation(x[:, :-3], pose_mode)
return xs.permute(1, 0, 2), x
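
For reference, a minimal smoke test of the sampler interface. This is a sketch, not part of the commit: it assumes a CUDA device, assumes PoseUtil.get_pose_dim("rot_matrix") == 9 (as the debug block in gf_view_finder.py suggests), and uses a zero score function as a stand-in for a trained network.

import torch
from modules.func_lib import cond_ode_sampler, init_sde
prior_fn, marginal_prob_fn, sde_fn, eps, T = init_sde("ve")
score_model = lambda data: torch.zeros_like(data["sampled_pose"]) # stand-in for a trained network
data = {"main_feat": torch.rand(4, 2048, device="cuda")} # only the batch size is read here
xs, x = cond_ode_sampler(score_model, data, prior=prior_fn, sde_coeff=sde_fn, eps=eps, T=T, num_steps=100, pose_mode="rot_matrix")
print(xs.shape, x.shape) # expected: [4, 100, 9] and [4, 9]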

121
modules/func_lib/sde.py Normal file
View File

@@ -0,0 +1,121 @@
import functools
import torch
import numpy as np
# ----- VE SDE -----
# ------------------
def ve_marginal_prob(x, t, sigma_min=0.01, sigma_max=90):
std = sigma_min * (sigma_max / sigma_min) ** t
mean = x
return mean, std
def ve_sde(t, sigma_min=0.01, sigma_max=90):
sigma = sigma_min * (sigma_max / sigma_min) ** t
drift_coeff = torch.tensor(0)
diffusion_coeff = sigma * torch.sqrt(torch.tensor(2 * (np.log(sigma_max) - np.log(sigma_min)), device=t.device))
return drift_coeff, diffusion_coeff
def ve_prior(shape, sigma_min=0.01, sigma_max=90, T=1.0):
_, sigma_max_prior = ve_marginal_prob(None, T, sigma_min=sigma_min, sigma_max=sigma_max)
return torch.randn(*shape) * sigma_max_prior
# ----- VP SDE -----
# ------------------
def vp_marginal_prob(x, t, beta_0=0.1, beta_1=20):
log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
mean = torch.exp(log_mean_coeff) * x
std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
return mean, std
def vp_sde(t, beta_0=0.1, beta_1=20):
beta_t = beta_0 + t * (beta_1 - beta_0)
drift_coeff = -0.5 * beta_t
diffusion_coeff = torch.sqrt(beta_t)
return drift_coeff, diffusion_coeff
def vp_prior(shape, beta_0=0.1, beta_1=20, T=1.0):
# T is accepted (and ignored) so the prior matches the cond_ode_sampler call prior(shape, T=T)
return torch.randn(*shape)
# ----- sub-VP SDE -----
# ----------------------
def subvp_marginal_prob(x, t, beta_0, beta_1):
log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
mean = torch.exp(log_mean_coeff) * x
std = 1 - torch.exp(2. * log_mean_coeff)
return mean, std
def subvp_sde(t, beta_0, beta_1):
beta_t = beta_0 + t * (beta_1 - beta_0)
drift_coeff = -0.5 * beta_t
discount = 1. - torch.exp(-2 * beta_0 * t - (beta_1 - beta_0) * t ** 2)
diffusion_coeff = torch.sqrt(beta_t * discount)
return drift_coeff, diffusion_coeff
def subvp_prior(shape, beta_0=0.1, beta_1=20, T=1.0):
# T is accepted (and ignored) for interface compatibility with cond_ode_sampler
return torch.randn(*shape)
# ----- EDM SDE -----
# ------------------
def edm_marginal_prob(x, t, sigma_min=0.002, sigma_max=80):
std = t
mean = x
return mean, std
def edm_sde(t, sigma_min=0.002, sigma_max=80):
drift_coeff = torch.tensor(0)
diffusion_coeff = torch.sqrt(2 * t)
return drift_coeff, diffusion_coeff
def edm_prior(shape, sigma_min=0.002, sigma_max=80, T=None):
# T is accepted (and ignored) for interface compatibility with cond_ode_sampler
return torch.randn(*shape) * sigma_max
def init_sde(sde_mode):
# the SDE-related hyperparameters are copied from https://github.com/yang-song/score_sde_pytorch
if sde_mode == 'edm':
sigma_min = 0.002
sigma_max = 80
eps = 0.002
prior_fn = functools.partial(edm_prior, sigma_min=sigma_min, sigma_max=sigma_max)
marginal_prob_fn = functools.partial(edm_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
sde_fn = functools.partial(edm_sde, sigma_min=sigma_min, sigma_max=sigma_max)
T = sigma_max
elif sde_mode == 've':
sigma_min = 0.01
sigma_max = 50
eps = 1e-5
marginal_prob_fn = functools.partial(ve_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
sde_fn = functools.partial(ve_sde, sigma_min=sigma_min, sigma_max=sigma_max)
T = 1.0
prior_fn = functools.partial(ve_prior, sigma_min=sigma_min, sigma_max=sigma_max)
elif sde_mode == 'vp':
beta_0 = 0.1
beta_1 = 20
eps = 1e-3
prior_fn = functools.partial(vp_prior, beta_0=beta_0, beta_1=beta_1)
marginal_prob_fn = functools.partial(vp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
sde_fn = functools.partial(vp_sde, beta_0=beta_0, beta_1=beta_1)
T = 1.0
elif sde_mode == 'subvp':
beta_0 = 0.1
beta_1 = 20
eps = 1e-3
prior_fn = functools.partial(subvp_prior, beta_0=beta_0, beta_1=beta_1)
marginal_prob_fn = functools.partial(subvp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
sde_fn = functools.partial(subvp_sde, beta_0=beta_0, beta_1=beta_1)
T = 1.0
else:
raise NotImplementedError
return prior_fn, marginal_prob_fn, sde_fn, eps, T
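
A short sketch of the returned interface (runs on CPU; the VE settings are the ones init_sde picks, sigma_min=0.01, sigma_max=50):

import torch
from modules.func_lib.sde import init_sde
prior_fn, marginal_prob_fn, sde_fn, eps, T = init_sde("ve")
x0 = torch.zeros(4, 9)
x_T = prior_fn((4, 9), T=T) # noise at the terminal time, std = sigma_max
mean, std = marginal_prob_fn(x0, torch.tensor(0.5)) # perturbation kernel p_t(x | x0)
drift, diffusion = sde_fn(torch.tensor(0.5)) # coefficients of dx = f dt + g dw
print(x_T.shape, std, drift, diffusion)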

167
modules/gf_view_finder.py Normal file
View File

@@ -0,0 +1,167 @@
import torch
import torch.nn as nn
import PytorchBoot.stereotype as stereotype
from utils.pose import PoseUtil
import modules.module_lib as mlib
import modules.func_lib as flib
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
@stereotype.module("gf_view_finder")
class GradientFieldViewFinder(nn.Module):
def __init__(self, config):
super(GradientFieldViewFinder, self).__init__()
self.regression_head = config["regression_head"]
self.per_point_feature = config["per_point_feature"]
self.act = nn.ReLU(True)
self.sample_mode = config["sample_mode"]
self.pose_mode = config["pose_mode"]
pose_dim = PoseUtil.get_pose_dim(self.pose_mode)
self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(config["sde_mode"])
self.sampling_steps = config["sampling_steps"]
self.t_feat_dim = config["t_feat_dim"]
self.pose_feat_dim = config["pose_feat_dim"]
self.main_feat_dim = config["main_feat_dim"]
''' encode pose '''
self.pose_encoder = nn.Sequential(
nn.Linear(pose_dim, self.pose_feat_dim),
self.act,
nn.Linear(self.pose_feat_dim, self.pose_feat_dim),
self.act,
)
''' encode t '''
self.t_encoder = nn.Sequential(
mlib.GaussianFourierProjection(embed_dim=self.t_feat_dim),
nn.Linear(self.t_feat_dim, self.t_feat_dim),
self.act,
)
''' fusion tail '''
if self.regression_head == 'Rx_Ry_and_T':
if self.pose_mode != 'rot_matrix':
raise NotImplementedError
if not self.per_point_feature:
''' rotation_x_axis regress head '''
self.fusion_tail_rot_x = nn.Sequential(
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
self.act,
zero_module(nn.Linear(256, 3)),
)
self.fusion_tail_rot_y = nn.Sequential(
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
self.act,
zero_module(nn.Linear(256, 3)),
)
''' translation regression head '''
self.fusion_tail_trans = nn.Sequential(
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
self.act,
zero_module(nn.Linear(256, 3)),
)
else:
raise NotImplementedError
else:
raise NotImplementedError
def forward(self, data):
"""
Args:
data, dict {
'main_feat': [bs, c]
'sampled_pose': [bs, pose_dim]
't': [bs, 1]
}
"""
main_feat = data['main_feat']
sampled_pose = data['sampled_pose']
t = data['t']
t_feat = self.t_encoder(t.squeeze(1))
pose_feat = self.pose_encoder(sampled_pose)
if self.per_point_feature:
raise NotImplementedError
else:
total_feat = torch.cat([main_feat, t_feat, pose_feat], dim=-1)
_, std = self.marginal_prob_fn(total_feat, t)
if self.regression_head == 'Rx_Ry_and_T':
rot_x = self.fusion_tail_rot_x(total_feat)
rot_y = self.fusion_tail_rot_y(total_feat)
trans = self.fusion_tail_trans(total_feat)
out_score = torch.cat([rot_x, rot_y, trans], dim=-1) / (std+1e-7) # normalisation
else:
raise NotImplementedError
return out_score
def marginal_prob(self, x, t):
return self.marginal_prob_fn(x,t)
def sample(self, data, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True, init_x=None, T0=None):
if self.sample_mode == 'ode':
T0 = self.T if T0 is None else T0
in_process_sample, res = flib.cond_ode_sampler(
score_model=self,
data=data,
prior=self.prior_fn,
sde_coeff=self.sde_fn,
atol=atol,
rtol=rtol,
eps=self.sampling_eps,
T=T0,
num_steps=self.sampling_steps,
pose_mode=self.pose_mode,
denoise=denoise,
init_x=init_x
)
else:
raise NotImplementedError
return in_process_sample, res
def next_best_view(self, main_feat):
data = {
'main_feat': main_feat,
}
in_process_sample, res = self.sample(data)
return res.to(dtype=torch.float32), in_process_sample
''' ----------- DEBUG -----------'''
if __name__ == "__main__":
config = {
"regression_head": "Rx_Ry_and_T",
"per_point_feature": False,
"pose_mode": "rot_matrix",
"sde_mode": "ve",
"sampling_steps": 500,
"sample_mode": "ode",
# the three feature dims below are required by __init__; values are illustrative
"t_feat_dim": 128,
"pose_feat_dim": 256,
"main_feat_dim": 2048,
}
test_main_feat = torch.rand(32, 2048).to("cuda:0")
test_pose = torch.rand(32, 9).to("cuda:0")
test_t = torch.rand(32, 1).to("cuda:0")
view_finder = GradientFieldViewFinder(config).to("cuda:0")
test_data = {
'main_feat': test_main_feat, # forward() reads 'main_feat', not 'seq_feat'
'sampled_pose': test_pose,
't': test_t
}
score = view_finder(test_data)
print(score.shape)
res, inprocess = view_finder.next_best_view(test_main_feat)
print(res.shape, inprocess.shape)

View File

@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
import PytorchBoot.stereotype as stereotype
from utils.pose import PoseUtil
import modules.module_lib as mlib
import modules.func_lib as flib
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
@stereotype.module("mlp_view_finder")
class MLPViewFinder(nn.Module):
def __init__(self, config):
super(MLPViewFinder, self).__init__()
self.regression_head = 'Rx_Ry_and_T'
self.per_point_feature = False
self.act = nn.ReLU(True)
self.main_feat_dim = config["main_feat_dim"]
''' rotation_x_axis regress head '''
self.fusion_tail_rot_x = nn.Sequential(
nn.Linear(self.main_feat_dim, 256),
self.act,
zero_module(nn.Linear(256, 3)),
)
self.fusion_tail_rot_y = nn.Sequential(
nn.Linear(self.main_feat_dim, 256),
self.act,
zero_module(nn.Linear(256, 3)),
)
''' translation regression head '''
self.fusion_tail_trans = nn.Sequential(
nn.Linear(self.main_feat_dim, 256),
self.act,
zero_module(nn.Linear(256, 3)),
)
def forward(self, data):
"""
Args:
data, dict {
'main_feat': [bs, c]
}
"""
total_feat = data['main_feat']
rot_x = self.fusion_tail_rot_x(total_feat)
rot_y = self.fusion_tail_rot_y(total_feat)
trans = self.fusion_tail_trans(total_feat)
output = torch.cat([rot_x,rot_y,trans], dim=-1)
return output
def next_best_view(self, main_feat):
data = {
'main_feat': main_feat,
}
res = self(data)
return res.to(dtype=torch.float32), None
''' ----------- DEBUG -----------'''
if __name__ == "__main__":
config = {
"main_feat_dim": 2048, # the only key MLPViewFinder.__init__ reads
}
test_main_feat = torch.rand(32, 2048).to("cuda:0")
view_finder = MLPViewFinder(config).to("cuda:0")
test_data = {
'main_feat': test_main_feat,
}
output = view_finder(test_data)
print(output.shape)
res, inprocess = view_finder.next_best_view(test_main_feat)
print(res.shape, inprocess) # the MLP finder returns None for the in-process samples

2
modules/module_lib/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
from modules.module_lib.gaussian_fourier_projection import GaussianFourierProjection
from modules.module_lib.linear import Linear

Binary files not shown.

17
modules/module_lib/gaussian_fourier_projection.py Normal file
View File

@@ -0,0 +1,17 @@
import torch
import numpy as np
import torch.nn as nn
class GaussianFourierProjection(nn.Module):
"""Gaussian random features for encoding time steps."""
def __init__(self, embed_dim, scale=30.):
super().__init__()
# Randomly sample weights during initialization. These weights are fixed
# during optimization and are not trainable.
self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
def forward(self, x):
x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
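
A quick shape check (a sketch; embed_dim should be even, since the sin and cos halves each use embed_dim // 2 random frequencies):

import torch
proj = GaussianFourierProjection(embed_dim=128)
t = torch.rand(32) # one scalar time step per batch element
emb = proj(t) # [32, 128]: sin/cos features at fixed random frequencies
print(emb.shape)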

30
modules/module_lib/linear.py Normal file
View File

@@ -0,0 +1,30 @@
import torch
import numpy as np
def weight_init(shape, mode, fan_in, fan_out):
if mode == 'xavier_uniform':
return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1)
if mode == 'xavier_normal':
return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape)
if mode == 'kaiming_uniform':
return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1)
if mode == 'kaiming_normal':
return np.sqrt(1 / fan_in) * torch.randn(*shape)
raise ValueError(f'Invalid init mode "{mode}"')
class Linear(torch.nn.Module):
def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0):
super().__init__()
self.in_features = in_features
self.out_features = out_features
init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features)
self.weight = torch.nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight)
self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None
def forward(self, x):
x = x @ self.weight.to(x.dtype).t()
if self.bias is not None:
x = x.add_(self.bias.to(x.dtype))
return x
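
A usage sketch: the init_weight / init_bias multipliers rescale the chosen initializer once, at construction time, so e.g. init_bias=0 gives an exactly-zero starting bias.

import torch
layer = Linear(64, 32, init_mode="xavier_uniform", init_weight=0.5, init_bias=0)
x = torch.randn(8, 64)
print(layer(x).shape) # [8, 32]; the bias starts at exactly zero because init_bias=0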

162
modules/module_lib/pointnet2_modules.py Normal file
View File

@@ -0,0 +1,162 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import pointnet2_utils
from . import pytorch_utils as pt_utils
from typing import List, Tuple
class _PointnetSAModuleBase(nn.Module):
def __init__(self):
super().__init__()
self.npoint = None
self.groupers = None
self.mlps = None
self.pool_method = 'max_pool'
def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:param xyz: (B, N, 3) tensor of the xyz coordinates of the features
:param features: (B, N, C) tensor of the descriptors of the features
:param new_xyz:
:return:
new_xyz: (B, npoint, 3) tensor of the new features' xyz
new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
"""
new_features_list = []
xyz_flipped = xyz.transpose(1, 2).contiguous()
if new_xyz is None:
new_xyz = pointnet2_utils.gather_operation(
xyz_flipped,
pointnet2_utils.furthest_point_sample(xyz, self.npoint)
).transpose(1, 2).contiguous() if self.npoint is not None else None
for i in range(len(self.groupers)):
new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample)
new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
if self.pool_method == 'max_pool':
new_features = F.max_pool2d(
new_features, kernel_size=[1, new_features.size(3)]
) # (B, mlp[-1], npoint, 1)
elif self.pool_method == 'avg_pool':
new_features = F.avg_pool2d(
new_features, kernel_size=[1, new_features.size(3)]
) # (B, mlp[-1], npoint, 1)
else:
raise NotImplementedError
new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
new_features_list.append(new_features)
return new_xyz, torch.cat(new_features_list, dim=1)
class PointnetSAModuleMSG(_PointnetSAModuleBase):
"""Pointnet set abstraction layer with multiscale grouping"""
def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
"""
:param npoint: int
:param radii: list of float, list of radii to group with
:param nsamples: list of int, number of samples in each ball query
:param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
:param bn: whether to use batchnorm
:param use_xyz:
:param pool_method: max_pool / avg_pool
:param instance_norm: whether to use instance_norm
"""
super().__init__()
assert len(radii) == len(nsamples) == len(mlps)
self.npoint = npoint
self.groupers = nn.ModuleList()
self.mlps = nn.ModuleList()
for i in range(len(radii)):
radius = radii[i]
nsample = nsamples[i]
self.groupers.append(
pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
)
mlp_spec = mlps[i]
if use_xyz:
mlp_spec[0] += 3
self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
self.pool_method = pool_method
class PointnetSAModule(PointnetSAModuleMSG):
"""Pointnet set abstraction layer"""
def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
"""
:param mlp: list of int, spec of the pointnet before the global max_pool
:param npoint: int, number of features
:param radius: float, radius of ball
:param nsample: int, number of samples in the ball query
:param bn: whether to use batchnorm
:param use_xyz:
:param pool_method: max_pool / avg_pool
:param instance_norm: whether to use instance_norm
"""
super().__init__(
mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz,
pool_method=pool_method, instance_norm=instance_norm
)
class PointnetFPModule(nn.Module):
r"""Propigates the features of one set to another"""
def __init__(self, *, mlp: List[int], bn: bool = True):
"""
:param mlp: list of int
:param bn: whether to use batchnorm
"""
super().__init__()
self.mlp = pt_utils.SharedMLP(mlp, bn=bn)
def forward(
self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
) -> torch.Tensor:
"""
:param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
:param known: (B, m, 3) tensor of the xyz positions of the known features
:param unknow_feats: (B, C1, n) tensor of the features to be propagated to
:param known_feats: (B, C2, m) tensor of features to be propagated
:return:
new_features: (B, mlp[-1], n) tensor of the features of the unknown features
"""
if known is not None:
dist, idx = pointnet2_utils.three_nn(unknown, known)
dist_recip = 1.0 / (dist + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm
interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)
else:
interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))
if unknow_feats is not None:
new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n)
else:
new_features = interpolated_feats
new_features = new_features.unsqueeze(-1)
new_features = self.mlp(new_features)
return new_features.squeeze(-1)
if __name__ == "__main__":
pass
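
A usage sketch (it assumes a CUDA device and that the compiled pointnet2_cuda extension is importable; mlp[0] is the per-point feature width before use_xyz appends the 3 coordinates):

import torch
sa = PointnetSAModule(mlp=[0, 32, 64], npoint=64, radius=0.1, nsample=16).cuda()
xyz = torch.randn(2, 1024, 3).cuda()
new_xyz, feats = sa(xyz) # features=None, so grouping uses the coordinates alone
print(new_xyz.shape, feats.shape) # expected: [2, 64, 3] and [2, 64, 64]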

291
modules/module_lib/pointnet2_utils.py Normal file
View File

@@ -0,0 +1,291 @@
import torch
from torch.autograd import Variable
from torch.autograd import Function
import torch.nn as nn
from typing import Tuple
import sys
import pointnet2_cuda as pointnet2
class FurthestPointSampling(Function):
@staticmethod
def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
"""
Uses iterative furthest point sampling to select a set of npoint features that have the largest
minimum distance
:param ctx:
:param xyz: (B, N, 3) where N > npoint
:param npoint: int, number of features in the sampled set
:return:
output: (B, npoint) tensor containing the set
"""
assert xyz.is_contiguous()
B, N, _ = xyz.size()
output = torch.cuda.IntTensor(B, npoint)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
return output
@staticmethod
def backward(xyz, a=None):
return None, None
furthest_point_sample = FurthestPointSampling.apply
class GatherOperation(Function):
@staticmethod
def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
"""
:param ctx:
:param features: (B, C, N)
:param idx: (B, npoint) index tensor of the features to gather
:return:
output: (B, C, npoint)
"""
assert features.is_contiguous()
assert idx.is_contiguous()
B, npoint = idx.size()
_, C, N = features.size()
output = torch.cuda.FloatTensor(B, C, npoint)
pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
ctx.for_backwards = (idx, C, N)
return output
@staticmethod
def backward(ctx, grad_out):
idx, C, N = ctx.for_backwards
B, npoint = idx.size()
grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
return grad_features, None
gather_operation = GatherOperation.apply
class ThreeNN(Function):
@staticmethod
def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Find the three nearest neighbors of unknown in known
:param ctx:
:param unknown: (B, N, 3)
:param known: (B, M, 3)
:return:
dist: (B, N, 3) l2 distance to the three nearest neighbors
idx: (B, N, 3) index of 3 nearest neighbors
"""
assert unknown.is_contiguous()
assert known.is_contiguous()
B, N, _ = unknown.size()
m = known.size(1)
dist2 = torch.cuda.FloatTensor(B, N, 3)
idx = torch.cuda.IntTensor(B, N, 3)
pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
return torch.sqrt(dist2), idx
@staticmethod
def backward(ctx, a=None, b=None):
return None, None
three_nn = ThreeNN.apply
class ThreeInterpolate(Function):
@staticmethod
def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""
Performs weighted linear interpolation using the 3 nearest features
:param ctx:
:param features: (B, C, M) Features descriptors to be interpolated from
:param idx: (B, n, 3) three nearest neighbors of the target features in features
:param weight: (B, n, 3) weights
:return:
output: (B, C, N) tensor of the interpolated features
"""
assert features.is_contiguous()
assert idx.is_contiguous()
assert weight.is_contiguous()
B, c, m = features.size()
n = idx.size(1)
ctx.three_interpolate_for_backward = (idx, weight, m)
output = torch.cuda.FloatTensor(B, c, n)
pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
return output
@staticmethod
def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
:param ctx:
:param grad_out: (B, C, N) tensor with gradients of outputs
:return:
grad_features: (B, C, M) tensor with gradients of features
None:
None:
"""
idx, weight, m = ctx.three_interpolate_for_backward
B, c, n = grad_out.size()
grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
return grad_features, None, None
three_interpolate = ThreeInterpolate.apply
class GroupingOperation(Function):
@staticmethod
def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
"""
:param ctx:
:param features: (B, C, N) tensor of features to group
:param idx: (B, npoint, nsample) tensor containing the indices of features to group with
:return:
output: (B, C, npoint, nsample) tensor
"""
assert features.is_contiguous()
assert idx.is_contiguous()
B, nfeatures, nsample = idx.size()
_, C, N = features.size()
output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
ctx.for_backwards = (idx, N)
return output
@staticmethod
def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:param ctx:
:param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
:return:
grad_features: (B, C, N) gradient of the features
"""
idx, N = ctx.for_backwards
B, C, npoint, nsample = grad_out.size()
grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
grad_out_data = grad_out.data.contiguous()
pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
return grad_features, None
grouping_operation = GroupingOperation.apply
class BallQuery(Function):
@staticmethod
def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
"""
:param ctx:
:param radius: float, radius of the balls
:param nsample: int, maximum number of features in the balls
:param xyz: (B, N, 3) xyz coordinates of the features
:param new_xyz: (B, npoint, 3) centers of the ball query
:return:
idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
"""
assert new_xyz.is_contiguous()
assert xyz.is_contiguous()
B, N, _ = xyz.size()
npoint = new_xyz.size(1)
idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
return idx
@staticmethod
def backward(ctx, a=None):
return None, None, None, None
ball_query = BallQuery.apply
class QueryAndGroup(nn.Module):
def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
"""
:param radius: float, radius of ball
:param nsample: int, maximum number of features to gather in the ball
:param use_xyz:
"""
super().__init__()
self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]:
"""
:param xyz: (B, N, 3) xyz coordinates of the features
:param new_xyz: (B, npoint, 3) centroids
:param features: (B, C, N) descriptors of the features
:return:
new_features: (B, 3 + C, npoint, nsample)
"""
idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
xyz_trans = xyz.transpose(1, 2).contiguous()
grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
if features is not None:
grouped_features = grouping_operation(features, idx)
if self.use_xyz:
new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, C + 3, npoint, nsample)
else:
new_features = grouped_features
else:
assert self.use_xyz, "use_xyz must be True when no point features are given!"
new_features = grouped_xyz
return new_features
class GroupAll(nn.Module):
def __init__(self, use_xyz: bool = True):
super().__init__()
self.use_xyz = use_xyz
def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
"""
:param xyz: (B, N, 3) xyz coordinates of the features
:param new_xyz: ignored
:param features: (B, C, N) descriptors of the features
:return:
new_features: (B, C + 3, 1, N)
"""
grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
if features is not None:
grouped_features = features.unsqueeze(2)
if self.use_xyz:
new_features = torch.cat([grouped_xyz, grouped_features], dim=1) # (B, 3 + C, 1, N)
else:
new_features = grouped_features
else:
new_features = grouped_xyz
return new_features
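
A sketch of the two lowest-level ops (again assuming a CUDA device and the compiled pointnet2_cuda extension):

import torch
xyz = torch.randn(2, 1024, 3).cuda()
idx = furthest_point_sample(xyz, 128) # [2, 128] indices of well-spread points
centroids = gather_operation(xyz.transpose(1, 2).contiguous(), idx) # [2, 3, 128]
print(idx.shape, centroids.shape)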

236
modules/module_lib/pytorch_utils.py Normal file
View File

@@ -0,0 +1,236 @@
import torch.nn as nn
from typing import List, Tuple
class SharedMLP(nn.Sequential):
def __init__(
self,
args: List[int],
*,
bn: bool = False,
activation=nn.ReLU(inplace=True),
preact: bool = False,
first: bool = False,
name: str = "",
instance_norm: bool = False,
):
super().__init__()
for i in range(len(args) - 1):
self.add_module(
name + 'layer{}'.format(i),
Conv2d(
args[i],
args[i + 1],
bn=(not first or not preact or (i != 0)) and bn,
activation=activation
if (not first or not preact or (i != 0)) else None,
preact=preact,
instance_norm=instance_norm
)
)
class _ConvBase(nn.Sequential):
def __init__(
self,
in_size,
out_size,
kernel_size,
stride,
padding,
activation,
bn,
init,
conv=None,
batch_norm=None,
bias=True,
preact=False,
name="",
instance_norm=False,
instance_norm_func=None
):
super().__init__()
bias = bias and (not bn)
conv_unit = conv(
in_size,
out_size,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias
)
init(conv_unit.weight)
if bias:
nn.init.constant_(conv_unit.bias, 0)
if bn:
if not preact:
bn_unit = batch_norm(out_size)
else:
bn_unit = batch_norm(in_size)
if instance_norm:
if not preact:
in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False)
else:
in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False)
if preact:
if bn:
self.add_module(name + 'bn', bn_unit)
if activation is not None:
self.add_module(name + 'activation', activation)
if not bn and instance_norm:
self.add_module(name + 'in', in_unit)
self.add_module(name + 'conv', conv_unit)
if not preact:
if bn:
self.add_module(name + 'bn', bn_unit)
if activation is not None:
self.add_module(name + 'activation', activation)
if not bn and instance_norm:
self.add_module(name + 'in', in_unit)
class _BNBase(nn.Sequential):
def __init__(self, in_size, batch_norm=None, name=""):
super().__init__()
self.add_module(name + "bn", batch_norm(in_size))
nn.init.constant_(self[0].weight, 1.0)
nn.init.constant_(self[0].bias, 0)
class BatchNorm1d(_BNBase):
def __init__(self, in_size: int, *, name: str = ""):
super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)
class BatchNorm2d(_BNBase):
def __init__(self, in_size: int, name: str = ""):
super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)
class Conv1d(_ConvBase):
def __init__(
self,
in_size: int,
out_size: int,
*,
kernel_size: int = 1,
stride: int = 1,
padding: int = 0,
activation=nn.ReLU(inplace=True),
bn: bool = False,
init=nn.init.kaiming_normal_,
bias: bool = True,
preact: bool = False,
name: str = "",
instance_norm=False
):
super().__init__(
in_size,
out_size,
kernel_size,
stride,
padding,
activation,
bn,
init,
conv=nn.Conv1d,
batch_norm=BatchNorm1d,
bias=bias,
preact=preact,
name=name,
instance_norm=instance_norm,
instance_norm_func=nn.InstanceNorm1d
)
class Conv2d(_ConvBase):
def __init__(
self,
in_size: int,
out_size: int,
*,
kernel_size: Tuple[int, int] = (1, 1),
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
activation=nn.ReLU(inplace=True),
bn: bool = False,
init=nn.init.kaiming_normal_,
bias: bool = True,
preact: bool = False,
name: str = "",
instance_norm=False
):
super().__init__(
in_size,
out_size,
kernel_size,
stride,
padding,
activation,
bn,
init,
conv=nn.Conv2d,
batch_norm=BatchNorm2d,
bias=bias,
preact=preact,
name=name,
instance_norm=instance_norm,
instance_norm_func=nn.InstanceNorm2d
)
class FC(nn.Sequential):
def __init__(
self,
in_size: int,
out_size: int,
*,
activation=nn.ReLU(inplace=True),
bn: bool = False,
init=None,
preact: bool = False,
name: str = ""
):
super().__init__()
fc = nn.Linear(in_size, out_size, bias=not bn)
if init is not None:
init(fc.weight)
if not bn:
nn.init.constant_(fc.bias, 0)
if preact:
if bn:
self.add_module(name + 'bn', BatchNorm1d(in_size))
if activation is not None:
self.add_module(name + 'activation', activation)
self.add_module(name + 'fc', fc)
if not preact:
if bn:
self.add_module(name + 'bn', BatchNorm1d(out_size))
if activation is not None:
self.add_module(name + 'activation', activation)

View File

@@ -0,0 +1,149 @@
import torch
import torch.nn as nn
import os
import sys
path = os.path.abspath(__file__)
for i in range(2):
path = os.path.dirname(path)
PROJECT_ROOT = path
sys.path.append(PROJECT_ROOT)
import PytorchBoot.stereotype as stereotype
from modules.module_lib.pointnet2_modules import PointnetSAModuleMSG
ClsMSG_CFG_Dense = {
'NPOINTS': [512, 256, 128, None],
'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
'NSAMPLE': [[32, 64], [16, 32], [8, 16], [None, None]],
'MLPS': [[[16, 16, 32], [32, 32, 64]],
[[64, 64, 128], [64, 96, 128]],
[[128, 196, 256], [128, 196, 256]],
[[256, 256, 512], [256, 384, 512]]],
'DP_RATIO': 0.5,
}
ClsMSG_CFG_Light = {
'NPOINTS': [512, 256, 128, None],
'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
'NSAMPLE': [[16, 32], [16, 32], [16, 32], [None, None]],
'MLPS': [[[16, 16, 32], [32, 32, 64]],
[[64, 64, 128], [64, 96, 128]],
[[128, 196, 256], [128, 196, 256]],
[[256, 256, 512], [256, 384, 512]]],
'DP_RATIO': 0.5,
}
ClsMSG_CFG_Light_2048 = {
'NPOINTS': [512, 256, 128, None],
'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
'NSAMPLE': [[16, 32], [16, 32], [16, 32], [None, None]],
'MLPS': [[[16, 16, 32], [32, 32, 64]],
[[64, 64, 128], [64, 96, 128]],
[[128, 196, 256], [128, 196, 256]],
[[256, 256, 1024], [256, 512, 1024]]],
'DP_RATIO': 0.5,
}
ClsMSG_CFG_Strong = {
'NPOINTS': [512, 256, 128, 64, None],
'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16],[0.16, 0.32], [None, None]],
'NSAMPLE': [[16, 32], [16, 32], [16, 32], [16, 32], [None, None]],
'MLPS': [[[16, 16, 32], [32, 32, 64]],
[[64, 64, 128], [64, 96, 128]],
[[128, 196, 256], [128, 196, 256]],
[[256, 256, 512], [256, 512, 512]],
[[512, 512, 2048], [512, 1024, 2048]]
],
'DP_RATIO': 0.5,
}
ClsMSG_CFG_Lighter = {
'NPOINTS': [512, 256, 128, 64, None],
'RADIUS': [[0.01], [0.02], [0.04], [0.08], [None]],
'NSAMPLE': [[64], [32], [16], [8], [None]],
'MLPS': [[[32, 32, 64]],
[[64, 64, 128]],
[[128, 196, 256]],
[[256, 256, 512]],
[[512, 512, 1024]]],
'DP_RATIO': 0.5,
}
def select_params(name):
if name == 'light':
return ClsMSG_CFG_Light
elif name == 'lighter':
return ClsMSG_CFG_Lighter
elif name == 'dense':
return ClsMSG_CFG_Dense
elif name == 'light_2048':
return ClsMSG_CFG_Light_2048
elif name == 'strong':
return ClsMSG_CFG_Strong
else:
raise NotImplementedError
def break_up_pc(pc):
xyz = pc[..., 0:3].contiguous()
features = (
pc[..., 3:].transpose(1, 2).contiguous()
if pc.size(-1) > 3 else None
)
return xyz, features
@stereotype.module("pointnet++_encoder")
class PointNet2Encoder(nn.Module):
def encode_points(self, pts, require_per_point_feat=False):
return self.forward(pts)
def __init__(self, config:dict):
super().__init__()
channel_in = config.get("in_dim", 3) - 3
params_name = config.get("params_name", "light")
self.SA_modules = nn.ModuleList()
selected_params = select_params(params_name)
for k in range(len(selected_params['NPOINTS'])):
mlps = selected_params['MLPS'][k].copy()
channel_out = 0
for idx in range(len(mlps)):
mlps[idx] = [channel_in] + mlps[idx]
channel_out += mlps[idx][-1]
self.SA_modules.append(
PointnetSAModuleMSG(
npoint=selected_params['NPOINTS'][k],
radii=selected_params['RADIUS'][k],
nsamples=selected_params['NSAMPLE'][k],
mlps=mlps,
use_xyz=True,
bn=True
)
)
channel_in = channel_out
def forward(self, point_cloud: torch.cuda.FloatTensor):
xyz, features = break_up_pc(point_cloud)
l_xyz, l_features = [xyz], [features]
for i in range(len(self.SA_modules)):
li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
l_xyz.append(li_xyz)
l_features.append(li_features)
return l_features[-1].squeeze(-1)
if __name__ == '__main__':
seed = 100
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
net = PointNet2Encoder(config={"in_dim": 3, "params_name": "strong"}).cuda()
pts = torch.randn(2, 2444, 3).cuda()
print(torch.mean(pts, dim=1))
pre = net.encode_points(pts)
print(pre.shape)

107
modules/pointnet_encoder.py Normal file
View File

@@ -0,0 +1,107 @@
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import PytorchBoot.stereotype as stereotype
@stereotype.module("pointnet_encoder")
class PointNetEncoder(nn.Module):
def __init__(self, config:dict):
super(PointNetEncoder, self).__init__()
self.out_dim = config["out_dim"]
self.in_dim = config["in_dim"]
self.feature_transform = config.get("feature_transform", False)
self.stn = STNkd(k=self.in_dim)
self.conv1 = torch.nn.Conv1d(self.in_dim , 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 512, 1)
self.conv4 = torch.nn.Conv1d(512, self.out_dim , 1)
if self.feature_transform:
self.f_stn = STNkd(k=64)
def forward(self, x):
trans = self.stn(x)
x = x.transpose(2, 1)
x = torch.bmm(x, trans)
x = x.transpose(2, 1)
x = F.relu(self.conv1(x))
if self.feature_transform:
trans_feat = self.f_stn(x)
x = x.transpose(2, 1)
x = torch.bmm(x, trans_feat)
x = x.transpose(2, 1)
point_feat = x
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = self.conv4(x)
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, self.out_dim)
return x, point_feat
def encode_points(self, pts, require_per_point_feat=False):
pts = pts.transpose(2, 1)
global_pts_feature, per_point_feature = self(pts)
if require_per_point_feat:
return global_pts_feature, per_point_feature.transpose(2, 1)
else:
return global_pts_feature
class STNkd(nn.Module):
def __init__(self, k=64):
super(STNkd, self).__init__()
self.conv1 = torch.nn.Conv1d(k, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, k * k)
self.relu = nn.ReLU()
self.k = k
def forward(self, x):
batchsize = x.size()[0]
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
iden = (
Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32)))
.view(1, self.k * self.k)
.repeat(batchsize, 1)
)
if x.is_cuda:
iden = iden.to(x.get_device())
x = x + iden
x = x.view(-1, self.k, self.k)
return x
if __name__ == "__main__":
sim_data = Variable(torch.rand(32, 2500, 3))
config = {
"in_dim": 3,
"out_dim": 1024,
"feature_transform": False
}
pointnet = PointNetEncoder(config)
out = pointnet.encode_points(sim_data)
print("global feat", out.size())
out, per_point_out = pointnet.encode_points(sim_data, require_per_point_feat=True)
print("point feat", out.size())
print("per point feat", per_point_out.size())

21
modules/pose_encoder.py Normal file
View File

@@ -0,0 +1,21 @@
from torch import nn
import PytorchBoot.stereotype as stereotype
@stereotype.module("pose_encoder")
class PoseEncoder(nn.Module):
def __init__(self, config):
super(PoseEncoder, self).__init__()
self.config = config
pose_dim = config["pose_dim"]
out_dim = config["out_dim"]
self.act = nn.ReLU(True)
self.pose_encoder = nn.Sequential(
nn.Linear(pose_dim, out_dim),
self.act,
nn.Linear(out_dim, out_dim),
self.act,
)
def encode_pose(self, pose):
return self.pose_encoder(pose)
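
A minimal sketch (pose_dim 9 matches the rot_matrix convention used elsewhere in this commit, and assumes the stereotype decorator returns the class unchanged):

import torch
encoder = PoseEncoder({"pose_dim": 9, "out_dim": 256})
pose = torch.rand(32, 9)
print(encoder.encode_pose(pose).shape) # [32, 256]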

View File

@@ -0,0 +1,20 @@
from torch import nn
import PytorchBoot.stereotype as stereotype
@stereotype.module("pts_num_encoder")
class PointsNumEncoder(nn.Module):
def __init__(self, config):
super(PointsNumEncoder, self).__init__()
self.config = config
out_dim = config["out_dim"]
self.act = nn.ReLU(True)
self.pts_num_encoder = nn.Sequential(
nn.Linear(1, out_dim),
self.act,
nn.Linear(out_dim, out_dim),
self.act,
)
def encode_pts_num(self, num_seq):
return self.pts_num_encoder(num_seq)
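
An analogous sketch; this encoder maps a single scalar per sample (e.g. a normalized point count) to a feature vector:

import torch
encoder = PointsNumEncoder({"out_dim": 256})
num_seq = torch.rand(32, 1)
print(encoder.encode_pts_num(num_seq).shape) # [32, 256]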

View File

@@ -0,0 +1,63 @@
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
import PytorchBoot.stereotype as stereotype
@stereotype.module("transformer_seq_encoder")
class TransformerSequenceEncoder(nn.Module):
def __init__(self, config):
super(TransformerSequenceEncoder, self).__init__()
self.config = config
embed_dim = config["embed_dim"]
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=config["num_heads"],
dim_feedforward=config["ffn_dim"],
batch_first=True,
)
self.transformer_encoder = nn.TransformerEncoder(
encoder_layer, num_layers=config["num_layers"]
)
self.fc = nn.Linear(embed_dim, config["output_dim"])
def encode_sequence(self, embedding_list_batch):
lengths = []
for embedding_list in embedding_list_batch:
lengths.append(len(embedding_list))
embedding_tensor = pad_sequence(embedding_list_batch, batch_first=True) # Shape: [batch_size, max_seq_len, embed_dim]
max_len = max(lengths)
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(embedding_tensor.device)
transformer_output = self.transformer_encoder(embedding_tensor, src_key_padding_mask=padding_mask)
final_feature = transformer_output.mean(dim=1)
final_output = self.fc(final_feature)
return final_output
if __name__ == "__main__":
config = {
"embed_dim": 256,
"num_heads": 4,
"ffn_dim": 256,
"num_layers": 3,
"output_dim": 1024,
}
encoder = TransformerSequenceEncoder(config)
seq_len = [5, 8, 9, 4]
batch_size = 4
embedding_list_batch = [
torch.randn(seq_len[idx], config["embed_dim"]) for idx in range(batch_size)
]
output_feature = encoder.encode_sequence(
embedding_list_batch
)
print("Encoded Feature:", output_feature)
print("Feature Shape:", output_feature.shape)