add global_feat

This commit is contained in:
2024-09-24 09:10:25 +00:00
parent b209ce050c
commit 43f22ad91b
7 changed files with 123 additions and 62 deletions

View File

@@ -32,7 +32,7 @@ def cond_ode_sampler(
init_x=None,
):
pose_dim = PoseUtil.get_pose_dim(pose_mode)
batch_size = data["seq_feat"].shape[0]
batch_size = data["main_feat"].shape[0]
init_x = (
prior((batch_size, pose_dim), T=T).to(device)
if init_x is None