modules/module_lib/rot_head.py (new executable file, 41 lines added)
@@ -0,0 +1,41 @@
import torch.nn as nn
import torch
import torch.nn.functional as F


class RotHead(nn.Module):
    def __init__(self, in_feat_dim, out_dim=3):
        super(RotHead, self).__init__()
        self.f = in_feat_dim
        self.k = out_dim

        # Per-point MLP built from 1x1 convolutions, followed by a small
        # head that maps the pooled global feature to the output vector.
        self.conv1 = torch.nn.Conv1d(self.f, 1024, 1)
        self.conv2 = torch.nn.Conv1d(1024, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 256, 1)
        self.conv4 = torch.nn.Conv1d(256, self.k, 1)
        self.drop1 = nn.Dropout(0.2)
        self.bn1 = nn.BatchNorm1d(1024)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(256)

    def forward(self, x):
        # x: (batch, in_feat_dim, num_points)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))

        # Global max pooling over the point dimension -> (batch, 256, 1)
        x = torch.max(x, 2, keepdim=True)[0]

        x = F.relu(self.bn3(self.conv3(x)))
        x = self.drop1(x)
        x = self.conv4(x)

        # (batch, out_dim, 1) -> (batch, out_dim)
        x = x.squeeze(2)
        x = x.contiguous()

        return x


if __name__ == "__main__":
    points = torch.rand(2, 1350, 1024)  # batch_size x feature x num_of_point
    rot_head = RotHead(in_feat_dim=1350, out_dim=3)
    rot = rot_head(points)
    print(rot.shape)  # expected: torch.Size([2, 3])