update strong p++
This commit is contained in:
@@ -33,6 +33,29 @@ ClsMSG_CFG_Light = {
|
||||
'DP_RATIO': 0.5,
|
||||
}
|
||||
|
||||
ClsMSG_CFG_Light_2048 = {
|
||||
'NPOINTS': [512, 256, 128, None],
|
||||
'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
|
||||
'NSAMPLE': [[16, 32], [16, 32], [16, 32], [None, None]],
|
||||
'MLPS': [[[16, 16, 32], [32, 32, 64]],
|
||||
[[64, 64, 128], [64, 96, 128]],
|
||||
[[128, 196, 256], [128, 196, 256]],
|
||||
[[256, 256, 1024], [256, 512, 1024]]],
|
||||
'DP_RATIO': 0.5,
|
||||
}
|
||||
|
||||
ClsMSG_CFG_Strong = {
|
||||
'NPOINTS': [1024, 512, 256, 128, None], # 增加采样点,获取更多细节
|
||||
'RADIUS': [[0.02, 0.05], [0.05, 0.1], [0.1, 0.2], [0.2, 0.4], [None, None]], # 增大感受野
|
||||
'NSAMPLE': [[32, 64], [32, 64], [32, 64], [32, 64], [None, None]], # 提高每层的采样点数
|
||||
'MLPS': [[[32, 32, 64], [64, 64, 128]], # 增强 MLP 层,增加特征提取能力
|
||||
[[128, 128, 256], [128, 128, 256]],
|
||||
[[256, 256, 512], [256, 384, 512]],
|
||||
[[512, 512, 1024], [512, 768, 1024]],
|
||||
[[1024, 1024, 2048], [1024, 1024, 2048]]], # 增加更深的特征层
|
||||
'DP_RATIO': 0.4, # Dropout 比率稍微降低,以保留更多信息
|
||||
}
|
||||
|
||||
ClsMSG_CFG_Lighter = {
|
||||
'NPOINTS': [512, 256, 128, 64, None],
|
||||
'RADIUS': [[0.01], [0.02], [0.04], [0.08], [None]],
|
||||
@@ -53,6 +76,10 @@ def select_params(name):
|
||||
return ClsMSG_CFG_Lighter
|
||||
elif name == 'dense':
|
||||
return ClsMSG_CFG_Dense
|
||||
elif name == 'light_2048':
|
||||
return ClsMSG_CFG_Light_2048
|
||||
elif name == 'strong':
|
||||
return ClsMSG_CFG_Strong
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -114,8 +141,8 @@ if __name__ == '__main__':
|
||||
seed = 100
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
net = PointNet2Encoder(config={"in_dim": 3, "params_name": "light"}).cuda()
|
||||
pts = torch.randn(2, 1024, 3).cuda()
|
||||
net = PointNet2Encoder(config={"in_dim": 3, "params_name": "strong"}).cuda()
|
||||
pts = torch.randn(2, 2444, 3).cuda()
|
||||
print(torch.mean(pts, dim=1))
|
||||
pre = net.encode_points(pts)
|
||||
print(pre.shape)
|
||||
|
Reference in New Issue
Block a user