success
baselines/grasping/GSNet/preprocessor.py (executable file, 256 lines)
@@ -0,0 +1,256 @@
import os
import re
import sys
import json

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from omni_util import OmniUtil

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(ROOT_DIR, "pointnet2"))
sys.path.append(os.path.join(ROOT_DIR, "utils"))
sys.path.append(os.path.join(ROOT_DIR, "models"))
sys.path.append(os.path.join(ROOT_DIR, "dataset"))

# These imports resolve against the sys.path entries appended above.
from models.graspnet import GraspNet
from dataset.graspnet_dataset import minkowski_collate_fn


class GSNetInferenceDataset(Dataset):
    CAMERA_PARAMS_TEMPLATE = "camera_params_{}.json"
    DISTANCE_TEMPLATE = "distance_to_camera_{}.npy"
    RGB_TEMPLATE = "rgb_{}.png"
    MASK_TEMPLATE = "semantic_segmentation_{}.png"
    MASK_LABELS_TEMPLATE = "semantic_segmentation_labels_{}.json"
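
    # Inference-only dataset: one item per rendered frame, loaded lazily from an
    # OmniUtil-formatted scene directory under data_dir/source/data_type.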
    def __init__(
        self,
        source="nbv1",
        data_type="sample",
        data_dir="/mnt/h/AI/Datasets",
        scene_pts_num=15000,
    ):
        self.data_dir = data_dir
        self.scene_pts_num = scene_pts_num
        self.data_path = str(os.path.join(self.data_dir, source, data_type))
        self.scene_list = os.listdir(self.data_path)
        self.data_list = self.get_datalist()
        self.voxel_size = 0.005

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        frame_path = self.data_list[index]
        frame_data = self.load_frame_data(frame_path=frame_path)
        return frame_data

    def get_datalist(self):
        # One entry per frame, accumulated across *all* scenes; the frame index
        # is parsed from the camera_params_<idx>.json filename.
        frame_list = []
        for scene in self.scene_list:
            scene_path = os.path.join(self.data_path, scene)
            file_list = os.listdir(scene_path)
            for file in file_list:
                if file.startswith("camera_params"):
                    frame_index = re.findall(r"\d+", file)[0]
                    frame_path = os.path.join(scene_path, frame_index)
                    frame_list.append(frame_path)
        return frame_list

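    # Load one frame: the full scene cloud plus a dict of per-object segmented
    # clouds, both produced by OmniUtil from the frame's depth and segmentation files.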
    def load_frame_data(self, frame_path):
        target_list = OmniUtil.get_object_list(path=frame_path, contains_nonobj=True)
        scene_pts, obj_pcl_dict = OmniUtil.get_segmented_points(
            path=frame_path, target_list=target_list
        )
        ret_dict = {
            "frame_path": frame_path,
            "point_clouds": scene_pts.astype(np.float32),
            # Voxel coordinates and constant features for minkowski_collate_fn.
            "coors": scene_pts.astype(np.float32) / self.voxel_size,
            "feats": np.ones_like(scene_pts).astype(np.float32),
            "obj_pcl_dict": obj_pcl_dict,
        }
        return ret_dict

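    # Fixed-size resampling helper; draws with replacement only when the cloud
    # has fewer points than requested.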
    @staticmethod
    def sample_pcl(pcl, n_pts=1024):
        indices = np.random.choice(pcl.shape[0], n_pts, replace=pcl.shape[0] < n_pts)
        return pcl[indices, :]

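
# End-to-end driver: builds the dataloader, restores a GraspNet checkpoint,
# runs inference, normalizes per-object scores across frames, and writes
# label_<idx>.json files back into each frame's directory.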
class GSNetPreprocessor:
    LABEL_TEMPLATE = "label_{}.json"

    def __init__(self):
        self.voxel_size = 0.005
        self.camera = "kinect"
        self.num_point = 15000
        self.batch_size = 1
        self.seed_feat_dim = 512
        self.checkpoint_path = "logs/log_kn/epoch10.tar"
        self.dump_dir = "logs/log_kn/dump_kinect"
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def get_dataloader(self, dataset_config=None):
        def my_worker_init_fn(worker_id):
            # Derive a distinct NumPy seed per dataloader worker.
            np.random.seed(np.random.get_state()[1][0] + worker_id)

        dataset = GSNetInferenceDataset()
        print("Test dataset length: ", len(dataset))
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=0,
            worker_init_fn=my_worker_init_fn,
            collate_fn=minkowski_collate_fn,
        )
        print("Test dataloader length: ", len(dataloader))
        return dataloader

    def get_model(self, model_config=None):
        model = GraspNet(seed_feat_dim=self.seed_feat_dim, is_training=False)
        model.to(self.device)
        # map_location keeps CPU-only machines able to load GPU-trained checkpoints.
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        start_epoch = checkpoint["epoch"]
        print(
            "-> loaded checkpoint %s (epoch: %d)" % (self.checkpoint_path, start_epoch)
        )
        model.eval()
        return model

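    # Run the network over every frame, regroup the raw predictions by frame
    # path, and attach each frame's per-object clouds for the matching step below.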
    def prediction(self, model, dataloader):
        preds = {}
        total = len(dataloader)
        for idx, batch_data in enumerate(dataloader):
            print(f"predicting... [{idx}/{total}]")
            # Move every tensor in the batch to the target device; keys
            # containing "list" hold nested lists of tensors.
            for key in batch_data:
                if "list" in key:
                    for i in range(len(batch_data[key])):
                        for j in range(len(batch_data[key][i])):
                            batch_data[key][i][j] = batch_data[key][i][j].to(
                                self.device
                            )
                elif not isinstance(batch_data[key], list):
                    batch_data[key] = batch_data[key].to(self.device)
            with torch.no_grad():
                end_points = model(batch_data)
                grasp_preds = self.decode_pred(end_points)
            for frame_idx in range(len(batch_data["frame_path"])):
                preds[batch_data["frame_path"][frame_idx]] = grasp_preds[frame_idx]
                preds[batch_data["frame_path"][frame_idx]]["obj_pcl_dict"] = (
                    batch_data["obj_pcl_dict"][frame_idx]
                )

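        # Match each predicted grasp center to the object whose segmented cloud
        # contains that exact point; exact float equality is usable here because
        # the grasp centers are sampled from the same scene cloud the per-object
        # clouds were cut from.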
        results = {}
        top_k = 50
        for frame_path in preds:
            predict_results = {}
            grasp_center = preds[frame_path]["grasp_center"]
            grasp_score = preds[frame_path]["grasp_score"]
            obj_pcl_dict = preds[frame_path]["obj_pcl_dict"]
            grasp_center = grasp_center.unsqueeze(1)  # (N, 1, 3) for broadcasting
            for obj_name in obj_pcl_dict:
                if obj_name in OmniUtil.NON_OBJECT_LIST:
                    continue
                obj_pcl = obj_pcl_dict[obj_name]
                obj_pcl = torch.tensor(
                    obj_pcl.astype(np.float32), device=grasp_center.device
                )
                obj_pcl = obj_pcl.unsqueeze(0)  # (1, M, 3)
                grasp_obj_table = (grasp_center == obj_pcl).all(dim=-1)  # (N, M)
                obj_pts_on_grasp = grasp_obj_table.any(dim=1)  # (N,)
                obj_graspable_pts = grasp_center[obj_pts_on_grasp].squeeze(1)
                obj_graspable_pts_score = grasp_score[obj_pts_on_grasp]
                obj_graspable_pts_info = torch.cat(
                    [obj_graspable_pts, obj_graspable_pts_score], dim=1
                )
                if obj_graspable_pts.shape[0] == 0:
                    # No grasp landed on this object: keep shapes consistent with
                    # zero-filled placeholders on the same device.
                    obj_graspable_pts_info = torch.zeros(
                        (top_k, 4), device=grasp_center.device
                    )
                ranked_obj_graspable_pts_info = self.sample_graspable_pts(
                    obj_graspable_pts_info, top_k=top_k
                )
                predict_results[obj_name] = {
                    "positions": ranked_obj_graspable_pts_info[:, :3].cpu().numpy().tolist(),
                    "scores": ranked_obj_graspable_pts_info[:, 3].cpu().numpy().tolist(),
                }
            results[frame_path] = {"predicted_results": predict_results}
        return results

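    # Per-object view scoring: each frame's summed grasp score is divided by the
    # object's best frame, i.e. score = sum_f / (max over frames of sum + 1e-6),
    # so a value near 1.0 marks the most graspable view of that object.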
    def preprocess(self, predicted_data):
        obj_score_list_dict = {}
        for frame_path in predicted_data:
            frame_obj_info = predicted_data[frame_path]["predicted_results"]
            predicted_data[frame_path]["sum_score"] = {}
            for obj_name in frame_obj_info:
                if obj_name not in obj_score_list_dict:
                    obj_score_list_dict[obj_name] = []
                obj_score_sum = np.sum(frame_obj_info[obj_name]["scores"])
                obj_score_list_dict[obj_name].append(obj_score_sum)
                predicted_data[frame_path]["sum_score"][obj_name] = obj_score_sum

        for frame_path in predicted_data:
            frame_obj_info = predicted_data[frame_path]["predicted_results"]
            predicted_data[frame_path]["regularized_score"] = {}
            for obj_name in frame_obj_info:
                obj_score_sum = predicted_data[frame_path]["sum_score"][obj_name]
                max_obj_score = max(obj_score_list_dict[obj_name])
                predicted_data[frame_path]["regularized_score"][obj_name] = (
                    obj_score_sum / (max_obj_score + 1e-6)
                )
        return predicted_data

    @staticmethod
    def sample_graspable_pts(graspable_pts, top_k=50):
        # Pad by resampling with replacement when fewer than top_k candidates
        # exist, then keep the top_k rows ranked by score (column 3).
        if graspable_pts.shape[0] < top_k:
            sampled_indices = torch.randint(0, graspable_pts.shape[0], (top_k,))
            graspable_pts = graspable_pts[sampled_indices]
        sorted_indices = torch.argsort(graspable_pts[:, 3], descending=True)
        ranked_pts = graspable_pts[sorted_indices][:top_k]
        return ranked_pts

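    # Persist results next to their source frames; the slice below assumes
    # frame indices are exactly four characters long.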
    def save_processed_data(self, processed_data, dataset_config):
        for frame_path in processed_data:
            data_item = processed_data[frame_path]
            save_root, idx = frame_path[:-4], frame_path[-4:]
            label_save_path = os.path.join(
                str(save_root), self.LABEL_TEMPLATE.format(idx)
            )
            with open(label_save_path, "w") as f:
                json.dump(data_item, f)

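    # Collapse GraspNet's per-point score grid to one best score per seed point;
    # downstream steps only need (x, y, z, score) per grasp center.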
    def decode_pred(self, end_points):
        batch_size = len(end_points["point_clouds"])
        grasp_preds = []
        for i in range(batch_size):
            grasp_center = end_points["xyz_graspable"][i].float()
            num_pts = end_points["xyz_graspable"][i].shape[0]
            grasp_score = end_points["grasp_score_pred"][i].float()
            grasp_score = grasp_score.view(num_pts, -1)
            grasp_score, _ = torch.max(grasp_score, -1)  # [M_POINT]
            grasp_score = grasp_score.view(-1, 1)
            grasp_preds.append(
                {"grasp_center": grasp_center, "grasp_score": grasp_score}
            )
        return grasp_preds

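
# Example end-to-end run; expects the dataset under data_dir and a trained
# checkpoint at checkpoint_path.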
if __name__ == "__main__":
    gs_preproc = GSNetPreprocessor()
    dataloader = gs_preproc.get_dataloader()
    model = gs_preproc.get_model()
    results = gs_preproc.prediction(model=model, dataloader=dataloader)
    results = gs_preproc.preprocess(results)
    gs_preproc.save_processed_data(results, None)
    # gs_preproc.evaluate()