modules/module_lib/position_embedding.py (executable file, 17 lines added)
@@ -0,0 +1,17 @@
import torch


class PositionalEmbedding(torch.nn.Module):
    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        freqs = torch.arange(start=0, end=self.num_channels // 2, dtype=torch.float32, device=x.device)
        freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
        freqs = (1 / self.max_positions) ** freqs
        x = x.ger(freqs.to(x.dtype))
        x = torch.cat([x.cos(), x.sin()], dim=1)
        return x
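
The module maps a 1-D tensor of positions x to [cos(x * f_i) | sin(x * f_i)], where f_i = (1/max_positions)^(i / (num_channels/2)) is a geometric frequency ladder from 1 down toward 1/max_positions (reaching it exactly when endpoint=True). A minimal usage sketch, assuming the input is a 1-D tensor of scalar positions such as timesteps; the names emb, t, and out are illustrative, not from the commit:

import torch

# Sketch: embed 8 scalar positions into 64-dimensional sinusoidal features
# using the PositionalEmbedding module added in the diff above.
emb = PositionalEmbedding(num_channels=64)
t = torch.arange(8, dtype=torch.float32)   # positions, shape (8,)
out = emb(t)                               # (8, 64): [cos features | sin features]
assert out.shape == (8, 64)

Note that forward() uses torch.Tensor.ger (the outer product, an alias of torch.outer), so the input must be one-dimensional; batched multi-dimensional positions would need to be flattened first.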
|