first commit
modules/func_lib/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from modules.func_lib.samplers import (
    cond_ode_sampler
)
from modules.func_lib.sde import (
    init_sde
)
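The two re-exports above define the package's public surface. A minimal import sketch for illustration (not part of the commit; it only assumes the modules/ package shown here is on the Python path):

# Hypothetical usage of the re-exports above.
from modules.func_lib import cond_ode_sampler, init_sde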
modules/func_lib/__pycache__/__init__.cpython-39.pyc (new binary file, not shown)
modules/func_lib/__pycache__/samplers.cpython-39.pyc (new binary file, not shown)
modules/func_lib/__pycache__/sde.cpython-39.pyc (new binary file, not shown)
modules/func_lib/samplers.py (new file, 95 lines)
@@ -0,0 +1,95 @@
import torch
import numpy as np

from scipy import integrate
from utils.pose import PoseUtil


def global_prior_likelihood(z, sigma_max):
    """The likelihood of a Gaussian distribution with mean zero and
    standard deviation sigma_max."""
    # z: [bs, pose_dim]
    shape = z.shape
    N = np.prod(shape[1:])  # pose_dim
    return -N / 2.0 * torch.log(2 * np.pi * sigma_max**2) - torch.sum(
        z**2, dim=-1
    ) / (2 * sigma_max**2)


def cond_ode_sampler(
    score_model,
    data,
    prior,
    sde_coeff,
    atol=1e-5,
    rtol=1e-5,
    device="cuda",
    eps=1e-5,
    T=1.0,
    num_steps=None,
    pose_mode="quat_wxyz",
    denoise=True,
    init_x=None,
):
    pose_dim = PoseUtil.get_pose_dim(pose_mode)
    batch_size = data["main_feat"].shape[0]
    init_x = (
        prior((batch_size, pose_dim), T=T).to(device)
        if init_x is None
        else init_x + prior((batch_size, pose_dim), T=T).to(device)
    )
    shape = init_x.shape

    def score_eval_wrapper(data):
        """A wrapper of the score-based model for use by the ODE solver."""
        with torch.no_grad():
            score = score_model(data)
        return score.cpu().numpy().reshape((-1,))

    def ode_func(t, x):
        """The ODE function for use by the ODE solver."""
        x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device)
        time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t
        drift, diffusion = sde_coeff(torch.tensor(t))
        drift = drift.cpu().numpy()
        diffusion = diffusion.cpu().numpy()
        data["sampled_pose"] = x
        data["t"] = time_steps
        return drift - 0.5 * (diffusion**2) * score_eval_wrapper(data)

    # Run the black-box ODE solver; t_eval selects which intermediate steps are returned.
    t_eval = None
    if num_steps is not None:
        # num_steps evaluation points, from T -> eps
        t_eval = np.linspace(T, eps, num_steps)
    res = integrate.solve_ivp(
        ode_func,
        (T, eps),
        init_x.reshape(-1).cpu().numpy(),
        rtol=rtol,
        atol=atol,
        method="RK45",
        t_eval=t_eval,
    )
    xs = torch.tensor(res.y, device=device).T.view(
        -1, batch_size, pose_dim
    )  # [num_steps, bs, pose_dim]
    x = torch.tensor(res.y[:, -1], device=device).reshape(shape)  # [bs, pose_dim]
    # denoise, using the predictor step in P-C sampler
    if denoise:
        # Reverse diffusion predictor for denoising
        vec_eps = torch.ones((x.shape[0], 1), device=x.device) * eps
        drift, diffusion = sde_coeff(vec_eps)
        data["sampled_pose"] = x.float()
        data["t"] = vec_eps
        grad = score_model(data)
        drift = drift - diffusion**2 * grad  # R-SDE
        mean_x = x + drift * ((1 - eps) / (1000 if num_steps is None else num_steps))
        x = mean_x

    num_steps = xs.shape[0]
    xs = xs.reshape(batch_size * num_steps, -1)
    xs[:, :-3] = PoseUtil.normalize_rotation(xs[:, :-3], pose_mode)
    xs = xs.reshape(num_steps, batch_size, -1)
    x[:, :-3] = PoseUtil.normalize_rotation(x[:, :-3], pose_mode)
    return xs.permute(1, 0, 2), x
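For context, a hedged end-to-end sketch of how cond_ode_sampler could be driven. It assumes the rest of this repository (utils.pose.PoseUtil and a trained score network) is importable; ToyScoreNet, the feature size 128, and the batch size 8 are made up purely to illustrate the expected call signature and output shapes:

# Illustrative only: ToyScoreNet is a hypothetical stand-in for the real score network.
import torch
from modules.func_lib import cond_ode_sampler, init_sde
from utils.pose import PoseUtil

pose_mode = "quat_wxyz"
pose_dim = PoseUtil.get_pose_dim(pose_mode)  # rotation dims + 3 translation dims

class ToyScoreNet(torch.nn.Module):
    # Ignores the conditioning features and returns a zero score with the same
    # shape as data["sampled_pose"], so the ODE integrates the drift term only.
    def forward(self, data):
        return torch.zeros_like(data["sampled_pose"])

prior_fn, marginal_prob_fn, sde_fn, sampling_eps, T = init_sde("ve")
data = {"main_feat": torch.randn(8, 128)}  # 8 samples; the feature size is arbitrary here

in_process_sample, final_pose = cond_ode_sampler(
    score_model=ToyScoreNet(),
    data=data,
    prior=prior_fn,
    sde_coeff=sde_fn,
    device="cpu",
    eps=sampling_eps,
    T=T,
    num_steps=100,
    pose_mode=pose_mode,
)
print(in_process_sample.shape)  # [bs, num_steps, pose_dim]
print(final_pose.shape)         # [bs, pose_dim]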
modules/func_lib/sde.py (new file, 121 lines)
@@ -0,0 +1,121 @@
import functools
import torch
import numpy as np


# ----- VE SDE -----
# ------------------
def ve_marginal_prob(x, t, sigma_min=0.01, sigma_max=90):
    std = sigma_min * (sigma_max / sigma_min) ** t
    mean = x
    return mean, std


def ve_sde(t, sigma_min=0.01, sigma_max=90):
    sigma = sigma_min * (sigma_max / sigma_min) ** t
    drift_coeff = torch.tensor(0)
    diffusion_coeff = sigma * torch.sqrt(torch.tensor(2 * (np.log(sigma_max) - np.log(sigma_min)), device=t.device))
    return drift_coeff, diffusion_coeff


def ve_prior(shape, sigma_min=0.01, sigma_max=90, T=1.0):
    _, sigma_max_prior = ve_marginal_prob(None, T, sigma_min=sigma_min, sigma_max=sigma_max)
    return torch.randn(*shape) * sigma_max_prior


# ----- VP SDE -----
# ------------------
def vp_marginal_prob(x, t, beta_0=0.1, beta_1=20):
    log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
    mean = torch.exp(log_mean_coeff) * x
    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
    return mean, std


def vp_sde(t, beta_0=0.1, beta_1=20):
    beta_t = beta_0 + t * (beta_1 - beta_0)
    drift_coeff = -0.5 * beta_t
    diffusion_coeff = torch.sqrt(beta_t)
    return drift_coeff, diffusion_coeff


def vp_prior(shape, beta_0=0.1, beta_1=20):
    return torch.randn(*shape)


# ----- sub-VP SDE -----
# ----------------------
def subvp_marginal_prob(x, t, beta_0, beta_1):
    log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
    mean = torch.exp(log_mean_coeff) * x
    std = 1 - torch.exp(2. * log_mean_coeff)
    return mean, std


def subvp_sde(t, beta_0, beta_1):
    beta_t = beta_0 + t * (beta_1 - beta_0)
    drift_coeff = -0.5 * beta_t
    discount = 1. - torch.exp(-2 * beta_0 * t - (beta_1 - beta_0) * t ** 2)
    diffusion_coeff = torch.sqrt(beta_t * discount)
    return drift_coeff, diffusion_coeff


def subvp_prior(shape, beta_0=0.1, beta_1=20):
    return torch.randn(*shape)


# ----- EDM SDE -----
# ------------------
def edm_marginal_prob(x, t, sigma_min=0.002, sigma_max=80):
    std = t
    mean = x
    return mean, std


def edm_sde(t, sigma_min=0.002, sigma_max=80):
    drift_coeff = torch.tensor(0)
    diffusion_coeff = torch.sqrt(2 * t)
    return drift_coeff, diffusion_coeff


def edm_prior(shape, sigma_min=0.002, sigma_max=80):
    return torch.randn(*shape) * sigma_max


def init_sde(sde_mode):
    # the SDE-related hyperparameters are copied from https://github.com/yang-song/score_sde_pytorch
    if sde_mode == 'edm':
        sigma_min = 0.002
        sigma_max = 80
        eps = 0.002
        prior_fn = functools.partial(edm_prior, sigma_min=sigma_min, sigma_max=sigma_max)
        marginal_prob_fn = functools.partial(edm_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
        sde_fn = functools.partial(edm_sde, sigma_min=sigma_min, sigma_max=sigma_max)
        T = sigma_max
    elif sde_mode == 've':
        sigma_min = 0.01
        sigma_max = 50
        eps = 1e-5
        marginal_prob_fn = functools.partial(ve_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
        sde_fn = functools.partial(ve_sde, sigma_min=sigma_min, sigma_max=sigma_max)
        T = 1.0
        prior_fn = functools.partial(ve_prior, sigma_min=sigma_min, sigma_max=sigma_max)
    elif sde_mode == 'vp':
        beta_0 = 0.1
        beta_1 = 20
        eps = 1e-3
        prior_fn = functools.partial(vp_prior, beta_0=beta_0, beta_1=beta_1)
        marginal_prob_fn = functools.partial(vp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
        sde_fn = functools.partial(vp_sde, beta_0=beta_0, beta_1=beta_1)
        T = 1.0
    elif sde_mode == 'subvp':
        beta_0 = 0.1
        beta_1 = 20
        eps = 1e-3
        prior_fn = functools.partial(subvp_prior, beta_0=beta_0, beta_1=beta_1)
        marginal_prob_fn = functools.partial(subvp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
        sde_fn = functools.partial(subvp_sde, beta_0=beta_0, beta_1=beta_1)
        T = 1.0
    else:
        raise NotImplementedError
    return prior_fn, marginal_prob_fn, sde_fn, eps, T
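As a quick illustration of how the returned pieces fit together (a sketch, not part of the commit; the pose dimension 7 and the batch size are arbitrary): for the VE schedule, samples drawn from prior_fn should have roughly the standard deviation reported by marginal_prob_fn at t=T.

# Illustrative sanity check of the init_sde('ve') outputs.
import torch
from modules.func_lib.sde import init_sde

prior_fn, marginal_prob_fn, sde_fn, eps, T = init_sde("ve")

_, std_T = marginal_prob_fn(None, torch.tensor(T))  # sigma_min * (sigma_max / sigma_min) ** T
samples = prior_fn((100000, 7))                     # large batch so the empirical std is stable
print(float(std_T), samples.std().item())           # the two values should be close (about 50)

# Drift/diffusion coefficients consumed by the probability-flow ODE in samplers.py
drift, diffusion = sde_fn(torch.tensor(0.5))
print(drift.item(), diffusion.item())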