load 16bit float

This commit is contained in:
2024-10-23 13:57:45 +08:00
parent be67be95e9
commit b18c1591b7
4 changed files with 19 additions and 16 deletions

View File

@@ -29,8 +29,8 @@ class DataLoadUtil:
# 读取 EXR 文件中的每个通道并转化为浮点数数组
img_data = []
for channel in float_channels:
channel_data = exr_file.channel(channel, Imath.PixelType(Imath.PixelType.FLOAT))
img_data.append(np.frombuffer(channel_data, dtype=np.float32).reshape((height, width)))
channel_data = exr_file.channel(channel)
img_data.append(np.frombuffer(channel_data, dtype=np.float16).reshape((height, width)))
# 将各通道组合成一个 (height, width, 3) 的 RGB 图像
img = np.stack(img_data, axis=-1)
@@ -141,8 +141,8 @@ class DataLoadUtil:
if binocular and not left_only:
def clean_mask(mask_image):
green = [0, 255, 0, 255]
red = [255, 0, 0, 255]
green = [0, 255, 0]
red = [255, 0, 0]
threshold = 2
mask_image = np.where(
np.abs(mask_image - green) <= threshold, green, mask_image