success
This commit is contained in:
17
modules/module_lib/gaussian_fourier_projection.py
Executable file
17
modules/module_lib/gaussian_fourier_projection.py
Executable file
@@ -0,0 +1,17 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class GaussianFourierProjection(nn.Module):
|
||||
"""Gaussian random features for encoding time steps."""
|
||||
|
||||
def __init__(self, embed_dim, scale=30.):
|
||||
super().__init__()
|
||||
# Randomly sample weights during initialization. These weights are fixed
|
||||
# during optimization and are not trainable.
|
||||
self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
|
||||
|
||||
def forward(self, x):
|
||||
x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
|
||||
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
|
Reference in New Issue
Block a user