update basic framework
modules/func_lib/__init__.py (new file, +7)
@@ -0,0 +1,7 @@
from modules.func_lib.samplers import (
    cond_pc_sampler,
    cond_ode_sampler
)
from modules.func_lib.sde import (
    init_sde
)
modules/func_lib/samplers.py (new file, +280)
@@ -0,0 +1,280 @@
import torch
import numpy as np

from scipy import integrate
from utils.pose import PoseUtil

def global_prior_likelihood(z, sigma_max):
    """The log-likelihood of a Gaussian distribution with mean zero and
    standard deviation sigma_max."""
    # z: [bs, pose_dim]
    shape = z.shape
    N = np.prod(shape[1:])  # pose_dim
    return -N / 2. * torch.log(2 * np.pi * sigma_max ** 2) - torch.sum(z ** 2, dim=-1) / (2 * sigma_max ** 2)

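A one-line sanity check of the shape contract (standalone sketch; the inputs are made up):

# Quick shape check of global_prior_likelihood (made-up inputs).
import torch
from modules.func_lib.samplers import global_prior_likelihood

z = torch.randn(4, 7)  # [bs, pose_dim]
logp = global_prior_likelihood(z, sigma_max=torch.tensor(50.0))
print(logp.shape)      # torch.Size([4]): one log-likelihood per sample
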
def cond_ode_likelihood(
    score_model,
    data,
    prior,
    sde_coeff,
    marginal_prob_fn,
    atol=1e-5,
    rtol=1e-5,
    device='cuda',
    eps=1e-5,
    num_steps=None,
    pose_mode='quat_wxyz',
    init_x=None,
):
    pose_dim = PoseUtil.get_pose_dim(pose_mode)
    batch_size = data['pts'].shape[0]
    epsilon = prior((batch_size, pose_dim)).to(device)
    init_x = data['sampled_pose'].clone().cpu().numpy() if init_x is None else init_x
    shape = init_x.shape
    init_logp = np.zeros((shape[0],))  # [bs]
    init_inp = np.concatenate([init_x.reshape(-1), init_logp], axis=0)

    def score_eval_wrapper(data):
        """A wrapper of the score-based model for use by the ODE solver."""
        with torch.no_grad():
            score = score_model(data)
        return score.cpu().numpy().reshape((-1,))

    def divergence_eval(data, epsilon):
        """Compute the divergence of the score-based model with the Skilling-Hutchinson estimator."""
        # save a checkpoint of sampled_pose
        origin_sampled_pose = data['sampled_pose'].clone()
        with torch.enable_grad():
            # make sampled_pose differentiable
            data['sampled_pose'].requires_grad_(True)
            score = score_model(data)
            score_energy = torch.sum(score * epsilon)  # scalar
            grad_score_energy = torch.autograd.grad(score_energy, data['sampled_pose'])[0]  # [bs, pose_dim]
        # reset sampled_pose
        data['sampled_pose'] = origin_sampled_pose
        return torch.sum(grad_score_energy * epsilon, dim=-1)  # [bs]

    def divergence_eval_wrapper(data):
        """A wrapper for evaluating the divergence of the score for the black-box ODE solver."""
        with torch.no_grad():
            # Compute likelihood.
            div = divergence_eval(data, epsilon)  # [bs]
        return div.cpu().numpy().reshape((-1,)).astype(np.float64)

    def ode_func(t, inp):
        """The ODE function for use by the ODE solver."""
        # split x, logp from inp
        x = inp[:-shape[0]]
        # calc x-grad
        x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device)
        time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t
        drift, diffusion = sde_coeff(torch.tensor(t))
        drift = drift.cpu().numpy()
        diffusion = diffusion.cpu().numpy()
        data['sampled_pose'] = x
        data['t'] = time_steps
        # NOTE: drift here is the SDE's drift coefficient (zero for the VE SDE)
        x_grad = drift - 0.5 * (diffusion ** 2) * score_eval_wrapper(data)
        # calc logp-grad
        logp_grad = drift - 0.5 * (diffusion ** 2) * divergence_eval_wrapper(data)
        # concat curr grad
        return np.concatenate([x_grad, logp_grad], axis=0)

    # Run the black-box ODE solver
    res = integrate.solve_ivp(ode_func, (eps, 1.0), init_inp, rtol=rtol, atol=atol, method='RK45')
    zp = torch.tensor(res.y[:, -1], device=device)  # [bs * (pose_dim + 1)]
    z = zp[:-shape[0]].reshape(shape)  # [bs, pose_dim]
    delta_logp = zp[-shape[0]:].reshape(shape[0])  # [bs,] logp
    _, sigma_max = marginal_prob_fn(None, torch.tensor(1.).to(device))  # we assume T = 1
    prior_logp = global_prior_likelihood(z, sigma_max)
    log_likelihoods = (prior_logp + delta_logp) / np.log(2)  # log-likelihoods in bits (nats / log 2)
    return z, log_likelihoods

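The divergence estimator above is the standard Skilling-Hutchinson trick: for random eps with identity covariance, E[eps^T J eps] equals tr(J). A self-contained toy illustration, unrelated to the repo's model:

# Skilling-Hutchinson trace estimation on a toy linear field f(x) = A x,
# whose true divergence is trace(A). Illustration only.
import torch

torch.manual_seed(0)
A = torch.randn(5, 5)
x = torch.randn(5, requires_grad=True)

est = 0.0
n_samples = 2000
for _ in range(n_samples):
    eps = torch.randn(5)
    fx = A @ x                                             # toy vector field
    grad = torch.autograd.grad(torch.sum(fx * eps), x)[0]  # J^T eps
    est = est + torch.dot(grad, eps)                       # eps^T J eps
est = est / n_samples
print(float(est), float(torch.trace(A)))  # estimate vs. true divergence
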
def cond_pc_sampler(
    score_model,
    data,
    prior,
    sde_coeff,
    num_steps=500,
    snr=0.16,
    device='cuda',
    eps=1e-5,
    pose_mode='quat_wxyz',
    init_x=None,
):
    pose_dim = PoseUtil.get_pose_dim(pose_mode)
    batch_size = data['target_pts_feat'].shape[0]
    init_x = prior((batch_size, pose_dim)).to(device) if init_x is None else init_x
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    noise_norm = np.sqrt(pose_dim)
    x = init_x
    poses = []
    with torch.no_grad():
        for time_step in time_steps:
            batch_time_step = torch.ones(batch_size, device=device).unsqueeze(-1) * time_step
            # Corrector step (Langevin MCMC)
            data['sampled_pose'] = x
            data['t'] = batch_time_step
            grad = score_model(data)
            grad_norm = torch.norm(grad.reshape(batch_size, -1), dim=-1).mean()
            langevin_step_size = 2 * (snr * noise_norm / grad_norm) ** 2
            x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x)

            # normalisation
            if pose_mode == 'quat_wxyz' or pose_mode == 'quat_xyzw':
                # quaternion, should be normalised
                x[:, :4] /= torch.norm(x[:, :4], dim=-1, keepdim=True)
            elif pose_mode == 'euler_xyz':
                pass
            else:
                # rotation (x axis, y axis), should be normalised
                x[:, :3] /= torch.norm(x[:, :3], dim=-1, keepdim=True)
                x[:, 3:6] /= torch.norm(x[:, 3:6], dim=-1, keepdim=True)

            # Predictor step (Euler-Maruyama)
            drift, diffusion = sde_coeff(batch_time_step)
            drift = drift - diffusion ** 2 * grad  # R-SDE (reverse-time SDE drift)
            mean_x = x + drift * step_size
            x = mean_x + diffusion * torch.sqrt(step_size) * torch.randn_like(x)

            # normalisation
            x[:, :-3] = PoseUtil.normalize_rotation(x[:, :-3], pose_mode)
            poses.append(x.unsqueeze(0))

    xs = torch.cat(poses, dim=0)
    xs[:, :, -3:] += data['pts_center'].unsqueeze(0).repeat(xs.shape[0], 1, 1)
    mean_x[:, -3:] += data['pts_center']
    mean_x[:, :-3] = PoseUtil.normalize_rotation(mean_x[:, :-3], pose_mode)
    # The last step (mean_x) does not include any noise
    return xs.permute(1, 0, 2), mean_x

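A minimal driver sketch for cond_pc_sampler. The stand-in score model is the exact score of a toy target concentrated at zero (score(x, t) = -x / std(t)^2); the feature shape, the CPU device, and pose_dim == 7 for quaternions are assumptions, and the real score network lives elsewhere in the repo:

# Hypothetical driver for cond_pc_sampler with a toy analytic score model.
import torch
from modules.func_lib import init_sde, cond_pc_sampler

prior_fn, marginal_prob_fn, sde_fn, eps, T = init_sde('ve')

def toy_score_model(data):
    # exact score of a toy target at zero: -x / std(t)^2
    _, std = marginal_prob_fn(None, data['t'])
    return -data['sampled_pose'] / std ** 2

data = {
    'target_pts_feat': torch.zeros(2, 1024),  # [bs, feat_dim] (assumed shape)
    'pts_center': torch.zeros(2, 3),          # [bs, 3]
}
poses, final_pose = cond_pc_sampler(
    toy_score_model, data, prior_fn, sde_fn,
    num_steps=50, device='cpu', pose_mode='quat_wxyz',
)
print(poses.shape, final_pose.shape)  # [2, 50, 7] and [2, 7], assuming pose_dim == 7
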
def cond_ode_sampler(
    score_model,
    data,
    prior,
    sde_coeff,
    atol=1e-5,
    rtol=1e-5,
    device='cuda',
    eps=1e-5,
    T=1.0,
    num_steps=None,
    pose_mode='quat_wxyz',
    denoise=True,
    init_x=None,
):
    pose_dim = PoseUtil.get_pose_dim(pose_mode)
    batch_size = data['target_feat'].shape[0]
    init_x = prior((batch_size, pose_dim), T=T).to(device) if init_x is None \
        else init_x + prior((batch_size, pose_dim), T=T).to(device)
    shape = init_x.shape

    def score_eval_wrapper(data):
        """A wrapper of the score-based model for use by the ODE solver."""
        with torch.no_grad():
            score = score_model(data)
        return score.cpu().numpy().reshape((-1,))

    def ode_func(t, x):
        """The ODE function for use by the ODE solver."""
        x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device)
        time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t
        drift, diffusion = sde_coeff(torch.tensor(t))
        drift = drift.cpu().numpy()
        diffusion = diffusion.cpu().numpy()
        data['sampled_pose'] = x
        data['t'] = time_steps
        # probability-flow ODE; drift is the coefficient only (zero for the VE SDE)
        return drift - 0.5 * (diffusion ** 2) * score_eval_wrapper(data)

    # Run the black-box ODE solver
    t_eval = None
    if num_steps is not None:
        # num_steps evaluation points, from T -> eps
        t_eval = np.linspace(T, eps, num_steps)
    res = integrate.solve_ivp(ode_func, (T, eps), init_x.reshape(-1).cpu().numpy(),
                              rtol=rtol, atol=atol, method='RK45', t_eval=t_eval)
    xs = torch.tensor(res.y, device=device).T.view(-1, batch_size, pose_dim)  # [num_steps, bs, pose_dim]
    x = torch.tensor(res.y[:, -1], device=device).reshape(shape)  # [bs, pose_dim]
    # denoise, using the predictor step of the P-C sampler
    if denoise:
        # Reverse-diffusion predictor for denoising
        vec_eps = torch.ones((x.shape[0], 1), device=x.device) * eps
        drift, diffusion = sde_coeff(vec_eps)
        data['sampled_pose'] = x.float()
        data['t'] = vec_eps
        grad = score_model(data)
        drift = drift - diffusion ** 2 * grad  # R-SDE (reverse-time SDE drift)
        mean_x = x + drift * ((1 - eps) / (1000 if num_steps is None else num_steps))
        x = mean_x

    num_steps = xs.shape[0]
    xs = xs.reshape(batch_size * num_steps, -1)
    xs = PoseUtil.normalize_rotation(xs, pose_mode)
    xs = xs.reshape(num_steps, batch_size, -1)
    x = PoseUtil.normalize_rotation(x, pose_mode)
    return xs.permute(1, 0, 2), x

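The ODE sampler is driven the same way; with num_steps set, solve_ivp also reports the intermediate trajectory. Same toy score model and assumptions as in the previous sketch:

# Hypothetical driver for cond_ode_sampler with the same toy score model.
import torch
from modules.func_lib import init_sde, cond_ode_sampler

prior_fn, marginal_prob_fn, sde_fn, eps, T = init_sde('ve')

def toy_score_model(data):
    _, std = marginal_prob_fn(None, data['t'])
    return -data['sampled_pose'] / std ** 2

data = {'target_feat': torch.zeros(2, 1024)}  # [bs, feat_dim] (assumed shape)
xs, x = cond_ode_sampler(
    toy_score_model, data, prior_fn, sde_fn,
    device='cpu', eps=eps, T=T, num_steps=20, pose_mode='quat_wxyz',
)
print(xs.shape, x.shape)  # [2, 20, 7] and [2, 7]: trajectory and final pose
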
def cond_edm_sampler(
    decoder_model, data, prior_fn, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
    pose_mode='quat_wxyz', device='cuda'
):
    pose_dim = PoseUtil.get_pose_dim(pose_mode)
    batch_size = data['pts'].shape[0]
    latents = prior_fn((batch_size, pose_dim)).to(device)

    # Time-step discretization. Note that sigma and t are interchangeable here.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
            sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([torch.as_tensor(t_steps), torch.zeros_like(t_steps[:1])])  # t_N = 0

    def decoder_wrapper(decoder, data, x, t):
        # save temp
        x_, t_ = data['sampled_pose'], data['t']
        # init data
        data['sampled_pose'], data['t'] = x, t
        # denoise
        data, denoised = decoder(data)
        # recover data
        data['sampled_pose'], data['t'] = x_, t_
        return denoised.to(torch.float64)

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    xs = []
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = torch.as_tensor(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = decoder_wrapper(decoder_model, data, x_hat, t_hat)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd-order correction.
        if i < num_steps - 1:
            denoised = decoder_wrapper(decoder_model, data, x_next, t_next)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
        xs.append(x_next.unsqueeze(0))

    xs = torch.cat(xs, dim=0)  # [num_steps, bs, pose_dim] (cat, not stack: entries are already unsqueezed)
    x = xs[-1]  # [bs, pose_dim]

    # post-processing
    xs = xs.reshape(batch_size * num_steps, -1)
    xs = PoseUtil.normalize_rotation(xs, pose_mode)
    xs = xs.reshape(num_steps, batch_size, -1)
    x = PoseUtil.normalize_rotation(x, pose_mode)
    return xs.permute(1, 0, 2), x

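The t_steps schedule is the rho-warped discretization from Karras et al.'s EDM paper; it can be checked in isolation (standalone snippet, using the defaults above):

# The EDM noise schedule in isolation: it decreases monotonically from
# sigma_max to sigma_min, with t_N = 0 appended.
import torch

num_steps, sigma_min, sigma_max, rho = 5, 0.002, 80.0, 7
idx = torch.arange(num_steps, dtype=torch.float64)
t = (sigma_max ** (1 / rho)
     + idx / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
t = torch.cat([t, torch.zeros(1, dtype=torch.float64)])
print(t)  # starts at 80.0, ends at 0.002, then the appended 0.0
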
modules/func_lib/sde.py (new file, +121)
@@ -0,0 +1,121 @@
import functools
import torch
import numpy as np

# ----- VE SDE -----
# ------------------
def ve_marginal_prob(x, t, sigma_min=0.01, sigma_max=90):
    std = sigma_min * (sigma_max / sigma_min) ** t
    mean = x
    return mean, std


def ve_sde(t, sigma_min=0.01, sigma_max=90):
    sigma = sigma_min * (sigma_max / sigma_min) ** t
    drift_coeff = torch.tensor(0)
    diffusion_coeff = sigma * torch.sqrt(torch.tensor(2 * (np.log(sigma_max) - np.log(sigma_min)), device=t.device))
    return drift_coeff, diffusion_coeff


def ve_prior(shape, sigma_min=0.01, sigma_max=90, T=1.0):
    _, sigma_max_prior = ve_marginal_prob(None, T, sigma_min=sigma_min, sigma_max=sigma_max)
    return torch.randn(*shape) * sigma_max_prior

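A quick check of the VE perturbation kernel (standalone, with the module defaults sigma_min=0.01, sigma_max=90; note that init_sde below overrides sigma_max to 50): the std interpolates geometrically between the two endpoints.

# VE marginal std at a few time points.
import torch
from modules.func_lib.sde import ve_marginal_prob

for t in (0.0, 0.5, 1.0):
    _, std = ve_marginal_prob(None, torch.tensor(t))
    print(t, float(std))  # 0.01 at t=0, ~0.95 at t=0.5, 90.0 at t=1
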
# ----- VP SDE -----
# ------------------
def vp_marginal_prob(x, t, beta_0=0.1, beta_1=20):
    log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
    mean = torch.exp(log_mean_coeff) * x
    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
    return mean, std


def vp_sde(t, beta_0=0.1, beta_1=20):
    beta_t = beta_0 + t * (beta_1 - beta_0)
    drift_coeff = -0.5 * beta_t
    diffusion_coeff = torch.sqrt(beta_t)
    return drift_coeff, diffusion_coeff


def vp_prior(shape, beta_0=0.1, beta_1=20):
    return torch.randn(*shape)

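At t = 1 the VP kernel has an essentially zero mean coefficient and unit std, which is why vp_prior can ignore beta_0/beta_1 and return a plain standard normal (standalone check):

# VP marginal at t=1: mean coefficient ~0, std ~1.
import torch
from modules.func_lib.sde import vp_marginal_prob

mean, std = vp_marginal_prob(torch.ones(1), torch.tensor(1.0))
print(float(mean), float(std))  # ~0.0066 and ~1.0
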
# ----- sub-VP SDE -----
# ----------------------
def subvp_marginal_prob(x, t, beta_0, beta_1):
    log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
    mean = torch.exp(log_mean_coeff) * x
    std = 1 - torch.exp(2. * log_mean_coeff)
    return mean, std


def subvp_sde(t, beta_0, beta_1):
    beta_t = beta_0 + t * (beta_1 - beta_0)
    drift_coeff = -0.5 * beta_t
    discount = 1. - torch.exp(-2 * beta_0 * t - (beta_1 - beta_0) * t ** 2)
    diffusion_coeff = torch.sqrt(beta_t * discount)
    return drift_coeff, diffusion_coeff


def subvp_prior(shape, beta_0=0.1, beta_1=20):
    return torch.randn(*shape)

# ----- EDM SDE -----
# ------------------
def edm_marginal_prob(x, t, sigma_min=0.002, sigma_max=80):
    std = t
    mean = x
    return mean, std


def edm_sde(t, sigma_min=0.002, sigma_max=80):
    drift_coeff = torch.tensor(0)
    diffusion_coeff = torch.sqrt(2 * t)
    return drift_coeff, diffusion_coeff


def edm_prior(shape, sigma_min=0.002, sigma_max=80):
    return torch.randn(*shape) * sigma_max

def init_sde(sde_mode):
    # the SDE-related hyperparameters are copied from https://github.com/yang-song/score_sde_pytorch
    if sde_mode == 'edm':
        sigma_min = 0.002
        sigma_max = 80
        eps = 0.002
        prior_fn = functools.partial(edm_prior, sigma_min=sigma_min, sigma_max=sigma_max)
        marginal_prob_fn = functools.partial(edm_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
        sde_fn = functools.partial(edm_sde, sigma_min=sigma_min, sigma_max=sigma_max)
        T = sigma_max
    elif sde_mode == 've':
        sigma_min = 0.01
        sigma_max = 50
        eps = 1e-5
        prior_fn = functools.partial(ve_prior, sigma_min=sigma_min, sigma_max=sigma_max)
        marginal_prob_fn = functools.partial(ve_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
        sde_fn = functools.partial(ve_sde, sigma_min=sigma_min, sigma_max=sigma_max)
        T = 1.0
    elif sde_mode == 'vp':
        beta_0 = 0.1
        beta_1 = 20
        eps = 1e-3
        prior_fn = functools.partial(vp_prior, beta_0=beta_0, beta_1=beta_1)
        marginal_prob_fn = functools.partial(vp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
        sde_fn = functools.partial(vp_sde, beta_0=beta_0, beta_1=beta_1)
        T = 1.0
    elif sde_mode == 'subvp':
        beta_0 = 0.1
        beta_1 = 20
        eps = 1e-3
        prior_fn = functools.partial(subvp_prior, beta_0=beta_0, beta_1=beta_1)
        marginal_prob_fn = functools.partial(subvp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
        sde_fn = functools.partial(subvp_sde, beta_0=beta_0, beta_1=beta_1)
        T = 1.0
    else:
        raise NotImplementedError(f'Unsupported sde_mode: {sde_mode}')
    return prior_fn, marginal_prob_fn, sde_fn, eps, T
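Downstream training code typically uses marginal_prob_fn to perturb clean poses; a hedged sketch of that pattern (the training loop itself is not part of this commit, and the pose values are made up):

# Sketch: perturbing a clean pose with the VE marginal, as denoising
# score-matching training typically does.
import torch
from modules.func_lib import init_sde

prior_fn, marginal_prob_fn, sde_fn, eps, T = init_sde('ve')

gt_pose = torch.randn(4, 7)             # [bs, pose_dim], made-up values
t = torch.rand(4, 1) * (T - eps) + eps  # t ~ U(eps, T)
mean, std = marginal_prob_fn(gt_pose, t)
perturbed = mean + std * torch.randn_like(gt_pose)
print(perturbed.shape)                  # torch.Size([4, 7])
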