success
modules/module_lib/linear.py · Executable file · 30 lines added
@@ -0,0 +1,30 @@
import torch
import numpy as np


def weight_init(shape, mode, fan_in, fan_out):
    # Sample an initial weight tensor of the given shape with the requested scheme.
    if mode == 'xavier_uniform':
        return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1)
    if mode == 'xavier_normal':
        return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape)
    if mode == 'kaiming_uniform':
        return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1)
    if mode == 'kaiming_normal':
        return np.sqrt(1 / fan_in) * torch.randn(*shape)
    raise ValueError(f'Invalid init mode "{mode}"')


class Linear(torch.nn.Module):
    # Fully-connected layer with configurable initialization, scaling, and optional bias.
    def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features)
        self.weight = torch.nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight)
        self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None

    def forward(self, x):
        # Cast parameters to the input dtype so the layer also works under mixed precision.
        x = x @ self.weight.to(x.dtype).t()
        if self.bias is not None:
            x = x.add_(self.bias.to(x.dtype))
        return x
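A minimal usage sketch of the new layer; the shapes, init mode, and import path below are illustrative assumptions, not part of the commit.

# Hypothetical usage of the Linear layer added above.
import torch
from modules.module_lib.linear import Linear  # assumes modules/ is importable as a package

layer = Linear(in_features=128, out_features=64, init_mode='xavier_uniform')
x = torch.randn(16, 128)   # batch of 16 input vectors (assumed sizes)
y = layer(x)               # x @ weight.T plus bias -> shape [16, 64]
print(y.shape)             # torch.Size([16, 64])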