92 lines
3.0 KiB
Python
Executable File
92 lines
3.0 KiB
Python
Executable File
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
#
|
|
# This source code is licensed under the Apache License, Version 2.0
|
|
# found in the LICENSE file in the root directory of this source tree.
|
|
|
|
from typing import Sequence
|
|
|
|
import torch
|
|
from torchvision import transforms
|
|
|
|
|
|
class GaussianBlur(transforms.RandomApply):
    """
    Randomly apply Gaussian blur to an image.

    Wraps ``transforms.GaussianBlur`` (fixed ``kernel_size=9``) in
    ``transforms.RandomApply`` so the blur is applied with probability ``p``.

    Args:
        p: Probability of blurring the input.
        radius_min: Lower bound of the ``sigma`` range handed to the blur.
        radius_max: Upper bound of the ``sigma`` range handed to the blur.
    """

    def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
        # NOTE: ``RandomApply`` returns the original image with probability
        # ``1 - p`` and applies the wrapped transforms with probability ``p``.
        # ``p`` is therefore forwarded unchanged; forwarding ``1 - p`` would
        # invert its meaning (e.g. ``p=1.0`` would never blur).
        transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
        super().__init__(transforms=[transform], p=p)
|
|
|
|
|
|
class MaybeToTensor(transforms.ToTensor):
    """
    ``ToTensor`` variant that lets inputs that are already ``torch.Tensor`` pass through.

    Anything else (``PIL Image`` or ``numpy.ndarray``) is delegated to
    ``transforms.ToTensor``.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor.

        Returns:
            Tensor: ``pic`` itself when it is already a tensor, otherwise the
            value returned by the parent ``ToTensor.__call__``.
        """
        already_tensor = isinstance(pic, torch.Tensor)
        return pic if already_tensor else super().__call__(pic)
|
|
|
|
|
|
# Use timm's names
# Default statistics used by the normalization helpers in this module
# (one value per channel).
IMAGENET_DEFAULT_MEAN = (
    0.485,
    0.456,
    0.406,
)
IMAGENET_DEFAULT_STD = (
    0.229,
    0.224,
    0.225,
)
|
|
|
|
|
|
def make_normalize_transform(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Normalize:
    """
    Build a ``transforms.Normalize`` from the given statistics.

    Args:
        mean: Values forwarded as ``mean``; defaults to ``IMAGENET_DEFAULT_MEAN``.
        std: Values forwarded as ``std``; defaults to ``IMAGENET_DEFAULT_STD``.

    Returns:
        The configured ``transforms.Normalize`` instance.
    """
    normalize = transforms.Normalize(mean=mean, std=std)
    return normalize
|
|
|
|
|
|
# This roughly matches torchvision's preset for classification training:
|
|
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
|
|
def make_classification_train_transform(
    *,
    crop_size: int = 224,
    interpolation=transforms.InterpolationMode.BICUBIC,
    hflip_prob: float = 0.5,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """
    Build the training-time preprocessing pipeline for classification.

    The pipeline is, in order: ``RandomResizedCrop`` -> optional
    ``RandomHorizontalFlip`` -> ``MaybeToTensor`` -> normalization.

    Args:
        crop_size: Output size passed to ``transforms.RandomResizedCrop``.
        interpolation: Interpolation mode passed to ``transforms.RandomResizedCrop``.
        hflip_prob: Probability passed to ``transforms.RandomHorizontalFlip``;
            the flip step is left out entirely when this is not positive.
        mean: Statistics forwarded to ``make_normalize_transform``.
        std: Statistics forwarded to ``make_normalize_transform``.

    Returns:
        A ``transforms.Compose`` chaining the steps above.
    """
    transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
    # Skip the flip altogether instead of adding a step that would never fire.
    if hflip_prob > 0.0:
        transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
    transforms_list.extend(
        [
            MaybeToTensor(),
            make_normalize_transform(mean=mean, std=std),
        ]
    )
    return transforms.Compose(transforms_list)
|
|
|
|
|
|
# This matches (roughly) torchvision's preset for classification evaluation:
|
|
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
|
|
def make_classification_eval_transform(
    *,
    resize_size: int = 256,
    interpolation=transforms.InterpolationMode.BICUBIC,
    crop_size: int = 224,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """
    Build the evaluation-time preprocessing pipeline for classification.

    The pipeline is, in order: ``Resize`` -> ``CenterCrop`` ->
    ``MaybeToTensor`` -> normalization.

    Args:
        resize_size: Size passed to ``transforms.Resize``.
        interpolation: Interpolation mode passed to ``transforms.Resize``.
        crop_size: Size passed to ``transforms.CenterCrop``.
        mean: Statistics forwarded to ``make_normalize_transform``.
        std: Statistics forwarded to ``make_normalize_transform``.

    Returns:
        A ``transforms.Compose`` chaining the steps above.
    """
    resize = transforms.Resize(resize_size, interpolation=interpolation)
    center_crop = transforms.CenterCrop(crop_size)
    to_tensor = MaybeToTensor()
    normalize = make_normalize_transform(mean=mean, std=std)
    return transforms.Compose([resize, center_crop, to_tensor, normalize])
|