update basic framework

This commit is contained in:
hofee
2024-08-21 17:11:56 +08:00
parent 73dcd592df
commit f977fd4b8e
29 changed files with 1393 additions and 719 deletions

View File

@@ -0,0 +1,17 @@
import torch
import numpy as np
import torch.nn as nn
class GaussianFourierProjection(nn.Module):
"""Gaussian random features for encoding time steps."""
def __init__(self, embed_dim, scale=30.):
super().__init__()
# Randomly sample weights during initialization. These weights are fixed
# during optimization and are not trainable.
self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
def forward(self, x):
x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

View File

@@ -0,0 +1,30 @@
import torch
import numpy as np
def weight_init(shape, mode, fan_in, fan_out):
if mode == 'xavier_uniform':
return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1)
if mode == 'xavier_normal':
return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape)
if mode == 'kaiming_uniform':
return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1)
if mode == 'kaiming_normal':
return np.sqrt(1 / fan_in) * torch.randn(*shape)
raise ValueError(f'Invalid init mode "{mode}"')
class Linear(torch.nn.Module):
def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0):
super().__init__()
self.in_features = in_features
self.out_features = out_features
init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features)
self.weight = torch.nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight)
self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None
def forward(self, x):
x = x @ self.weight.to(x.dtype).t()
if self.bias is not None:
x = x.add_(self.bias.to(x.dtype))
return x