success
This commit is contained in:
188
modules/module_lib/pointnet2_utils/tools/dataset.py
Executable file
188
modules/module_lib/pointnet2_utils/tools/dataset.py
Executable file
@@ -0,0 +1,188 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import torch.utils.data as torch_data
|
||||
import kitti_utils
|
||||
import cv2
|
||||
from PIL import Image
|
||||
|
||||
|
||||
USE_INTENSITY = False
|
||||
|
||||
|
||||
class KittiDataset(torch_data.Dataset):
|
||||
def __init__(self, root_dir, split='train', mode='TRAIN'):
|
||||
self.split = split
|
||||
self.mode = mode
|
||||
self.classes = ['Car']
|
||||
is_test = self.split == 'test'
|
||||
self.imageset_dir = os.path.join(root_dir, 'KITTI', 'object', 'testing' if is_test else 'training')
|
||||
|
||||
split_dir = os.path.join(root_dir, 'KITTI', 'ImageSets', split + '.txt')
|
||||
self.image_idx_list = [x.strip() for x in open(split_dir).readlines()]
|
||||
self.sample_id_list = [int(sample_id) for sample_id in self.image_idx_list]
|
||||
self.num_sample = self.image_idx_list.__len__()
|
||||
|
||||
self.npoints = 16384
|
||||
|
||||
self.image_dir = os.path.join(self.imageset_dir, 'image_2')
|
||||
self.lidar_dir = os.path.join(self.imageset_dir, 'velodyne')
|
||||
self.calib_dir = os.path.join(self.imageset_dir, 'calib')
|
||||
self.label_dir = os.path.join(self.imageset_dir, 'label_2')
|
||||
self.plane_dir = os.path.join(self.imageset_dir, 'planes')
|
||||
|
||||
def get_image(self, idx):
|
||||
img_file = os.path.join(self.image_dir, '%06d.png' % idx)
|
||||
assert os.path.exists(img_file)
|
||||
return cv2.imread(img_file) # (H, W, 3) BGR mode
|
||||
|
||||
def get_image_shape(self, idx):
|
||||
img_file = os.path.join(self.image_dir, '%06d.png' % idx)
|
||||
assert os.path.exists(img_file)
|
||||
im = Image.open(img_file)
|
||||
width, height = im.size
|
||||
return height, width, 3
|
||||
|
||||
def get_lidar(self, idx):
|
||||
lidar_file = os.path.join(self.lidar_dir, '%06d.bin' % idx)
|
||||
assert os.path.exists(lidar_file)
|
||||
return np.fromfile(lidar_file, dtype=np.float32).reshape(-1, 4)
|
||||
|
||||
def get_calib(self, idx):
|
||||
calib_file = os.path.join(self.calib_dir, '%06d.txt' % idx)
|
||||
assert os.path.exists(calib_file)
|
||||
return kitti_utils.Calibration(calib_file)
|
||||
|
||||
def get_label(self, idx):
|
||||
label_file = os.path.join(self.label_dir, '%06d.txt' % idx)
|
||||
assert os.path.exists(label_file)
|
||||
return kitti_utils.get_objects_from_label(label_file)
|
||||
|
||||
@staticmethod
|
||||
def get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape):
|
||||
val_flag_1 = np.logical_and(pts_img[:, 0] >= 0, pts_img[:, 0] < img_shape[1])
|
||||
val_flag_2 = np.logical_and(pts_img[:, 1] >= 0, pts_img[:, 1] < img_shape[0])
|
||||
val_flag_merge = np.logical_and(val_flag_1, val_flag_2)
|
||||
pts_valid_flag = np.logical_and(val_flag_merge, pts_rect_depth >= 0)
|
||||
return pts_valid_flag
|
||||
|
||||
def filtrate_objects(self, obj_list):
|
||||
type_whitelist = self.classes
|
||||
if self.mode == 'TRAIN':
|
||||
type_whitelist = list(self.classes)
|
||||
if 'Car' in self.classes:
|
||||
type_whitelist.append('Van')
|
||||
|
||||
valid_obj_list = []
|
||||
for obj in obj_list:
|
||||
if obj.cls_type not in type_whitelist:
|
||||
continue
|
||||
|
||||
valid_obj_list.append(obj)
|
||||
return valid_obj_list
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sample_id_list)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample_id = int(self.sample_id_list[index])
|
||||
calib = self.get_calib(sample_id)
|
||||
img_shape = self.get_image_shape(sample_id)
|
||||
pts_lidar = self.get_lidar(sample_id)
|
||||
|
||||
# get valid point (projected points should be in image)
|
||||
pts_rect = calib.lidar_to_rect(pts_lidar[:, 0:3])
|
||||
pts_intensity = pts_lidar[:, 3]
|
||||
|
||||
pts_img, pts_rect_depth = calib.rect_to_img(pts_rect)
|
||||
pts_valid_flag = self.get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape)
|
||||
|
||||
pts_rect = pts_rect[pts_valid_flag][:, 0:3]
|
||||
pts_intensity = pts_intensity[pts_valid_flag]
|
||||
|
||||
if self.npoints < len(pts_rect):
|
||||
pts_depth = pts_rect[:, 2]
|
||||
pts_near_flag = pts_depth < 40.0
|
||||
far_idxs_choice = np.where(pts_near_flag == 0)[0]
|
||||
near_idxs = np.where(pts_near_flag == 1)[0]
|
||||
near_idxs_choice = np.random.choice(near_idxs, self.npoints - len(far_idxs_choice), replace=False)
|
||||
|
||||
choice = np.concatenate((near_idxs_choice, far_idxs_choice), axis=0) \
|
||||
if len(far_idxs_choice) > 0 else near_idxs_choice
|
||||
np.random.shuffle(choice)
|
||||
else:
|
||||
choice = np.arange(0, len(pts_rect), dtype=np.int32)
|
||||
if self.npoints > len(pts_rect):
|
||||
extra_choice = np.random.choice(choice, self.npoints - len(pts_rect), replace=False)
|
||||
choice = np.concatenate((choice, extra_choice), axis=0)
|
||||
np.random.shuffle(choice)
|
||||
|
||||
ret_pts_rect = pts_rect[choice, :]
|
||||
ret_pts_intensity = pts_intensity[choice] - 0.5 # translate intensity to [-0.5, 0.5]
|
||||
|
||||
pts_features = [ret_pts_intensity.reshape(-1, 1)]
|
||||
ret_pts_features = np.concatenate(pts_features, axis=1) if pts_features.__len__() > 1 else pts_features[0]
|
||||
|
||||
sample_info = {'sample_id': sample_id}
|
||||
|
||||
if self.mode == 'TEST':
|
||||
if USE_INTENSITY:
|
||||
pts_input = np.concatenate((ret_pts_rect, ret_pts_features), axis=1) # (N, C)
|
||||
else:
|
||||
pts_input = ret_pts_rect
|
||||
sample_info['pts_input'] = pts_input
|
||||
sample_info['pts_rect'] = ret_pts_rect
|
||||
sample_info['pts_features'] = ret_pts_features
|
||||
return sample_info
|
||||
|
||||
gt_obj_list = self.filtrate_objects(self.get_label(sample_id))
|
||||
|
||||
gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list)
|
||||
|
||||
# prepare input
|
||||
if USE_INTENSITY:
|
||||
pts_input = np.concatenate((ret_pts_rect, ret_pts_features), axis=1) # (N, C)
|
||||
else:
|
||||
pts_input = ret_pts_rect
|
||||
|
||||
# generate training labels
|
||||
cls_labels = self.generate_training_labels(ret_pts_rect, gt_boxes3d)
|
||||
sample_info['pts_input'] = pts_input
|
||||
sample_info['pts_rect'] = ret_pts_rect
|
||||
sample_info['cls_labels'] = cls_labels
|
||||
return sample_info
|
||||
|
||||
@staticmethod
|
||||
def generate_training_labels(pts_rect, gt_boxes3d):
|
||||
cls_label = np.zeros((pts_rect.shape[0]), dtype=np.int32)
|
||||
gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d, rotate=True)
|
||||
extend_gt_boxes3d = kitti_utils.enlarge_box3d(gt_boxes3d, extra_width=0.2)
|
||||
extend_gt_corners = kitti_utils.boxes3d_to_corners3d(extend_gt_boxes3d, rotate=True)
|
||||
for k in range(gt_boxes3d.shape[0]):
|
||||
box_corners = gt_corners[k]
|
||||
fg_pt_flag = kitti_utils.in_hull(pts_rect, box_corners)
|
||||
cls_label[fg_pt_flag] = 1
|
||||
|
||||
# enlarge the bbox3d, ignore nearby points
|
||||
extend_box_corners = extend_gt_corners[k]
|
||||
fg_enlarge_flag = kitti_utils.in_hull(pts_rect, extend_box_corners)
|
||||
ignore_flag = np.logical_xor(fg_pt_flag, fg_enlarge_flag)
|
||||
cls_label[ignore_flag] = -1
|
||||
|
||||
return cls_label
|
||||
|
||||
def collate_batch(self, batch):
|
||||
batch_size = batch.__len__()
|
||||
ans_dict = {}
|
||||
|
||||
for key in batch[0].keys():
|
||||
if isinstance(batch[0][key], np.ndarray):
|
||||
ans_dict[key] = np.concatenate([batch[k][key][np.newaxis, ...] for k in range(batch_size)], axis=0)
|
||||
|
||||
else:
|
||||
ans_dict[key] = [batch[k][key] for k in range(batch_size)]
|
||||
if isinstance(batch[0][key], int):
|
||||
ans_dict[key] = np.array(ans_dict[key], dtype=np.int32)
|
||||
elif isinstance(batch[0][key], float):
|
||||
ans_dict[key] = np.array(ans_dict[key], dtype=np.float32)
|
||||
|
||||
return ans_dict
|
Reference in New Issue
Block a user