global: debug inference

This commit is contained in:
hofee 2024-11-01 22:51:16 +00:00
parent 982a3b9b60
commit 287983277a
3 changed files with 10 additions and 2 deletions

View File

@ -50,6 +50,9 @@ class SeqReconstructionDataset(BaseDataset):
scene_name_list.append(scene_name) scene_name_list.append(scene_name)
return scene_name_list return scene_name_list
def get_scene_name_list(self):
return self.scene_name_list
def get_datalist(self): def get_datalist(self):
datalist = [] datalist = []
for scene_name in self.scene_name_list: for scene_name in self.scene_name_list:

View File

@ -68,9 +68,16 @@ class Inferencer(Runner):
test_set_name = test_set.get_name() test_set_name = test_set.get_name()
total=int(len(test_set)) total=int(len(test_set))
scene_name_list = test_set.get_scene_name_list()
for i in range(total): for i in range(total):
scene_name = scene_name_list[i]
inference_result_path = os.path.join(self.output_dir, test_set_name, f"{scene_name}.pkl")
if os.path.exists(inference_result_path):
Log.info(f"Inference result already exists for scene: {scene_name}")
continue
data = test_set.__getitem__(i) data = test_set.__getitem__(i)
status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i+1, total) status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i+1, total)
scene_name = data["scene_name"]
output = self.predict_sequence(data) output = self.predict_sequence(data)
self.save_inference_result(test_set_name, data["scene_name"], output) self.save_inference_result(test_set_name, data["scene_name"], output)

View File

@ -24,8 +24,6 @@ class DataLoadUtil:
for channel in float_channels: for channel in float_channels:
channel_data = exr_file.channel(channel) channel_data = exr_file.channel(channel)
img_data.append(np.frombuffer(channel_data, dtype=np.float16).reshape((height, width))) img_data.append(np.frombuffer(channel_data, dtype=np.float16).reshape((height, width)))
# 将各通道组合成一个 (height, width, 3) 的 RGB 图像
img = np.stack(img_data, axis=-1) img = np.stack(img_data, axis=-1)
return img return img