success
This commit is contained in:
185
runners/preprocessors/object_pose/FoundationPose_preprocessor.py
Executable file
185
runners/preprocessors/object_pose/FoundationPose_preprocessor.py
Executable file
@@ -0,0 +1,185 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import trimesh
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
path = os.path.abspath(__file__)
|
||||
for i in range(4):
|
||||
path = os.path.dirname(path)
|
||||
PROJECT_ROOT = path
|
||||
sys.path.append(PROJECT_ROOT)
|
||||
|
||||
from utils.omni_util import OmniUtil
|
||||
from utils.view_util import ViewUtil
|
||||
from runners.preprocessors.object_pose.abstract_object_pose_preprocessor import ObjectPosePreprocessor
|
||||
from configs.config import ConfigManager
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class ObjectPoseInferenceDataset(Dataset):
|
||||
CAMERA_PARAMS_TEMPLATE = "camera_params_{}.json"
|
||||
DISTANCE_TEMPLATE = "distance_to_camera_{}.npy"
|
||||
RGB_TEMPLATE = "rgb_{}.png"
|
||||
MASK_TEMPLATE = "semantic_segmentation_{}.png"
|
||||
MASK_LABELS_TEMPLATE = "semantic_segmentation_labels_{}.json"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source="nbv1",
|
||||
data_type="sample",
|
||||
data_dir="/mnt/h/AI/Datasets",
|
||||
):
|
||||
|
||||
self.data_dir = data_dir
|
||||
self.empty_frame = set()
|
||||
self.data_path = str(os.path.join(self.data_dir, source, data_type))
|
||||
self.scene_list = os.listdir(self.data_path)
|
||||
self.data_list = self.get_datalist()
|
||||
|
||||
self.object_data_list = self.get_object_datalist()
|
||||
self.object_name_list = list(self.object_data_list.keys())
|
||||
self.mesh_dir_path = os.path.join(self.data_dir, source, "objects")
|
||||
|
||||
self.meshes = {}
|
||||
self.load_all_meshes()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_list)
|
||||
|
||||
def __getitem__(self, index):
|
||||
frame_path, target = self.data_list[index]
|
||||
frame_data = self.load_frame_data(frame_path=frame_path, object_name=target)
|
||||
return frame_data
|
||||
|
||||
def load_all_meshes(self):
|
||||
object_name_list = os.listdir(self.mesh_dir_path)
|
||||
for object_name in object_name_list:
|
||||
mesh_path = os.path.join(self.mesh_dir_path, object_name, "Scan", "Simp.obj")
|
||||
mesh = trimesh.load(mesh_path)
|
||||
object_model_scale = [0.001, 0.001, 0.001]
|
||||
mesh.apply_scale(object_model_scale)
|
||||
self.meshes[object_name] = mesh
|
||||
|
||||
def get_datalist(self):
|
||||
for scene in self.scene_list:
|
||||
scene_path = os.path.join(self.data_path, scene)
|
||||
file_list = os.listdir(scene_path)
|
||||
scene_frame_list = []
|
||||
for file in file_list:
|
||||
if file.startswith("camera_params"):
|
||||
frame_index = re.findall(r"\d+", file)[0]
|
||||
frame_path = os.path.join(scene_path, frame_index)
|
||||
target_list = OmniUtil.get_object_list(frame_path)
|
||||
for target in target_list:
|
||||
scene_frame_list.append((frame_path,target))
|
||||
if len(target_list) == 0:
|
||||
self.empty_frame.add(frame_path)
|
||||
|
||||
return scene_frame_list
|
||||
|
||||
def get_object_datalist(self):
|
||||
object_datalist = {}
|
||||
for data_item in self.data_list:
|
||||
frame_path, target = data_item
|
||||
if target not in object_datalist:
|
||||
object_datalist[target] = []
|
||||
object_datalist[target].append(frame_path)
|
||||
return object_datalist
|
||||
|
||||
def get_object_data_batch(self, object_name):
|
||||
object_data_list = self.object_data_list[object_name]
|
||||
batch_data = {"frame_path_list":[],
|
||||
"rgb_batch":[],
|
||||
"depth_batch":[],
|
||||
"seg_batch":[],
|
||||
"gt_pose_batch":[],
|
||||
"K":None,
|
||||
"mesh":None}
|
||||
for frame_path in object_data_list:
|
||||
frame_data = self.load_frame_data(frame_path, object_name)
|
||||
batch_data["frame_path_list"].append(frame_path)
|
||||
batch_data["rgb_batch"].append(frame_data["rgb"])
|
||||
batch_data["depth_batch"].append(frame_data["depth"])
|
||||
batch_data["seg_batch"].append(frame_data["seg"])
|
||||
batch_data["gt_pose_batch"].append(frame_data["gt_pose"])
|
||||
batch_data["K"] = frame_data["K"]
|
||||
batch_data["mesh"] = frame_data["mesh"]
|
||||
|
||||
batch_data["rgb_batch"] = np.asarray(batch_data["rgb_batch"],dtype=np.uint8)
|
||||
batch_data["depth_batch"] = np.asarray(batch_data["depth_batch"])
|
||||
batch_data["seg_batch"] = np.asarray(batch_data["seg_batch"])
|
||||
batch_data["gt_pose_batch"] = np.asarray(batch_data["gt_pose_batch"])
|
||||
return batch_data
|
||||
|
||||
def load_frame_data(self, frame_path, object_name):
|
||||
rgb = OmniUtil.get_rgb(frame_path)
|
||||
depth = OmniUtil.get_depth(frame_path)
|
||||
seg = OmniUtil.get_single_seg(frame_path, object_name)
|
||||
K = OmniUtil.get_intrinsic_matrix(frame_path)
|
||||
gt_obj_pose = OmniUtil.get_o2c_pose(frame_path, object_name)
|
||||
ret_dict = {
|
||||
"frame_path": frame_path,
|
||||
"rgb": rgb.astype(np.float32),
|
||||
"depth": depth.astype(np.float32),
|
||||
"seg": seg,
|
||||
"K": K.astype(np.float32),
|
||||
"object_name": object_name,
|
||||
"mesh": self.meshes[object_name],
|
||||
"gt_pose": gt_obj_pose.astype(np.float32)
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
class FoundationPosePreprocessor(ObjectPosePreprocessor):
|
||||
|
||||
def __init__(self, config_path):
|
||||
super().__init__(config_path)
|
||||
|
||||
def run(self):
|
||||
for dataset_config in self.dataset_list_config:
|
||||
dataset = ObjectPoseInferenceDataset(
|
||||
source=dataset_config["source"],
|
||||
data_type=dataset_config["data_type"],
|
||||
data_dir=dataset_config["data_dir"],
|
||||
)
|
||||
result = self.prediction(dataset)
|
||||
self.save_processed_data(result, dataset_config)
|
||||
|
||||
def prediction(self, dataset):
|
||||
final_result = {}
|
||||
cnt = 0
|
||||
for object_name in dataset.object_name_list:
|
||||
cnt += 1
|
||||
print(f"Processing object: {object_name} ({cnt}/{len(dataset.object_name_list)})")
|
||||
object_data_batch = dataset.get_object_data_batch(object_name)
|
||||
print(f"batch size of object {object_name}: {len(object_data_batch['frame_path_list'])}")
|
||||
pose_batch, result_batch = ViewUtil.get_object_pose_batch(
|
||||
object_data_batch["K"],
|
||||
object_data_batch["mesh"],
|
||||
object_data_batch["rgb_batch"],
|
||||
object_data_batch["depth_batch"],
|
||||
object_data_batch["seg_batch"],
|
||||
object_data_batch["gt_pose_batch"],
|
||||
self.web_server_config["port"]
|
||||
)
|
||||
for frame_path, pred_pose,gt_pose,result in zip(object_data_batch["frame_path_list"], pose_batch,object_data_batch["gt_pose_batch"],result_batch):
|
||||
if frame_path not in final_result:
|
||||
final_result[frame_path]={}
|
||||
final_result[frame_path][object_name] = {"gt_pose":gt_pose.tolist(),"pred_pose":pred_pose.tolist(),"eval_result":result}
|
||||
for frame_path in dataset.empty_frame:
|
||||
final_result[frame_path] = {}
|
||||
return final_result
|
||||
|
||||
if __name__ == "__main__":
|
||||
config_path = os.path.join(PROJECT_ROOT, "configs/server_object_preprocess_config.yaml")
|
||||
preprocessor = FoundationPosePreprocessor(config_path)
|
||||
preprocessor.run()
|
||||
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user