add inference

2024-09-19 00:14:26 +08:00
parent 9ec3a00fd4
commit 935069d68c
10 changed files with 302 additions and 139 deletions

@@ -1,33 +1,35 @@
import os
import json
from datetime import datetime
from utils.render import RenderUtil
from utils.pose import PoseUtil
from utils.pts import PtsUtil
from utils.reconstruction import ReconstructionUtil
import torch
from tqdm import tqdm
import numpy as np
import pickle
from PytorchBoot.config import ConfigManager
import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory import ComponentFactory
from PytorchBoot.factory import OptimizerFactory
from PytorchBoot.dataset import BaseDataset
from PytorchBoot.runners.runner import Runner
from PytorchBoot.stereotype import EXTERNAL_FRONZEN_MODULES
from PytorchBoot.utils import Log
from PytorchBoot.status import status_manager
@stereotype.runner("nbv_evaluator")
class NextBestViewEvaluator(Runner):
@stereotype.runner("inferencer", comment="not tested")
class Inferencer(Runner):
def __init__(self, config_path):
super().__init__(config_path)
self.script_path = ConfigManager.get(namespace.Stereotype.RUNNER, "blender_script_path")
self.output_dir = ConfigManager.get(namespace.Stereotype.RUNNER, "output_dir")
''' Pipeline '''
self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
self.parallel = self.config["general"]["parallel"]
self.pipeline:torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
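# Optionally wrap the pipeline in DataParallel so inference can run across multiple GPUs.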
if self.parallel and self.device == "cuda":
self.pipeline = torch.nn.DataParallel(self.pipeline)
self.pipeline = self.pipeline.to(self.device)
''' Experiment '''
@@ -46,55 +48,135 @@ class NextBestViewEvaluator(Runner):
raise ValueError("Duplicate test dataset name: {}".format(test_dataset_name))
test_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, test_dataset_name)
self.test_set_list.append(test_set)
self.print_info()
def run(self):
Log.info("Loading from epoch {}.".format(self.current_epoch))
self.test()
self.inference()
Log.success("Inference finished.")
def test(self):
def inference(self):
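# Run the trained pipeline over every configured test dataset (batch size must be 1),
# predict a view sequence for each scene and save the result to disk.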
self.pipeline.eval()
with torch.no_grad():
test_set: BaseDataset
for dataset_idx, test_set in enumerate(self.test_set_list):
test_set_config = test_set.get_config()
eval_list = test_set_config["eval_list"]
ratio = test_set_config["ratio"]
status_manager.set_progress("inference", "inferencer", f"dataset", dataset_idx, len(self.test_set_list))
test_set_name = test_set.get_name()
output_list = []
data_list = []
test_loader = test_set.get_loader()
if test_loader.batch_size > 1:
Log.error("Batch size should be 1 for inference, found {} in {}".format(test_loader.batch_size, test_set_name), terminate=True)
total=int(len(test_loader))
loop = tqdm(enumerate(test_loader), total=total)
for i, data in loop:
status_manager.set_progress("train", "default_trainer", f"(test) Batch[{test_set_name}]", i+1, total)
status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i+1, total)
test_set.process_batch(data, self.device)
data["mode"] = namespace.Mode.TEST
output = self.pipeline(data)
output_list.append(output)
data_list.append(data)
loop.set_description(
f'Epoch [{self.current_epoch}/{self.max_epochs}] (Test: {test_set_name}, ratio={ratio})')
result_dict = self.eval_fn(output_list, data_list, eval_list)
@staticmethod
def eval_fn(output_list, data_list, eval_list):
collected_result = {}
for eval_method_name in eval_list:
eval_method = ComponentFactory.create(namespace.Stereotype.EVALUATION_METHOD, eval_method_name)
eval_results:dict = eval_method.evaluate(output_list, data_list)
for data_type, eval_result in eval_results.items():
if data_type not in collected_result:
collected_result[data_type] = {}
for name, value in eval_result.items():
collected_result[data_type][name] = value
status_manager.set_status("train", "default_trainer", f"[eval]{name}", value)
output = self.predict_sequence(data)
self.save_inference_result(test_set_name, data["scene_name"][0], output)
status_manager.set_progress("inference", "inferencer", f"dataset", len(self.test_set_list), len(self.test_set_list))
return collected_result
def predict_sequence(self, data, cr_increase_threshold=0, max_iter=100):
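# Next-best-view loop: starting from the first scanned frame, repeatedly predict the next
# camera pose, render the points visible from it, and stop once the coverage rate reaches
# the dataset's maximum, improves by less than cr_increase_threshold, or max_iter is hit.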
pred_cr_seq = []
scene_name = data["scene_name"][0]
Log.info(f"Processing scene: {scene_name}")
status_manager.set_status("inference", "inferencer", "scene", scene_name)
''' data for rendering '''
scene_path = data["scene_path"][0]
O_to_L_pose = data["O_to_L_pose"][0]
voxel_threshold = data["voxel_threshold"][0]
filter_degree = data["filter_degree"][0]
model_points_normals = data["model_points_normals"][0]
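# Columns 0:3 of model_points_normals are the xyz coordinates; the remaining columns hold the per-point normals (as the name suggests).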
model_pts = model_points_normals[:,:3]
down_sampled_model_pts = PtsUtil.voxel_downsample_point_cloud(model_pts, voxel_threshold)
first_frame_to_world = data["first_frame_to_world"][0]
''' data for inference '''
input_data = {}
input_data["scanned_pts"] = [data["first_pts"][0].to(self.device)]
input_data["scanned_n_to_1_pose_9d"] = [data["first_to_first_9d"][0].to(self.device)]
input_data["mode"] = namespace.Mode.TEST
input_pts_N = input_data["scanned_pts"][0].shape[1]
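# Render the target points visible from the first frame; this initial scan seeds the coverage computation.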
first_frame_target_pts, _ = RenderUtil.render_pts(first_frame_to_world, scene_path, self.script_path, model_points_normals, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
scanned_view_pts = [first_frame_target_pts]
last_pred_cr = self.compute_coverage_rate(scanned_view_pts, None, down_sampled_model_pts, threshold=voxel_threshold)
while len(pred_cr_seq) < max_iter:
output = self.pipeline(input_data)
next_pose_9d = output["pred_pose_9d"]
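# The predicted 9D pose is a 6D rotation representation (converted back to a rotation matrix below) plus a 3D translation.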
pred_pose = torch.eye(4, device=next_pose_9d.device)
pred_pose[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(next_pose_9d[:,:6])[0]
pred_pose[:3,3] = next_pose_9d[0,6:]
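# The network predicts poses relative to the first frame, so compose with first_frame_to_world to obtain a world-frame camera pose.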
pred_n_to_world_pose_mat = torch.matmul(first_frame_to_world, pred_pose)
try:
new_target_pts_world, new_pts_world = RenderUtil.render_pts(pred_n_to_world_pose_mat, scene_path, self.script_path, model_points_normals, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose, require_full_scene=True)
except Exception as e:
Log.warning(f"Error in scene {scene_path}, {e}")
print("current pose: ", pred_pose)
print("curr_pred_cr: ", last_pred_cr)
continue
pred_cr = self.compute_coverage_rate(scanned_view_pts, new_target_pts_world, down_sampled_model_pts, threshold=voxel_threshold)
pred_cr_seq.append(pred_cr)
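# Stop when the predicted coverage reaches the dataset's maximum coverage rate, or when the gain over the previous step falls below cr_increase_threshold.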
if pred_cr >= data["max_coverage_rate"]:
break
if pred_cr < last_pred_cr + cr_increase_threshold:
break
scanned_view_pts.append(new_target_pts_world)
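# Downsample the newly rendered full scan and transform it back into the first-frame coordinate system before appending it to the network input.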
down_sampled_new_pts_world = PtsUtil.random_downsample_point_cloud(new_pts_world, input_pts_N)
new_pts_world_aug = np.hstack([down_sampled_new_pts_world, np.ones((down_sampled_new_pts_world.shape[0], 1))])
new_pts = np.dot(np.linalg.inv(first_frame_to_world.cpu()), new_pts_world_aug.T).T[:,:3]
new_pts_tensor = torch.tensor(new_pts, dtype=torch.float32).unsqueeze(0).to(self.device)
input_data["scanned_pts"] = [torch.cat([input_data["scanned_pts"][0] , new_pts_tensor], dim=0)]
input_data["scanned_n_to_1_pose_9d"] = [torch.cat([input_data["scanned_n_to_1_pose_9d"][0], next_pose_9d], dim=0)]
last_pred_cr = pred_cr
# ------ Debug Start ------
import ipdb;ipdb.set_trace()
# ------ Debug End ------
input_data["scanned_pts"] = input_data["scanned_pts"][0].cpu().numpy().tolist()
input_data["scanned_n_to_1_pose_9d"] = input_data["scanned_n_to_1_pose_9d"][0].cpu().numpy().tolist()
result = {
"pred_pose_9d_seq": input_data["scanned_n_to_1_pose_9d"],
"pts_seq": input_data["scanned_pts"],
"target_pts_seq": scanned_view_pts,
"coverage_rate_seq": pred_cr_seq,
"max_coverage_rate": data["max_coverage_rate"],
"pred_max_coverage_rate": max(pred_cr_seq)
}
return result
def compute_coverage_rate(self, scanned_view_pts, new_pts, model_pts, threshold=0.005):
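# Coverage rate (via ReconstructionUtil): how much of the downsampled model the union of all scanned views covers at the given voxel threshold.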
if new_pts is not None:
new_scanned_view_pts = scanned_view_pts + [new_pts]
else:
new_scanned_view_pts = scanned_view_pts
combined_point_cloud = np.vstack(new_scanned_view_pts)
down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_point_cloud,threshold)
return ReconstructionUtil.compute_coverage_rate(model_pts, down_sampled_combined_point_cloud, threshold)
def save_inference_result(self, dataset_name, scene_name, output):
dataset_dir = os.path.join(self.output_dir, dataset_name)
if not os.path.exists(dataset_dir):
os.makedirs(dataset_dir)
pickle.dump(output, open(os.path.join(dataset_dir, f"result_{scene_name}.pkl"), "wb"))
def get_checkpoint_path(self, is_last=False):
return os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME,
@@ -116,54 +198,18 @@ class NextBestViewEvaluator(Runner):
self.current_epoch = meta["last_epoch"]
self.current_iter = meta["last_iter"]
def save_checkpoint(self, is_last=False):
self.save(self.get_checkpoint_path(is_last))
if not is_last:
Log.success(f"Checkpoint at epoch {self.current_epoch} saved to {self.get_checkpoint_path(is_last)}")
else:
meta = {
"last_epoch": self.current_epoch,
"last_iter": self.current_iter,
"time": str(datetime.now())
}
checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
file_path = os.path.join(checkpoint_root, "meta.json")
with open(file_path, "w") as f:
json.dump(meta, f)
def load_experiment(self, backup_name=None):
super().load_experiment(backup_name)
if self.experiments_config["use_checkpoint"]:
self.current_epoch = self.experiments_config["epoch"]
self.load_checkpoint(is_last=(self.current_epoch == -1))
self.current_epoch = self.experiments_config["epoch"]
self.load_checkpoint(is_last=(self.current_epoch == -1))
def create_experiment(self, backup_name=None):
super().create_experiment(backup_name)
ckpt_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.CHECKPOINT_DIR_NAME)
os.makedirs(ckpt_dir)
tensorboard_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.TENSORBOARD_DIR_NAME)
os.makedirs(tensorboard_dir)
def load(self, path):
state_dict = torch.load(path)
if self.parallel:
self.pipeline.module.load_state_dict(state_dict)
else:
self.pipeline.load_state_dict(state_dict)
def save(self, path):
if self.parallel:
state_dict = self.pipeline.module.state_dict()
else:
state_dict = self.pipeline.state_dict()
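# Entries belonging to externally frozen modules are dropped so their weights are not written into the checkpoint.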
for name, module in self.pipeline.named_modules():
if module.__class__ in EXTERNAL_FRONZEN_MODULES:
if name in state_dict:
del state_dict[name]
torch.save(state_dict, path)
self.pipeline.load_state_dict(state_dict)
def print_info(self):
def print_dataset(dataset: BaseDataset):
@@ -178,8 +224,6 @@ class NextBestViewEvaluator(Runner):
Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+')
Log.blue(self.pipeline)
Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
Log.blue("train dataset: ")
print_dataset(self.train_set)
for i, test_set in enumerate(self.test_set_list):
Log.blue(f"test dataset {i}: ")
print_dataset(test_set)