diff --git a/modules/module_lib/pointnet2_utils/.gitignore b/modules/module_lib/pointnet2_utils/.gitignore
new file mode 100644
index 0000000..cf42194
--- /dev/null
+++ b/modules/module_lib/pointnet2_utils/.gitignore
@@ -0,0 +1,4 @@
+pointnet2/build/
+pointnet2/dist/
+pointnet2/pointnet2.egg-info/
+__pycache__/
diff --git a/modules/module_lib/pointnet2_utils/LICENSE b/modules/module_lib/pointnet2_utils/LICENSE
new file mode 100644
index 0000000..77c8ebe
--- /dev/null
+++ b/modules/module_lib/pointnet2_utils/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 Shaoshuai Shi
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/modules/module_lib/pointnet2_utils/README.md b/modules/module_lib/pointnet2_utils/README.md
new file mode 100644
index 0000000..c5a43f0
--- /dev/null
+++ b/modules/module_lib/pointnet2_utils/README.md
@@ -0,0 +1,51 @@
+# Pointnet2.PyTorch
+
+* PyTorch implementation of [PointNet++](https://arxiv.org/abs/1706.02413) based on [erikwijmans/Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch).
+* Faster than the original code by re-implementing the CUDA operations.
+
+## Installation
+### Requirements
+* Linux (tested on Ubuntu 14.04/16.04)
+* Python 3.6+
+* PyTorch 1.0
+
+### Install
+Install this library by running the following commands:
+
+```shell
+cd pointnet2
+python setup.py install
+cd ../
+```
+
+## Examples
+Here I provide a simple example of using this library for KITTI outdoor foreground point cloud segmentation; you can refer to the paper [PointRCNN](https://arxiv.org/abs/1812.04244) for details of the task definition and foreground label generation.
+
+1. Download the training data from the [KITTI 3D object detection](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d) website and organize the downloaded files as follows:
+```
+Pointnet2.PyTorch
+├── pointnet2
+├── tools
+│   ├── data
+│   │   ├── KITTI
+│   │   │   ├── ImageSets
+│   │   │   ├── object
+│   │   │   │   ├── training
+│   │   │   │   ├── calib & velodyne & label_2 & image_2
+│   ├── train_and_eval.py
+```
+
+2. Run the following command to train and evaluate:
+```shell
+cd tools
+python train_and_eval.py --batch_size 8 --epochs 100 --ckpt_save_interval 2
+```
+
+## Projects using this repo
+* [PointRCNN](https://github.com/sshaoshuai/PointRCNN): 3D object detector from raw point cloud.
+
+## Acknowledgement
+* [charlesq34/pointnet2](https://github.com/charlesq34/pointnet2): paper author's official code repo.
+* [erikwijmans/Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch): initial PyTorch implementation of PointNet++.
diff --git a/modules/module_lib/pointnet2_utils/pointnet2/pointnet2_cuda.cpython-39-x86_64-linux-gnu.so b/modules/module_lib/pointnet2_utils/pointnet2/pointnet2_cuda.cpython-39-x86_64-linux-gnu.so
new file mode 100755
index 0000000..d0d5a94
Binary files /dev/null and b/modules/module_lib/pointnet2_utils/pointnet2/pointnet2_cuda.cpython-39-x86_64-linux-gnu.so differ
diff --git a/modules/module_lib/pointnet2_utils/pointnet2/pointnet2_modules.py b/modules/module_lib/pointnet2_utils/pointnet2/pointnet2_modules.py
new file mode 100644
index 0000000..4b94326
--- /dev/null
+++ b/modules/module_lib/pointnet2_utils/pointnet2/pointnet2_modules.py
@@ -0,0 +1,162 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from . import pointnet2_utils
+from . import pytorch_utils as pt_utils
+from typing import List
+
+
+class _PointnetSAModuleBase(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.npoint = None
+        self.groupers = None
+        self.mlps = None
+        self.pool_method = 'max_pool'
+
+    def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor):
+        """
+        :param xyz: (B, N, 3) tensor of the xyz coordinates of the features
+        :param features: (B, N, C) tensor of the descriptors of the features
+        :param new_xyz:
+        :return:
+            new_xyz: (B, npoint, 3) tensor of the new features' xyz
+            new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
+        """
+        new_features_list = []
+
+        xyz_flipped = xyz.transpose(1, 2).contiguous()
+        if new_xyz is None:
+            new_xyz = pointnet2_utils.gather_operation(
+                xyz_flipped,
+                pointnet2_utils.furthest_point_sample(xyz, self.npoint)
+            ).transpose(1, 2).contiguous() if self.npoint is not None else None
+
+        for i in range(len(self.groupers)):
+            new_features = self.groupers[i](xyz, new_xyz, features)  # (B, C, npoint, nsample)
+
+            new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint, nsample)
+
+            if self.pool_method == 'max_pool':
+                new_features = F.max_pool2d(
+                    new_features, kernel_size=[1, new_features.size(3)]
+                )  # (B, mlp[-1], npoint, 1)
+            elif self.pool_method == 'avg_pool':
+                new_features = F.avg_pool2d(
+                    new_features, kernel_size=[1, new_features.size(3)]
+                )  # (B, mlp[-1], npoint, 1)
+            else:
+                raise NotImplementedError
+
+            new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)
+            new_features_list.append(new_features)
+
+        return new_xyz, torch.cat(new_features_list, dim=1)
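The pooling flow above is easiest to follow with concrete shapes. Here is a minimal sketch of driving it through the multi-scale-grouping subclass defined just below; the batch size, radii, and MLP specs are illustrative only, and it assumes the `pointnet2_cuda` extension from this diff has been built and a GPU is available:

```python
# Shape walkthrough for the set-abstraction forward pass (illustrative sketch;
# assumes the pointnet2 CUDA extension is built and a CUDA device is present).
import torch
from pointnet2.pointnet2_modules import PointnetSAModuleMSG

sa = PointnetSAModuleMSG(
    npoint=512,               # centroids kept after furthest point sampling
    radii=[0.1, 0.2],         # one ball-query radius per scale
    nsamples=[16, 32],        # neighbors gathered per ball
    mlps=[[3, 32], [3, 64]],  # per-scale MLP specs; +3 is added when use_xyz=True
).cuda()

xyz = torch.rand(2, 1024, 3).cuda()       # (B, N, 3) point coordinates
features = torch.rand(2, 3, 1024).cuda()  # (B, C, N) per-point descriptors

new_xyz, new_features = sa(xyz, features)
print(new_xyz.shape)       # torch.Size([2, 512, 3])
print(new_features.shape)  # torch.Size([2, 96, 512]) -> concat of 32 + 64 channels
```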
+class PointnetSAModuleMSG(_PointnetSAModuleBase):
+    """Pointnet set abstraction layer with multiscale grouping"""
+
+    def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
+                 use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
+        """
+        :param npoint: int
+        :param radii: list of float, list of radii to group with
+        :param nsamples: list of int, number of samples in each ball query
+        :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
+        :param bn: whether to use batchnorm
+        :param use_xyz:
+        :param pool_method: max_pool / avg_pool
+        :param instance_norm: whether to use instance_norm
+        """
+        super().__init__()
+
+        assert len(radii) == len(nsamples) == len(mlps)
+
+        self.npoint = npoint
+        self.groupers = nn.ModuleList()
+        self.mlps = nn.ModuleList()
+        for i in range(len(radii)):
+            radius = radii[i]
+            nsample = nsamples[i]
+            self.groupers.append(
+                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
+                if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
+            )
+            mlp_spec = mlps[i]
+            if use_xyz:
+                mlp_spec[0] += 3
+
+            self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
+        self.pool_method = pool_method
+
+
+class PointnetSAModule(PointnetSAModuleMSG):
+    """Pointnet set abstraction layer"""
+
+    def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
+                 bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
+        """
+        :param mlp: list of int, spec of the pointnet before the global max_pool
+        :param npoint: int, number of features
+        :param radius: float, radius of ball
+        :param nsample: int, number of samples in the ball query
+        :param bn: whether to use batchnorm
+        :param use_xyz:
+        :param pool_method: max_pool / avg_pool
+        :param instance_norm: whether to use instance_norm
+        """
+        super().__init__(
+            mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz,
+            pool_method=pool_method, instance_norm=instance_norm
+        )
+
+
+class PointnetFPModule(nn.Module):
+    r"""Propagates the features of one set to another"""
+
+    def __init__(self, *, mlp: List[int], bn: bool = True):
+        """
+        :param mlp: list of int
+        :param bn: whether to use batchnorm
+        """
+        super().__init__()
+        self.mlp = pt_utils.SharedMLP(mlp, bn=bn)
+
+    def forward(
+            self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
+        :param known: (B, m, 3) tensor of the xyz positions of the known features
+        :param unknow_feats: (B, C1, n) tensor of the features to be propagated to
+        :param known_feats: (B, C2, m) tensor of features to be propagated
+        :return:
+            new_features: (B, mlp[-1], n) tensor of the features of the unknown features
+        """
+        if known is not None:
+            dist, idx = pointnet2_utils.three_nn(unknown, known)
+            dist_recip = 1.0 / (dist + 1e-8)
+            norm = torch.sum(dist_recip, dim=2, keepdim=True)
+            weight = dist_recip / norm
+
+            interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)
+        else:
+            interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))
+
+        if unknow_feats is not None:
+            new_features = torch.cat([interpolated_feats, unknow_feats], dim=1)  # (B, C2 + C1, n)
+        else:
+            new_features = interpolated_feats
+
+        new_features = new_features.unsqueeze(-1)
+
+        new_features = self.mlp(new_features)
+
+        return new_features.squeeze(-1)
+
+
+if __name__ == "__main__":
+    pass
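To see how the set-abstraction and feature-propagation modules pair up in a segmentation-style network, here is a hypothetical two-level encoder/decoder; the layer sizes are arbitrary and the import path assumes the package layout above:

```python
# Hypothetical pairing of SA and FP for per-point prediction (a sketch, not
# part of this library; assumes a CUDA build of the extension).
import torch
import torch.nn as nn
from pointnet2.pointnet2_modules import PointnetSAModule, PointnetFPModule

class TinySeg(nn.Module):
    def __init__(self):
        super().__init__()
        # mlp[0] = 0 because we feed no extra features; +3 is added for xyz
        self.sa = PointnetSAModule(mlp=[0, 32, 64], npoint=256, radius=0.2, nsample=32)
        self.fp = PointnetFPModule(mlp=[64, 64])  # 64 interpolated channels -> 64

    def forward(self, xyz):                # xyz: (B, N, 3), no extra features
        l1_xyz, l1_feats = self.sa(xyz)    # (B, 256, 3), (B, 64, 256)
        # propagate coarse features back onto the original N points
        return self.fp(xyz, l1_xyz, None, l1_feats)  # (B, 64, N)

model = TinySeg().cuda()
out = model(torch.rand(2, 1024, 3).cuda())
print(out.shape)  # torch.Size([2, 64, 1024])
```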
diff --git a/modules/module_lib/pointnet2_utils/pointnet2/pointnet2_utils.py b/modules/module_lib/pointnet2_utils/pointnet2/pointnet2_utils.py
new file mode 100644
index 0000000..97a5466
--- /dev/null
+++ b/modules/module_lib/pointnet2_utils/pointnet2/pointnet2_utils.py
@@ -0,0 +1,291 @@
+import torch
+from torch.autograd import Variable
+from torch.autograd import Function
+import torch.nn as nn
+from typing import Tuple
+import sys
+
+import pointnet2_cuda as pointnet2
+
+
+class FurthestPointSampling(Function):
+    @staticmethod
+    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
+        """
+        Uses iterative furthest point sampling to select a set of npoint features that have the largest
+        minimum distance
+        :param ctx:
+        :param xyz: (B, N, 3) where N > npoint
+        :param npoint: int, number of features in the sampled set
+        :return:
+            output: (B, npoint) tensor containing the set
+        """
+        assert xyz.is_contiguous()
+
+        B, N, _ = xyz.size()
+        output = torch.cuda.IntTensor(B, npoint)
+        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
+
+        pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
+        return output
+
+    @staticmethod
+    def backward(xyz, a=None):
+        return None, None
+
+
+furthest_point_sample = FurthestPointSampling.apply
+
+
+class GatherOperation(Function):
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param features: (B, C, N)
+        :param idx: (B, npoint) index tensor of the features to gather
+        :return:
+            output: (B, C, npoint)
+        """
+        assert features.is_contiguous()
+        assert idx.is_contiguous()
+
+        B, npoint = idx.size()
+        _, C, N = features.size()
+        output = torch.cuda.FloatTensor(B, C, npoint)
+
+        pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
+
+        ctx.for_backwards = (idx, C, N)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        idx, C, N = ctx.for_backwards
+        B, npoint = idx.size()
+
+        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
+        grad_out_data = grad_out.data.contiguous()
+        pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
+        return grad_features, None
+
+
+gather_operation = GatherOperation.apply
+
+
+class ThreeNN(Function):
+
+    @staticmethod
+    def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Find the three nearest neighbors of unknown in known
+        :param ctx:
+        :param unknown: (B, N, 3)
+        :param known: (B, M, 3)
+        :return:
+            dist: (B, N, 3) l2 distance to the three nearest neighbors
+            idx: (B, N, 3) index of 3 nearest neighbors
+        """
+        assert unknown.is_contiguous()
+        assert known.is_contiguous()
+
+        B, N, _ = unknown.size()
+        m = known.size(1)
+        dist2 = torch.cuda.FloatTensor(B, N, 3)
+        idx = torch.cuda.IntTensor(B, N, 3)
+
+        pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
+        return torch.sqrt(dist2), idx
+
+    @staticmethod
+    def backward(ctx, a=None, b=None):
+        return None, None
+
+
+three_nn = ThreeNN.apply
+
+
+class ThreeInterpolate(Function):
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+        """
+        Performs weighted linear interpolation on 3 features
+        :param ctx:
+        :param features: (B, C, M) feature descriptors to be interpolated from
+        :param idx: (B, n, 3) three nearest neighbors of the target features in features
+        :param weight: (B, n, 3) weights
+        :return:
+            output: (B, C, n) tensor of the interpolated features
+        """
+        assert features.is_contiguous()
+        assert idx.is_contiguous()
+        assert weight.is_contiguous()
+
+        B, c, m = features.size()
+        n = idx.size(1)
+        ctx.three_interpolate_for_backward = (idx, weight, m)
+        output = torch.cuda.FloatTensor(B, c, n)
+
+        pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        :param ctx:
+        :param grad_out: (B, C, N) tensor with gradients of outputs
+        :return:
+            grad_features: (B, C, M) tensor with gradients of features
+            None:
+            None:
+        """
+        idx, weight, m = ctx.three_interpolate_for_backward
+        B, c, n = grad_out.size()
+
+        grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
+        grad_out_data = grad_out.data.contiguous()
+
+        pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
+        return grad_features, None, None
+
+
+three_interpolate = ThreeInterpolate.apply
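The CUDA path above is the real implementation; for intuition (or CPU-side sanity checks), the same three-NN interpolation semantics can be sketched in pure PyTorch along these lines:

```python
# Pure-PyTorch reference for three_nn + three_interpolate (a sketch for sanity
# checks only; not part of the library).
import torch

def three_nn_ref(unknown, known):
    # unknown: (B, N, 3), known: (B, M, 3)
    d = torch.cdist(unknown, known)                  # (B, N, M) pairwise l2 distances
    dist, idx = torch.topk(d, k=3, dim=2, largest=False)
    return dist, idx                                  # both (B, N, 3)

def three_interpolate_ref(features, idx, weight):
    # features: (B, C, M), idx/weight: (B, N, 3) -> (B, C, N)
    B, C, M = features.shape
    gathered = torch.gather(
        features.unsqueeze(2).expand(B, C, idx.shape[1], M), 3,
        idx.unsqueeze(1).expand(B, C, -1, -1)
    )                                                 # (B, C, N, 3)
    return (gathered * weight.unsqueeze(1)).sum(-1)

dist, idx = three_nn_ref(torch.rand(2, 100, 3), torch.rand(2, 30, 3))
w = 1.0 / (dist + 1e-8)
w = w / w.sum(dim=2, keepdim=True)                    # inverse-distance weights, as in PointnetFPModule
out = three_interpolate_ref(torch.rand(2, 16, 30), idx, w)
print(out.shape)  # torch.Size([2, 16, 100])
```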
+class GroupingOperation(Function):
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param features: (B, C, N) tensor of features to group
+        :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
+        :return:
+            output: (B, C, npoint, nsample) tensor
+        """
+        assert features.is_contiguous()
+        assert idx.is_contiguous()
+
+        B, nfeatures, nsample = idx.size()
+        _, C, N = features.size()
+        output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
+
+        pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
+
+        ctx.for_backwards = (idx, N)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        :param ctx:
+        :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
+        :return:
+            grad_features: (B, C, N) gradient of the features
+        """
+        idx, N = ctx.for_backwards
+
+        B, C, npoint, nsample = grad_out.size()
+        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
+
+        grad_out_data = grad_out.data.contiguous()
+        pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
+        return grad_features, None
+
+
+grouping_operation = GroupingOperation.apply
+
+
+class BallQuery(Function):
+
+    @staticmethod
+    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param radius: float, radius of the balls
+        :param nsample: int, maximum number of features in the balls
+        :param xyz: (B, N, 3) xyz coordinates of the features
+        :param new_xyz: (B, npoint, 3) centers of the ball query
+        :return:
+            idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
+        """
+        assert new_xyz.is_contiguous()
+        assert xyz.is_contiguous()
+
+        B, N, _ = xyz.size()
+        npoint = new_xyz.size(1)
+        idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
+
+        pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
+        return idx
+
+    @staticmethod
+    def backward(ctx, a=None):
+        return None, None, None, None
+
+
+ball_query = BallQuery.apply
+
+
+class QueryAndGroup(nn.Module):
+    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
+        """
+        :param radius: float, radius of ball
+        :param nsample: int, maximum number of features to gather in the ball
+        :param use_xyz:
+        """
+        super().__init__()
+        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
+
+    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> torch.Tensor:
+        """
+        :param xyz: (B, N, 3) xyz coordinates of the features
+        :param new_xyz: (B, npoint, 3) centroids
+        :param features: (B, C, N) descriptors of the features
+        :return:
+            new_features: (B, 3 + C, npoint, nsample)
+        """
+        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
+        xyz_trans = xyz.transpose(1, 2).contiguous()
+        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
+        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
+
+        if features is not None:
+            grouped_features = grouping_operation(features, idx)
+            if self.use_xyz:
+                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, C + 3, npoint, nsample)
+            else:
+                new_features = grouped_features
+        else:
+            assert self.use_xyz, "Cannot have features=None and use_xyz=False!"
+            new_features = grouped_xyz
+
+        return new_features
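Note the padding contract implied by the CUDA kernel (see `ball_query_gpu.cu` below): every query ball returns exactly `nsample` indices, with unused slots filled by the first in-range index, so downstream grouping always has valid columns. A rough CPU sketch of that contract, for testing only:

```python
# CPU sketch of the ball-query padding behavior (illustrative, not the library API).
import torch

def ball_query_ref(radius, nsample, xyz, new_xyz):
    # xyz: (B, N, 3), new_xyz: (B, npoint, 3) -> idx: (B, npoint, nsample)
    B, npoint, _ = new_xyz.shape
    idx = torch.zeros(B, npoint, nsample, dtype=torch.long)
    d = torch.cdist(new_xyz, xyz)                      # (B, npoint, N)
    for b in range(B):
        for p in range(npoint):
            hits = torch.nonzero(d[b, p] < radius).flatten()[:nsample]
            if len(hits) > 0:
                idx[b, p] = hits[0]                    # pad every slot with the first hit
                idx[b, p, :len(hits)] = hits           # then overwrite with real hits
    return idx

idx = ball_query_ref(0.3, 16, torch.rand(1, 512, 3), torch.rand(1, 64, 3))
print(idx.shape)  # torch.Size([1, 64, 16])
```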
+class GroupAll(nn.Module):
+    def __init__(self, use_xyz: bool = True):
+        super().__init__()
+        self.use_xyz = use_xyz
+
+    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
+        """
+        :param xyz: (B, N, 3) xyz coordinates of the features
+        :param new_xyz: ignored
+        :param features: (B, C, N) descriptors of the features
+        :return:
+            new_features: (B, C + 3, 1, N)
+        """
+        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
+        if features is not None:
+            grouped_features = features.unsqueeze(2)
+            if self.use_xyz:
+                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, 3 + C, 1, N)
+            else:
+                new_features = grouped_features
+        else:
+            new_features = grouped_xyz
+
+        return new_features
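Putting the low-level ops together, a typical sampling-and-grouping round trip looks roughly like this (a sketch; it assumes the built extension and a GPU, and the exact import path may differ depending on how the package is installed):

```python
# End-to-end sketch of the low-level ops: FPS -> gather -> query-and-group.
import torch
from pointnet2 import pointnet2_utils

xyz = torch.rand(2, 1024, 3).cuda()

# pick 128 well-spread centroids, then fetch their coordinates
idx = pointnet2_utils.furthest_point_sample(xyz, 128)    # (2, 128) int32 indices
centroids = pointnet2_utils.gather_operation(
    xyz.transpose(1, 2).contiguous(), idx                # gather wants (B, C, N)
).transpose(1, 2).contiguous()                           # (2, 128, 3)

# group up to 32 neighbors within radius 0.2 around each centroid
grouper = pointnet2_utils.QueryAndGroup(radius=0.2, nsample=32)
grouped = grouper(xyz, centroids)                        # (2, 3, 128, 32) centered xyz
print(grouped.shape)
```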
diff --git a/modules/module_lib/pointnet2_utils/pointnet2/pytorch_utils.py b/modules/module_lib/pointnet2_utils/pointnet2/pytorch_utils.py
new file mode 100644
index 0000000..09cb7bc
--- /dev/null
+++ b/modules/module_lib/pointnet2_utils/pointnet2/pytorch_utils.py
@@ -0,0 +1,236 @@
+import torch.nn as nn
+from typing import List, Tuple
+
+
+class SharedMLP(nn.Sequential):
+
+    def __init__(
+            self,
+            args: List[int],
+            *,
+            bn: bool = False,
+            activation=nn.ReLU(inplace=True),
+            preact: bool = False,
+            first: bool = False,
+            name: str = "",
+            instance_norm: bool = False,
+    ):
+        super().__init__()
+
+        for i in range(len(args) - 1):
+            self.add_module(
+                name + 'layer{}'.format(i),
+                Conv2d(
+                    args[i],
+                    args[i + 1],
+                    bn=(not first or not preact or (i != 0)) and bn,
+                    activation=activation
+                    if (not first or not preact or (i != 0)) else None,
+                    preact=preact,
+                    instance_norm=instance_norm
+                )
+            )
+
+
+class _ConvBase(nn.Sequential):
+
+    def __init__(
+            self,
+            in_size,
+            out_size,
+            kernel_size,
+            stride,
+            padding,
+            activation,
+            bn,
+            init,
+            conv=None,
+            batch_norm=None,
+            bias=True,
+            preact=False,
+            name="",
+            instance_norm=False,
+            instance_norm_func=None
+    ):
+        super().__init__()
+
+        bias = bias and (not bn)
+        conv_unit = conv(
+            in_size,
+            out_size,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            bias=bias
+        )
+        init(conv_unit.weight)
+        if bias:
+            nn.init.constant_(conv_unit.bias, 0)
+
+        if bn:
+            if not preact:
+                bn_unit = batch_norm(out_size)
+            else:
+                bn_unit = batch_norm(in_size)
+        if instance_norm:
+            if not preact:
+                in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False)
+            else:
+                in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False)
+
+        if preact:
+            if bn:
+                self.add_module(name + 'bn', bn_unit)
+
+            if activation is not None:
+                self.add_module(name + 'activation', activation)
+
+            if not bn and instance_norm:
+                self.add_module(name + 'in', in_unit)
+
+        self.add_module(name + 'conv', conv_unit)
+
+        if not preact:
+            if bn:
+                self.add_module(name + 'bn', bn_unit)
+
+            if activation is not None:
+                self.add_module(name + 'activation', activation)
+
+            if not bn and instance_norm:
+                self.add_module(name + 'in', in_unit)
+
+
+class _BNBase(nn.Sequential):
+
+    def __init__(self, in_size, batch_norm=None, name=""):
+        super().__init__()
+        self.add_module(name + "bn", batch_norm(in_size))
+
+        nn.init.constant_(self[0].weight, 1.0)
+        nn.init.constant_(self[0].bias, 0)
+
+
+class BatchNorm1d(_BNBase):
+
+    def __init__(self, in_size: int, *, name: str = ""):
+        super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)
+
+
+class BatchNorm2d(_BNBase):
+
+    def __init__(self, in_size: int, name: str = ""):
+        super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)
+
+
+class Conv1d(_ConvBase):
+
+    def __init__(
+            self,
+            in_size: int,
+            out_size: int,
+            *,
+            kernel_size: int = 1,
+            stride: int = 1,
+            padding: int = 0,
+            activation=nn.ReLU(inplace=True),
+            bn: bool = False,
+            init=nn.init.kaiming_normal_,
+            bias: bool = True,
+            preact: bool = False,
+            name: str = "",
+            instance_norm=False
+    ):
+        super().__init__(
+            in_size,
+            out_size,
+            kernel_size,
+            stride,
+            padding,
+            activation,
+            bn,
+            init,
+            conv=nn.Conv1d,
+            batch_norm=BatchNorm1d,
+            bias=bias,
+            preact=preact,
+            name=name,
+            instance_norm=instance_norm,
+            instance_norm_func=nn.InstanceNorm1d
+        )
+
+
+class Conv2d(_ConvBase):
+
+    def __init__(
+            self,
+            in_size: int,
+            out_size: int,
+            *,
+            kernel_size: Tuple[int, int] = (1, 1),
+            stride: Tuple[int, int] = (1, 1),
+            padding: Tuple[int, int] = (0, 0),
+            activation=nn.ReLU(inplace=True),
+            bn: bool = False,
+            init=nn.init.kaiming_normal_,
+            bias: bool = True,
+            preact: bool = False,
+            name: str = "",
+            instance_norm=False
+    ):
+        super().__init__(
+            in_size,
+            out_size,
+            kernel_size,
+            stride,
+            padding,
+            activation,
+            bn,
+            init,
+            conv=nn.Conv2d,
+            batch_norm=BatchNorm2d,
+            bias=bias,
+            preact=preact,
+            name=name,
+            instance_norm=instance_norm,
+            instance_norm_func=nn.InstanceNorm2d
+        )
+
+
+class FC(nn.Sequential):
+
+    def __init__(
+            self,
+            in_size: int,
+            out_size: int,
+            *,
+            activation=nn.ReLU(inplace=True),
+            bn: bool = False,
+            init=None,
+            preact: bool = False,
+            name: str = ""
+    ):
+        super().__init__()
+
+        fc = nn.Linear(in_size, out_size, bias=not bn)
+        if init is not None:
+            init(fc.weight)
+        if not bn:
+            nn.init.constant_(fc.bias, 0)
+
+        if preact:
+            if bn:
+                self.add_module(name + 'bn', BatchNorm1d(in_size))
+
+            if activation is not None:
+                self.add_module(name + 'activation', activation)
+
+        self.add_module(name + 'fc', fc)
+
+        if not preact:
+            if bn:
+                self.add_module(name + 'bn', BatchNorm1d(out_size))
+
+            if activation is not None:
+                self.add_module(name + 'activation', activation)
diff --git a/modules/module_lib/pointnet2_utils/pointnet2/setup.py b/modules/module_lib/pointnet2_utils/pointnet2/setup.py
new file mode 100644
index 0000000..99e59e3
--- /dev/null
+++ b/modules/module_lib/pointnet2_utils/pointnet2/setup.py
@@ -0,0 +1,23 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+setup(
+    name='pointnet2',
+    ext_modules=[
+        CUDAExtension('pointnet2_cuda', [
+            'src/pointnet2_api.cpp',
+
+            'src/ball_query.cpp',
+            'src/ball_query_gpu.cu',
+            'src/group_points.cpp',
+            'src/group_points_gpu.cu',
+            'src/interpolate.cpp',
+            'src/interpolate_gpu.cu',
+            'src/sampling.cpp',
+            'src/sampling_gpu.cu',
+        ],
+        extra_compile_args={'cxx': ['-g'],
+                            'nvcc': ['-O2']})
+    ],
+    cmdclass={'build_ext': BuildExtension}
+)
diff --git a/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query.cpp b/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query.cpp
new file mode 100644
index 0000000..21f787e
--- /dev/null
+++ b/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query.cpp
@@ -0,0 +1,28 @@
+#include <torch/serialize/tensor.h>
+#include <vector>
+// #include <THC/THC.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include "ball_query_gpu.h"
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/ATen.h>
+
+// extern THCState *state;
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
+#define CHECK_INPUT(x) 
CHECK_CUDA(x);CHECK_CONTIGUOUS(x) + +int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, + at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) { + CHECK_INPUT(new_xyz_tensor); + CHECK_INPUT(xyz_tensor); + const float *new_xyz = new_xyz_tensor.data(); + const float *xyz = xyz_tensor.data(); + int *idx = idx_tensor.data(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream); + return 1; +} \ No newline at end of file diff --git a/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query_gpu.cu b/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query_gpu.cu new file mode 100644 index 0000000..f8840aa --- /dev/null +++ b/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query_gpu.cu @@ -0,0 +1,67 @@ +#include +#include +#include + +#include "ball_query_gpu.h" +#include "cuda_utils.h" + + +__global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample, + const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) { + // new_xyz: (B, M, 3) + // xyz: (B, N, 3) + // output: + // idx: (B, M, nsample) + int bs_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= b || pt_idx >= m) return; + + new_xyz += bs_idx * m * 3 + pt_idx * 3; + xyz += bs_idx * n * 3; + idx += bs_idx * m * nsample + pt_idx * nsample; + + float radius2 = radius * radius; + float new_x = new_xyz[0]; + float new_y = new_xyz[1]; + float new_z = new_xyz[2]; + + int cnt = 0; + for (int k = 0; k < n; ++k) { + float x = xyz[k * 3 + 0]; + float y = xyz[k * 3 + 1]; + float z = xyz[k * 3 + 2]; + float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); + if (d2 < radius2){ + if (cnt == 0){ + for (int l = 0; l < nsample; ++l) { + idx[l] = k; + } + } + idx[cnt] = k; + ++cnt; + if (cnt >= nsample) break; + } + } +} + + +void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \ + const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) { + // new_xyz: (B, M, 3) + // xyz: (B, N, 3) + // output: + // idx: (B, M, nsample) + + cudaError_t err; + + dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + ball_query_kernel_fast<<>>(b, n, m, radius, nsample, new_xyz, xyz, idx); + // cudaDeviceSynchronize(); // for using printf in kernel function + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} \ No newline at end of file diff --git a/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query_gpu.h b/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query_gpu.h new file mode 100644 index 0000000..ffc831a --- /dev/null +++ b/modules/module_lib/pointnet2_utils/pointnet2/src/ball_query_gpu.h @@ -0,0 +1,15 @@ +#ifndef _BALL_QUERY_GPU_H +#define _BALL_QUERY_GPU_H + +#include +#include +#include +#include + +int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, + at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor); + +void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, + const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream); + +#endif diff --git a/modules/module_lib/pointnet2_utils/pointnet2/src/cuda_utils.h 
b/modules/module_lib/pointnet2_utils/pointnet2/src/cuda_utils.h new file mode 100644 index 0000000..7fe2796 --- /dev/null +++ b/modules/module_lib/pointnet2_utils/pointnet2/src/cuda_utils.h @@ -0,0 +1,15 @@ +#ifndef _CUDA_UTILS_H +#define _CUDA_UTILS_H + +#include + +#define TOTAL_THREADS 1024 +#define THREADS_PER_BLOCK 256 +#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) + +inline int opt_n_threads(int work_size) { + const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); + + return max(min(1 << pow_2, TOTAL_THREADS), 1); +} +#endif diff --git a/modules/module_lib/pointnet2_utils/pointnet2/src/group_points.cpp b/modules/module_lib/pointnet2_utils/pointnet2/src/group_points.cpp new file mode 100644 index 0000000..f0e74e9 --- /dev/null +++ b/modules/module_lib/pointnet2_utils/pointnet2/src/group_points.cpp @@ -0,0 +1,37 @@ +#include +#include +#include +#include +// #include +#include "group_points_gpu.h" +#include +#include +// extern THCState *state; + + +int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, + at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { + + float *grad_points = grad_points_tensor.data(); + const int *idx = idx_tensor.data(); + const float *grad_out = grad_out_tensor.data(); + + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream); + return 1; +} + + +int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, + at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) { + + const float *points = points_tensor.data(); + const int *idx = idx_tensor.data(); + float *out = out_tensor.data(); + + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream); + return 1; +} \ No newline at end of file diff --git a/modules/module_lib/pointnet2_utils/pointnet2/src/group_points_gpu.cu b/modules/module_lib/pointnet2_utils/pointnet2/src/group_points_gpu.cu new file mode 100644 index 0000000..c015a81 --- /dev/null +++ b/modules/module_lib/pointnet2_utils/pointnet2/src/group_points_gpu.cu @@ -0,0 +1,86 @@ +#include +#include + +#include "cuda_utils.h" +#include "group_points_gpu.h" + + +__global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample, + const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) { + // grad_out: (B, C, npoints, nsample) + // idx: (B, npoints, nsample) + // output: + // grad_points: (B, C, N) + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int pt_idx = index / nsample; + if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; + + int sample_idx = index % nsample; + grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; + idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; + + atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0] , grad_out[0]); +} + +void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, + const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { + // grad_out: (B, C, npoints, nsample) + // idx: (B, npoints, nsample) + // output: + // grad_points: (B, C, N) + 
cudaError_t err; + dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + group_points_grad_kernel_fast<<>>(b, c, n, npoints, nsample, grad_out, idx, grad_points); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + + +__global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample, + const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { + // points: (B, C, N) + // idx: (B, npoints, nsample) + // output: + // out: (B, C, npoints, nsample) + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int pt_idx = index / nsample; + if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; + + int sample_idx = index % nsample; + + idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; + int in_idx = bs_idx * c * n + c_idx * n + idx[0]; + int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; + + out[out_idx] = points[in_idx]; +} + + +void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, + const float *points, const int *idx, float *out, cudaStream_t stream) { + // points: (B, C, N) + // idx: (B, npoints, nsample) + // output: + // out: (B, C, npoints, nsample) + cudaError_t err; + dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + group_points_kernel_fast<<>>(b, c, n, npoints, nsample, points, idx, out); + // cudaDeviceSynchronize(); // for using printf in kernel function + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} diff --git a/modules/module_lib/pointnet2_utils/pointnet2/src/group_points_gpu.h b/modules/module_lib/pointnet2_utils/pointnet2/src/group_points_gpu.h new file mode 100644 index 0000000..76c73ca --- /dev/null +++ b/modules/module_lib/pointnet2_utils/pointnet2/src/group_points_gpu.h @@ -0,0 +1,22 @@ +#ifndef _GROUP_POINTS_GPU_H +#define _GROUP_POINTS_GPU_H + +#include +#include +#include +#include + + +int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, + at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); + +void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, + const float *points, const int *idx, float *out, cudaStream_t stream); + +int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, + at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); + +void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, + const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); + +#endif diff --git a/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate.cpp b/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate.cpp new file mode 100644 index 0000000..d01f045 --- /dev/null +++ b/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate.cpp @@ -0,0 +1,59 @@ +#include +#include +// #include +#include +#include +#include +#include +#include +#include +#include +#include "interpolate_gpu.h" + +// extern THCState *state; + + +void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, + at::Tensor known_tensor, 
at::Tensor dist2_tensor, at::Tensor idx_tensor) { + const float *unknown = unknown_tensor.data(); + const float *known = known_tensor.data(); + float *dist2 = dist2_tensor.data(); + int *idx = idx_tensor.data(); + + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream); +} + + +void three_interpolate_wrapper_fast(int b, int c, int m, int n, + at::Tensor points_tensor, + at::Tensor idx_tensor, + at::Tensor weight_tensor, + at::Tensor out_tensor) { + + const float *points = points_tensor.data(); + const float *weight = weight_tensor.data(); + float *out = out_tensor.data(); + const int *idx = idx_tensor.data(); + + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream); +} + +void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, + at::Tensor grad_out_tensor, + at::Tensor idx_tensor, + at::Tensor weight_tensor, + at::Tensor grad_points_tensor) { + + const float *grad_out = grad_out_tensor.data(); + const float *weight = weight_tensor.data(); + float *grad_points = grad_points_tensor.data(); + const int *idx = idx_tensor.data(); + + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream); +} \ No newline at end of file diff --git a/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate_gpu.cu b/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate_gpu.cu new file mode 100644 index 0000000..a123dd8 --- /dev/null +++ b/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate_gpu.cu @@ -0,0 +1,161 @@ +#include +#include +#include + +#include "cuda_utils.h" +#include "interpolate_gpu.h" + + +__global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown, + const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) { + // unknown: (B, N, 3) + // known: (B, M, 3) + // output: + // dist2: (B, N, 3) + // idx: (B, N, 3) + + int bs_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= b || pt_idx >= n) return; + + unknown += bs_idx * n * 3 + pt_idx * 3; + known += bs_idx * m * 3; + dist2 += bs_idx * n * 3 + pt_idx * 3; + idx += bs_idx * n * 3 + pt_idx * 3; + + float ux = unknown[0]; + float uy = unknown[1]; + float uz = unknown[2]; + + double best1 = 1e40, best2 = 1e40, best3 = 1e40; + int besti1 = 0, besti2 = 0, besti3 = 0; + for (int k = 0; k < m; ++k) { + float x = known[k * 3 + 0]; + float y = known[k * 3 + 1]; + float z = known[k * 3 + 2]; + float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); + if (d < best1) { + best3 = best2; besti3 = besti2; + best2 = best1; besti2 = besti1; + best1 = d; besti1 = k; + } + else if (d < best2) { + best3 = best2; besti3 = besti2; + best2 = d; besti2 = k; + } + else if (d < best3) { + best3 = d; besti3 = k; + } + } + dist2[0] = best1; dist2[1] = best2; dist2[2] = best3; + idx[0] = besti1; idx[1] = besti2; idx[2] = besti3; +} + + +void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, + const float *known, float *dist2, int *idx, cudaStream_t stream) { + // unknown: (B, N, 3) + // known: (B, M, 3) + // output: + // dist2: (B, N, 3) + // idx: (B, N, 3) + + 
cudaError_t err; + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + three_nn_kernel_fast<<>>(b, n, m, unknown, known, dist2, idx); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + + +__global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points, + const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) { + // points: (B, C, M) + // idx: (B, N, 3) + // weight: (B, N, 3) + // output: + // out: (B, C, N) + + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; + + weight += bs_idx * n * 3 + pt_idx * 3; + points += bs_idx * c * m + c_idx * m; + idx += bs_idx * n * 3 + pt_idx * 3; + out += bs_idx * c * n + c_idx * n; + + out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]]; +} + +void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, + const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) { + // points: (B, C, M) + // idx: (B, N, 3) + // weight: (B, N, 3) + // output: + // out: (B, C, N) + + cudaError_t err; + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + three_interpolate_kernel_fast<<>>(b, c, m, n, points, idx, weight, out); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + + +__global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, + const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) { + // grad_out: (B, C, N) + // weight: (B, N, 3) + // output: + // grad_points: (B, C, M) + + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; + + grad_out += bs_idx * c * n + c_idx * n + pt_idx; + weight += bs_idx * n * 3 + pt_idx * 3; + grad_points += bs_idx * c * m + c_idx * m; + idx += bs_idx * n * 3 + pt_idx * 3; + + + atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]); + atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]); + atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]); +} + +void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, + const int *idx, const float *weight, float *grad_points, cudaStream_t stream) { + // grad_out: (B, C, N) + // weight: (B, N, 3) + // output: + // grad_points: (B, C, M) + + cudaError_t err; + dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + three_interpolate_grad_kernel_fast<<>>(b, c, n, m, grad_out, idx, weight, grad_points); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} \ No newline at end of file diff --git a/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate_gpu.h b/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate_gpu.h new file mode 100644 index 0000000..f177108 --- /dev/null +++ b/modules/module_lib/pointnet2_utils/pointnet2/src/interpolate_gpu.h @@ -0,0 +1,30 @@ +#ifndef 
_INTERPOLATE_GPU_H +#define _INTERPOLATE_GPU_H + +#include +#include +#include +#include + + +void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, + at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor); + +void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, + const float *known, float *dist2, int *idx, cudaStream_t stream); + + +void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor, + at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor); + +void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, + const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream); + + +void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor, + at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor); + +void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, + const int *idx, const float *weight, float *grad_points, cudaStream_t stream); + +#endif diff --git a/modules/module_lib/pointnet2_utils/pointnet2/src/pointnet2_api.cpp b/modules/module_lib/pointnet2_utils/pointnet2/src/pointnet2_api.cpp new file mode 100644 index 0000000..d91f0f2 --- /dev/null +++ b/modules/module_lib/pointnet2_utils/pointnet2/src/pointnet2_api.cpp @@ -0,0 +1,24 @@ +#include +#include + +#include "ball_query_gpu.h" +#include "group_points_gpu.h" +#include "sampling_gpu.h" +#include "interpolate_gpu.h" + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast"); + + m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast"); + m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast"); + + m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast"); + m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast"); + + m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper"); + + m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast"); + m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast"); + m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast"); +} diff --git a/modules/module_lib/pointnet2_utils/pointnet2/src/sampling.cpp b/modules/module_lib/pointnet2_utils/pointnet2/src/sampling.cpp new file mode 100644 index 0000000..fbb277a --- /dev/null +++ b/modules/module_lib/pointnet2_utils/pointnet2/src/sampling.cpp @@ -0,0 +1,51 @@ +#include +#include +#include +// #include + +#include "sampling_gpu.h" +#include +#include + +// extern THCState *state; + + +int gather_points_wrapper_fast(int b, int c, int n, int npoints, + at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){ + const float *points = points_tensor.data(); + const int *idx = idx_tensor.data(); + float *out = out_tensor.data(); + + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream); + return 1; +} + + +int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, + at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { + + 
const float *grad_out = grad_out_tensor.data(); + const int *idx = idx_tensor.data(); + float *grad_points = grad_points_tensor.data(); + + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream); + return 1; +} + + +int furthest_point_sampling_wrapper(int b, int n, int m, + at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) { + + const float *points = points_tensor.data(); + float *temp = temp_tensor.data(); + int *idx = idx_tensor.data(); + + // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream); + return 1; +} diff --git a/modules/module_lib/pointnet2_utils/pointnet2/src/sampling_gpu.cu b/modules/module_lib/pointnet2_utils/pointnet2/src/sampling_gpu.cu new file mode 100644 index 0000000..9e49a60 --- /dev/null +++ b/modules/module_lib/pointnet2_utils/pointnet2/src/sampling_gpu.cu @@ -0,0 +1,253 @@ +#include +#include + +#include "cuda_utils.h" +#include "sampling_gpu.h" + + +__global__ void gather_points_kernel_fast(int b, int c, int n, int m, + const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { + // points: (B, C, N) + // idx: (B, M) + // output: + // out: (B, C, M) + + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; + + out += bs_idx * c * m + c_idx * m + pt_idx; + idx += bs_idx * m + pt_idx; + points += bs_idx * c * n + c_idx * n; + out[0] = points[idx[0]]; +} + +void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, + const float *points, const int *idx, float *out, cudaStream_t stream) { + // points: (B, C, N) + // idx: (B, npoints) + // output: + // out: (B, C, npoints) + + cudaError_t err; + dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + gather_points_kernel_fast<<>>(b, c, n, npoints, points, idx, out); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + +__global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, + const int *__restrict__ idx, float *__restrict__ grad_points) { + // grad_out: (B, C, M) + // idx: (B, M) + // output: + // grad_points: (B, C, N) + + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; + + grad_out += bs_idx * c * m + c_idx * m + pt_idx; + idx += bs_idx * m + pt_idx; + grad_points += bs_idx * c * n + c_idx * n; + + atomicAdd(grad_points + idx[0], grad_out[0]); +} + +void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, + const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { + // grad_out: (B, C, npoints) + // idx: (B, npoints) + // output: + // grad_points: (B, C, N) + + cudaError_t err; + dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) + dim3 threads(THREADS_PER_BLOCK); + + gather_points_grad_kernel_fast<<>>(b, c, n, npoints, grad_out, idx, grad_points); + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA 
kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} + + +__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){ + const float v1 = dists[idx1], v2 = dists[idx2]; + const int i1 = dists_i[idx1], i2 = dists_i[idx2]; + dists[idx1] = max(v1, v2); + dists_i[idx1] = v2 > v1 ? i2 : i1; +} + +template +__global__ void furthest_point_sampling_kernel(int b, int n, int m, + const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) { + // dataset: (B, N, 3) + // tmp: (B, N) + // output: + // idx: (B, M) + + if (m <= 0) return; + __shared__ float dists[block_size]; + __shared__ int dists_i[block_size]; + + int batch_index = blockIdx.x; + dataset += batch_index * n * 3; + temp += batch_index * n; + idxs += batch_index * m; + + int tid = threadIdx.x; + const int stride = block_size; + + int old = 0; + if (threadIdx.x == 0) + idxs[0] = old; + + __syncthreads(); + for (int j = 1; j < m; j++) { + int besti = 0; + float best = -1; + float x1 = dataset[old * 3 + 0]; + float y1 = dataset[old * 3 + 1]; + float z1 = dataset[old * 3 + 2]; + for (int k = tid; k < n; k += stride) { + float x2, y2, z2; + x2 = dataset[k * 3 + 0]; + y2 = dataset[k * 3 + 1]; + z2 = dataset[k * 3 + 2]; + // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); + // if (mag <= 1e-3) + // continue; + + float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); + float d2 = min(d, temp[k]); + temp[k] = d2; + besti = d2 > best ? k : besti; + best = d2 > best ? d2 : best; + } + dists[tid] = best; + dists_i[tid] = besti; + __syncthreads(); + + if (block_size >= 1024) { + if (tid < 512) { + __update(dists, dists_i, tid, tid + 512); + } + __syncthreads(); + } + + if (block_size >= 512) { + if (tid < 256) { + __update(dists, dists_i, tid, tid + 256); + } + __syncthreads(); + } + if (block_size >= 256) { + if (tid < 128) { + __update(dists, dists_i, tid, tid + 128); + } + __syncthreads(); + } + if (block_size >= 128) { + if (tid < 64) { + __update(dists, dists_i, tid, tid + 64); + } + __syncthreads(); + } + if (block_size >= 64) { + if (tid < 32) { + __update(dists, dists_i, tid, tid + 32); + } + __syncthreads(); + } + if (block_size >= 32) { + if (tid < 16) { + __update(dists, dists_i, tid, tid + 16); + } + __syncthreads(); + } + if (block_size >= 16) { + if (tid < 8) { + __update(dists, dists_i, tid, tid + 8); + } + __syncthreads(); + } + if (block_size >= 8) { + if (tid < 4) { + __update(dists, dists_i, tid, tid + 4); + } + __syncthreads(); + } + if (block_size >= 4) { + if (tid < 2) { + __update(dists, dists_i, tid, tid + 2); + } + __syncthreads(); + } + if (block_size >= 2) { + if (tid < 1) { + __update(dists, dists_i, tid, tid + 1); + } + __syncthreads(); + } + + old = dists_i[0]; + if (tid == 0) + idxs[j] = old; + } +} + +void furthest_point_sampling_kernel_launcher(int b, int n, int m, + const float *dataset, float *temp, int *idxs, cudaStream_t stream) { + // dataset: (B, N, 3) + // tmp: (B, N) + // output: + // idx: (B, M) + + cudaError_t err; + unsigned int n_threads = opt_n_threads(n); + + switch (n_threads) { + case 1024: + furthest_point_sampling_kernel<1024><<>>(b, n, m, dataset, temp, idxs); break; + case 512: + furthest_point_sampling_kernel<512><<>>(b, n, m, dataset, temp, idxs); break; + case 256: + furthest_point_sampling_kernel<256><<>>(b, n, m, dataset, temp, idxs); break; + case 128: + furthest_point_sampling_kernel<128><<>>(b, n, m, dataset, temp, idxs); break; + case 64: + furthest_point_sampling_kernel<64><<>>(b, n, m, 
dataset, temp, idxs); break; + case 32: + furthest_point_sampling_kernel<32><<>>(b, n, m, dataset, temp, idxs); break; + case 16: + furthest_point_sampling_kernel<16><<>>(b, n, m, dataset, temp, idxs); break; + case 8: + furthest_point_sampling_kernel<8><<>>(b, n, m, dataset, temp, idxs); break; + case 4: + furthest_point_sampling_kernel<4><<>>(b, n, m, dataset, temp, idxs); break; + case 2: + furthest_point_sampling_kernel<2><<>>(b, n, m, dataset, temp, idxs); break; + case 1: + furthest_point_sampling_kernel<1><<>>(b, n, m, dataset, temp, idxs); break; + default: + furthest_point_sampling_kernel<512><<>>(b, n, m, dataset, temp, idxs); + } + + err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } +} diff --git a/modules/module_lib/pointnet2_utils/pointnet2/src/sampling_gpu.h b/modules/module_lib/pointnet2_utils/pointnet2/src/sampling_gpu.h new file mode 100644 index 0000000..6200c59 --- /dev/null +++ b/modules/module_lib/pointnet2_utils/pointnet2/src/sampling_gpu.h @@ -0,0 +1,29 @@ +#ifndef _SAMPLING_GPU_H +#define _SAMPLING_GPU_H + +#include +#include +#include + + +int gather_points_wrapper_fast(int b, int c, int n, int npoints, + at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); + +void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, + const float *points, const int *idx, float *out, cudaStream_t stream); + + +int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, + at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); + +void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, + const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); + + +int furthest_point_sampling_wrapper(int b, int n, int m, + at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor); + +void furthest_point_sampling_kernel_launcher(int b, int n, int m, + const float *dataset, float *temp, int *idxs, cudaStream_t stream); + +#endif diff --git a/modules/module_lib/pointnet2_utils/tools/_init_path.py b/modules/module_lib/pointnet2_utils/tools/_init_path.py new file mode 100644 index 0000000..c6c4565 --- /dev/null +++ b/modules/module_lib/pointnet2_utils/tools/_init_path.py @@ -0,0 +1,2 @@ +import os, sys +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../')) diff --git a/modules/module_lib/pointnet2_utils/tools/dataset.py b/modules/module_lib/pointnet2_utils/tools/dataset.py new file mode 100644 index 0000000..deca8ec --- /dev/null +++ b/modules/module_lib/pointnet2_utils/tools/dataset.py @@ -0,0 +1,188 @@ +import os +import numpy as np +import torch.utils.data as torch_data +import kitti_utils +import cv2 +from PIL import Image + + +USE_INTENSITY = False + + +class KittiDataset(torch_data.Dataset): + def __init__(self, root_dir, split='train', mode='TRAIN'): + self.split = split + self.mode = mode + self.classes = ['Car'] + is_test = self.split == 'test' + self.imageset_dir = os.path.join(root_dir, 'KITTI', 'object', 'testing' if is_test else 'training') + + split_dir = os.path.join(root_dir, 'KITTI', 'ImageSets', split + '.txt') + self.image_idx_list = [x.strip() for x in open(split_dir).readlines()] + self.sample_id_list = [int(sample_id) for sample_id in self.image_idx_list] + self.num_sample = self.image_idx_list.__len__() + + self.npoints = 16384 + + self.image_dir = os.path.join(self.imageset_dir, 'image_2') + self.lidar_dir = 
os.path.join(self.imageset_dir, 'velodyne') + self.calib_dir = os.path.join(self.imageset_dir, 'calib') + self.label_dir = os.path.join(self.imageset_dir, 'label_2') + self.plane_dir = os.path.join(self.imageset_dir, 'planes') + + def get_image(self, idx): + img_file = os.path.join(self.image_dir, '%06d.png' % idx) + assert os.path.exists(img_file) + return cv2.imread(img_file) # (H, W, 3) BGR mode + + def get_image_shape(self, idx): + img_file = os.path.join(self.image_dir, '%06d.png' % idx) + assert os.path.exists(img_file) + im = Image.open(img_file) + width, height = im.size + return height, width, 3 + + def get_lidar(self, idx): + lidar_file = os.path.join(self.lidar_dir, '%06d.bin' % idx) + assert os.path.exists(lidar_file) + return np.fromfile(lidar_file, dtype=np.float32).reshape(-1, 4) + + def get_calib(self, idx): + calib_file = os.path.join(self.calib_dir, '%06d.txt' % idx) + assert os.path.exists(calib_file) + return kitti_utils.Calibration(calib_file) + + def get_label(self, idx): + label_file = os.path.join(self.label_dir, '%06d.txt' % idx) + assert os.path.exists(label_file) + return kitti_utils.get_objects_from_label(label_file) + + @staticmethod + def get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape): + val_flag_1 = np.logical_and(pts_img[:, 0] >= 0, pts_img[:, 0] < img_shape[1]) + val_flag_2 = np.logical_and(pts_img[:, 1] >= 0, pts_img[:, 1] < img_shape[0]) + val_flag_merge = np.logical_and(val_flag_1, val_flag_2) + pts_valid_flag = np.logical_and(val_flag_merge, pts_rect_depth >= 0) + return pts_valid_flag + + def filtrate_objects(self, obj_list): + type_whitelist = self.classes + if self.mode == 'TRAIN': + type_whitelist = list(self.classes) + if 'Car' in self.classes: + type_whitelist.append('Van') + + valid_obj_list = [] + for obj in obj_list: + if obj.cls_type not in type_whitelist: + continue + + valid_obj_list.append(obj) + return valid_obj_list + + def __len__(self): + return len(self.sample_id_list) + + def __getitem__(self, index): + sample_id = int(self.sample_id_list[index]) + calib = self.get_calib(sample_id) + img_shape = self.get_image_shape(sample_id) + pts_lidar = self.get_lidar(sample_id) + + # get valid point (projected points should be in image) + pts_rect = calib.lidar_to_rect(pts_lidar[:, 0:3]) + pts_intensity = pts_lidar[:, 3] + + pts_img, pts_rect_depth = calib.rect_to_img(pts_rect) + pts_valid_flag = self.get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape) + + pts_rect = pts_rect[pts_valid_flag][:, 0:3] + pts_intensity = pts_intensity[pts_valid_flag] + + if self.npoints < len(pts_rect): + pts_depth = pts_rect[:, 2] + pts_near_flag = pts_depth < 40.0 + far_idxs_choice = np.where(pts_near_flag == 0)[0] + near_idxs = np.where(pts_near_flag == 1)[0] + near_idxs_choice = np.random.choice(near_idxs, self.npoints - len(far_idxs_choice), replace=False) + + choice = np.concatenate((near_idxs_choice, far_idxs_choice), axis=0) \ + if len(far_idxs_choice) > 0 else near_idxs_choice + np.random.shuffle(choice) + else: + choice = np.arange(0, len(pts_rect), dtype=np.int32) + if self.npoints > len(pts_rect): + extra_choice = np.random.choice(choice, self.npoints - len(pts_rect), replace=False) + choice = np.concatenate((choice, extra_choice), axis=0) + np.random.shuffle(choice) + + ret_pts_rect = pts_rect[choice, :] + ret_pts_intensity = pts_intensity[choice] - 0.5 # translate intensity to [-0.5, 0.5] + + pts_features = [ret_pts_intensity.reshape(-1, 1)] + ret_pts_features = np.concatenate(pts_features, axis=1) if pts_features.__len__() > 1 
else pts_features[0] + + sample_info = {'sample_id': sample_id} + + if self.mode == 'TEST': + if USE_INTENSITY: + pts_input = np.concatenate((ret_pts_rect, ret_pts_features), axis=1) # (N, C) + else: + pts_input = ret_pts_rect + sample_info['pts_input'] = pts_input + sample_info['pts_rect'] = ret_pts_rect + sample_info['pts_features'] = ret_pts_features + return sample_info + + gt_obj_list = self.filtrate_objects(self.get_label(sample_id)) + + gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list) + + # prepare input + if USE_INTENSITY: + pts_input = np.concatenate((ret_pts_rect, ret_pts_features), axis=1) # (N, C) + else: + pts_input = ret_pts_rect + + # generate training labels + cls_labels = self.generate_training_labels(ret_pts_rect, gt_boxes3d) + sample_info['pts_input'] = pts_input + sample_info['pts_rect'] = ret_pts_rect + sample_info['cls_labels'] = cls_labels + return sample_info + + @staticmethod + def generate_training_labels(pts_rect, gt_boxes3d): + cls_label = np.zeros((pts_rect.shape[0]), dtype=np.int32) + gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d, rotate=True) + extend_gt_boxes3d = kitti_utils.enlarge_box3d(gt_boxes3d, extra_width=0.2) + extend_gt_corners = kitti_utils.boxes3d_to_corners3d(extend_gt_boxes3d, rotate=True) + for k in range(gt_boxes3d.shape[0]): + box_corners = gt_corners[k] + fg_pt_flag = kitti_utils.in_hull(pts_rect, box_corners) + cls_label[fg_pt_flag] = 1 + + # enlarge the bbox3d, ignore nearby points + extend_box_corners = extend_gt_corners[k] + fg_enlarge_flag = kitti_utils.in_hull(pts_rect, extend_box_corners) + ignore_flag = np.logical_xor(fg_pt_flag, fg_enlarge_flag) + cls_label[ignore_flag] = -1 + + return cls_label + + def collate_batch(self, batch): + batch_size = batch.__len__() + ans_dict = {} + + for key in batch[0].keys(): + if isinstance(batch[0][key], np.ndarray): + ans_dict[key] = np.concatenate([batch[k][key][np.newaxis, ...] 
for k in range(batch_size)], axis=0) + + else: + ans_dict[key] = [batch[k][key] for k in range(batch_size)] + if isinstance(batch[0][key], int): + ans_dict[key] = np.array(ans_dict[key], dtype=np.int32) + elif isinstance(batch[0][key], float): + ans_dict[key] = np.array(ans_dict[key], dtype=np.float32) + + return ans_dict
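`collate_batch` stacks the per-sample numpy arrays along a new leading batch dimension, so the dataset plugs straight into a `DataLoader`. A minimal sketch, assuming it is run from `tools/` with the KITTI layout from the README present under `./data`:

```python
from torch.utils.data import DataLoader
from dataset import KittiDataset

train_set = KittiDataset(root_dir='./data', mode='TRAIN', split='train')
loader = DataLoader(train_set, batch_size=4, shuffle=True,
                    collate_fn=train_set.collate_batch)

batch = next(iter(loader))
print(batch['pts_input'].shape)   # (4, 16384, 3) since USE_INTENSITY is False
print(batch['cls_labels'].shape)  # (4, 16384), values in {-1, 0, 1}
```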
diff --git a/modules/module_lib/pointnet2_utils/tools/kitti_utils.py b/modules/module_lib/pointnet2_utils/tools/kitti_utils.py new file mode 100644 index 0000000..43f06b3 --- /dev/null +++ b/modules/module_lib/pointnet2_utils/tools/kitti_utils.py @@ -0,0 +1,229 @@ +import numpy as np +from scipy.spatial import Delaunay +import scipy + + +def cls_type_to_id(cls_type): + type_to_id = {'Car': 1, 'Pedestrian': 2, 'Cyclist': 3, 'Van': 4} + if cls_type not in type_to_id.keys(): + return -1 + return type_to_id[cls_type] + + +class Object3d(object): + def __init__(self, line): + label = line.strip().split(' ') + self.src = line + self.cls_type = label[0] + self.cls_id = cls_type_to_id(self.cls_type) + self.trucation = float(label[1]) + self.occlusion = float(label[2]) # 0:fully visible 1:partly occluded 2:largely occluded 3:unknown + self.alpha = float(label[3]) + self.box2d = np.array((float(label[4]), float(label[5]), float(label[6]), float(label[7])), dtype=np.float32) + self.h = float(label[8]) + self.w = float(label[9]) + self.l = float(label[10]) + self.pos = np.array((float(label[11]), float(label[12]), float(label[13])), dtype=np.float32) + self.dis_to_cam = np.linalg.norm(self.pos) + self.ry = float(label[14]) + self.score = float(label[15]) if label.__len__() == 16 else -1.0 + self.level_str = None + self.level = self.get_obj_level() + + def get_obj_level(self): + height = float(self.box2d[3]) - float(self.box2d[1]) + 1 + + if height >= 40 and self.trucation <= 0.15 and self.occlusion <= 0: + self.level_str = 'Easy' + return 1 # Easy + elif height >= 25 and self.trucation <= 0.3 and self.occlusion <= 1: + self.level_str = 'Moderate' + return 2 # Moderate + elif height >= 25 and self.trucation <= 0.5 and self.occlusion <= 2: + self.level_str = 'Hard' + return 3 # Hard + else: + self.level_str = 'UnKnown' + return 4 + + def generate_corners3d(self): + """ + generate corners3d representation for this object + :return corners_3d: (8, 3) corners of box3d in camera coord + """ + l, h, w = self.l, self.h, self.w + x_corners = [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2] + y_corners = [0, 0, 0, 0, -h, -h, -h, -h] + z_corners = [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2] + + R = np.array([[np.cos(self.ry), 0, np.sin(self.ry)], + [0, 1, 0], + [-np.sin(self.ry), 0, np.cos(self.ry)]]) + corners3d = np.vstack([x_corners, y_corners, z_corners]) # (3, 8) + corners3d = np.dot(R, corners3d).T + corners3d = corners3d + self.pos + return corners3d + + def to_str(self): + print_str = '%s %.3f %.3f %.3f box2d: %s hwl: [%.3f %.3f %.3f] pos: %s ry: %.3f' \ + % (self.cls_type, self.trucation, self.occlusion, self.alpha, self.box2d, self.h, self.w, self.l, + self.pos, self.ry) + return print_str + + def to_kitti_format(self): + kitti_str = '%s %.2f %d %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f' \ + % (self.cls_type, self.trucation, int(self.occlusion), self.alpha, self.box2d[0], self.box2d[1], + self.box2d[2], self.box2d[3], self.h, self.w, self.l, self.pos[0], self.pos[1], self.pos[2], + self.ry) + return kitti_str + + +def get_calib_from_file(calib_file): + with open(calib_file) as f: + lines = f.readlines() + + obj = lines[2].strip().split(' ')[1:] + P2 = np.array(obj, dtype=np.float32) + obj = lines[3].strip().split(' ')[1:] + P3 = np.array(obj, dtype=np.float32) + obj = lines[4].strip().split(' ')[1:] + R0 = np.array(obj, dtype=np.float32) + obj = lines[5].strip().split(' ')[1:] + Tr_velo_to_cam = np.array(obj, dtype=np.float32) + + return {'P2': P2.reshape(3, 4), + 'P3': P3.reshape(3, 4), + 'R0': R0.reshape(3, 3), + 'Tr_velo2cam': Tr_velo_to_cam.reshape(3, 4)} + + +class Calibration(object): + def __init__(self, calib_file): + if isinstance(calib_file, str): + calib = get_calib_from_file(calib_file) + else: + calib = calib_file + + self.P2 = calib['P2'] # 3 x 4 + self.R0 = calib['R0'] # 3 x 3 + self.V2C = calib['Tr_velo2cam'] # 3 x 4 + + def cart_to_hom(self, pts): + """ + :param pts: (N, 3 or 2) + :return pts_hom: (N, 4 or 3) + """ + pts_hom = np.hstack((pts, np.ones((pts.shape[0], 1), dtype=np.float32))) + return pts_hom + + def lidar_to_rect(self, pts_lidar): + """ + :param pts_lidar: (N, 3) + :return pts_rect: (N, 3) + """ + # x_rect = R0 * Tr_velo2cam * x_lidar (in homogeneous coordinates) + pts_lidar_hom = self.cart_to_hom(pts_lidar) + pts_rect = np.dot(pts_lidar_hom, np.dot(self.V2C.T, self.R0.T)) + return pts_rect + + def rect_to_img(self, pts_rect): + """ + :param pts_rect: (N, 3) + :return pts_img: (N, 2) + """ + pts_rect_hom = self.cart_to_hom(pts_rect) + pts_2d_hom = np.dot(pts_rect_hom, self.P2.T) + pts_img = (pts_2d_hom[:, 0:2].T / pts_rect_hom[:, 2]).T # (N, 2) + pts_rect_depth = pts_2d_hom[:, 2] - self.P2.T[3, 2] # depth in rect camera coord + return pts_img, pts_rect_depth + + def lidar_to_img(self, pts_lidar): + """ + :param pts_lidar: (N, 3) + :return pts_img: (N, 2) + """ + pts_rect = self.lidar_to_rect(pts_lidar) + pts_img, pts_depth = self.rect_to_img(pts_rect) + return pts_img, pts_depth + + +def get_objects_from_label(label_file): + with open(label_file, 'r') as f: + lines = f.readlines() + objects = [Object3d(line) for line in lines] + return objects + + +def objs_to_boxes3d(obj_list): + boxes3d = np.zeros((obj_list.__len__(), 7), dtype=np.float32) + for k, obj in enumerate(obj_list): + boxes3d[k, 0:3], boxes3d[k, 3], boxes3d[k, 4], boxes3d[k, 5], boxes3d[k, 6] \ + = obj.pos, obj.h, obj.w, obj.l, obj.ry + return boxes3d + + +def boxes3d_to_corners3d(boxes3d, rotate=True): + """ + :param boxes3d: (N, 7) [x, y, z, h, w, l, ry] + :param rotate: + :return: corners3d: (N, 8, 3) + """ + boxes_num = boxes3d.shape[0] + h, w, l = boxes3d[:, 3], boxes3d[:, 4], boxes3d[:, 5] + x_corners = np.array([l / 2., l / 2., -l / 2., -l / 2., l / 2., l / 2., -l / 2., -l / 2.], dtype=np.float32).T # (N, 8) + z_corners = np.array([w / 2., -w / 2., -w / 2., w / 2., w / 2., -w / 2., -w / 2., w / 2.], dtype=np.float32).T # (N, 8) + + y_corners = np.zeros((boxes_num, 8), dtype=np.float32) + y_corners[:, 4:8] = -h.reshape(boxes_num, 1).repeat(4, axis=1) # (N, 8) + + if rotate: + ry = boxes3d[:, 6] + zeros, ones = np.zeros(ry.size, dtype=np.float32), np.ones(ry.size, dtype=np.float32) + rot_list = np.array([[np.cos(ry), zeros, -np.sin(ry)], + [zeros, ones, zeros], + [np.sin(ry), zeros, np.cos(ry)]]) # (3, 3, N) + R_list = np.transpose(rot_list, (2, 0, 1)) # (N, 3, 3) + + temp_corners = np.concatenate((x_corners.reshape(-1, 8, 1), y_corners.reshape(-1, 8, 1), + z_corners.reshape(-1, 8, 1)), axis=2) # (N, 8, 3) + rotated_corners = np.matmul(temp_corners, R_list) # (N, 8, 3) + x_corners, y_corners, z_corners = rotated_corners[:, :, 0], rotated_corners[:, :, 1], rotated_corners[:, :, 2] + + x_loc, y_loc, z_loc = boxes3d[:, 0], boxes3d[:, 1], boxes3d[:, 2] + + x = x_loc.reshape(-1, 1) +
x_corners.reshape(-1, 8) + y = y_loc.reshape(-1, 1) + y_corners.reshape(-1, 8) + z = z_loc.reshape(-1, 1) + z_corners.reshape(-1, 8) + + corners = np.concatenate((x.reshape(-1, 8, 1), y.reshape(-1, 8, 1), z.reshape(-1, 8, 1)), axis=2) + + return corners.astype(np.float32) + + +def enlarge_box3d(boxes3d, extra_width): + """ + :param boxes3d: (N, 7) [x, y, z, h, w, l, ry] + """ + if isinstance(boxes3d, np.ndarray): + large_boxes3d = boxes3d.copy() + else: + large_boxes3d = boxes3d.clone() + large_boxes3d[:, 3:6] += extra_width * 2 + large_boxes3d[:, 1] += extra_width + return large_boxes3d + + +def in_hull(p, hull): + """ + :param p: (N, K) test points + :param hull: (M, K) M corners of a box + :return (N) bool + """ + try: + if not isinstance(hull, Delaunay): + hull = Delaunay(hull) + flag = hull.find_simplex(p) >= 0 + except scipy.spatial.QhullError: + print('Warning: not a hull %s' % str(hull)) + flag = np.zeros(p.shape[0], dtype=bool) + + return flag diff --git a/modules/module_lib/pointnet2_utils/tools/pointnet2_msg.py b/modules/module_lib/pointnet2_utils/tools/pointnet2_msg.py new file mode 100644 index 0000000..59a2207 --- /dev/null +++ b/modules/module_lib/pointnet2_utils/tools/pointnet2_msg.py @@ -0,0 +1,102 @@ +import torch +import torch.nn as nn +import sys +sys.path.append('..') +from pointnet2.pointnet2_modules import PointnetFPModule, PointnetSAModuleMSG +import pointnet2.pytorch_utils as pt_utils + + +def get_model(input_channels=0): + return Pointnet2MSG(input_channels=input_channels) + + +NPOINTS = [4096, 1024, 256, 64] +RADIUS = [[0.1, 0.5], [0.5, 1.0], [1.0, 2.0], [2.0, 4.0]] +NSAMPLE = [[16, 32], [16, 32], [16, 32], [16, 32]] +MLPS = [[[16, 16, 32], [32, 32, 64]], [[64, 64, 128], [64, 96, 128]], + [[128, 196, 256], [128, 196, 256]], [[256, 256, 512], [256, 384, 512]]] +FP_MLPS = [[128, 128], [256, 256], [512, 512], [512, 512]] +CLS_FC = [128] +DP_RATIO = 0.5 + + +class Pointnet2MSG(nn.Module): + def __init__(self, input_channels=6): + super().__init__() + + self.SA_modules = nn.ModuleList() + channel_in = input_channels + + skip_channel_list = [input_channels] + for k in range(NPOINTS.__len__()): + mlps = MLPS[k].copy() + channel_out = 0 + for idx in range(mlps.__len__()): + mlps[idx] = [channel_in] + mlps[idx] + channel_out += mlps[idx][-1] + + self.SA_modules.append( + PointnetSAModuleMSG( + npoint=NPOINTS[k], + radii=RADIUS[k], + nsamples=NSAMPLE[k], + mlps=mlps, + use_xyz=True, + bn=True + ) + ) + skip_channel_list.append(channel_out) + channel_in = channel_out + + self.FP_modules = nn.ModuleList() + + for k in range(FP_MLPS.__len__()): + pre_channel = FP_MLPS[k + 1][-1] if k + 1 < len(FP_MLPS) else channel_out + self.FP_modules.append( + PointnetFPModule(mlp=[pre_channel + skip_channel_list[k]] + FP_MLPS[k]) + ) + + cls_layers = [] + pre_channel = FP_MLPS[0][-1] + for k in range(0, CLS_FC.__len__()): + cls_layers.append(pt_utils.Conv1d(pre_channel, CLS_FC[k], bn=True)) + pre_channel = CLS_FC[k] + cls_layers.append(pt_utils.Conv1d(pre_channel, 1, activation=None)) + cls_layers.insert(1, nn.Dropout(0.5)) + self.cls_layer = nn.Sequential(*cls_layers) + + def _break_up_pc(self, pc): + xyz = pc[..., 0:3].contiguous() + features = ( + pc[..., 3:].transpose(1, 2).contiguous() + if pc.size(-1) > 3 else None + ) + + return xyz, features + + def forward(self, pointcloud: torch.cuda.FloatTensor): + xyz, features = self._break_up_pc(pointcloud) + + l_xyz, l_features = [xyz], [features] + for i in range(len(self.SA_modules)): + li_xyz, li_features =
self.SA_modules[i](l_xyz[i], l_features[i]) + + print(li_xyz.shape, li_features.shape) + + l_xyz.append(li_xyz) + l_features.append(li_features) + + for i in range(-1, -(len(self.FP_modules) + 1), -1): + l_features[i - 1] = self.FP_modules[i]( + l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i] + ) + + pred_cls = self.cls_layer(l_features[0]).transpose(1, 2).contiguous() # (B, N, 1) + return pred_cls + +if __name__ == '__main__': + net = Pointnet2MSG(0).cuda() + pts = torch.randn(2, 1024, 3).cuda() + + pre = net(pts) + print(pre.shape) diff --git a/modules/module_lib/pointnet2_utils/tools/train_and_eval.py b/modules/module_lib/pointnet2_utils/tools/train_and_eval.py new file mode 100644 index 0000000..d35502b --- /dev/null +++ b/modules/module_lib/pointnet2_utils/tools/train_and_eval.py @@ -0,0 +1,217 @@ +import _init_path +import numpy as np +import os +import torch +import torch.nn as nn +import torch.optim as optim +import torch.optim.lr_scheduler as lr_sched +from torch.nn.utils import clip_grad_norm_ +from torch.utils.data import DataLoader +import tensorboard_logger as tb_log +from dataset import KittiDataset +import argparse +import importlib + +parser = argparse.ArgumentParser(description="Arg parser") +parser.add_argument("--batch_size", type=int, default=8) +parser.add_argument("--epochs", type=int, default=100) +parser.add_argument("--ckpt_save_interval", type=int, default=5) +parser.add_argument('--workers', type=int, default=4) +parser.add_argument("--mode", type=str, default='train') +parser.add_argument("--ckpt", type=str, default='None') + +parser.add_argument("--net", type=str, default='pointnet2_msg') + +parser.add_argument('--lr', type=float, default=0.002) +parser.add_argument('--lr_decay', type=float, default=0.2) +parser.add_argument('--lr_clip', type=float, default=0.000001) +parser.add_argument('--decay_step_list', type=list, default=[50, 70, 80, 90]) +parser.add_argument('--weight_decay', type=float, default=0.001) + +parser.add_argument("--output_dir", type=str, default='output') +parser.add_argument("--extra_tag", type=str, default='default') + +args = parser.parse_args() + +FG_THRESH = 0.3 + + +def log_print(info, log_f=None): + print(info) + if log_f is not None: + print(info, file=log_f) + + +class DiceLoss(nn.Module): + def __init__(self, ignore_target=-1): + super().__init__() + self.ignore_target = ignore_target + + def forward(self, input, target): + """ + :param input: (N), logit + :param target: (N), {0, 1} + :return: + """ + input = torch.sigmoid(input.view(-1)) + target = target.float().view(-1) + mask = (target != self.ignore_target).float() + return 1.0 - (torch.min(input, target) * mask).sum() / torch.clamp((torch.max(input, target) * mask).sum(), min=1.0) + + +def train_one_epoch(model, train_loader, optimizer, epoch, lr_scheduler, total_it, tb_log, log_f): + model.train() + log_print('===============TRAIN EPOCH %d================' % epoch, log_f=log_f) + loss_func = DiceLoss(ignore_target=-1) + + for it, batch in enumerate(train_loader): + optimizer.zero_grad() + + pts_input, cls_labels = batch['pts_input'], batch['cls_labels'] + pts_input = torch.from_numpy(pts_input).cuda(non_blocking=True).float() + cls_labels = torch.from_numpy(cls_labels).cuda(non_blocking=True).long().view(-1) + + pred_cls = model(pts_input) + pred_cls = pred_cls.view(-1) + + loss = loss_func(pred_cls, cls_labels) + loss.backward() + clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + + total_it += 1 + + pred_class = (torch.sigmoid(pred_cls) > FG_THRESH) + 
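# Foreground IoU for logging: 'correct' counts true-positive foreground points, + # and the union is |gt fg| + |pred fg| - correct, clamped to avoid division by zero. +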
fg_mask = cls_labels > 0 + correct = ((pred_class.long() == cls_labels) & fg_mask).float().sum() + union = fg_mask.sum().float() + (pred_class > 0).sum().float() - correct + iou = correct / torch.clamp(union, min=1.0) + + cur_lr = lr_scheduler.get_lr()[0] + tb_log.log_value('learning_rate', cur_lr, epoch) + if tb_log is not None: + tb_log.log_value('train_loss', loss, total_it) + tb_log.log_value('train_fg_iou', iou, total_it) + + log_print('training epoch %d: it=%d/%d, total_it=%d, loss=%.5f, fg_iou=%.3f, lr=%f' % + (epoch, it, len(train_loader), total_it, loss.item(), iou.item(), cur_lr), log_f=log_f) + + return total_it + + +def eval_one_epoch(model, eval_loader, epoch, tb_log=None, log_f=None): + model.train() + log_print('===============EVAL EPOCH %d================' % epoch, log_f=log_f) + + iou_list = [] + for it, batch in enumerate(eval_loader): + pts_input, cls_labels = batch['pts_input'], batch['cls_labels'] + pts_input = torch.from_numpy(pts_input).cuda(non_blocking=True).float() + cls_labels = torch.from_numpy(cls_labels).cuda(non_blocking=True).long().view(-1) + + pred_cls = model(pts_input) + pred_cls = pred_cls.view(-1) + + pred_class = (torch.sigmoid(pred_cls) > FG_THRESH) + fg_mask = cls_labels > 0 + correct = ((pred_class.long() == cls_labels) & fg_mask).float().sum() + union = fg_mask.sum().float() + (pred_class > 0).sum().float() - correct + iou = correct / torch.clamp(union, min=1.0) + + iou_list.append(iou.item()) + log_print('EVAL: it=%d/%d, iou=%.3f' % (it, len(eval_loader), iou), log_f=log_f) + + iou_list = np.array(iou_list) + avg_iou = iou_list.mean() + if tb_log is not None: + tb_log.log_value('eval_fg_iou', avg_iou, epoch) + + log_print('\nEpoch %d: Average IoU (samples=%d): %.6f' % (epoch, iou_list.__len__(), avg_iou), log_f=log_f) + return avg_iou + + +def save_checkpoint(model, epoch, ckpt_name): + if isinstance(model, torch.nn.DataParallel): + model_state = model.module.state_dict() + else: + model_state = model.state_dict() + + state = {'epoch': epoch, 'model_state': model_state} + ckpt_name = '{}.pth'.format(ckpt_name) + torch.save(state, ckpt_name) + + +def load_checkpoint(model, filename): + if os.path.isfile(filename): + log_print("==> Loading from checkpoint %s" % filename) + checkpoint = torch.load(filename) + epoch = checkpoint['epoch'] + model.load_state_dict(checkpoint['model_state']) + log_print("==> Done") + else: + raise FileNotFoundError + + return epoch + + +def train_and_eval(model, train_loader, eval_loader, tb_log, ckpt_dir, log_f): + model.cuda() + optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + def lr_lbmd(cur_epoch): + cur_decay = 1 + for decay_step in args.decay_step_list: + if cur_epoch >= decay_step: + cur_decay = cur_decay * args.lr_decay + return max(cur_decay, args.lr_clip / args.lr) + + lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd) + + total_it = 0 + for epoch in range(1, args.epochs + 1): + lr_scheduler.step(epoch) + total_it = train_one_epoch(model, train_loader, optimizer, epoch, lr_scheduler, total_it, tb_log, log_f) + + if epoch % args.ckpt_save_interval == 0: + with torch.no_grad(): + avg_iou = eval_one_epoch(model, eval_loader, epoch, tb_log, log_f) + ckpt_name = os.path.join(ckpt_dir, 'checkpoint_epoch_%d' % epoch) + save_checkpoint(model, epoch, ckpt_name) + + +if __name__ == '__main__': + MODEL = importlib.import_module(args.net) # import network module + model = MODEL.get_model(input_channels=0) + + eval_set = KittiDataset(root_dir='./data', mode='EVAL', split='val') + 
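# mode='EVAL' keeps ground-truth cls_labels in every sample; only mode='TEST' skips label generation +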
eval_loader = DataLoader(eval_set, batch_size=args.batch_size, shuffle=False, pin_memory=True, + num_workers=args.workers, collate_fn=eval_set.collate_batch) + + if args.mode == 'train': + train_set = KittiDataset(root_dir='./data', mode='TRAIN', split='train') + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, pin_memory=True, + num_workers=args.workers, collate_fn=train_set.collate_batch) + # output dir config + output_dir = os.path.join(args.output_dir, args.extra_tag) + os.makedirs(output_dir, exist_ok=True) + tb_log.configure(os.path.join(output_dir, 'tensorboard')) + ckpt_dir = os.path.join(output_dir, 'ckpt') + os.makedirs(ckpt_dir, exist_ok=True) + + log_file = os.path.join(output_dir, 'log.txt') + log_f = open(log_file, 'w') + + for key, val in vars(args).items(): + log_print("{:16} {}".format(key, val), log_f=log_f) + + # train and eval + train_and_eval(model, train_loader, eval_loader, tb_log, ckpt_dir, log_f) + log_f.close() + elif args.mode == 'eval': + epoch = load_checkpoint(model, args.ckpt) + model.cuda() + with torch.no_grad(): + avg_iou = eval_one_epoch(model, eval_loader, epoch) + else: + raise NotImplementedError + diff --git a/modules/pointnet++_encoder.py b/modules/pointnet++_encoder.py new file mode 100644 index 0000000..ee982a3 --- /dev/null +++ b/modules/pointnet++_encoder.py @@ -0,0 +1,119 @@ +import torch +import torch.nn as nn +import os +import sys +path = os.path.abspath(__file__) +for i in range(2): + path = os.path.dirname(path) +PROJECT_ROOT = path +sys.path.append(PROJECT_ROOT) +from modules.module_lib.pointnet2_utils.pointnet2.pointnet2_modules import PointnetSAModuleMSG + +ClsMSG_CFG_Dense = { + 'NPOINTS': [512, 256, 128, None], + 'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]], + 'NSAMPLE': [[32, 64], [16, 32], [8, 16], [None, None]], + 'MLPS': [[[16, 16, 32], [32, 32, 64]], + [[64, 64, 128], [64, 96, 128]], + [[128, 196, 256], [128, 196, 256]], + [[256, 256, 512], [256, 384, 512]]], + 'DP_RATIO': 0.5, +} + +ClsMSG_CFG_Light = { + 'NPOINTS': [512, 256, 128, None], + 'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]], + 'NSAMPLE': [[16, 32], [16, 32], [16, 32], [None, None]], + 'MLPS': [[[16, 16, 32], [32, 32, 64]], + [[64, 64, 128], [64, 96, 128]], + [[128, 196, 256], [128, 196, 256]], + [[256, 256, 512], [256, 384, 512]]], + 'DP_RATIO': 0.5, +} + +ClsMSG_CFG_Lighter = { + 'NPOINTS': [512, 256, 128, 64, None], + 'RADIUS': [[0.01], [0.02], [0.04], [0.08], [None]], + 'NSAMPLE': [[64], [32], [16], [8], [None]], + 'MLPS': [[[32, 32, 64]], + [[64, 64, 128]], + [[128, 196, 256]], + [[256, 256, 512]], + [[512, 512, 1024]]], + 'DP_RATIO': 0.5, +} + + +def select_params(name): + if name == 'light': + return ClsMSG_CFG_Light + elif name == 'lighter': + return ClsMSG_CFG_Lighter + elif name == 'dense': + return ClsMSG_CFG_Dense + else: + raise NotImplementedError + + +def break_up_pc(pc): + xyz = pc[..., 0:3].contiguous() + features = ( + pc[..., 3:].transpose(1, 2).contiguous() + if pc.size(-1) > 3 else None + ) + + return xyz, features + + +class PointNet2Encoder(nn.Module): + def encode_points(self, pts): + return self.forward(pts) + + def __init__(self, config:dict): + super().__init__() + + input_channels = config.get("in_dim", 0) + params_name = config.get("params_name", "light") + + self.SA_modules = nn.ModuleList() + channel_in = input_channels + selected_params = select_params(params_name) + for k in range(selected_params['NPOINTS'].__len__()): + mlps = selected_params['MLPS'][k].copy() + 
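# prepend the running input width to each per-scale MLP spec; channel_out accumulates the + # concatenated output width across scales and becomes the next level's channel_in +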
channel_out = 0 + for idx in range(mlps.__len__()): + mlps[idx] = [channel_in] + mlps[idx] + channel_out += mlps[idx][-1] + + self.SA_modules.append( + PointnetSAModuleMSG( + npoint=selected_params['NPOINTS'][k], + radii=selected_params['RADIUS'][k], + nsamples=selected_params['NSAMPLE'][k], + mlps=mlps, + use_xyz=True, + bn=True + ) + ) + channel_in = channel_out + + def forward(self, point_cloud: torch.cuda.FloatTensor): + xyz, features = break_up_pc(point_cloud) + + l_xyz, l_features = [xyz], [features] + for i in range(len(self.SA_modules)): + li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i]) + l_xyz.append(li_xyz) + l_features.append(li_features) + return l_features[-1].squeeze(-1) + + +if __name__ == '__main__': + seed = 100 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + net = PointNet2Encoder(config={"in_dim": 0, "params_name": "light"}).cuda() + pts = torch.randn(2, 1024, 3).cuda() + print(torch.mean(pts, dim=1)) + pre = net.encode_points(pts) + print(pre.shape) diff --git a/runners/inferencer.py b/runners/inferencer.py index bf4185f..66e202e 100644 --- a/runners/inferencer.py +++ b/runners/inferencer.py @@ -142,6 +142,7 @@ class Inferencer(Runner): voxel_downsampled_combined_scanned_pts_np, inverse = self.voxel_downsample_with_mapping(combined_scanned_pts, voxel_threshold) output = self.pipeline(input_data) pred_pose_9d = output["pred_pose_9d"] + import ipdb; ipdb.set_trace() pred_pose = torch.eye(4, device=pred_pose_9d.device) pred_pose[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(pred_pose_9d[:,:6])[0] diff --git a/utils/render.py b/utils/render.py index 3949f50..1356208 100644 --- a/utils/render.py +++ b/utils/render.py @@ -83,7 +83,7 @@ class RenderUtil: shutil.copy(scene_info_path, os.path.join(temp_dir, "scene_info.json")) params_data_path = os.path.join(temp_dir, "params.json") with open(params_data_path, 'w') as f: - json.dump(params, f) + json.dump(params, f) result = subprocess.run([ '/home/hofee/blender-4.0.2-linux-x64/blender', '-b', '-P', script_path, '--', temp_dir ], capture_output=True, text=True)
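One practical note on the new encoder: the file name `modules/pointnet++_encoder.py` contains `+` characters, which are illegal in Python module names, so it cannot be imported with a plain `import` statement and has to be loaded by path. A minimal usage sketch, assuming a CUDA device and the compiled `pointnet2` extension, mirroring the module's own `__main__` block:

```python
import importlib.util
import torch

# 'pointnet++_encoder' is not a valid module name, so load the file by path.
spec = importlib.util.spec_from_file_location(
    "pointnet2_encoder", "modules/pointnet++_encoder.py")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

net = mod.PointNet2Encoder(config={"in_dim": 0, "params_name": "light"}).cuda()
pts = torch.randn(2, 1024, 3).cuda()   # (B, N, 3) xyz-only cloud
feat = net.encode_points(pts)
print(feat.shape)                       # torch.Size([2, 1024]) with the 'light' config
```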