global: debug inference

This commit is contained in:
hofee 2024-11-01 22:51:16 +00:00
parent 982a3b9b60
commit 287983277a
3 changed files with 10 additions and 2 deletions

View File

@ -49,6 +49,9 @@ class SeqReconstructionDataset(BaseDataset):
scene_name = line.strip()
scene_name_list.append(scene_name)
return scene_name_list
def get_scene_name_list(self):
return self.scene_name_list
def get_datalist(self):
datalist = []

View File

@ -68,9 +68,16 @@ class Inferencer(Runner):
test_set_name = test_set.get_name()
total=int(len(test_set))
scene_name_list = test_set.get_scene_name_list()
for i in range(total):
scene_name = scene_name_list[i]
inference_result_path = os.path.join(self.output_dir, test_set_name, f"{scene_name}.pkl")
if os.path.exists(inference_result_path):
Log.info(f"Inference result already exists for scene: {scene_name}")
continue
data = test_set.__getitem__(i)
status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i+1, total)
scene_name = data["scene_name"]
output = self.predict_sequence(data)
self.save_inference_result(test_set_name, data["scene_name"], output)

View File

@ -24,8 +24,6 @@ class DataLoadUtil:
for channel in float_channels:
channel_data = exr_file.channel(channel)
img_data.append(np.frombuffer(channel_data, dtype=np.float16).reshape((height, width)))
# 将各通道组合成一个 (height, width, 3) 的 RGB 图像
img = np.stack(img_data, axis=-1)
return img