global: debug inference
This commit is contained in:
parent
982a3b9b60
commit
287983277a
@ -50,6 +50,9 @@ class SeqReconstructionDataset(BaseDataset):
|
|||||||
scene_name_list.append(scene_name)
|
scene_name_list.append(scene_name)
|
||||||
return scene_name_list
|
return scene_name_list
|
||||||
|
|
||||||
|
def get_scene_name_list(self):
|
||||||
|
return self.scene_name_list
|
||||||
|
|
||||||
def get_datalist(self):
|
def get_datalist(self):
|
||||||
datalist = []
|
datalist = []
|
||||||
for scene_name in self.scene_name_list:
|
for scene_name in self.scene_name_list:
|
||||||
|
@ -68,9 +68,16 @@ class Inferencer(Runner):
|
|||||||
test_set_name = test_set.get_name()
|
test_set_name = test_set.get_name()
|
||||||
|
|
||||||
total=int(len(test_set))
|
total=int(len(test_set))
|
||||||
|
scene_name_list = test_set.get_scene_name_list()
|
||||||
for i in range(total):
|
for i in range(total):
|
||||||
|
scene_name = scene_name_list[i]
|
||||||
|
inference_result_path = os.path.join(self.output_dir, test_set_name, f"{scene_name}.pkl")
|
||||||
|
if os.path.exists(inference_result_path):
|
||||||
|
Log.info(f"Inference result already exists for scene: {scene_name}")
|
||||||
|
continue
|
||||||
data = test_set.__getitem__(i)
|
data = test_set.__getitem__(i)
|
||||||
status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i+1, total)
|
status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i+1, total)
|
||||||
|
scene_name = data["scene_name"]
|
||||||
output = self.predict_sequence(data)
|
output = self.predict_sequence(data)
|
||||||
self.save_inference_result(test_set_name, data["scene_name"], output)
|
self.save_inference_result(test_set_name, data["scene_name"], output)
|
||||||
|
|
||||||
|
@ -24,8 +24,6 @@ class DataLoadUtil:
|
|||||||
for channel in float_channels:
|
for channel in float_channels:
|
||||||
channel_data = exr_file.channel(channel)
|
channel_data = exr_file.channel(channel)
|
||||||
img_data.append(np.frombuffer(channel_data, dtype=np.float16).reshape((height, width)))
|
img_data.append(np.frombuffer(channel_data, dtype=np.float16).reshape((height, width)))
|
||||||
|
|
||||||
# 将各通道组合成一个 (height, width, 3) 的 RGB 图像
|
|
||||||
img = np.stack(img_data, axis=-1)
|
img = np.stack(img_data, axis=-1)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user