update basic framework

This commit is contained in:
hofee
2024-08-21 17:11:56 +08:00
parent 73dcd592df
commit f977fd4b8e
29 changed files with 1393 additions and 719 deletions

View File

@@ -0,0 +1,7 @@
from modules.func_lib.samplers import (
cond_pc_sampler,
cond_ode_sampler
)
from modules.func_lib.sde import (
init_sde
)

View File

@@ -0,0 +1,280 @@
import torch
import numpy as np
from scipy import integrate
from utils.pose import PoseUtil
def global_prior_likelihood(z, sigma_max):
    """Log-density of ``z`` under a zero-mean isotropic Gaussian.

    Args:
        z: Tensor whose first axis is the batch, e.g. [bs, pose_dim].
        sigma_max: Standard deviation of the Gaussian. May be a tensor or a
            plain Python/NumPy scalar.

    Returns:
        Tensor with the last axis of ``z`` reduced, i.e. [bs] for a 2-D ``z``.
    """
    if not torch.is_tensor(sigma_max):
        # torch.log only accepts tensors; promote scalars to z's dtype/device.
        sigma_max = torch.as_tensor(sigma_max, dtype=z.dtype, device=z.device)
    n_dims = float(np.prod(z.shape[1:]))  # pose_dim
    variance = sigma_max ** 2
    log_normaliser = -n_dims / 2. * torch.log(2 * np.pi * variance)
    return log_normaliser - torch.sum(z ** 2, dim=-1) / (2 * variance)
def cond_ode_likelihood(
        score_model,
        data,
        prior,
        sde_coeff,
        marginal_prob_fn,
        atol=1e-5,
        rtol=1e-5,
        device='cuda',
        eps=1e-5,
        num_steps=None,
        pose_mode='quat_wxyz',
        init_x=None,
):
    """Integrate the probability-flow ODE from ``eps`` to 1 to score poses.

    The pose and an accumulated log-density term are packed into a single
    flat vector and integrated together with SciPy's RK45 solver.

    Args:
        score_model: Callable taking ``data`` and returning the score.
        data: Dict; ``data['pts']`` gives the batch size and
            ``data['sampled_pose']`` the starting pose when ``init_x`` is
            None. NOTE: ``data['sampled_pose']`` and ``data['t']`` are
            overwritten on every solver step and not restored afterwards.
        prior: Callable ``prior(shape)`` used to draw the Hutchinson probe.
        sde_coeff: Callable ``t -> (drift, diffusion)``.
        marginal_prob_fn: Callable ``(x, t) -> (mean, std)``; only the std at
            t = 1 is used.
        atol, rtol: Solver tolerances.
        device: Device for the tensors created here.
        eps: Start time of the integration.
        num_steps: Unused in this function.
        pose_mode: Forwarded to ``PoseUtil.get_pose_dim``.
        init_x: Optional NumPy array of starting poses.

    Returns:
        Tuple ``(z, log_likelihoods)``: the latent reached at t = 1 and the
        per-sample value ``(prior_logp + delta_logp) / log(2)``.
    """
    pose_dim = PoseUtil.get_pose_dim(pose_mode)
    batch_size = data['pts'].shape[0]
    # Fixed probe vector for the Skilling-Hutchinson divergence estimate.
    epsilon = prior((batch_size, pose_dim)).to(device)
    if init_x is None:
        init_x = data['sampled_pose'].clone().cpu().numpy()
    shape = init_x.shape
    n_samples = shape[0]
    init_logp = np.zeros((n_samples,))  # [bs]
    init_inp = np.concatenate([init_x.reshape(-1), init_logp], axis=0)

    def score_eval_wrapper(data):
        """Evaluate the score without autograd and flatten it for the solver."""
        with torch.no_grad():
            score = score_model(data)
        return score.cpu().numpy().reshape((-1,))

    def divergence_eval(data, epsilon):
        """Skilling-Hutchinson estimate of the divergence of the score."""
        # Keep a copy so the dict entry can be restored after differentiation.
        origin_sampled_pose = data['sampled_pose'].clone()
        with torch.enable_grad():
            # The pose must be differentiable to take d(score . eps)/d(pose).
            data['sampled_pose'].requires_grad_(True)
            score = score_model(data)
            score_energy = torch.sum(score * epsilon)  # scalar
            grad_score_energy = torch.autograd.grad(score_energy, data['sampled_pose'])[0]  # [bs, pose_dim]
        # Put the non-differentiable copy back.
        data['sampled_pose'] = origin_sampled_pose
        return torch.sum(grad_score_energy * epsilon, dim=-1)  # [bs]

    def divergence_eval_wrapper(data):
        """Evaluate the divergence and return it as a float64 NumPy vector."""
        with torch.no_grad():
            div = divergence_eval(data, epsilon)
        return div.cpu().numpy().reshape((-1,)).astype(np.float64)

    def ode_func(t, inp):
        """Right-hand side handed to the black-box ODE solver."""
        # The trailing n_samples entries are logp; everything before is the pose.
        pose_flat = inp[:-n_samples]
        pose = torch.tensor(pose_flat.reshape(-1, pose_dim), dtype=torch.float32, device=device)
        time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t
        drift, diffusion = sde_coeff(torch.tensor(t))
        drift = drift.cpu().numpy()
        diffusion = diffusion.cpu().numpy()
        data['sampled_pose'] = pose
        data['t'] = time_steps
        # Pose derivative.
        x_grad = drift - 0.5 * (diffusion ** 2) * score_eval_wrapper(data)
        # Log-density derivative.
        logp_grad = drift - 0.5 * (diffusion ** 2) * divergence_eval_wrapper(data)
        return np.concatenate([x_grad, logp_grad], axis=0)

    # Run the black-box ODE solver forward in time.
    res = integrate.solve_ivp(ode_func, (eps, 1.0), init_inp, rtol=rtol, atol=atol, method='RK45')
    zp = torch.tensor(res.y[:, -1], device=device)  # [bs * (pose_dim + 1)]
    z = zp[:-n_samples].reshape(shape)  # [bs, pose_dim]
    delta_logp = zp[-n_samples:].reshape(n_samples)  # [bs]
    _, sigma_max = marginal_prob_fn(None, torch.tensor(1.).to(device))  # we assume T = 1
    prior_logp = global_prior_likelihood(z, sigma_max)
    # Division by log(2) converts nats to bits.
    log_likelihoods = (prior_logp + delta_logp) / np.log(2)
    return z, log_likelihoods
def cond_pc_sampler(
        score_model,
        data,
        prior,
        sde_coeff,
        num_steps=500,
        snr=0.16,
        device='cuda',
        eps=1e-5,
        pose_mode='quat_wxyz',
        init_x=None,
):
    """Draw poses with a predictor-corrector sampler.

    Each iteration runs one Langevin corrector step followed by one
    Euler-Maruyama predictor step of the reverse SDE, from t = 1 down to
    ``eps``.

    Args:
        score_model: Callable taking ``data`` and returning the score.
        data: Dict of conditioning tensors. The batch size is read from
            ``data['target_pts_feat']`` when present, otherwise from
            ``data['target_feat']``. ``data['pts_center']`` is optional; when
            present it is added to the last three pose components.
            ``data['sampled_pose']`` and ``data['t']`` are overwritten.
        prior: Callable ``prior(shape)`` used when ``init_x`` is None.
        sde_coeff: Callable ``t -> (drift, diffusion)``.
        num_steps: Number of discretisation steps.
        snr: Signal-to-noise ratio controlling the Langevin step size.
        device: Device for the tensors created here.
        eps: Final time of the reverse process.
        pose_mode: Pose parameterisation name understood by ``PoseUtil``.
        init_x: Optional starting poses, shape [bs, pose_dim].

    Returns:
        Tuple ``(xs, mean_x)``: the trajectory as [bs, num_steps, pose_dim]
        and the noise-free result of the last predictor step.
    """
    pose_dim = PoseUtil.get_pose_dim(pose_mode)
    # Accept the key used by the ODE sampler / view finder as well as the
    # legacy one, so both samplers work on the same data dict.
    feat_key = 'target_pts_feat' if 'target_pts_feat' in data else 'target_feat'
    batch_size = data[feat_key].shape[0]
    init_x = prior((batch_size, pose_dim)).to(device) if init_x is None else init_x
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    noise_norm = np.sqrt(pose_dim)
    x = init_x
    poses = []
    with torch.no_grad():
        for time_step in time_steps:
            batch_time_step = torch.ones(batch_size, device=device).unsqueeze(-1) * time_step
            # Corrector step (Langevin MCMC)
            data['sampled_pose'] = x
            data['t'] = batch_time_step
            grad = score_model(data)
            grad_norm = torch.norm(grad.reshape(batch_size, -1), dim=-1).mean()
            langevin_step_size = 2 * (snr * noise_norm / grad_norm) ** 2
            x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x)
            # Re-normalise the rotation part after the corrector update.
            if pose_mode == 'quat_wxyz' or pose_mode == 'quat_xyzw':
                # quaternion: unit norm
                x[:, :4] /= torch.norm(x[:, :4], dim=-1, keepdim=True)
            elif pose_mode == 'euler_xyz':
                pass
            else:
                # rotation (x axis, y axis): each axis unit norm
                x[:, :3] /= torch.norm(x[:, :3], dim=-1, keepdim=True)
                x[:, 3:6] /= torch.norm(x[:, 3:6], dim=-1, keepdim=True)
            # Predictor step (Euler-Maruyama)
            drift, diffusion = sde_coeff(batch_time_step)
            drift = drift - diffusion ** 2 * grad  # R-SDE
            mean_x = x + drift * step_size
            x = mean_x + diffusion * torch.sqrt(step_size) * torch.randn_like(x)
            # NOTE(review): this treats the last three components as a
            # translation; confirm that matches PoseUtil.get_pose_dim for
            # every pose_mode used with this sampler.
            x[:, :-3] = PoseUtil.normalize_rotation(x[:, :-3], pose_mode)
            poses.append(x.unsqueeze(0))
    xs = torch.cat(poses, dim=0)
    if 'pts_center' in data:
        # Shift the last three components by the supplied centre.
        pts_center = data['pts_center']
        xs[:, :, -3:] += pts_center.unsqueeze(0).repeat(xs.shape[0], 1, 1)
        mean_x[:, -3:] += pts_center
    mean_x[:, :-3] = PoseUtil.normalize_rotation(mean_x[:, :-3], pose_mode)
    # The last step does not include any noise
    return xs.permute(1, 0, 2), mean_x
def cond_ode_sampler(
        score_model,
        data,
        prior,
        sde_coeff,
        atol=1e-5,
        rtol=1e-5,
        device='cuda',
        eps=1e-5,
        T=1.0,
        num_steps=None,
        pose_mode='quat_wxyz',
        denoise=True,
        init_x=None,
):
    """Draw poses by integrating the probability-flow ODE from ``T`` to ``eps``.

    Args:
        score_model: Callable taking ``data`` and returning the score.
        data: Dict of conditioning tensors; ``data['target_feat']`` gives the
            batch size. ``data['sampled_pose']`` and ``data['t']`` are
            overwritten.
        prior: Callable ``prior(shape, T=...)`` returning a noise sample.
        sde_coeff: Callable ``t -> (drift, diffusion)``.
        atol, rtol: Solver tolerances.
        device: Device for the tensors created here.
        eps: Final time of the integration.
        T: Start time of the integration.
        num_steps: If given, the solver reports the state at this many evenly
            spaced times between ``T`` and ``eps``.
        pose_mode: Pose parameterisation name understood by ``PoseUtil``.
        denoise: If True, apply one reverse-diffusion predictor step (without
            noise) to the final state.
        init_x: Optional starting poses; a prior sample is added to them.

    Returns:
        Tuple ``(xs, x)``: the trajectory as [bs, n_steps, pose_dim] and the
        final pose [bs, pose_dim], both passed through
        ``PoseUtil.normalize_rotation``.
    """
    pose_dim = PoseUtil.get_pose_dim(pose_mode)
    batch_size = data['target_feat'].shape[0]
    prior_sample = prior((batch_size, pose_dim), T=T).to(device)
    init_x = prior_sample if init_x is None else init_x + prior_sample
    shape = init_x.shape

    def score_eval_wrapper(data):
        """Evaluate the score without autograd and flatten it for the solver."""
        with torch.no_grad():
            score = score_model(data)
        return score.cpu().numpy().reshape((-1,))

    def ode_func(t, x):
        """Right-hand side handed to the black-box ODE solver."""
        x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device)
        time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t
        drift, diffusion = sde_coeff(torch.tensor(t))
        drift = drift.cpu().numpy()
        diffusion = diffusion.cpu().numpy()
        data['sampled_pose'] = x
        data['t'] = time_steps
        return drift - 0.5 * (diffusion ** 2) * score_eval_wrapper(data)

    # Run the black-box ODE solver backwards in time.
    t_eval = None
    if num_steps is not None:
        # num_steps, from T -> eps
        t_eval = np.linspace(T, eps, num_steps)
    res = integrate.solve_ivp(
        ode_func, (T, eps), init_x.detach().reshape(-1).cpu().numpy(),
        rtol=rtol, atol=atol, method='RK45', t_eval=t_eval,
    )
    xs = torch.tensor(res.y, device=device).T.reshape(-1, batch_size, pose_dim)  # [n_steps, bs, pose_dim]
    x = torch.tensor(res.y[:, -1], device=device).reshape(shape)  # [bs, pose_dim]
    # denoise, using the predictor step in P-C sampler
    if denoise:
        # Reverse diffusion predictor for denoising
        vec_eps = torch.ones((x.shape[0], 1), device=x.device) * eps
        drift, diffusion = sde_coeff(vec_eps)
        data['sampled_pose'] = x.float()
        data['t'] = vec_eps
        # Sampling never needs gradients; without this the returned pose
        # would keep the score network's autograd graph alive.
        with torch.no_grad():
            grad = score_model(data)
        drift = drift - diffusion ** 2 * grad  # R-SDE
        mean_x = x + drift * ((1 - eps) / (1000 if num_steps is None else num_steps))
        x = mean_x
    num_steps = xs.shape[0]
    xs = xs.reshape(batch_size * num_steps, -1)
    xs = PoseUtil.normalize_rotation(xs, pose_mode)
    xs = xs.reshape(num_steps, batch_size, -1)
    x = PoseUtil.normalize_rotation(x, pose_mode)
    return xs.permute(1, 0, 2), x
def cond_edm_sampler(
        decoder_model, data, prior_fn, randn_like=torch.randn_like,
        num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
        S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
        pose_mode='quat_wxyz', device='cuda'
):
    """Draw poses with a second-order (Heun) EDM-style sampler.

    Args:
        decoder_model: Callable ``decoder(data) -> (data, denoised)``.
        data: Dict of conditioning tensors; ``data['pts']`` gives the batch
            size. ``data['sampled_pose']`` and ``data['t']`` must already be
            present: they are swapped out for each denoiser call and restored
            afterwards.
        prior_fn: Callable ``prior_fn(shape)`` returning the initial latents.
        randn_like: Noise generator with the ``torch.randn_like`` interface.
        num_steps: Number of sampling steps.
        sigma_min, sigma_max, rho: Parameters of the noise-level schedule.
        S_churn, S_min, S_max, S_noise: Stochasticity ("churn") settings.
        pose_mode: Pose parameterisation name understood by ``PoseUtil``.
        device: Device the latents are moved to.

    Returns:
        Tuple ``(xs, x)``: the trajectory as [bs, num_steps, pose_dim] and the
        final pose [bs, pose_dim], both passed through
        ``PoseUtil.normalize_rotation``.
    """
    pose_dim = PoseUtil.get_pose_dim(pose_mode)
    batch_size = data['pts'].shape[0]
    latents = prior_fn((batch_size, pose_dim)).to(device)
    # Time step discretisation. Note that sigma and t are interchangeable.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
        sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([torch.as_tensor(t_steps), torch.zeros_like(t_steps[:1])])  # t_N = 0

    def decoder_wrapper(decoder, data, x, t):
        """Run the denoiser on (x, t) while leaving the caller's dict intact."""
        # save temp
        x_, t_ = data['sampled_pose'], data['t']
        # init data
        data['sampled_pose'], data['t'] = x, t
        # denoise
        data, denoised = decoder(data)
        # recover data
        data['sampled_pose'], data['t'] = x_, t_
        return denoised.to(torch.float64)

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    xs = []
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next
        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = torch.as_tensor(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
        # Euler step.
        denoised = decoder_wrapper(decoder_model, data, x_hat, t_hat)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur
        # Apply 2nd order correction (skipped on the last step, where t_next = 0).
        if i < num_steps - 1:
            denoised = decoder_wrapper(decoder_model, data, x_next, t_next)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
        # torch.stack adds the step axis itself; unsqueezing here as well
        # would leave the final pose with a spurious leading dimension.
        xs.append(x_next)
    xs = torch.stack(xs, dim=0)  # [num_steps, bs, pose_dim]
    x = xs[-1]  # [bs, pose_dim]
    # post-processing
    xs = xs.reshape(batch_size * num_steps, -1)
    xs = PoseUtil.normalize_rotation(xs, pose_mode)
    xs = xs.reshape(num_steps, batch_size, -1)
    x = PoseUtil.normalize_rotation(x, pose_mode)
    return xs.permute(1, 0, 2), x

121
modules/func_lib/sde.py Normal file
View File

@@ -0,0 +1,121 @@
import functools
import torch
import numpy as np
# ----- VE SDE -----
# ------------------
def ve_marginal_prob(x, t, sigma_min=0.01, sigma_max=90):
    """Return ``(mean, std)`` of the VE perturbation kernel at time ``t``.

    The mean is ``x`` unchanged; the std interpolates geometrically between
    ``sigma_min`` (t = 0) and ``sigma_max`` (t = 1).
    """
    noise_std = sigma_min * (sigma_max / sigma_min) ** t
    return x, noise_std
def ve_sde(t, sigma_min=0.01, sigma_max=90):
    """Return ``(drift_coeff, diffusion_coeff)`` of the VE SDE at time ``t``.

    ``t`` must be a tensor: its device is used for the constant factor.
    The drift coefficient is a zero scalar tensor.
    """
    log_ratio = torch.tensor(2 * (np.log(sigma_max) - np.log(sigma_min)), device=t.device)
    noise_level = sigma_min * (sigma_max / sigma_min) ** t
    return torch.tensor(0), noise_level * torch.sqrt(log_ratio)
def ve_prior(shape, sigma_min=0.01, sigma_max=90, T=1.0):
    """Sample Gaussian noise of ``shape`` scaled by the VE std at time ``T``."""
    prior_std = ve_marginal_prob(None, T, sigma_min=sigma_min, sigma_max=sigma_max)[1]
    unit_noise = torch.randn(*shape)
    return unit_noise * prior_std
# ----- VP SDE -----
# ------------------
def vp_marginal_prob(x, t, beta_0=0.1, beta_1=20):
    """Return ``(mean, std)`` of the VP perturbation kernel at time ``t``.

    Args:
        x: Clean sample, or None when only the std is needed (the mean is
            then returned as None, matching ``ve_marginal_prob``).
        t: Diffusion time, tensor or plain number.
        beta_0, beta_1: End points of the linear beta schedule.
    """
    log_mean_coeff = torch.as_tensor(-0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0)
    mean = None if x is None else torch.exp(log_mean_coeff) * x
    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
    return mean, std
def vp_sde(t, beta_0=0.1, beta_1=20):
    """Return ``(drift_coeff, diffusion_coeff)`` of the VP SDE at time ``t``.

    ``t`` must be a tensor because the diffusion uses ``torch.sqrt``.
    """
    beta_t = beta_0 + t * (beta_1 - beta_0)  # linear schedule
    return -0.5 * beta_t, torch.sqrt(beta_t)
def vp_prior(shape, beta_0=0.1, beta_1=20, T=1.0):
    """Sample standard Gaussian noise of ``shape``.

    ``beta_0``, ``beta_1`` and ``T`` do not affect the sample; ``T`` is
    accepted so that callers passing ``T=...`` (as for ``ve_prior``) work
    with every prior.
    """
    return torch.randn(*shape)
# ----- sub-VP SDE -----
# ----------------------
def subvp_marginal_prob(x, t, beta_0=0.1, beta_1=20):
    """Return ``(mean, std)`` of the sub-VP perturbation kernel at time ``t``.

    Args:
        x: Clean sample, or None when only the std is needed (the mean is
            then returned as None, matching ``ve_marginal_prob``).
        t: Diffusion time, tensor or plain number.
        beta_0, beta_1: End points of the linear beta schedule; the defaults
            match ``vp_marginal_prob``.
    """
    log_mean_coeff = torch.as_tensor(-0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0)
    mean = None if x is None else torch.exp(log_mean_coeff) * x
    std = 1 - torch.exp(2. * log_mean_coeff)
    return mean, std
def subvp_sde(t, beta_0=0.1, beta_1=20):
    """Return ``(drift_coeff, diffusion_coeff)`` of the sub-VP SDE at ``t``.

    ``t`` must be a tensor because ``torch.exp``/``torch.sqrt`` are applied.
    The schedule defaults match ``vp_sde``.
    """
    beta_t = beta_0 + t * (beta_1 - beta_0)
    drift_coeff = -0.5 * beta_t
    discount = 1. - torch.exp(-2 * beta_0 * t - (beta_1 - beta_0) * t ** 2)
    diffusion_coeff = torch.sqrt(beta_t * discount)
    return drift_coeff, diffusion_coeff
def subvp_prior(shape, beta_0=0.1, beta_1=20, T=1.0):
    """Sample standard Gaussian noise of ``shape``.

    ``beta_0``, ``beta_1`` and ``T`` do not affect the sample; ``T`` is
    accepted so that callers passing ``T=...`` (as for ``ve_prior``) work
    with every prior.
    """
    return torch.randn(*shape)
# ----- EDM SDE -----
# ------------------
def edm_marginal_prob(x, t, sigma_min=0.002, sigma_max=80):
    """Return ``(mean, std)`` for the EDM kernel: ``x`` and ``t`` unchanged.

    ``sigma_min`` and ``sigma_max`` are accepted but not used.
    """
    return x, t
def edm_sde(t, sigma_min=0.002, sigma_max=80):
    """Return ``(drift_coeff, diffusion_coeff)`` = ``(0, sqrt(2 t))``.

    ``t`` must be a tensor; ``sigma_min`` and ``sigma_max`` are not used.
    """
    zero_drift = torch.tensor(0)
    return zero_drift, torch.sqrt(2 * t)
def edm_prior(shape, sigma_min=0.002, sigma_max=80, T=None):
    """Sample Gaussian noise of ``shape`` for the EDM process.

    Args:
        shape: Shape of the sample.
        sigma_min: Unused.
        sigma_max: Std of the sample when ``T`` is None.
        T: Optional time at which to sample the prior. Since
            ``edm_marginal_prob`` returns ``std = t``, the sample is scaled
            by ``T`` when it is given. Accepting it also lets callers that
            pass ``T=...`` (as for ``ve_prior``) use this prior.
    """
    prior_std = sigma_max if T is None else T
    return torch.randn(*shape) * prior_std
def init_sde(sde_mode):
    """Build the prior, marginal and SDE-coefficient functions for a mode.

    Args:
        sde_mode: One of ``'edm'``, ``'ve'``, ``'vp'`` or ``'subvp'``.

    Returns:
        Tuple ``(prior_fn, marginal_prob_fn, sde_fn, eps, T)`` where the three
        callables have their hyperparameters bound, ``eps`` is the smallest
        sampling time and ``T`` the largest.

    Raises:
        NotImplementedError: If ``sde_mode`` is not one of the modes above.
    """
    # the SDE-related hyperparameters are copied from https://github.com/yang-song/score_sde_pytorch
    if sde_mode == 'edm':
        sigma_min = 0.002
        sigma_max = 80
        eps = 0.002
        prior_fn = functools.partial(edm_prior, sigma_min=sigma_min, sigma_max=sigma_max)
        marginal_prob_fn = functools.partial(edm_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
        sde_fn = functools.partial(edm_sde, sigma_min=sigma_min, sigma_max=sigma_max)
        T = sigma_max
    elif sde_mode == 've':
        sigma_min = 0.01
        sigma_max = 50
        eps = 1e-5
        marginal_prob_fn = functools.partial(ve_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
        sde_fn = functools.partial(ve_sde, sigma_min=sigma_min, sigma_max=sigma_max)
        T = 1.0
        prior_fn = functools.partial(ve_prior, sigma_min=sigma_min, sigma_max=sigma_max)
    elif sde_mode == 'vp':
        beta_0 = 0.1
        beta_1 = 20
        eps = 1e-3
        prior_fn = functools.partial(vp_prior, beta_0=beta_0, beta_1=beta_1)
        marginal_prob_fn = functools.partial(vp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
        sde_fn = functools.partial(vp_sde, beta_0=beta_0, beta_1=beta_1)
        T = 1.0
    elif sde_mode == 'subvp':
        beta_0 = 0.1
        beta_1 = 20
        eps = 1e-3
        prior_fn = functools.partial(subvp_prior, beta_0=beta_0, beta_1=beta_1)
        marginal_prob_fn = functools.partial(subvp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
        sde_fn = functools.partial(subvp_sde, beta_0=beta_0, beta_1=beta_1)
        T = 1.0
    else:
        # Name the offending value so a misconfigured mode is easy to spot.
        raise NotImplementedError(
            f"Unsupported sde_mode {sde_mode!r}; expected 'edm', 've', 'vp' or 'subvp'"
        )
    return prior_fn, marginal_prob_fn, sde_fn, eps, T

View File

@@ -0,0 +1,17 @@
import torch
import numpy as np
import torch.nn as nn
class GaussianFourierProjection(nn.Module):
    """Random Fourier features used to embed scalar time steps."""

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Frequencies are drawn once at construction and frozen
        # (requires_grad=False), so the optimiser never updates them.
        frequencies = torch.randn(embed_dim // 2) * scale
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        """Map ``x`` to the concatenation ``[sin(2*pi*x*W), cos(2*pi*x*W)]``."""
        phase = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * np.pi
        return torch.cat((phase.sin(), phase.cos()), dim=-1)

View File

@@ -0,0 +1,30 @@
import torch
import numpy as np
def weight_init(shape, mode, fan_in, fan_out):
    """Sample an initial weight tensor of ``shape``.

    Args:
        shape: Sequence of ints giving the tensor shape.
        mode: ``'xavier_uniform'``, ``'xavier_normal'``, ``'kaiming_uniform'``
            or ``'kaiming_normal'``.
        fan_in: Number of input units.
        fan_out: Number of output units (used by the Xavier modes only).

    Raises:
        ValueError: If ``mode`` is not one of the four names above.
    """
    # Pick the squared scale and the base distribution for the chosen mode.
    if mode == 'xavier_uniform':
        scale_sq, uniform = 6 / (fan_in + fan_out), True
    elif mode == 'xavier_normal':
        scale_sq, uniform = 2 / (fan_in + fan_out), False
    elif mode == 'kaiming_uniform':
        scale_sq, uniform = 3 / fan_in, True
    elif mode == 'kaiming_normal':
        scale_sq, uniform = 1 / fan_in, False
    else:
        raise ValueError(f'Invalid init mode "{mode}"')
    if uniform:
        # U(0, 1) rescaled to U(-1, 1).
        base = torch.rand(*shape) * 2 - 1
    else:
        base = torch.randn(*shape)
    return np.sqrt(scale_sq) * base
class Linear(torch.nn.Module):
    """Fully connected layer whose parameters are drawn by ``weight_init``."""

    def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        init_kwargs = {'mode': init_mode, 'fan_in': in_features, 'fan_out': out_features}
        # The weight is sampled first, then the bias; init_weight / init_bias
        # rescale the samples (init_bias=0 gives a zero bias).
        weight = weight_init([out_features, in_features], **init_kwargs) * init_weight
        self.weight = torch.nn.Parameter(weight)
        if bias:
            self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias)
        else:
            self.bias = None

    def forward(self, x):
        """Apply ``x @ W^T (+ b)`` with the parameters cast to ``x``'s dtype."""
        out = x @ self.weight.to(x.dtype).t()
        if self.bias is None:
            return out
        # In-place add is safe: ``out`` is a fresh tensor from the matmul.
        return out.add_(self.bias.to(x.dtype))

View File

@@ -1,20 +0,0 @@
from torch import nn
from configs.config import ConfigManager
class Pipeline(nn.Module):
    """Skeleton pipeline module that reads its settings from ConfigManager."""

    # Mode identifiers callers can pass to ``forward``.
    TRAIN_MODE: str = "train"
    TEST_MODE: str = "test"

    def __init__(self, pipeline_config):
        # ``pipeline_config`` is accepted but not used here.
        super().__init__()
        self.modules_config = ConfigManager.get("modules")
        self.device = ConfigManager.get("settings", "general", "device")

    def forward(self, data, mode):
        """Placeholder: does nothing and returns None."""
        return None
# Placeholder entry point: running this module as a script does nothing.
if __name__ == '__main__':
    pass

View File

@@ -0,0 +1,12 @@
from abc import abstractmethod
from torch import nn
class PointsEncoder(nn.Module):
    """Base class for modules that turn a point cloud into features."""

    def __init__(self):
        super(PointsEncoder, self).__init__()

    @abstractmethod
    def encode_points(self, pts):
        """Encode ``pts`` into features; subclasses must override this.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # nn.Module does not use ABCMeta, so @abstractmethod alone is not
        # enforced; fail loudly instead of silently returning None.
        raise NotImplementedError(
            f"{type(self).__name__} must implement encode_points()"
        )

View File

@@ -0,0 +1,110 @@
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from modules.pts_encoder.abstract_pts_encoder import PointsEncoder
import PytorchBoot.stereotype as stereotype
class STNkd(nn.Module):
    """Spatial transformer that predicts a ``k x k`` matrix per sample.

    The network regresses an offset that is added to the identity matrix.
    """

    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        self.relu = nn.ReLU()
        self.k = k

    def forward(self, x):
        """Return a [bs, k, k] transform for input ``x`` of shape [bs, k, n]."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # Max-pool over the points axis to get one global feature per sample.
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        # Flattened identity, created directly on x's device (any backend,
        # not only CUDA) and broadcast over the batch.
        iden = torch.eye(self.k, dtype=torch.float32, device=x.device).view(1, self.k * self.k)
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x
@stereotype.module("pointnet_encoder")
class PointNetEncoder(PointsEncoder):
    """PointNet-style encoder producing global or per-point features."""

    def __init__(self, global_feat=True, in_dim=3, out_dim=1024, feature_transform=False):
        super().__init__()
        self.out_dim = out_dim
        self.feature_transform = feature_transform
        self.stn = STNkd(k=in_dim)
        self.conv1 = torch.nn.Conv1d(in_dim, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 512, 1)
        self.conv4 = torch.nn.Conv1d(512, out_dim, 1)
        self.global_feat = global_feat
        if self.feature_transform:
            self.f_stn = STNkd(k=64)

    @staticmethod
    def _apply_transform(feat, trans):
        """Multiply channel-first features [bs, c, n] by a [bs, c, c] matrix."""
        return torch.bmm(feat.transpose(2, 1), trans).transpose(2, 1)

    def forward(self, x):
        """Encode channel-first points ``x`` of shape [bs, in_dim, n_pts].

        Returns [bs, out_dim] when ``global_feat`` is set, otherwise the
        global feature repeated per point and concatenated with the 64-channel
        point features along the channel axis.
        """
        n_pts = x.shape[2]
        x = self._apply_transform(x, self.stn(x))
        x = F.relu(self.conv1(x))
        if self.feature_transform:
            x = self._apply_transform(x, self.f_stn(x))
        point_feat = x
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.conv4(x)
        # Max-pool over the points axis.
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, self.out_dim)
        if self.global_feat:
            return x
        tiled = x.view(-1, self.out_dim, 1).repeat(1, 1, n_pts)
        return torch.cat([tiled, point_feat], 1)

    def encode_points(self, pts):
        """Encode points given as [bs, n_pts, in_dim] (channel-last)."""
        feats = self(pts.transpose(2, 1))
        # Per-point features are returned channel-last again.
        return feats if self.global_feat else feats.transpose(2, 1)
# Smoke test: print the output sizes for global and per-point features.
if __name__ == "__main__":
    # torch.autograd.Variable is deprecated; plain tensors behave the same.
    sim_data = torch.rand(32, 2500, 3)
    pointnet_global = PointNetEncoder(global_feat=True)
    out = pointnet_global.encode_points(sim_data)
    print("global feat", out.size())
    pointnet = PointNetEncoder(global_feat=False)
    out = pointnet.encode_points(sim_data)
    print("point feat", out.size())

View File

@@ -0,0 +1,12 @@
from abc import abstractmethod
from torch import nn
class ViewFinder(nn.Module):
    """Base class for modules that propose the next best view."""

    def __init__(self):
        super(ViewFinder, self).__init__()

    @abstractmethod
    def next_best_view(self, scene_pts_feat, target_pts_feat):
        """Propose a view from scene and target features; must be overridden.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # nn.Module does not use ABCMeta, so @abstractmethod alone is not
        # enforced; fail loudly instead of silently returning None.
        raise NotImplementedError(
            f"{type(self).__name__} must implement next_best_view()"
        )

View File

@@ -0,0 +1,168 @@
import torch
import torch.nn as nn
import PytorchBoot.stereotype as stereotype
from utils.pose import PoseUtil
from modules.view_finder.abstract_view_finder import ViewFinder
import modules.module_lib as mlib
import modules.func_lib as flib
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return it."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
@stereotype.module("gf_view_finder")
class GradientFieldViewFinder(ViewFinder):
    """Score network over poses, conditioned on scene and target features.

    ``forward`` predicts the score of a noisy pose; ``sample`` and
    ``next_best_view`` draw poses with the configured sampler.
    """

    def __init__(self, pose_mode='rot_matrix', regression_head='Rx_Ry', per_point_feature=False,
                 sample_mode="ode", sampling_steps=None, sde_mode="ve"):
        """
        Args:
            pose_mode: Pose parameterisation; only 'rot_matrix' is supported
                with the 'Rx_Ry' head.
            regression_head: Only 'Rx_Ry' is implemented.
            per_point_feature: Must be False (per-point features are not
                implemented).
            sample_mode: 'ode' or 'pc'.
            sampling_steps: Number of sampler steps, or None to use each
                sampler's own default.
            sde_mode: Name passed to ``flib.init_sde``.
        """
        super(GradientFieldViewFinder, self).__init__()
        self.regression_head = regression_head
        self.per_point_feature = per_point_feature
        self.act = nn.ReLU(True)
        self.sample_mode = sample_mode
        self.pose_mode = pose_mode
        pose_dim = PoseUtil.get_pose_dim(pose_mode)
        self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(sde_mode)
        self.sampling_steps = sampling_steps

        # encode pose
        self.pose_encoder = nn.Sequential(
            nn.Linear(pose_dim, 256),
            self.act,
            nn.Linear(256, 256),
            self.act,
        )

        # encode t
        self.t_encoder = nn.Sequential(
            mlib.GaussianFourierProjection(embed_dim=128),
            nn.Linear(128, 128),
            self.act,
        )

        # fusion tail
        if self.regression_head == 'Rx_Ry':
            if pose_mode != 'rot_matrix':
                raise NotImplementedError
            if not per_point_feature:
                # Input width: t (128) + pose (256) + scene (1024) + target (1024).
                # The last layers are zero-initialised.
                # rotation_x_axis regress head
                self.fusion_tail_rot_x = nn.Sequential(
                    nn.Linear(128 + 256 + 1024 + 1024, 256),
                    self.act,
                    zero_module(nn.Linear(256, 3)),
                )
                # rotation_y_axis regress head
                self.fusion_tail_rot_y = nn.Sequential(
                    nn.Linear(128 + 256 + 1024 + 1024, 256),
                    self.act,
                    zero_module(nn.Linear(256, 3)),
                )
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

    def forward(self, data):
        """Predict the score of ``data['sampled_pose']`` at time ``data['t']``.

        Args:
            data, dict {
                'scene_feat': [bs, c]
                'target_feat': [bs, c]
                'sampled_pose': [bs, pose_dim]
                't': [bs, 1]
            }

        Returns:
            Tensor [bs, 6]: the x-axis and y-axis heads concatenated and
            divided by the marginal std at ``t``.
        """
        scene_pts_feat = data['scene_feat']
        target_pts_feat = data['target_feat']
        sampled_pose = data['sampled_pose']
        t = data['t']
        t_feat = self.t_encoder(t.squeeze(1))
        pose_feat = self.pose_encoder(sampled_pose)
        if self.per_point_feature:
            raise NotImplementedError
        else:
            total_feat = torch.cat([scene_pts_feat, target_pts_feat, t_feat, pose_feat], dim=-1)
        _, std = self.marginal_prob_fn(total_feat, t)
        if self.regression_head == 'Rx_Ry':
            rot_x = self.fusion_tail_rot_x(total_feat)
            rot_y = self.fusion_tail_rot_y(total_feat)
            out_score = torch.cat([rot_x, rot_y], dim=-1) / (std + 1e-7)  # normalisation
        else:
            raise NotImplementedError
        return out_score

    def marginal_prob(self, x, t):
        """Return ``(mean, std)`` of the configured perturbation kernel."""
        return self.marginal_prob_fn(x, t)

    def sample(self, data, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True, init_x=None, T0=None):
        """Draw poses with the sampler selected by ``self.sample_mode``.

        Args:
            data: Dict with at least 'scene_feat' and 'target_feat'.
            atol, rtol: ODE solver tolerances ('ode' mode only).
            snr: Langevin signal-to-noise ratio ('pc' mode only).
            denoise: Whether to apply the final denoising step ('ode' only).
            init_x: Optional starting poses.
            T0: Start time for the ODE; defaults to ``self.T``.

        Returns:
            Tuple ``(in_process_sample, res)`` as returned by the sampler.

        Raises:
            NotImplementedError: If ``self.sample_mode`` is not 'pc' or 'ode'.
        """
        # Run the sampler where this module lives instead of relying on the
        # samplers' hard-coded default of 'cuda'.
        device = next(self.parameters()).device
        if self.sample_mode == 'pc':
            # Only override the sampler's own default step count when one was
            # configured; passing None would break torch.linspace.
            pc_kwargs = {}
            if self.sampling_steps is not None:
                pc_kwargs['num_steps'] = self.sampling_steps
            in_process_sample, res = flib.cond_pc_sampler(
                score_model=self,
                data=data,
                prior=self.prior_fn,
                sde_coeff=self.sde_fn,
                snr=snr,
                device=device,
                eps=self.sampling_eps,
                pose_mode=self.pose_mode,
                init_x=init_x,
                **pc_kwargs
            )
        elif self.sample_mode == 'ode':
            T0 = self.T if T0 is None else T0
            in_process_sample, res = flib.cond_ode_sampler(
                score_model=self,
                data=data,
                prior=self.prior_fn,
                sde_coeff=self.sde_fn,
                atol=atol,
                rtol=rtol,
                device=device,
                eps=self.sampling_eps,
                T=T0,
                num_steps=self.sampling_steps,
                pose_mode=self.pose_mode,
                denoise=denoise,
                init_x=init_x
            )
        else:
            raise NotImplementedError
        return in_process_sample, res

    def next_best_view(self, scene_pts_feat, target_pts_feat):
        """Sample a pose for the given features.

        Returns:
            Tuple ``(pose, in_process_sample)`` with the pose cast to float32.
        """
        data = {
            'scene_feat': scene_pts_feat,
            'target_feat': target_pts_feat,
        }
        in_process_sample, res = self.sample(data)
        return res.to(dtype=torch.float32), in_process_sample
# ----------- DEBUG -----------
# Smoke test: one forward pass and one sampling call on random features.
if __name__ == "__main__":
    debug_device = "cuda:0"
    test_scene_feat = torch.rand(32, 1024).to(debug_device)
    test_target_feat = torch.rand(32, 1024).to(debug_device)
    test_pose = torch.rand(32, 6).to(debug_device)
    test_t = torch.rand(32, 1).to(debug_device)
    view_finder = GradientFieldViewFinder().to(debug_device)
    test_data = dict(
        target_feat=test_target_feat,
        scene_feat=test_scene_feat,
        sampled_pose=test_pose,
        t=test_t,
    )
    score = view_finder(test_data)
    result = view_finder.next_best_view(test_scene_feat, test_target_feat)
    print(result)