success
This commit is contained in:
56
modules/pts_encoder/pts_encoder_factory.py
Executable file
56
modules/pts_encoder/pts_encoder_factory.py
Executable file
@@ -0,0 +1,56 @@
|
||||
import sys
|
||||
import os
|
||||
path = os.path.abspath(__file__)
|
||||
for i in range(3):
|
||||
path = os.path.dirname(path)
|
||||
PROJECT_ROOT = path
|
||||
sys.path.append(PROJECT_ROOT)
|
||||
|
||||
from modules.pts_encoder.abstract_pts_encoder import PointsEncoder
|
||||
from modules.pts_encoder.pointnet_encoder import PointNetEncoder
|
||||
from modules.pts_encoder.pointnet2_encoder import PointNet2Encoder
|
||||
from modules.pts_encoder.pointnet3_encoder import PointNet3Encoder
|
||||
|
||||
class PointsEncoderFactory:
|
||||
@staticmethod
|
||||
def create(name, config) -> PointsEncoder:
|
||||
general_config = config["general"]
|
||||
pts_encoder_config = config["pts_encoder"][name]
|
||||
if name == "pointnet":
|
||||
return PointNetEncoder(
|
||||
in_dim=general_config["pts_channels"],
|
||||
out_dim=general_config["feature_dim"],
|
||||
global_feat=not general_config["per_point_feature"]
|
||||
)
|
||||
elif name == "pointnet++":
|
||||
return PointNet2Encoder(
|
||||
input_channels=general_config["pts_channels"] - 3,
|
||||
params_name=pts_encoder_config["params_name"]
|
||||
)
|
||||
elif name == "pointnet++rgb":
|
||||
return PointNet3Encoder(
|
||||
input_channels=general_config["pts_channels"] - 3,
|
||||
params_name=pts_encoder_config["params_name"],
|
||||
target_layer=pts_encoder_config["target_layer"],
|
||||
rgb_feat_dim=pts_encoder_config["rgb_feat_dim"]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown encoder name: {name}")
|
||||
|
||||
|
||||
''' ------------ Debug ------------ '''
|
||||
if __name__ == "__main__":
|
||||
from configs.config import ConfigManager
|
||||
import torch
|
||||
|
||||
pts = torch.rand(32, 1200, 3) # BxNxC
|
||||
ConfigManager.load_config_with('configs/local_train_config.yaml')
|
||||
ConfigManager.print_config()
|
||||
pts_encoder = PointsEncoderFactory.create(name="pointnet++", config=ConfigManager.get("modules"))
|
||||
print(pts_encoder)
|
||||
pts = pts.to("cuda")
|
||||
pts_encoder = pts_encoder.to("cuda")
|
||||
|
||||
pts_feat = pts_encoder.encode_points(pts)
|
||||
|
||||
print(pts_feat.shape)
|
Reference in New Issue
Block a user