add global_feat
This commit is contained in:
@@ -32,7 +32,7 @@ def cond_ode_sampler(
|
||||
init_x=None,
|
||||
):
|
||||
pose_dim = PoseUtil.get_pose_dim(pose_mode)
|
||||
batch_size = data["seq_feat"].shape[0]
|
||||
batch_size = data["main_feat"].shape[0]
|
||||
init_x = (
|
||||
prior((batch_size, pose_dim), T=T).to(device)
|
||||
if init_x is None
|
||||
|
@@ -80,13 +80,13 @@ class GradientFieldViewFinder(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
data, dict {
|
||||
'seq_feat': [bs, c]
|
||||
'main_feat': [bs, c]
|
||||
'pose_sample': [bs, pose_dim]
|
||||
't': [bs, 1]
|
||||
}
|
||||
"""
|
||||
|
||||
seq_feat = data['seq_feat']
|
||||
main_feat = data['main_feat']
|
||||
sampled_pose = data['sampled_pose']
|
||||
t = data['t']
|
||||
t_feat = self.t_encoder(t.squeeze(1))
|
||||
@@ -95,7 +95,7 @@ class GradientFieldViewFinder(nn.Module):
|
||||
if self.per_point_feature:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
total_feat = torch.cat([seq_feat, t_feat, pose_feat], dim=-1)
|
||||
total_feat = torch.cat([main_feat, t_feat, pose_feat], dim=-1)
|
||||
_, std = self.marginal_prob_fn(total_feat, t)
|
||||
|
||||
if self.regression_head == 'Rx_Ry_and_T':
|
||||
@@ -134,9 +134,9 @@ class GradientFieldViewFinder(nn.Module):
|
||||
|
||||
return in_process_sample, res
|
||||
|
||||
def next_best_view(self, seq_feat):
|
||||
def next_best_view(self, main_feat):
|
||||
data = {
|
||||
'seq_feat': seq_feat,
|
||||
'main_feat': main_feat,
|
||||
}
|
||||
in_process_sample, res = self.sample(data)
|
||||
return res.to(dtype=torch.float32), in_process_sample
|
||||
|
Reference in New Issue
Block a user