commit 0ea3f048dc (2024-10-09 16:13:22 +00:00)
437 changed files with 44406 additions and 0 deletions

modules/module_lib/__init__.py (executable file, 4 added lines)

@@ -0,0 +1,4 @@
from modules.module_lib.gaussian_fourier_projection import GaussianFourierProjection
from modules.module_lib.linear import Linear
from modules.module_lib.position_embedding import PositionalEmbedding
from modules.module_lib.rot_head import RotHead


@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
__version__ = "0.0.1"


@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import pathlib
from omegaconf import OmegaConf
def load_config(config_name: str):
config_filename = config_name + ".yaml"
return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename)
dinov2_default_config = load_config("ssl_default_config")
def load_and_merge_config(config_name: str):
default_config = OmegaConf.create(dinov2_default_config)
loaded_config = load_config(config_name)
return OmegaConf.merge(default_config, loaded_config)
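
A quick illustration of what the merge above does (plain OmegaConf, independent of the actual config files in this commit): an override config only needs to carry the keys it changes.

from omegaconf import OmegaConf

# Hypothetical default and override, mirroring the structure of the YAML files below.
default = OmegaConf.create({"student": {"arch": "vit_large", "patch_size": 16}, "optim": {"epochs": 100}})
override = OmegaConf.create({"student": {"arch": "vit_base", "patch_size": 14}})
merged = OmegaConf.merge(default, override)
assert merged.student.arch == "vit_base" and merged.optim.epochs == 100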


@@ -0,0 +1,6 @@
student:
arch: vit_base
patch_size: 14
crops:
global_crops_size: 518 # this is to set up the position embeddings properly
local_crops_size: 98
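
These sizes are multiples of the 14-pixel patch, so every crop maps onto an integer patch grid; a quick sanity check (plain Python, not part of the config):

patch_size = 14
global_crops_size, local_crops_size = 518, 98
assert global_crops_size % patch_size == 0 and local_crops_size % patch_size == 0
print(global_crops_size // patch_size, local_crops_size // patch_size)  # 37 and 7 patches per side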


@@ -0,0 +1,9 @@
student:
arch: vit_base
patch_size: 14
num_register_tokens: 4
interpolate_antialias: true
interpolate_offset: 0.0
crops:
global_crops_size: 518 # this is to set up the position embeddings properly
local_crops_size: 98


@@ -0,0 +1,7 @@
student:
arch: vit_giant2
patch_size: 14
ffn_layer: swiglufused
crops:
global_crops_size: 518 # this is to set up the position embeddings properly
local_crops_size: 98


@@ -0,0 +1,10 @@
student:
arch: vit_giant2
patch_size: 14
ffn_layer: swiglufused
num_register_tokens: 4
interpolate_antialias: true
interpolate_offset: 0.0
crops:
global_crops_size: 518 # this is to set up the position embeddings properly
local_crops_size: 98


@@ -0,0 +1,6 @@
student:
arch: vit_large
patch_size: 14
crops:
global_crops_size: 518 # this is to set up the position embeddings properly
local_crops_size: 98


@@ -0,0 +1,9 @@
student:
arch: vit_large
patch_size: 14
num_register_tokens: 4
interpolate_antialias: true
interpolate_offset: 0.0
crops:
global_crops_size: 518 # this is to set up the position embeddings properly
local_crops_size: 98


@@ -0,0 +1,6 @@
student:
arch: vit_small
patch_size: 14
crops:
global_crops_size: 518 # this is to set up the position embeddings properly
local_crops_size: 98


@@ -0,0 +1,9 @@
student:
arch: vit_small
patch_size: 14
num_register_tokens: 4
interpolate_antialias: true
interpolate_offset: 0.0
crops:
global_crops_size: 518 # this is to set up the position embeddings properly
local_crops_size: 98


@@ -0,0 +1,118 @@
MODEL:
WEIGHTS: ''
compute_precision:
grad_scaler: true
teacher:
backbone:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
reduce_dtype: fp16
buffer_dtype: fp32
dino_head:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
reduce_dtype: fp16
buffer_dtype: fp32
ibot_head:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
reduce_dtype: fp16
buffer_dtype: fp32
student:
backbone:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
reduce_dtype: fp16
buffer_dtype: fp32
dino_head:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
reduce_dtype: fp32
buffer_dtype: fp32
ibot_head:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
reduce_dtype: fp32
buffer_dtype: fp32
dino:
loss_weight: 1.0
head_n_prototypes: 65536
head_bottleneck_dim: 256
head_nlayers: 3
head_hidden_dim: 2048
koleo_loss_weight: 0.1
ibot:
loss_weight: 1.0
mask_sample_probability: 0.5
mask_ratio_min_max:
- 0.1
- 0.5
separate_head: false
head_n_prototypes: 65536
head_bottleneck_dim: 256
head_nlayers: 3
head_hidden_dim: 2048
train:
batch_size_per_gpu: 64
dataset_path: ImageNet:split=TRAIN
output_dir: .
saveckp_freq: 20
seed: 0
num_workers: 10
OFFICIAL_EPOCH_LENGTH: 1250
cache_dataset: true
centering: "centering" # or "sinkhorn_knopp"
student:
arch: vit_large
patch_size: 16
drop_path_rate: 0.3
layerscale: 1.0e-05
drop_path_uniform: true
pretrained_weights: ''
ffn_layer: "mlp"
block_chunks: 0
qkv_bias: true
proj_bias: true
ffn_bias: true
num_register_tokens: 0
interpolate_antialias: false
interpolate_offset: 0.1
teacher:
momentum_teacher: 0.992
final_momentum_teacher: 1
warmup_teacher_temp: 0.04
teacher_temp: 0.07
warmup_teacher_temp_epochs: 30
optim:
epochs: 100
weight_decay: 0.04
weight_decay_end: 0.4
base_lr: 0.004 # learning rate for a batch size of 1024
lr: 0. # will be set after applying scaling rule
warmup_epochs: 10
min_lr: 1.0e-06
clip_grad: 3.0
freeze_last_layer_epochs: 1
scaling_rule: sqrt_wrt_1024
patch_embed_lr_mult: 0.2
layerwise_decay: 0.9
adamw_beta1: 0.9
adamw_beta2: 0.999
crops:
global_crops_scale:
- 0.32
- 1.0
local_crops_number: 8
local_crops_scale:
- 0.05
- 0.32
global_crops_size: 224
local_crops_size: 96
evaluation:
eval_period_iterations: 12500
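
The base_lr comment and scaling_rule above imply that the effective learning rate is rescaled from a reference global batch size of 1024. A sketch of a square-root rule under that assumption (the world size is hypothetical; verify against the training code before relying on the exact formula):

import math

base_lr = 0.004                          # tuned for a global batch size of 1024 (see comment above)
batch_size_per_gpu, world_size = 64, 8   # world_size is a hypothetical example
global_batch = batch_size_per_gpu * world_size
lr = base_lr * math.sqrt(global_batch / 1024)   # "sqrt_wrt_1024"
print(f"effective lr: {lr:.6f}")         # ~0.002828 for a global batch of 512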


@@ -0,0 +1,26 @@
dino:
head_n_prototypes: 131072
head_bottleneck_dim: 384
ibot:
separate_head: true
head_n_prototypes: 131072
train:
batch_size_per_gpu: 12
dataset_path: ImageNet22k
centering: sinkhorn_knopp
student:
arch: vit_giant2
patch_size: 14
drop_path_rate: 0.4
ffn_layer: swiglufused
block_chunks: 4
teacher:
momentum_teacher: 0.994
optim:
epochs: 500
weight_decay_end: 0.2
base_lr: 2.0e-04 # learning rate for a batch size of 1024
warmup_epochs: 80
layerwise_decay: 1.0
crops:
local_crops_size: 98


@@ -0,0 +1,26 @@
dino:
head_n_prototypes: 131072
head_bottleneck_dim: 384
ibot:
separate_head: true
head_n_prototypes: 131072
train:
batch_size_per_gpu: 32
dataset_path: ImageNet22k
centering: sinkhorn_knopp
student:
arch: vit_large
patch_size: 14
drop_path_rate: 0.4
ffn_layer: swiglufused
block_chunks: 4
teacher:
momentum_teacher: 0.994
optim:
epochs: 500
weight_decay_end: 0.2
base_lr: 2.0e-04 # learning rate for a batch size of 1024
warmup_epochs: 80
layerwise_decay: 1.0
crops:
local_crops_size: 98


@@ -0,0 +1,6 @@
# this corresponds to the default config
train:
dataset_path: ImageNet:split=TRAIN
batch_size_per_gpu: 64
student:
block_chunks: 4


@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .adapters import DatasetWithEnumeratedTargets
from .loaders import make_data_loader, make_dataset, SamplerType
from .collate import collate_data_and_cast
from .masking import MaskingGenerator
from .augmentations import DataAugmentationDINO


@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from typing import Any, Tuple
from torch.utils.data import Dataset
class DatasetWithEnumeratedTargets(Dataset):
def __init__(self, dataset):
self._dataset = dataset
def get_image_data(self, index: int) -> bytes:
return self._dataset.get_image_data(index)
def get_target(self, index: int) -> Tuple[Any, int]:
target = self._dataset.get_target(index)
return (index, target)
def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]:
image, target = self._dataset[index]
target = index if target is None else target
return image, (index, target)
def __len__(self) -> int:
return len(self._dataset)
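
A minimal usage sketch for the wrapper above, using torchvision's FakeData purely as a stand-in for the datasets defined elsewhere in this commit:

from torchvision import transforms
from torchvision.datasets import FakeData

base = FakeData(size=4, transform=transforms.ToTensor())
wrapped = DatasetWithEnumeratedTargets(base)
image, (index, target) = wrapped[2]
print(index, target)  # 2, followed by the original class index of sample 2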


@@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import logging
from torchvision import transforms
from .transforms import (
GaussianBlur,
make_normalize_transform,
)
logger = logging.getLogger("dinov2")
class DataAugmentationDINO(object):
def __init__(
self,
global_crops_scale,
local_crops_scale,
local_crops_number,
global_crops_size=224,
local_crops_size=96,
):
self.global_crops_scale = global_crops_scale
self.local_crops_scale = local_crops_scale
self.local_crops_number = local_crops_number
self.global_crops_size = global_crops_size
self.local_crops_size = local_crops_size
logger.info("###################################")
logger.info("Using data augmentation parameters:")
logger.info(f"global_crops_scale: {global_crops_scale}")
logger.info(f"local_crops_scale: {local_crops_scale}")
logger.info(f"local_crops_number: {local_crops_number}")
logger.info(f"global_crops_size: {global_crops_size}")
logger.info(f"local_crops_size: {local_crops_size}")
logger.info("###################################")
# random resized crop and flip
self.geometric_augmentation_global = transforms.Compose(
[
transforms.RandomResizedCrop(
global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
),
transforms.RandomHorizontalFlip(p=0.5),
]
)
self.geometric_augmentation_local = transforms.Compose(
[
transforms.RandomResizedCrop(
local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
),
transforms.RandomHorizontalFlip(p=0.5),
]
)
# color distortions / blurring
color_jittering = transforms.Compose(
[
transforms.RandomApply(
[transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
p=0.8,
),
transforms.RandomGrayscale(p=0.2),
]
)
global_transfo1_extra = GaussianBlur(p=1.0)
global_transfo2_extra = transforms.Compose(
[
GaussianBlur(p=0.1),
transforms.RandomSolarize(threshold=128, p=0.2),
]
)
local_transfo_extra = GaussianBlur(p=0.5)
# normalization
self.normalize = transforms.Compose(
[
transforms.ToTensor(),
make_normalize_transform(),
]
)
self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])
def __call__(self, image):
output = {}
# global crops:
im1_base = self.geometric_augmentation_global(image)
global_crop_1 = self.global_transfo1(im1_base)
im2_base = self.geometric_augmentation_global(image)
global_crop_2 = self.global_transfo2(im2_base)
output["global_crops"] = [global_crop_1, global_crop_2]
# global crops for teacher:
output["global_crops_teacher"] = [global_crop_1, global_crop_2]
# local crops:
local_crops = [
self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
]
output["local_crops"] = local_crops
output["offsets"] = ()
return output
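
A quick sketch of running the pipeline above on a single PIL image; the scale ranges mirror the crops section of the default training config in this commit:

from PIL import Image

aug = DataAugmentationDINO(
    global_crops_scale=(0.32, 1.0),
    local_crops_scale=(0.05, 0.32),
    local_crops_number=8,
)
out = aug(Image.new("RGB", (256, 256)))
print(len(out["global_crops"]), len(out["local_crops"]))          # 2 and 8
print(out["global_crops"][0].shape, out["local_crops"][0].shape)  # [3, 224, 224] and [3, 96, 96]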


@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import random
def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None):
# dtype = torch.half # TODO: Remove
n_global_crops = len(samples_list[0][0]["global_crops"])
n_local_crops = len(samples_list[0][0]["local_crops"])
collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list])
collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list])
B = len(collated_global_crops)
N = n_tokens
n_samples_masked = int(B * mask_probability)
probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)
upperbound = 0
masks_list = []
for i in range(0, n_samples_masked):
prob_min = probs[i]
prob_max = probs[i + 1]
masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max)))))
upperbound += int(N * prob_max)
for i in range(n_samples_masked, B):
masks_list.append(torch.BoolTensor(mask_generator(0)))
random.shuffle(masks_list)
collated_masks = torch.stack(masks_list).flatten(1)
mask_indices_list = collated_masks.flatten().nonzero().flatten()
masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]
return {
"collated_global_crops": collated_global_crops.to(dtype),
"collated_local_crops": collated_local_crops.to(dtype),
"collated_masks": collated_masks,
"mask_indices_list": mask_indices_list,
"masks_weight": masks_weight,
"upperbound": upperbound,
"n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
}
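
A sketch of binding this collate function for a DataLoader, assuming 224-pixel global crops with a 14-pixel patch (256 tokens) and the MaskingGenerator defined later in this commit; the exact wiring in the training entry point may differ:

import functools
import torch

n_tokens = (224 // 14) ** 2  # 256 patch tokens per global crop
mask_generator = MaskingGenerator(input_size=224 // 14, max_num_patches=int(0.5 * n_tokens))
collate_fn = functools.partial(
    collate_data_and_cast,
    mask_ratio_tuple=(0.1, 0.5),
    mask_probability=0.5,
    dtype=torch.half,
    n_tokens=n_tokens,
    mask_generator=mask_generator,
)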


@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .image_net import ImageNet
from .image_net_22k import ImageNet22k


@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from io import BytesIO
from typing import Any
from PIL import Image
class Decoder:
def decode(self) -> Any:
raise NotImplementedError
class ImageDataDecoder(Decoder):
def __init__(self, image_data: bytes) -> None:
self._image_data = image_data
def decode(self) -> Image:
f = BytesIO(self._image_data)
return Image.open(f).convert(mode="RGB")
class TargetDecoder(Decoder):
def __init__(self, target: Any):
self._target = target
def decode(self) -> Any:
return self._target
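
A tiny usage sketch: round-tripping an in-memory PNG through the image decoder above.

from io import BytesIO
from PIL import Image

buf = BytesIO()
Image.new("RGB", (8, 8)).save(buf, format="PNG")
decoded = ImageDataDecoder(buf.getvalue()).decode()
print(decoded.size, decoded.mode)  # (8, 8) RGB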


@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from typing import Any, Tuple
from torchvision.datasets import VisionDataset
from .decoders import TargetDecoder, ImageDataDecoder
class ExtendedVisionDataset(VisionDataset):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) # type: ignore
def get_image_data(self, index: int) -> bytes:
raise NotImplementedError
def get_target(self, index: int) -> Any:
raise NotImplementedError
def __getitem__(self, index: int) -> Tuple[Any, Any]:
try:
image_data = self.get_image_data(index)
image = ImageDataDecoder(image_data).decode()
except Exception as e:
raise RuntimeError(f"can not read image for sample {index}") from e
target = self.get_target(index)
target = TargetDecoder(target).decode()
if self.transforms is not None:
image, target = self.transforms(image, target)
return image, target
def __len__(self) -> int:
raise NotImplementedError
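
A sketch of the minimal surface a concrete dataset has to implement on top of ExtendedVisionDataset (TinyDataset is hypothetical and not part of this commit):

from io import BytesIO
from PIL import Image

class TinyDataset(ExtendedVisionDataset):
    def __init__(self, root: str = ".", **kwargs) -> None:
        super().__init__(root, **kwargs)
        buf = BytesIO()
        Image.new("RGB", (8, 8)).save(buf, format="JPEG")
        self._blobs = [buf.getvalue()] * 4

    def get_image_data(self, index: int) -> bytes:
        return self._blobs[index]

    def get_target(self, index: int) -> int:
        return index % 2

    def __len__(self) -> int:
        return len(self._blobs)

image, target = TinyDataset()[0]  # decoded PIL image plus integer target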


@@ -0,0 +1,290 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import csv
from enum import Enum
import logging
import os
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
from .extended import ExtendedVisionDataset
logger = logging.getLogger("dinov2")
_Target = int
class _Split(Enum):
TRAIN = "train"
VAL = "val"
TEST = "test" # NOTE: torchvision does not support the test split
@property
def length(self) -> int:
split_lengths = {
_Split.TRAIN: 1_281_167,
_Split.VAL: 50_000,
_Split.TEST: 100_000,
}
return split_lengths[self]
def get_dirname(self, class_id: Optional[str] = None) -> str:
return self.value if class_id is None else os.path.join(self.value, class_id)
def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str:
dirname = self.get_dirname(class_id)
if self == _Split.TRAIN:
basename = f"{class_id}_{actual_index}"
else: # self in (_Split.VAL, _Split.TEST):
basename = f"ILSVRC2012_{self.value}_{actual_index:08d}"
return os.path.join(dirname, basename + ".JPEG")
def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]:
assert self != _Split.TEST
dirname, filename = os.path.split(image_relpath)
class_id = os.path.split(dirname)[-1]
basename, _ = os.path.splitext(filename)
actual_index = int(basename.split("_")[-1])
return class_id, actual_index
class ImageNet(ExtendedVisionDataset):
Target = Union[_Target]
Split = Union[_Split]
def __init__(
self,
*,
split: "ImageNet.Split",
root: str,
extra: str,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super().__init__(root, transforms, transform, target_transform)
self._extra_root = extra
self._split = split
self._entries = None
self._class_ids = None
self._class_names = None
@property
def split(self) -> "ImageNet.Split":
return self._split
def _get_extra_full_path(self, extra_path: str) -> str:
return os.path.join(self._extra_root, extra_path)
def _load_extra(self, extra_path: str) -> np.ndarray:
extra_full_path = self._get_extra_full_path(extra_path)
return np.load(extra_full_path, mmap_mode="r")
def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
extra_full_path = self._get_extra_full_path(extra_path)
os.makedirs(self._extra_root, exist_ok=True)
np.save(extra_full_path, extra_array)
@property
def _entries_path(self) -> str:
return f"entries-{self._split.value.upper()}.npy"
@property
def _class_ids_path(self) -> str:
return f"class-ids-{self._split.value.upper()}.npy"
@property
def _class_names_path(self) -> str:
return f"class-names-{self._split.value.upper()}.npy"
def _get_entries(self) -> np.ndarray:
if self._entries is None:
self._entries = self._load_extra(self._entries_path)
assert self._entries is not None
return self._entries
def _get_class_ids(self) -> np.ndarray:
if self._split == _Split.TEST:
assert False, "Class IDs are not available in TEST split"
if self._class_ids is None:
self._class_ids = self._load_extra(self._class_ids_path)
assert self._class_ids is not None
return self._class_ids
def _get_class_names(self) -> np.ndarray:
if self._split == _Split.TEST:
assert False, "Class names are not available in TEST split"
if self._class_names is None:
self._class_names = self._load_extra(self._class_names_path)
assert self._class_names is not None
return self._class_names
def find_class_id(self, class_index: int) -> str:
class_ids = self._get_class_ids()
return str(class_ids[class_index])
def find_class_name(self, class_index: int) -> str:
class_names = self._get_class_names()
return str(class_names[class_index])
def get_image_data(self, index: int) -> bytes:
entries = self._get_entries()
actual_index = entries[index]["actual_index"]
class_id = self.get_class_id(index)
image_relpath = self.split.get_image_relpath(actual_index, class_id)
image_full_path = os.path.join(self.root, image_relpath)
with open(image_full_path, mode="rb") as f:
image_data = f.read()
return image_data
def get_target(self, index: int) -> Optional[Target]:
entries = self._get_entries()
class_index = entries[index]["class_index"]
return None if self.split == _Split.TEST else int(class_index)
def get_targets(self) -> Optional[np.ndarray]:
entries = self._get_entries()
return None if self.split == _Split.TEST else entries["class_index"]
def get_class_id(self, index: int) -> Optional[str]:
entries = self._get_entries()
class_id = entries[index]["class_id"]
return None if self.split == _Split.TEST else str(class_id)
def get_class_name(self, index: int) -> Optional[str]:
entries = self._get_entries()
class_name = entries[index]["class_name"]
return None if self.split == _Split.TEST else str(class_name)
def __len__(self) -> int:
entries = self._get_entries()
assert len(entries) == self.split.length
return len(entries)
def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]:
labels_full_path = os.path.join(self.root, labels_path)
labels = []
try:
with open(labels_full_path, "r") as f:
reader = csv.reader(f)
for row in reader:
class_id, class_name = row
labels.append((class_id, class_name))
except OSError as e:
raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e
return labels
def _dump_entries(self) -> None:
split = self.split
if split == ImageNet.Split.TEST:
dataset = None
sample_count = split.length
max_class_id_length, max_class_name_length = 0, 0
else:
labels_path = "labels.txt"
logger.info(f'loading labels from "{labels_path}"')
labels = self._load_labels(labels_path)
# NOTE: Using torchvision ImageFolder for consistency
from torchvision.datasets import ImageFolder
dataset_root = os.path.join(self.root, split.get_dirname())
dataset = ImageFolder(dataset_root)
sample_count = len(dataset)
max_class_id_length, max_class_name_length = -1, -1
for sample in dataset.samples:
_, class_index = sample
class_id, class_name = labels[class_index]
max_class_id_length = max(len(class_id), max_class_id_length)
max_class_name_length = max(len(class_name), max_class_name_length)
dtype = np.dtype(
[
("actual_index", "<u4"),
("class_index", "<u4"),
("class_id", f"U{max_class_id_length}"),
("class_name", f"U{max_class_name_length}"),
]
)
entries_array = np.empty(sample_count, dtype=dtype)
if split == ImageNet.Split.TEST:
old_percent = -1
for index in range(sample_count):
percent = 100 * (index + 1) // sample_count
if percent > old_percent:
logger.info(f"creating entries: {percent}%")
old_percent = percent
actual_index = index + 1
class_index = np.uint32(-1)
class_id, class_name = "", ""
entries_array[index] = (actual_index, class_index, class_id, class_name)
else:
class_names = {class_id: class_name for class_id, class_name in labels}
assert dataset
old_percent = -1
for index in range(sample_count):
percent = 100 * (index + 1) // sample_count
if percent > old_percent:
logger.info(f"creating entries: {percent}%")
old_percent = percent
image_full_path, class_index = dataset.samples[index]
image_relpath = os.path.relpath(image_full_path, self.root)
class_id, actual_index = split.parse_image_relpath(image_relpath)
class_name = class_names[class_id]
entries_array[index] = (actual_index, class_index, class_id, class_name)
logger.info(f'saving entries to "{self._entries_path}"')
self._save_extra(entries_array, self._entries_path)
def _dump_class_ids_and_names(self) -> None:
split = self.split
if split == ImageNet.Split.TEST:
return
entries_array = self._load_extra(self._entries_path)
max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1
for entry in entries_array:
class_index, class_id, class_name = (
entry["class_index"],
entry["class_id"],
entry["class_name"],
)
max_class_index = max(int(class_index), max_class_index)
max_class_id_length = max(len(str(class_id)), max_class_id_length)
max_class_name_length = max(len(str(class_name)), max_class_name_length)
class_count = max_class_index + 1
class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}")
class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}")
for entry in entries_array:
class_index, class_id, class_name = (
entry["class_index"],
entry["class_id"],
entry["class_name"],
)
class_ids_array[class_index] = class_id
class_names_array[class_index] = class_name
logger.info(f'saving class IDs to "{self._class_ids_path}"')
self._save_extra(class_ids_array, self._class_ids_path)
logger.info(f'saving class names to "{self._class_names_path}"')
self._save_extra(class_names_array, self._class_names_path)
def dump_extra(self) -> None:
self._dump_entries()
self._dump_class_ids_and_names()
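
Typical call sequence for the class above, with placeholder paths standing in for a local ImageNet layout (train/<class_id>/*.JPEG plus a labels.txt mapping class ids to names); dump_extra() has to run once before the dataset can be indexed:

dataset = ImageNet(
    split=ImageNet.Split.TRAIN,
    root="<IMAGENET_ROOT>",
    extra="<EXTRA_METADATA_DIR>",
)
dataset.dump_extra()        # writes entries-TRAIN.npy, class-ids-TRAIN.npy, class-names-TRAIN.npy
image, target = dataset[0]  # later accesses go through the memory-mapped entries array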


@@ -0,0 +1,302 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from gzip import GzipFile
from io import BytesIO
from mmap import ACCESS_READ, mmap
import os
from typing import Any, Callable, List, Optional, Set, Tuple
import warnings
import numpy as np
from .extended import ExtendedVisionDataset
_Labels = int
_DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors
@dataclass
class _ClassEntry:
block_offset: int
maybe_filename: Optional[str] = None
@dataclass
class _Entry:
class_index: int # noqa: E701
start_offset: int
end_offset: int
filename: str
class _Split(Enum):
TRAIN = "train"
VAL = "val"
@property
def length(self) -> int:
return {
_Split.TRAIN: 11_797_647,
_Split.VAL: 561_050,
}[self]
def entries_path(self):
return f"imagenet21kp_{self.value}.txt"
def _get_tarball_path(class_id: str) -> str:
return f"{class_id}.tar"
def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int):
@lru_cache(maxsize=mmap_cache_size)
def _mmap_tarball(class_id: str) -> mmap:
tarball_path = _get_tarball_path(class_id)
tarball_full_path = os.path.join(tarballs_root, tarball_path)
with open(tarball_full_path) as f:
return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)
return _mmap_tarball
class ImageNet22k(ExtendedVisionDataset):
_GZIPPED_INDICES: Set[int] = {
841_545,
1_304_131,
2_437_921,
2_672_079,
2_795_676,
2_969_786,
6_902_965,
6_903_550,
6_903_628,
7_432_557,
7_432_589,
7_813_809,
8_329_633,
10_296_990,
10_417_652,
10_492_265,
10_598_078,
10_782_398,
10_902_612,
11_203_736,
11_342_890,
11_397_596,
11_589_762,
11_705_103,
12_936_875,
13_289_782,
}
Labels = _Labels
def __init__(
self,
*,
root: str,
extra: str,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE,
) -> None:
super().__init__(root, transforms, transform, target_transform)
self._extra_root = extra
entries_path = self._get_entries_path(root)
self._entries = self._load_extra(entries_path)
class_ids_path = self._get_class_ids_path(root)
self._class_ids = self._load_extra(class_ids_path)
self._gzipped_indices = ImageNet22k._GZIPPED_INDICES
self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size)
def _get_entries_path(self, root: Optional[str] = None) -> str:
return "entries.npy"
def _get_class_ids_path(self, root: Optional[str] = None) -> str:
return "class-ids.npy"
def _find_class_ids(self, path: str) -> List[str]:
class_ids = []
with os.scandir(path) as entries:
for entry in entries:
root, ext = os.path.splitext(entry.name)
if ext != ".tar":
continue
class_ids.append(root)
return sorted(class_ids)
def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]:
root = self.get_root(root)
entries: List[_Entry] = []
class_ids = self._find_class_ids(root)
for class_index, class_id in enumerate(class_ids):
path = os.path.join(root, "blocks", f"{class_id}.log")
class_entries = []
try:
with open(path) as f:
for line in f:
line = line.rstrip()
block, filename = line.split(":")
block_offset = int(block[6:])
filename = filename[1:]
maybe_filename = None
if filename != "** Block of NULs **":
maybe_filename = filename
_, ext = os.path.splitext(filename)
# assert ext == ".JPEG"
class_entry = _ClassEntry(block_offset, maybe_filename)
class_entries.append(class_entry)
except OSError as e:
raise RuntimeError(f'can not read blocks file "{path}"') from e
assert class_entries[-1].maybe_filename is None
for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]):
assert class_entry1.block_offset <= class_entry2.block_offset
start_offset = 512 * class_entry1.block_offset
end_offset = 512 * class_entry2.block_offset
assert class_entry1.maybe_filename is not None
filename = class_entry1.maybe_filename
entry = _Entry(class_index, start_offset, end_offset, filename)
# Skip invalid image files (PIL throws UnidentifiedImageError)
if filename == "n06470073_47249.JPEG":
continue
entries.append(entry)
return entries, class_ids
def _load_extra(self, extra_path: str) -> np.ndarray:
extra_root = self._extra_root
extra_full_path = os.path.join(extra_root, extra_path)
return np.load(extra_full_path, mmap_mode="r")
def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
extra_root = self._extra_root
extra_full_path = os.path.join(extra_root, extra_path)
os.makedirs(extra_root, exist_ok=True)
np.save(extra_full_path, extra_array)
@property
def _tarballs_root(self) -> str:
return self.root
def find_class_id(self, class_index: int) -> str:
return str(self._class_ids[class_index])
def get_image_data(self, index: int) -> bytes:
entry = self._entries[index]
class_id = entry["class_id"]
class_mmap = self._mmap_tarball(class_id)
start_offset, end_offset = entry["start_offset"], entry["end_offset"]
try:
mapped_data = class_mmap[start_offset:end_offset]
data = mapped_data[512:] # Skip entry header block
if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B):
assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}"
with GzipFile(fileobj=BytesIO(data)) as g:
data = g.read()
except Exception as e:
raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e
return data
def get_target(self, index: int) -> Any:
return int(self._entries[index]["class_index"])
def get_targets(self) -> np.ndarray:
return self._entries["class_index"]
def get_class_id(self, index: int) -> str:
return str(self._entries[index]["class_id"])
def get_class_ids(self) -> np.ndarray:
return self._entries["class_id"]
def __getitem__(self, index: int) -> Tuple[Any, Any]:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return super().__getitem__(index)
def __len__(self) -> int:
return len(self._entries)
def _dump_entries(self, *args, **kwargs) -> None:
entries, class_ids = self._load_entries_class_ids(*args, **kwargs)
max_class_id_length, max_filename_length, max_class_index = -1, -1, -1
for entry in entries:
class_id = class_ids[entry.class_index]
max_class_index = max(entry.class_index, max_class_index)
max_class_id_length = max(len(class_id), max_class_id_length)
max_filename_length = max(len(entry.filename), max_filename_length)
dtype = np.dtype(
[
("class_index", "<u4"),
("class_id", f"U{max_class_id_length}"),
("start_offset", "<u4"),
("end_offset", "<u4"),
("filename", f"U{max_filename_length}"),
]
)
sample_count = len(entries)
entries_array = np.empty(sample_count, dtype=dtype)
for i, entry in enumerate(entries):
class_index = entry.class_index
class_id = class_ids[class_index]
start_offset = entry.start_offset
end_offset = entry.end_offset
filename = entry.filename
entries_array[i] = (
class_index,
class_id,
start_offset,
end_offset,
filename,
)
entries_path = self._get_entries_path(*args, **kwargs)
self._save_extra(entries_array, entries_path)
def _dump_class_ids(self, *args, **kwargs) -> None:
entries_path = self._get_entries_path(*args, **kwargs)
entries_array = self._load_extra(entries_path)
max_class_id_length, max_class_index = -1, -1
for entry in entries_array:
class_index, class_id = entry["class_index"], entry["class_id"]
max_class_index = max(int(class_index), max_class_index)
max_class_id_length = max(len(str(class_id)), max_class_id_length)
class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}")
for entry in entries_array:
class_index, class_id = entry["class_index"], entry["class_id"]
class_ids_array[class_index] = class_id
class_ids_path = self._get_class_ids_path(*args, **kwargs)
self._save_extra(class_ids_array, class_ids_path)
def _dump_extra(self, *args, **kwargs) -> None:
self._dump_entries(*args, **kwargs)
self._dump_class_ids(*args, **kwargs)
def dump_extra(self, root: Optional[str] = None) -> None:
return self._dump_extra(root)


@@ -0,0 +1,222 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import logging
from enum import Enum
from typing import Any, Callable, List, Optional, TypeVar
import torch
from torch.utils.data import Sampler
from .datasets import ImageNet, ImageNet22k
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
logger = logging.getLogger("dinov2")
class SamplerType(Enum):
DISTRIBUTED = 0
EPOCH = 1
INFINITE = 2
SHARDED_INFINITE = 3
SHARDED_INFINITE_NEW = 4
def _make_bool_str(b: bool) -> str:
return "yes" if b else "no"
def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
def transform(sample):
image, target = sample
if image_transform is not None:
image = image_transform(image)
if target_transform is not None:
target = target_transform(target)
return image, target
return transform
def _parse_dataset_str(dataset_str: str):
tokens = dataset_str.split(":")
name = tokens[0]
kwargs = {}
for token in tokens[1:]:
key, value = token.split("=")
assert key in ("root", "extra", "split")
kwargs[key] = value
if name == "ImageNet":
class_ = ImageNet
if "split" in kwargs:
kwargs["split"] = ImageNet.Split[kwargs["split"]]
elif name == "ImageNet22k":
class_ = ImageNet22k
else:
raise ValueError(f'Unsupported dataset "{name}"')
return class_, kwargs
def make_dataset(
*,
dataset_str: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
):
"""
Creates a dataset with the specified parameters.
Args:
dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN).
transform: A transform to apply to images.
target_transform: A transform to apply to targets.
Returns:
The created dataset.
"""
logger.info(f'using dataset: "{dataset_str}"')
class_, kwargs = _parse_dataset_str(dataset_str)
dataset = class_(transform=transform, target_transform=target_transform, **kwargs)
logger.info(f"# of dataset samples: {len(dataset):,d}")
# Aggregated datasets do not expose (yet) these attributes, so add them.
if not hasattr(dataset, "transform"):
setattr(dataset, "transform", transform)
if not hasattr(dataset, "target_transform"):
setattr(dataset, "target_transform", target_transform)
return dataset
def _make_sampler(
*,
dataset,
type: Optional[SamplerType] = None,
shuffle: bool = False,
seed: int = 0,
size: int = -1,
advance: int = 0,
) -> Optional[Sampler]:
sample_count = len(dataset)
if type == SamplerType.INFINITE:
logger.info("sampler: infinite")
if size > 0:
raise ValueError("sampler size > 0 is invalid")
return InfiniteSampler(
sample_count=sample_count,
shuffle=shuffle,
seed=seed,
advance=advance,
)
elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
logger.info("sampler: sharded infinite")
if size > 0:
raise ValueError("sampler size > 0 is invalid")
# TODO: Remove support for old shuffling
use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW
return ShardedInfiniteSampler(
sample_count=sample_count,
shuffle=shuffle,
seed=seed,
advance=advance,
use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice,
)
elif type == SamplerType.EPOCH:
logger.info("sampler: epoch")
if advance > 0:
raise NotImplementedError("sampler advance > 0 is not supported")
size = size if size > 0 else sample_count
logger.info(f"# of samples / epoch: {size:,d}")
return EpochSampler(
size=size,
sample_count=sample_count,
shuffle=shuffle,
seed=seed,
)
elif type == SamplerType.DISTRIBUTED:
logger.info("sampler: distributed")
if size > 0:
raise ValueError("sampler size > 0 is invalid")
if advance > 0:
raise ValueError("sampler advance > 0 is invalid")
return torch.utils.data.DistributedSampler(
dataset=dataset,
shuffle=shuffle,
seed=seed,
drop_last=False,
)
logger.info("sampler: none")
return None
T = TypeVar("T")
def make_data_loader(
*,
dataset,
batch_size: int,
num_workers: int,
shuffle: bool = True,
seed: int = 0,
sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
sampler_size: int = -1,
sampler_advance: int = 0,
drop_last: bool = True,
persistent_workers: bool = False,
collate_fn: Optional[Callable[[List[T]], Any]] = None,
):
"""
Creates a data loader with the specified parameters.
Args:
dataset: A dataset (third party, LaViDa or WebDataset).
batch_size: The size of batches to generate.
num_workers: The number of workers to use.
shuffle: Whether to shuffle samples.
seed: The random seed to use.
sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
sampler_advance: How many samples to skip (when applicable).
drop_last: Whether the last non-full batch of data should be dropped.
persistent_workers: keep the worker Dataset instances alive after the dataset has been consumed once.
collate_fn: Function that performs batch collation
"""
sampler = _make_sampler(
dataset=dataset,
type=sampler_type,
shuffle=shuffle,
seed=seed,
size=sampler_size,
advance=sampler_advance,
)
logger.info("using PyTorch data loader")
data_loader = torch.utils.data.DataLoader(
dataset,
sampler=sampler,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
drop_last=drop_last,
persistent_workers=persistent_workers,
collate_fn=collate_fn,
)
try:
logger.info(f"# of batches: {len(data_loader):,d}")
except TypeError: # data loader has no length
logger.info("infinite data loader")
return data_loader
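
A sketch tying the loader helpers to the rest of this commit; the dataset string follows the root/extra/split syntax parsed above, and the paths are placeholders:

# DataAugmentationDINO comes from the augmentations module added in this commit.
dataset = make_dataset(
    dataset_str="ImageNet:split=TRAIN:root=<IMAGENET_ROOT>:extra=<EXTRA_METADATA_DIR>",
    transform=DataAugmentationDINO((0.32, 1.0), (0.05, 0.32), 8),
)
data_loader = make_data_loader(
    dataset=dataset,
    batch_size=64,
    num_workers=10,
    shuffle=True,
    sampler_type=SamplerType.SHARDED_INFINITE,
    collate_fn=None,  # or a functools.partial around collate_data_and_cast, as sketched earlier
)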


@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import random
import math
import numpy as np
class MaskingGenerator:
def __init__(
self,
input_size,
num_masking_patches=None,
min_num_patches=4,
max_num_patches=None,
min_aspect=0.3,
max_aspect=None,
):
if not isinstance(input_size, tuple):
input_size = (input_size,) * 2
self.height, self.width = input_size
self.num_patches = self.height * self.width
self.num_masking_patches = num_masking_patches
self.min_num_patches = min_num_patches
self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches
max_aspect = max_aspect or 1 / min_aspect
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
def __repr__(self):
repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
self.height,
self.width,
self.min_num_patches,
self.max_num_patches,
self.num_masking_patches,
self.log_aspect_ratio[0],
self.log_aspect_ratio[1],
)
return repr_str
def get_shape(self):
return self.height, self.width
def _mask(self, mask, max_mask_patches):
delta = 0
for _ in range(10):
target_area = random.uniform(self.min_num_patches, max_mask_patches)
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < self.width and h < self.height:
top = random.randint(0, self.height - h)
left = random.randint(0, self.width - w)
num_masked = mask[top : top + h, left : left + w].sum()
# Overlap
if 0 < h * w - num_masked <= max_mask_patches:
for i in range(top, top + h):
for j in range(left, left + w):
if mask[i, j] == 0:
mask[i, j] = 1
delta += 1
if delta > 0:
break
return delta
def __call__(self, num_masking_patches=0):
mask = np.zeros(shape=self.get_shape(), dtype=bool)
mask_count = 0
while mask_count < num_masking_patches:
max_mask_patches = num_masking_patches - mask_count
max_mask_patches = min(max_mask_patches, self.max_num_patches)
delta = self._mask(mask, max_mask_patches)
if delta == 0:
break
else:
mask_count += delta
return mask
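
A quick sketch for a 16x16 patch grid (224-pixel crops with 14-pixel patches), requesting roughly 30% of the patches:

gen = MaskingGenerator(input_size=16, max_num_patches=int(0.5 * 16 * 16))
mask = gen(int(0.3 * 16 * 16))
print(mask.shape, int(mask.sum()))  # (16, 16) and at most 76 masked patches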


@@ -0,0 +1,229 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import itertools
from typing import Any, Optional
import warnings
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
import dinov2.distributed as distributed
class EpochSampler(Sampler):
def __init__(
self,
*,
size: int,
sample_count: int,
shuffle: bool = False,
seed: int = 0,
start: Optional[int] = None,
step: Optional[int] = None,
):
self._size = size
self._sample_count = sample_count
self._shuffle = shuffle
self._seed = seed
self._start = distributed.get_global_rank() if start is None else start
self._step = distributed.get_global_size() if step is None else step
self._epoch = 0
def __iter__(self):
count = (self._size + self._sample_count - 1) // self._sample_count
tiled_indices = np.tile(np.arange(self._sample_count), count)
if self._shuffle:
seed = self._seed * self._epoch if self._seed != 0 else self._epoch
rng = np.random.default_rng(seed)
iterable = rng.choice(tiled_indices, self._size, replace=False)
else:
iterable = tiled_indices[: self._size]
yield from itertools.islice(iterable, self._start, None, self._step)
def __len__(self):
return (self._size - self._start + self._step - 1) // self._step
def set_epoch(self, epoch):
self._epoch = epoch
def _get_numpy_dtype(size: int) -> Any:
return np.int32 if size <= 2**31 else np.int64
def _get_torch_dtype(size: int) -> Any:
return torch.int32 if size <= 2**31 else torch.int64
def _generate_randperm_indices(*, size: int, generator: torch.Generator):
"""Generate the indices of a random permutation."""
dtype = _get_torch_dtype(size)
# This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921
perm = torch.arange(size, dtype=dtype)
for i in range(size):
j = torch.randint(i, size, size=(1,), generator=generator).item()
# Always swap even if no-op
value = perm[j].item()
perm[j] = perm[i].item()
perm[i] = value
yield value
class InfiniteSampler(Sampler):
def __init__(
self,
*,
sample_count: int,
shuffle: bool = False,
seed: int = 0,
start: Optional[int] = None,
step: Optional[int] = None,
advance: int = 0,
):
self._sample_count = sample_count
self._seed = seed
self._shuffle = shuffle
self._start = distributed.get_global_rank() if start is None else start
self._step = distributed.get_global_size() if step is None else step
self._advance = advance
def __iter__(self):
if self._shuffle:
iterator = self._shuffled_iterator()
else:
iterator = self._iterator()
yield from itertools.islice(iterator, self._advance, None)
def _iterator(self):
assert not self._shuffle
while True:
iterable = range(self._sample_count)
yield from itertools.islice(iterable, self._start, None, self._step)
def _shuffled_iterator(self):
assert self._shuffle
# Instantiate a generator here (rather than in the ctor) to keep the class
# picklable (requirement of mp.spawn)
generator = torch.Generator().manual_seed(self._seed)
while True:
iterable = _generate_randperm_indices(size=self._sample_count, generator=generator)
yield from itertools.islice(iterable, self._start, None, self._step)
# The following function is somewhat equivalent to _new_shuffle_tensor_slice below,
# but avoids a full in-place random permutation generation.
def _shuffle_tensor_slice(
*, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
) -> np.ndarray:
stop = len(tensor)
count = stop // step
drop_count = stop - step * count
if drop_count:
warnings.warn(f"# of dropped samples: {drop_count}")
dtype = _get_numpy_dtype(stop)
result = np.empty(count, dtype=dtype)
for i in range(count):
j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0
result[i] = result[j]
result[j] = tensor[start + i * step].item()
return result
def _new_shuffle_tensor_slice(
*, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
) -> np.ndarray:
stop = len(tensor)
count = stop // step
dtype = torch.int64 # Needed for using randperm result as indices
drop_count = stop - step * count
if drop_count:
warnings.warn(f"# of dropped samples: {drop_count}")
indices = torch.randperm(count, dtype=dtype, generator=generator)
return tensor[start::step][indices].numpy()
def _make_seed(seed: int, start: int, iter_count: int) -> int:
# NOTE: Tried a few variants (including iter_count << 32), this one worked best.
return seed + start + (iter_count << 24)
class ShardedInfiniteSampler(Sampler):
def __init__(
self,
*,
sample_count: int,
shuffle: bool = False,
seed: int = 0,
start: Optional[int] = None,
step: Optional[int] = None,
advance: int = 0,
use_new_shuffle_tensor_slice: bool = False,
):
self._sample_count = sample_count
self._seed = seed
self._shuffle = shuffle
self._start = distributed.get_global_rank() if start is None else start
self._step = distributed.get_global_size() if step is None else step
self._advance = advance
self._iter_count = 0
self._shuffle_tensor_slice_fn = (
_new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice
)
def __iter__(self):
iter_count = self._advance // self._sample_count
if iter_count > 0:
self._advance -= iter_count * self._sample_count
self._iter_count += iter_count
if self._shuffle:
iterator = self._shuffled_iterator()
else:
iterator = self._iterator()
yield from itertools.islice(iterator, self._advance, None)
def _iterator(self):
assert not self._shuffle
while True:
iterable = range(self._sample_count)
yield from itertools.islice(iterable, self._start, None, self._step)
def _shuffled_iterator(self):
assert self._shuffle
# Instantiate a generator here (rather than in the ctor) to keep the class
# picklable (requirement of mp.spawn)
generator = torch.Generator()
# Always shuffle everything first
generator.manual_seed(self._seed)
dtype = _get_torch_dtype(self._sample_count)
perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator)
while True:
# Re-seed on each iteration to allow skipping whole permutations
seed = _make_seed(self._seed, self._start, self._iter_count)
generator.manual_seed(seed)
iterable = self._shuffle_tensor_slice_fn(
tensor=perm, start=self._start, step=self._step, generator=generator
)
yield from iterable
self._iter_count += 1
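
The infinite samplers yield raw dataset indices forever, each rank reading its own start/step slice of the shuffled stream; a minimal single-process sketch:

import itertools

sampler = ShardedInfiniteSampler(sample_count=10, shuffle=True, seed=0, start=0, step=2)
print(list(itertools.islice(iter(sampler), 8)))  # first 8 indices seen by rank 0 of a hypothetical 2-rank job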


@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from typing import Sequence
import torch
from torchvision import transforms
class GaussianBlur(transforms.RandomApply):
"""
Apply Gaussian Blur to the PIL image.
"""
def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
# NOTE: torchvision is applying 1 - probability to return the original image
keep_p = 1 - p
transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
super().__init__(transforms=[transform], p=keep_p)
class MaybeToTensor(transforms.ToTensor):
"""
Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor.
"""
def __call__(self, pic):
"""
Args:
pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
if isinstance(pic, torch.Tensor):
return pic
return super().__call__(pic)
# Use timm's names
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
def make_normalize_transform(
mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Normalize:
return transforms.Normalize(mean=mean, std=std)
# This roughly matches torchvision's preset for classification training:
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
def make_classification_train_transform(
*,
crop_size: int = 224,
interpolation=transforms.InterpolationMode.BICUBIC,
hflip_prob: float = 0.5,
mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
std: Sequence[float] = IMAGENET_DEFAULT_STD,
):
transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
if hflip_prob > 0.0:
transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
transforms_list.extend(
[
MaybeToTensor(),
make_normalize_transform(mean=mean, std=std),
]
)
return transforms.Compose(transforms_list)
# This matches (roughly) torchvision's preset for classification evaluation:
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
def make_classification_eval_transform(
*,
resize_size: int = 256,
interpolation=transforms.InterpolationMode.BICUBIC,
crop_size: int = 224,
mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
transforms_list = [
transforms.Resize(resize_size, interpolation=interpolation),
transforms.CenterCrop(crop_size),
MaybeToTensor(),
make_normalize_transform(mean=mean, std=std),
]
return transforms.Compose(transforms_list)
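
Applying the evaluation preset above to a single PIL image:

from PIL import Image

transform = make_classification_eval_transform()
tensor = transform(Image.new("RGB", (320, 240)))
print(tensor.shape)  # torch.Size([3, 224, 224])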


@@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import os
import random
import re
import socket
from typing import Dict, List
import torch
import torch.distributed as dist
_LOCAL_RANK = -1
_LOCAL_WORLD_SIZE = -1
def is_enabled() -> bool:
"""
Returns:
True if distributed training is enabled
"""
return dist.is_available() and dist.is_initialized()
def get_global_size() -> int:
"""
Returns:
The number of processes in the process group
"""
return dist.get_world_size() if is_enabled() else 1
def get_global_rank() -> int:
"""
Returns:
The rank of the current process within the global process group.
"""
return dist.get_rank() if is_enabled() else 0
def get_local_rank() -> int:
"""
Returns:
The rank of the current process within the local (per-machine) process group.
"""
if not is_enabled():
return 0
assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
return _LOCAL_RANK
def get_local_size() -> int:
"""
Returns:
The size of the per-machine process group,
i.e. the number of processes per machine.
"""
if not is_enabled():
return 1
assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
return _LOCAL_WORLD_SIZE
def is_main_process() -> bool:
"""
Returns:
True if the current process is the main one.
"""
return get_global_rank() == 0
def _restrict_print_to_main_process() -> None:
"""
This function disables printing when not in the main process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_main_process() or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def _get_master_port(seed: int = 0) -> int:
MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000)
master_port_str = os.environ.get("MASTER_PORT")
if master_port_str is None:
rng = random.Random(seed)
return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
return int(master_port_str)
def _get_available_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
# A "" host address means INADDR_ANY i.e. binding to all interfaces.
# Note this is not compatible with IPv6.
s.bind(("", 0))
port = s.getsockname()[1]
return port
_TORCH_DISTRIBUTED_ENV_VARS = (
"MASTER_ADDR",
"MASTER_PORT",
"RANK",
"WORLD_SIZE",
"LOCAL_RANK",
"LOCAL_WORLD_SIZE",
)
def _collect_env_vars() -> Dict[str, str]:
return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ}
def _is_slurm_job_process() -> bool:
return "SLURM_JOB_ID" in os.environ
def _parse_slurm_node_list(s: str) -> List[str]:
nodes = []
# Extract "hostname", "hostname[1-2,3,4-5]," substrings
p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?")
for m in p.finditer(s):
prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)]
for suffix in suffixes.split(","):
span = suffix.split("-")
if len(span) == 1:
nodes.append(prefix + suffix)
else:
width = len(span[0])
start, end = int(span[0]), int(span[1]) + 1
nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)])
return nodes
def _check_env_variable(key: str, new_value: str):
# Only check for difference with preset environment variables
if key in os.environ and os.environ[key] != new_value:
raise RuntimeError(f"Cannot export environment variables as {key} is already set")
class _TorchDistributedEnvironment:
def __init__(self):
self.master_addr = "127.0.0.1"
self.master_port = 0
self.rank = -1
self.world_size = -1
self.local_rank = -1
self.local_world_size = -1
if _is_slurm_job_process():
return self._set_from_slurm_env()
env_vars = _collect_env_vars()
if not env_vars:
# Environment is not set
pass
elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS):
# Environment is fully set
return self._set_from_preset_env()
else:
# Environment is partially set
collected_env_vars = ", ".join(env_vars.keys())
raise RuntimeError(f"Partially set environment: {collected_env_vars}")
if torch.cuda.device_count() > 0:
return self._set_from_local()
raise RuntimeError("Can't initialize PyTorch distributed environment")
# Slurm job created with sbatch, submitit, etc...
def _set_from_slurm_env(self):
# logger.info("Initialization from Slurm environment")
job_id = int(os.environ["SLURM_JOB_ID"])
node_count = int(os.environ["SLURM_JOB_NUM_NODES"])
nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"])
assert len(nodes) == node_count
self.master_addr = nodes[0]
self.master_port = _get_master_port(seed=job_id)
self.rank = int(os.environ["SLURM_PROCID"])
self.world_size = int(os.environ["SLURM_NTASKS"])
assert self.rank < self.world_size
self.local_rank = int(os.environ["SLURM_LOCALID"])
self.local_world_size = self.world_size // node_count
assert self.local_rank < self.local_world_size
# Single node job with preset environment (i.e. torchrun)
def _set_from_preset_env(self):
# logger.info("Initialization from preset environment")
self.master_addr = os.environ["MASTER_ADDR"]
self.master_port = os.environ["MASTER_PORT"]
self.rank = int(os.environ["RANK"])
self.world_size = int(os.environ["WORLD_SIZE"])
assert self.rank < self.world_size
self.local_rank = int(os.environ["LOCAL_RANK"])
self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
assert self.local_rank < self.local_world_size
# Single node and GPU job (i.e. local script run)
def _set_from_local(self):
# logger.info("Initialization from local")
self.master_addr = "127.0.0.1"
self.master_port = _get_available_port()
self.rank = 0
self.world_size = 1
self.local_rank = 0
self.local_world_size = 1
def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment":
# See the "Environment variable initialization" section from
# https://pytorch.org/docs/stable/distributed.html for the complete list of
# environment variables required for the env:// initialization method.
env_vars = {
"MASTER_ADDR": self.master_addr,
"MASTER_PORT": str(self.master_port),
"RANK": str(self.rank),
"WORLD_SIZE": str(self.world_size),
"LOCAL_RANK": str(self.local_rank),
"LOCAL_WORLD_SIZE": str(self.local_world_size),
}
if not overwrite:
for k, v in env_vars.items():
_check_env_variable(k, v)
os.environ.update(env_vars)
return self
def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False):
"""Enable distributed mode
Args:
set_cuda_current_device: If True, call torch.cuda.set_device() to set the
current PyTorch CUDA device to the one matching the local rank.
overwrite: If True, overwrites already set variables. Else fails.
"""
global _LOCAL_RANK, _LOCAL_WORLD_SIZE
if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0:
raise RuntimeError("Distributed mode has already been enabled")
torch_env = _TorchDistributedEnvironment()
torch_env.export(overwrite=overwrite)
if set_cuda_current_device:
torch.cuda.set_device(torch_env.local_rank)
if allow_nccl_timeout:
# This allows to use torch distributed timeout in a NCCL backend
key, value = "NCCL_ASYNC_ERROR_HANDLING", "1"
if not overwrite:
_check_env_variable(key, value)
os.environ[key] = value
dist.init_process_group(backend="nccl")
dist.barrier()
# Finalize setup
_LOCAL_RANK = torch_env.local_rank
_LOCAL_WORLD_SIZE = torch_env.local_world_size
_restrict_print_to_main_process()
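
Typical usage at the top of a training entry point; with no Slurm or torchrun environment variables set, this falls back to a single-process setup (note that enable() still needs a visible CUDA device and the NCCL backend):

if __name__ == "__main__":
    enable(overwrite=True)
    print(f"rank {get_global_rank()}/{get_global_size()}, local rank {get_local_rank()}")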


@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.


@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.


@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .backbones import * # noqa: F403
from .builder import BACKBONES, DEPTHER, HEADS, LOSSES, build_backbone, build_depther, build_head, build_loss
from .decode_heads import * # noqa: F403
from .depther import * # noqa: F403
from .losses import * # noqa: F403

View File

@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .vision_transformer import DinoVisionTransformer

View File

@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from mmcv.runner import BaseModule
from ..builder import BACKBONES
@BACKBONES.register_module()
class DinoVisionTransformer(BaseModule):
"""Vision Transformer."""
def __init__(self, *args, **kwargs):
super().__init__()

View File

@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import warnings
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
from mmcv.utils import Registry
MODELS = Registry("models", parent=MMCV_MODELS)
ATTENTION = Registry("attention", parent=MMCV_ATTENTION)
BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS
DEPTHER = MODELS
def build_backbone(cfg):
"""Build backbone."""
return BACKBONES.build(cfg)
def build_neck(cfg):
"""Build neck."""
return NECKS.build(cfg)
def build_head(cfg):
"""Build head."""
return HEADS.build(cfg)
def build_loss(cfg):
"""Build loss."""
return LOSSES.build(cfg)
def build_depther(cfg, train_cfg=None, test_cfg=None):
"""Build depther."""
if train_cfg is not None or test_cfg is not None:
warnings.warn("train_cfg and test_cfg is deprecated, " "please specify them in model", UserWarning)
assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in both outer field and model field "
assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in both outer field and model field "
return DEPTHER.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
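A hedged sketch of how these registries are meant to be used; the config keys below are illustrative examples, not configs shipped in this commit:

# `type` selects the registered class; the remaining keys go to its constructor.
cfg = dict(
    type="DepthEncoderDecoder",
    backbone=dict(type="DinoVisionTransformer"),
    decode_head=dict(
        type="BNHead",
        in_channels=[768],
        channels=96,
        min_depth=1e-3,
        max_depth=10,
        loss_decode=dict(type="SigLoss", valid_mask=True, loss_weight=10),
    ),
)
model = build_depther(cfg)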

View File

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .dpt_head import DPTHead
from .linear_head import BNHead

View File

@@ -0,0 +1,225 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import copy
from abc import ABCMeta, abstractmethod
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner import BaseModule, auto_fp16, force_fp32
from ...ops import resize
from ..builder import build_loss
class DepthBaseDecodeHead(BaseModule, metaclass=ABCMeta):
"""Base class for BaseDecodeHead.
Args:
in_channels (List): Input channels.
channels (int): Channels after modules, before conv_depth.
conv_cfg (dict|None): Config of conv layers. Default: None.
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU')
loss_decode (dict): Config of decode loss.
Default: dict(type='SigLoss').
sampler (dict|None): The config of depth map sampler.
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
min_depth (int): Min depth in dataset setting.
Default: 1e-3.
max_depth (int): Max depth in dataset setting.
Default: None.
norm_cfg (dict|None): Config of norm layers.
Default: None.
classify (bool): Whether predict depth in a cls.-reg. manner.
Default: False.
n_bins (int): The number of bins used in cls. step.
Default: 256.
bins_strategy (str): The discrete strategy used in cls. step.
Default: 'UD'.
norm_strategy (str): The norm strategy on cls. probability
distribution. Default: 'linear'
scale_up (str): Whether predict depth in a scale-up manner.
Default: False.
"""
def __init__(
self,
in_channels,
channels=96,
conv_cfg=None,
act_cfg=dict(type="ReLU"),
loss_decode=dict(type="SigLoss", valid_mask=True, loss_weight=10),
sampler=None,
align_corners=False,
min_depth=1e-3,
max_depth=None,
norm_cfg=None,
classify=False,
n_bins=256,
bins_strategy="UD",
norm_strategy="linear",
scale_up=False,
):
super(DepthBaseDecodeHead, self).__init__()
self.in_channels = in_channels
self.channels = channels
self.conv_cfg = conv_cfg
self.act_cfg = act_cfg
if isinstance(loss_decode, dict):
self.loss_decode = build_loss(loss_decode)
elif isinstance(loss_decode, (list, tuple)):
self.loss_decode = nn.ModuleList()
for loss in loss_decode:
self.loss_decode.append(build_loss(loss))
self.align_corners = align_corners
self.min_depth = min_depth
self.max_depth = max_depth
self.norm_cfg = norm_cfg
self.classify = classify
self.n_bins = n_bins
self.scale_up = scale_up
if self.classify:
assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
self.bins_strategy = bins_strategy
self.norm_strategy = norm_strategy
self.softmax = nn.Softmax(dim=1)
self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
else:
self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)
self.fp16_enabled = False
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
def extra_repr(self):
"""Extra repr."""
s = f"align_corners={self.align_corners}"
return s
@auto_fp16()
@abstractmethod
def forward(self, inputs, img_metas):
"""Placeholder of forward function."""
pass
def forward_train(self, img, inputs, img_metas, depth_gt, train_cfg):
"""Forward function for training.
Args:
inputs (list[Tensor]): List of multi-level img features.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`depth/datasets/pipelines/formatting.py:Collect`.
depth_gt (Tensor): GT depth
train_cfg (dict): The training config.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
depth_pred = self.forward(inputs, img_metas)
losses = self.losses(depth_pred, depth_gt)
log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
losses.update(**log_imgs)
return losses
def forward_test(self, inputs, img_metas, test_cfg):
"""Forward function for testing.
Args:
inputs (list[Tensor]): List of multi-level img features.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`depth/datasets/pipelines/formatting.py:Collect`.
test_cfg (dict): The testing config.
Returns:
Tensor: Output depth map.
"""
return self.forward(inputs, img_metas)
def depth_pred(self, feat):
"""Prediction each pixel."""
if self.classify:
logit = self.conv_depth(feat)
if self.bins_strategy == "UD":
bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
elif self.bins_strategy == "SID":
bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
# following Adabins, default linear
if self.norm_strategy == "linear":
logit = torch.relu(logit)
eps = 0.1
logit = logit + eps
logit = logit / logit.sum(dim=1, keepdim=True)
elif self.norm_strategy == "softmax":
logit = torch.softmax(logit, dim=1)
elif self.norm_strategy == "sigmoid":
logit = torch.sigmoid(logit)
logit = logit / logit.sum(dim=1, keepdim=True)
output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
else:
if self.scale_up:
output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
else:
output = self.relu(self.conv_depth(feat)) + self.min_depth
return output
@force_fp32(apply_to=("depth_pred",))
def losses(self, depth_pred, depth_gt):
"""Compute depth loss."""
loss = dict()
depth_pred = resize(
input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
)
if not isinstance(self.loss_decode, nn.ModuleList):
losses_decode = [self.loss_decode]
else:
losses_decode = self.loss_decode
for loss_decode in losses_decode:
if loss_decode.loss_name not in loss:
loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
else:
loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
return loss
def log_images(self, img_path, depth_pred, depth_gt, img_meta):
show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
show_img = show_img.numpy().astype(np.float32)
show_img = mmcv.imdenormalize(
show_img,
img_meta["img_norm_cfg"]["mean"],
img_meta["img_norm_cfg"]["std"],
img_meta["img_norm_cfg"]["to_rgb"],
)
show_img = np.clip(show_img, 0, 255)
show_img = show_img.astype(np.uint8)
show_img = show_img[:, :, ::-1]
show_img = show_img.transpose(0, 2, 1)
show_img = show_img.transpose(1, 0, 2)
depth_pred = depth_pred / torch.max(depth_pred)
depth_gt = depth_gt / torch.max(depth_gt)
depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())
return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}

View File

@@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, Linear, build_activation_layer
from mmcv.runner import BaseModule
from ...ops import resize
from ..builder import HEADS
from .decode_head import DepthBaseDecodeHead
class Interpolate(nn.Module):
def __init__(self, scale_factor, mode, align_corners=False):
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
return x
class HeadDepth(nn.Module):
def __init__(self, features):
super(HeadDepth, self).__init__()
self.head = nn.Sequential(
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
)
def forward(self, x):
x = self.head(x)
return x
class ReassembleBlocks(BaseModule):
"""ViTPostProcessBlock, process cls_token in ViT backbone output and
rearrange the feature vector to feature map.
Args:
in_channels (int): ViT feature channels. Default: 768.
out_channels (List): output channels of each stage.
Default: [96, 192, 384, 768].
readout_type (str): Type of readout operation. Default: 'ignore'.
patch_size (int): The patch size. Default: 16.
init_cfg (dict, optional): Initialization config dict. Default: None.
"""
def __init__(
self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16, init_cfg=None
):
super(ReassembleBlocks, self).__init__(init_cfg)
assert readout_type in ["ignore", "add", "project"]
self.readout_type = readout_type
self.patch_size = patch_size
self.projects = nn.ModuleList(
[
ConvModule(
in_channels=in_channels,
out_channels=out_channel,
kernel_size=1,
act_cfg=None,
)
for out_channel in out_channels
]
)
self.resize_layers = nn.ModuleList(
[
nn.ConvTranspose2d(
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
),
nn.ConvTranspose2d(
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
),
nn.Identity(),
nn.Conv2d(
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
),
]
)
if self.readout_type == "project":
self.readout_projects = nn.ModuleList()
for _ in range(len(self.projects)):
self.readout_projects.append(
nn.Sequential(Linear(2 * in_channels, in_channels), build_activation_layer(dict(type="GELU")))
)
def forward(self, inputs):
assert isinstance(inputs, list)
out = []
for i, x in enumerate(inputs):
assert len(x) == 2
x, cls_token = x[0], x[1]
feature_shape = x.shape
if self.readout_type == "project":
x = x.flatten(2).permute((0, 2, 1))
readout = cls_token.unsqueeze(1).expand_as(x)
x = self.readout_projects[i](torch.cat((x, readout), -1))
x = x.permute(0, 2, 1).reshape(feature_shape)
elif self.readout_type == "add":
x = x.flatten(2) + cls_token.unsqueeze(-1)
x = x.reshape(feature_shape)
else:
pass
x = self.projects[i](x)
x = self.resize_layers[i](x)
out.append(x)
return out
class PreActResidualConvUnit(BaseModule):
"""ResidualConvUnit, pre-activate residual unit.
Args:
in_channels (int): number of channels in the input feature map.
act_cfg (dict): dictionary to construct and config activation layer.
norm_cfg (dict): dictionary to construct and config norm layer.
stride (int): stride of the first block. Default: 1
dilation (int): dilation rate for convs layers. Default: 1.
init_cfg (dict, optional): Initialization config dict. Default: None.
"""
def __init__(self, in_channels, act_cfg, norm_cfg, stride=1, dilation=1, init_cfg=None):
super(PreActResidualConvUnit, self).__init__(init_cfg)
self.conv1 = ConvModule(
in_channels,
in_channels,
3,
stride=stride,
padding=dilation,
dilation=dilation,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
bias=False,
order=("act", "conv", "norm"),
)
self.conv2 = ConvModule(
in_channels,
in_channels,
3,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
bias=False,
order=("act", "conv", "norm"),
)
def forward(self, inputs):
inputs_ = inputs.clone()
x = self.conv1(inputs)
x = self.conv2(x)
return x + inputs_
class FeatureFusionBlock(BaseModule):
"""FeatureFusionBlock, merge feature map from different stages.
Args:
in_channels (int): Input channels.
act_cfg (dict): The activation config for ResidualConvUnit.
norm_cfg (dict): Config dict for normalization layer.
expand (bool): Whether expand the channels in post process block.
Default: False.
align_corners (bool): align_corner setting for bilinear upsample.
Default: True.
init_cfg (dict, optional): Initialization config dict. Default: None.
"""
def __init__(self, in_channels, act_cfg, norm_cfg, expand=False, align_corners=True, init_cfg=None):
super(FeatureFusionBlock, self).__init__(init_cfg)
self.in_channels = in_channels
self.expand = expand
self.align_corners = align_corners
self.out_channels = in_channels
if self.expand:
self.out_channels = in_channels // 2
self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_cfg=None, bias=True)
self.res_conv_unit1 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
self.res_conv_unit2 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
def forward(self, *inputs):
x = inputs[0]
if len(inputs) == 2:
if x.shape != inputs[1].shape:
res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
else:
res = inputs[1]
x = x + self.res_conv_unit1(res)
x = self.res_conv_unit2(x)
x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
x = self.project(x)
return x
@HEADS.register_module()
class DPTHead(DepthBaseDecodeHead):
"""Vision Transformers for Dense Prediction.
This head is an implementation of `DPT <https://arxiv.org/abs/2103.13413>`_.
Args:
embed_dims (int): The embed dimension of the ViT backbone.
Default: 768.
post_process_channels (List): Out channels of post process conv
layers. Default: [96, 192, 384, 768].
readout_type (str): Type of readout operation. Default: 'ignore'.
patch_size (int): The patch size. Default: 16.
expand_channels (bool): Whether expand the channels in post process
block. Default: False.
"""
def __init__(
self,
embed_dims=768,
post_process_channels=[96, 192, 384, 768],
readout_type="ignore",
patch_size=16,
expand_channels=False,
**kwargs
):
super(DPTHead, self).__init__(**kwargs)
self.in_channels = self.in_channels
self.expand_channels = expand_channels
self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)
self.post_process_channels = [
channel * int(math.pow(2, i)) if expand_channels else channel for i, channel in enumerate(post_process_channels)
]
self.convs = nn.ModuleList()
for channel in self.post_process_channels:
self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_cfg=None, bias=False))
self.fusion_blocks = nn.ModuleList()
for _ in range(len(self.convs)):
self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_cfg, self.norm_cfg))
self.fusion_blocks[0].res_conv_unit1 = None
self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_cfg=self.norm_cfg)
self.num_fusion_blocks = len(self.fusion_blocks)
self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
self.num_post_process_channels = len(self.post_process_channels)
assert self.num_fusion_blocks == self.num_reassemble_blocks
assert self.num_reassemble_blocks == self.num_post_process_channels
self.conv_depth = HeadDepth(self.channels)
def forward(self, inputs, img_metas):
assert len(inputs) == self.num_reassemble_blocks
x = [inp for inp in inputs]
x = self.reassemble_blocks(x)
x = [self.convs[i](feature) for i, feature in enumerate(x)]
out = self.fusion_blocks[0](x[-1])
for i in range(1, len(self.fusion_blocks)):
out = self.fusion_blocks[i](out, x[-(i + 1)])
out = self.project(out)
out = self.depth_pred(out)
return out

View File

@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from ...ops import resize
from ..builder import HEADS
from .decode_head import DepthBaseDecodeHead
@HEADS.register_module()
class BNHead(DepthBaseDecodeHead):
"""Just a batchnorm."""
def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
super().__init__(**kwargs)
self.input_transform = input_transform
self.in_index = in_index
self.upsample = upsample
# self.bn = nn.SyncBatchNorm(self.in_channels)
if self.classify:
self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
else:
self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if "concat" in self.input_transform:
inputs = [inputs[i] for i in self.in_index]
if "resize" in self.input_transform:
inputs = [
resize(
input=x,
size=[s * self.upsample for s in inputs[0].shape[2:]],
mode="bilinear",
align_corners=self.align_corners,
)
for x in inputs
]
inputs = torch.cat(inputs, dim=1)
elif self.input_transform == "multiple_select":
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
def _forward_feature(self, inputs, img_metas=None, **kwargs):
"""Forward function for feature maps before classifying each pixel with
``self.cls_seg`` fc.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
feats (Tensor): A tensor of shape (batch_size, self.channels,
H, W) which is feature map for last layer of decoder head.
"""
# accept lists (for cls token)
inputs = list(inputs)
for i, x in enumerate(inputs):
if len(x) == 2:
x, cls_token = x[0], x[1]
if len(x.shape) == 2:
x = x[:, :, None, None]
cls_token = cls_token[:, :, None, None].expand_as(x)
inputs[i] = torch.cat((x, cls_token), 1)
else:
x = x[0]
if len(x.shape) == 2:
x = x[:, :, None, None]
inputs[i] = x
x = self._transform_inputs(inputs)
# feats = self.bn(x)
return x
def forward(self, inputs, img_metas=None, **kwargs):
"""Forward function."""
output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
output = self.depth_pred(output)
return output
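A toy sketch (assumed feature shapes) of the "resize_concat" transform applied in _transform_inputs: every selected stage is resized to the first stage's spatial size scaled by `upsample`, then concatenated along channels:

import torch
import torch.nn.functional as F

feats = [torch.rand(1, 256, 32, 32), torch.rand(1, 256, 16, 16)]
size = [s * 4 for s in feats[0].shape[2:]]                   # upsample=4 -> 128x128
feats = [F.interpolate(x, size=size, mode="bilinear", align_corners=False) for x in feats]
x = torch.cat(feats, dim=1)                                  # (1, 512, 128, 128)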

View File

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .base import BaseDepther
from .encoder_decoder import DepthEncoderDecoder

View File

@@ -0,0 +1,194 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
import torch
import torch.distributed as dist
from mmcv.runner import BaseModule, auto_fp16
class BaseDepther(BaseModule, metaclass=ABCMeta):
"""Base class for depther."""
def __init__(self, init_cfg=None):
super(BaseDepther, self).__init__(init_cfg)
self.fp16_enabled = False
@property
def with_neck(self):
"""bool: whether the depther has neck"""
return hasattr(self, "neck") and self.neck is not None
@property
def with_auxiliary_head(self):
"""bool: whether the depther has auxiliary head"""
return hasattr(self, "auxiliary_head") and self.auxiliary_head is not None
@property
def with_decode_head(self):
"""bool: whether the depther has decode head"""
return hasattr(self, "decode_head") and self.decode_head is not None
@abstractmethod
def extract_feat(self, imgs):
"""Placeholder for extract features from images."""
pass
@abstractmethod
def encode_decode(self, img, img_metas):
"""Placeholder for encode images with backbone and decode into a
semantic depth map of the same size as input."""
pass
@abstractmethod
def forward_train(self, imgs, img_metas, **kwargs):
"""Placeholder for Forward function for training."""
pass
@abstractmethod
def simple_test(self, img, img_meta, **kwargs):
"""Placeholder for single image test."""
pass
@abstractmethod
def aug_test(self, imgs, img_metas, **kwargs):
"""Placeholder for augmentation test."""
pass
def forward_test(self, imgs, img_metas, **kwargs):
"""
Args:
imgs (List[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains all images in the batch.
img_metas (List[List[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch.
"""
for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
if not isinstance(var, list):
raise TypeError(f"{name} must be a list, but got " f"{type(var)}")
num_augs = len(imgs)
if num_augs != len(img_metas):
raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})")
# all images in the same aug batch must have the same ori_shape and
# pad_shape
for img_meta in img_metas:
ori_shapes = [_["ori_shape"] for _ in img_meta]
assert all(shape == ori_shapes[0] for shape in ori_shapes)
img_shapes = [_["img_shape"] for _ in img_meta]
assert all(shape == img_shapes[0] for shape in img_shapes)
pad_shapes = [_["pad_shape"] for _ in img_meta]
assert all(shape == pad_shapes[0] for shape in pad_shapes)
if num_augs == 1:
return self.simple_test(imgs[0], img_metas[0], **kwargs)
else:
return self.aug_test(imgs, img_metas, **kwargs)
@auto_fp16(apply_to=("img",))
def forward(self, img, img_metas, return_loss=True, **kwargs):
"""Calls either :func:`forward_train` or :func:`forward_test` depending
on whether ``return_loss`` is ``True``.
Note this setting will change the expected inputs. When
``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
and List[dict]), and when ``return_loss=False``, img and img_meta
should be double nested (i.e. List[Tensor], List[List[dict]]), with
the outer list indicating test time augmentations.
"""
if return_loss:
return self.forward_train(img, img_metas, **kwargs)
else:
return self.forward_test(img, img_metas, **kwargs)
def train_step(self, data_batch, optimizer, **kwargs):
"""The iteration step during training.
This method defines an iteration step during training, except for the
back propagation and optimizer updating, which are done in an optimizer
hook. Note that in some complicated cases or models, the whole process
including back propagation and optimizer updating is also defined in
this method, such as GAN.
Args:
data_batch (dict): The output of the dataloader.
optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
runner is passed to ``train_step()``. This argument is unused
and reserved.
Returns:
dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
``num_samples``.
``loss`` is a tensor for back propagation, which can be a
weighted sum of multiple losses.
``log_vars`` contains all the variables to be sent to the
logger.
``num_samples`` indicates the batch size (when the model is
DDP, it means the batch size on each GPU), which is used for
averaging the logs.
"""
losses = self(**data_batch)
# split losses and images
real_losses = {}
log_imgs = {}
for k, v in losses.items():
if "img" in k:
log_imgs[k] = v
else:
real_losses[k] = v
loss, log_vars = self._parse_losses(real_losses)
outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)
return outputs
def val_step(self, data_batch, **kwargs):
"""The iteration step during validation.
This method shares the same signature as :func:`train_step`, but is used
during val epochs. Note that the evaluation after training epochs is
not implemented with this method, but an evaluation hook.
"""
output = self(**data_batch, **kwargs)
return output
@staticmethod
def _parse_losses(losses):
"""Parse the raw outputs (losses) of the network.
Args:
losses (dict): Raw output of the network, which usually contain
losses and other necessary information.
Returns:
tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
which may be a weighted sum of all losses, log_vars contains
all the variables to be sent to the logger.
"""
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value, list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(f"{loss_name} is not a tensor or list of tensors")
loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)
log_vars["loss"] = loss
for loss_name, loss_value in log_vars.items():
# reduce loss when distributed training
if dist.is_available() and dist.is_initialized():
loss_value = loss_value.data.clone()
dist.all_reduce(loss_value.div_(dist.get_world_size()))
log_vars[loss_name] = loss_value.item()
return loss, log_vars
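A minimal single-process sketch (dummy values) of what _parse_losses does with the raw loss dict returned by the model:

import torch

raw = {"decode.loss_depth": torch.tensor(0.8), "decode.loss_grad": [torch.tensor(0.1), torch.tensor(0.2)]}
log_vars = {k: v.mean() if torch.is_tensor(v) else sum(x.mean() for x in v) for k, v in raw.items()}
loss = sum(v for k, v in log_vars.items() if "loss" in k)    # tensor(1.1), the value used for backprop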

View File

@@ -0,0 +1,236 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from ...models import builder
from ...models.builder import DEPTHER
from ...ops import resize
from .base import BaseDepther
def add_prefix(inputs, prefix):
"""Add prefix for dict.
Args:
inputs (dict): The input dict with str keys.
prefix (str): The prefix to add.
Returns:
dict: The dict with keys updated with ``prefix``.
"""
outputs = dict()
for name, value in inputs.items():
outputs[f"{prefix}.{name}"] = value
return outputs
@DEPTHER.register_module()
class DepthEncoderDecoder(BaseDepther):
"""Encoder Decoder depther.
EncoderDecoder typically consists of a backbone, an optional neck, and a decode_head.
"""
def __init__(self, backbone, decode_head, neck=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None):
super(DepthEncoderDecoder, self).__init__(init_cfg)
if pretrained is not None:
assert backbone.get("pretrained") is None, "both backbone and depther set pretrained weight"
backbone.pretrained = pretrained
self.backbone = builder.build_backbone(backbone)
self._init_decode_head(decode_head)
if neck is not None:
self.neck = builder.build_neck(neck)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
assert self.with_decode_head
def _init_decode_head(self, decode_head):
"""Initialize ``decode_head``"""
self.decode_head = builder.build_head(decode_head)
self.align_corners = self.decode_head.align_corners
def extract_feat(self, img):
"""Extract features from images."""
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
return x
def encode_decode(self, img, img_metas, rescale=True, size=None):
"""Encode images with backbone and decode into a depth estimation
map of the same size as input."""
x = self.extract_feat(img)
out = self._decode_head_forward_test(x, img_metas)
# crop the pred depth to the certain range.
out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
if rescale:
if size is None:
if img_metas is not None:
size = img_metas[0]["ori_shape"][:2]
else:
size = img.shape[2:]
out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
return out
def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
"""Run forward function and calculate loss for decode head in
training."""
losses = dict()
loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, self.train_cfg, **kwargs)
losses.update(add_prefix(loss_decode, "decode"))
return losses
def _decode_head_forward_test(self, x, img_metas):
"""Run forward function and calculate loss for decode head in
inference."""
depth_pred = self.decode_head.forward_test(x, img_metas, self.test_cfg)
return depth_pred
def forward_dummy(self, img):
"""Dummy forward function."""
depth = self.encode_decode(img, None)
return depth
def forward_train(self, img, img_metas, depth_gt, **kwargs):
"""Forward function for training.
Args:
img (Tensor): Input images.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`depth/datasets/pipelines/formatting.py:Collect`.
depth_gt (Tensor): Depth gt
used if the architecture supports depth estimation task.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
x = self.extract_feat(img)
losses = dict()
# the last element of x carries the information from the neck
loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)
losses.update(loss_decode)
return losses
def whole_inference(self, img, img_meta, rescale, size=None):
"""Inference with full image."""
depth_pred = self.encode_decode(img, img_meta, rescale, size=size)
return depth_pred
def slide_inference(self, img, img_meta, rescale):
"""Inference by sliding-window with overlap.
If h_crop > h_img or w_crop > w_img, the small patch will be used to
decode without padding.
"""
h_stride, w_stride = self.test_cfg.stride
h_crop, w_crop = self.test_cfg.crop_size
batch_size, _, h_img, w_img = img.size()
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
preds = img.new_zeros((batch_size, 1, h_img, w_img))
count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
for h_idx in range(h_grids):
for w_idx in range(w_grids):
y1 = h_idx * h_stride
x1 = w_idx * w_stride
y2 = min(y1 + h_crop, h_img)
x2 = min(x1 + w_crop, w_img)
y1 = max(y2 - h_crop, 0)
x1 = max(x2 - w_crop, 0)
crop_img = img[:, :, y1:y2, x1:x2]
depth_pred = self.encode_decode(crop_img, img_meta, rescale)
preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
count_mat[:, :, y1:y2, x1:x2] += 1
assert (count_mat == 0).sum() == 0
if torch.onnx.is_in_onnx_export():
# cast count_mat to constant while exporting to ONNX
count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
preds = preds / count_mat
return preds
def inference(self, img, img_meta, rescale, size=None):
"""Inference with slide/whole style.
Args:
img (Tensor): The input image of shape (N, 3, H, W).
img_meta (dict): Image info dict where each dict has: 'img_shape',
'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`depth/datasets/pipelines/formatting.py:Collect`.
rescale (bool): Whether rescale back to original shape.
Returns:
Tensor: The output depth map.
"""
assert self.test_cfg.mode in ["slide", "whole"]
ori_shape = img_meta[0]["ori_shape"]
assert all(_["ori_shape"] == ori_shape for _ in img_meta)
if self.test_cfg.mode == "slide":
depth_pred = self.slide_inference(img, img_meta, rescale)
else:
depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
output = depth_pred
flip = img_meta[0]["flip"]
if flip:
flip_direction = img_meta[0]["flip_direction"]
assert flip_direction in ["horizontal", "vertical"]
if flip_direction == "horizontal":
output = output.flip(dims=(3,))
elif flip_direction == "vertical":
output = output.flip(dims=(2,))
return output
def simple_test(self, img, img_meta, rescale=True):
"""Simple test with single image."""
depth_pred = self.inference(img, img_meta, rescale)
if torch.onnx.is_in_onnx_export():
# our inference backend only supports 4D output
depth_pred = depth_pred.unsqueeze(0)
return depth_pred
depth_pred = depth_pred.cpu().numpy()
# unravel batch dim
depth_pred = list(depth_pred)
return depth_pred
def aug_test(self, imgs, img_metas, rescale=True):
"""Test with augmentations.
Only rescale=True is supported.
"""
# aug_test rescale all imgs back to ori_shape for now
assert rescale
# to save memory, we get augmented depth logit inplace
depth_pred = self.inference(imgs[0], img_metas[0], rescale)
for i in range(1, len(imgs)):
cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
depth_pred += cur_depth_pred
depth_pred /= len(imgs)
depth_pred = depth_pred.cpu().numpy()
# unravel batch dim
depth_pred = list(depth_pred)
return depth_pred
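A numeric sketch (illustrative sizes only) of the window-grid arithmetic in slide_inference: each window prediction is padded into `preds` and the overlap is normalised with `count_mat`.

h_img, w_img = 480, 640            # input resolution
h_crop, w_crop = 352, 352          # test_cfg.crop_size
h_stride, w_stride = 176, 176      # test_cfg.stride
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1   # 2 rows of windows
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1   # 3 columns of windows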

View File

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .gradientloss import GradientLoss
from .sigloss import SigLoss

View File

@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from ...models.builder import LOSSES
@LOSSES.register_module()
class GradientLoss(nn.Module):
"""GradientLoss.
Adapted from https://www.cs.cornell.edu/projects/megadepth/
Args:
valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True.
loss_weight (float): Weight of the loss. Default: 1.0.
max_depth (int): When filtering invalid gt, set a max threshold. Default: None.
"""
def __init__(self, valid_mask=True, loss_weight=1.0, max_depth=None, loss_name="loss_grad"):
super(GradientLoss, self).__init__()
self.valid_mask = valid_mask
self.loss_weight = loss_weight
self.max_depth = max_depth
self.loss_name = loss_name
self.eps = 0.001 # avoid grad explode
def gradientloss(self, input, target):
input_downscaled = [input] + [input[:: 2 * i, :: 2 * i] for i in range(1, 4)]
target_downscaled = [target] + [target[:: 2 * i, :: 2 * i] for i in range(1, 4)]
gradient_loss = 0
for input, target in zip(input_downscaled, target_downscaled):
if self.valid_mask:
mask = target > 0
if self.max_depth is not None:
mask = torch.logical_and(target > 0, target <= self.max_depth)
N = torch.sum(mask)
else:
mask = torch.ones_like(target)
N = input.numel()
input_log = torch.log(input + self.eps)
target_log = torch.log(target + self.eps)
log_d_diff = input_log - target_log
log_d_diff = torch.mul(log_d_diff, mask)
v_gradient = torch.abs(log_d_diff[0:-2, :] - log_d_diff[2:, :])
v_mask = torch.mul(mask[0:-2, :], mask[2:, :])
v_gradient = torch.mul(v_gradient, v_mask)
h_gradient = torch.abs(log_d_diff[:, 0:-2] - log_d_diff[:, 2:])
h_mask = torch.mul(mask[:, 0:-2], mask[:, 2:])
h_gradient = torch.mul(h_gradient, h_mask)
gradient_loss += (torch.sum(h_gradient) + torch.sum(v_gradient)) / N
return gradient_loss
def forward(self, depth_pred, depth_gt):
"""Forward function."""
gradient_loss = self.loss_weight * self.gradientloss(depth_pred, depth_gt)
return gradient_loss
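A tiny usage sketch (random tensors, not part of the commit); the slicing in gradientloss above indexes the first two dimensions, so single (H, W) depth maps are assumed:

import torch

loss_fn = GradientLoss(valid_mask=True, max_depth=10)
pred = torch.rand(64, 64) * 9 + 0.5    # (H, W) predicted depth
gt = torch.rand(64, 64) * 9 + 0.5      # (H, W) ground-truth depth; zeros would be masked out
print(loss_fn(pred, gt))               # multi-scale loss on log-depth gradients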

View File

@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from ...models.builder import LOSSES
@LOSSES.register_module()
class SigLoss(nn.Module):
"""SigLoss.
This follows `AdaBins <https://arxiv.org/abs/2011.14141>`_.
Args:
valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True.
loss_weight (float): Weight of the loss. Default: 1.0.
max_depth (int): When filtering invalid gt, set a max threshold. Default: None.
warm_up (bool): A simple warm up stage to help convergence. Default: False.
warm_iter (int): The number of warm-up iterations. Default: 100.
"""
def __init__(
self, valid_mask=True, loss_weight=1.0, max_depth=None, warm_up=False, warm_iter=100, loss_name="sigloss"
):
super(SigLoss, self).__init__()
self.valid_mask = valid_mask
self.loss_weight = loss_weight
self.max_depth = max_depth
self.loss_name = loss_name
self.eps = 0.001 # avoid grad explode
# HACK: a hack implementation for warmup sigloss
self.warm_up = warm_up
self.warm_iter = warm_iter
self.warm_up_counter = 0
def sigloss(self, input, target):
if self.valid_mask:
valid_mask = target > 0
if self.max_depth is not None:
valid_mask = torch.logical_and(target > 0, target <= self.max_depth)
input = input[valid_mask]
target = target[valid_mask]
if self.warm_up:
if self.warm_up_counter < self.warm_iter:
g = torch.log(input + self.eps) - torch.log(target + self.eps)
g = 0.15 * torch.pow(torch.mean(g), 2)
self.warm_up_counter += 1
return torch.sqrt(g)
g = torch.log(input + self.eps) - torch.log(target + self.eps)
Dg = torch.var(g) + 0.15 * torch.pow(torch.mean(g), 2)
return torch.sqrt(Dg)
def forward(self, depth_pred, depth_gt):
"""Forward function."""
loss_depth = self.loss_weight * self.sigloss(depth_pred, depth_gt)
return loss_depth
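A brief numeric check (random tensors in a valid depth range) of the scale-invariant term computed above, sqrt(Var[g] + 0.15 * E[g]^2) with g = log(pred + eps) - log(gt + eps):

import torch

loss_fn = SigLoss(valid_mask=True, max_depth=10)
pred = torch.rand(2, 1, 32, 32) * 9 + 0.5
gt = torch.rand(2, 1, 32, 32) * 9 + 0.5       # all values valid, so the mask keeps everything
g = torch.log(pred + 1e-3) - torch.log(gt + 1e-3)
manual = torch.sqrt(torch.var(g) + 0.15 * torch.mean(g) ** 2)
print(loss_fn(pred, gt), manual)              # the two values should match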

View File

@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .wrappers import resize

View File

@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import warnings
import torch.nn.functional as F
def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
if warning:
if size is not None and align_corners:
input_h, input_w = tuple(int(x) for x in input.shape[2:])
output_h, output_w = tuple(int(x) for x in size)
if output_h > input_h or output_w > input_w:
if (
(output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
and (output_h - 1) % (input_h - 1)
and (output_w - 1) % (input_w - 1)
):
warnings.warn(
f"When align_corners={align_corners}, "
"the output would more aligned if "
f"input size {(input_h, input_w)} is `x+1` and "
f"out size {(output_h, output_w)} is `nx+1`"
)
return F.interpolate(input, size, scale_factor, mode, align_corners)

View File

@@ -0,0 +1,404 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import argparse
from functools import partial
import json
import logging
import os
import sys
from typing import List, Optional
import torch
from torch.nn.functional import one_hot, softmax
import dinov2.distributed as distributed
from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data.transforms import make_classification_eval_transform
from dinov2.eval.metrics import AccuracyAveraging, build_topk_accuracy_metric
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
from dinov2.eval.setup import setup_and_build_model
from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features
logger = logging.getLogger("dinov2")
def get_args_parser(
description: Optional[str] = None,
parents: Optional[List[argparse.ArgumentParser]] = None,
add_help: bool = True,
):
parents = parents or []
setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
parents = [setup_args_parser]
parser = argparse.ArgumentParser(
description=description,
parents=parents,
add_help=add_help,
)
parser.add_argument(
"--train-dataset",
dest="train_dataset_str",
type=str,
help="Training dataset",
)
parser.add_argument(
"--val-dataset",
dest="val_dataset_str",
type=str,
help="Validation dataset",
)
parser.add_argument(
"--nb_knn",
nargs="+",
type=int,
help="Number of NN to use. 20 is usually working the best.",
)
parser.add_argument(
"--temperature",
type=float,
help="Temperature used in the voting coefficient",
)
parser.add_argument(
"--gather-on-cpu",
action="store_true",
help="Whether to gather the train features on cpu, slower"
"but useful to avoid OOM for large datasets (e.g. ImageNet22k).",
)
parser.add_argument(
"--batch-size",
type=int,
help="Batch size.",
)
parser.add_argument(
"--n-per-class-list",
nargs="+",
type=int,
help="Number to take per class",
)
parser.add_argument(
"--n-tries",
type=int,
help="Number of tries",
)
parser.set_defaults(
train_dataset_str="ImageNet:split=TRAIN",
val_dataset_str="ImageNet:split=VAL",
nb_knn=[10, 20, 100, 200],
temperature=0.07,
batch_size=256,
n_per_class_list=[-1],
n_tries=1,
)
return parser
class KnnModule(torch.nn.Module):
"""
Gets the k-NN of test features from all processes on a chunk of the train features.
Each rank holds a chunk of the train features as well as a chunk of the test features.
In `compute_neighbors`, for each rank one after the other, its chunk of test features
is sent to all devices, partial k-NNs are computed against each chunk of train features,
then collated back on the originating device.
"""
def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000):
super().__init__()
self.global_rank = distributed.get_global_rank()
self.global_size = distributed.get_global_size()
self.device = device
self.train_features_rank_T = train_features.chunk(self.global_size)[self.global_rank].T.to(self.device)
self.candidates = train_labels.chunk(self.global_size)[self.global_rank].view(1, -1).to(self.device)
self.nb_knn = nb_knn
self.max_k = max(self.nb_knn)
self.T = T
self.num_classes = num_classes
def _get_knn_sims_and_labels(self, similarity, train_labels):
topk_sims, indices = similarity.topk(self.max_k, largest=True, sorted=True)
neighbors_labels = torch.gather(train_labels, 1, indices)
return topk_sims, neighbors_labels
def _similarity_for_rank(self, features_rank, source_rank):
# Send the features from `source_rank` to all ranks
broadcast_shape = torch.tensor(features_rank.shape).to(self.device)
torch.distributed.broadcast(broadcast_shape, source_rank)
broadcasted = features_rank
if self.global_rank != source_rank:
broadcasted = torch.zeros(*broadcast_shape, dtype=features_rank.dtype, device=self.device)
torch.distributed.broadcast(broadcasted, source_rank)
# Compute the neighbors for `source_rank` among `train_features_rank_T`
similarity_rank = torch.mm(broadcasted, self.train_features_rank_T)
candidate_labels = self.candidates.expand(len(similarity_rank), -1)
return self._get_knn_sims_and_labels(similarity_rank, candidate_labels)
def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank):
# Gather all neighbors for `target_rank`
topk_sims_rank = retrieved_rank = None
if self.global_rank == target_rank:
topk_sims_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)]
retrieved_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)]
torch.distributed.gather(topk_sims, topk_sims_rank, dst=target_rank)
torch.distributed.gather(neighbors_labels, retrieved_rank, dst=target_rank)
if self.global_rank == target_rank:
# Perform a second top-k on the k * global_size retrieved neighbors
topk_sims_rank = torch.cat(topk_sims_rank, dim=1)
retrieved_rank = torch.cat(retrieved_rank, dim=1)
results = self._get_knn_sims_and_labels(topk_sims_rank, retrieved_rank)
return results
return None
def compute_neighbors(self, features_rank):
for rank in range(self.global_size):
topk_sims, neighbors_labels = self._similarity_for_rank(features_rank, rank)
results = self._gather_all_knn_for_rank(topk_sims, neighbors_labels, rank)
if results is not None:
topk_sims_rank, neighbors_labels_rank = results
return topk_sims_rank, neighbors_labels_rank
def forward(self, features_rank):
"""
Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k`
"""
assert all(k <= self.max_k for k in self.nb_knn)
topk_sims, neighbors_labels = self.compute_neighbors(features_rank)
batch_size = neighbors_labels.shape[0]
topk_sims_transform = softmax(topk_sims / self.T, 1)
matmul = torch.mul(
one_hot(neighbors_labels, num_classes=self.num_classes),
topk_sims_transform.view(batch_size, -1, 1),
)
probas_for_k = {k: torch.sum(matmul[:, :k, :], 1) for k in self.nb_knn}
return probas_for_k
class DictKeysModule(torch.nn.Module):
def __init__(self, keys):
super().__init__()
self.keys = keys
def forward(self, features_dict, targets):
for k in self.keys:
features_dict = features_dict[k]
return {"preds": features_dict, "target": targets}
def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, train_features, train_labels):
modules = {}
mapping = create_class_indices_mapping(train_labels)
for npc in n_per_class_list:
if npc < 0: # Only one try needed when using the full data
full_module = module(
train_features=train_features,
train_labels=train_labels,
nb_knn=nb_knn,
)
modules["full"] = ModuleDictWithForward({"1": full_module})
continue
all_tries = {}
for t in range(n_tries):
final_indices = filter_train(mapping, npc, seed=t)
k_list = list(set(nb_knn + [npc]))
k_list = sorted([el for el in k_list if el <= npc])
all_tries[str(t)] = module(
train_features=train_features[final_indices],
train_labels=train_labels[final_indices],
nb_knn=k_list,
)
modules[f"{npc} per class"] = ModuleDictWithForward(all_tries)
return ModuleDictWithForward(modules)
def filter_train(mapping, n_per_class, seed):
torch.manual_seed(seed)
final_indices = []
for k in mapping.keys():
index = torch.randperm(len(mapping[k]))[:n_per_class]
final_indices.append(mapping[k][index])
return torch.cat(final_indices).squeeze()
def create_class_indices_mapping(labels):
unique_labels, inverse = torch.unique(labels, return_inverse=True)
mapping = {unique_labels[i]: (inverse == i).nonzero() for i in range(len(unique_labels))}
return mapping
class ModuleDictWithForward(torch.nn.ModuleDict):
def forward(self, *args, **kwargs):
return {k: module(*args, **kwargs) for k, module in self._modules.items()}
def eval_knn(
model,
train_dataset,
val_dataset,
accuracy_averaging,
nb_knn,
temperature,
batch_size,
num_workers,
gather_on_cpu,
n_per_class_list=[-1],
n_tries=1,
):
model = ModelWithNormalize(model)
logger.info("Extracting features for train set...")
train_features, train_labels = extract_features(
model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
)
logger.info(f"Train features created, shape {train_features.shape}.")
val_dataloader = make_data_loader(
dataset=val_dataset,
batch_size=batch_size,
num_workers=num_workers,
sampler_type=SamplerType.DISTRIBUTED,
drop_last=False,
shuffle=False,
persistent_workers=True,
)
num_classes = train_labels.max() + 1
metric_collection = build_topk_accuracy_metric(accuracy_averaging, num_classes=num_classes)
device = torch.cuda.current_device()
partial_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes)
knn_module_dict = create_module_dict(
module=partial_module,
n_per_class_list=n_per_class_list,
n_tries=n_tries,
nb_knn=nb_knn,
train_features=train_features,
train_labels=train_labels,
)
postprocessors, metrics = {}, {}
for n_per_class, knn_module in knn_module_dict.items():
for t, knn_try in knn_module.items():
postprocessors = {
**postprocessors,
**{(n_per_class, t, k): DictKeysModule([n_per_class, t, k]) for k in knn_try.nb_knn},
}
metrics = {**metrics, **{(n_per_class, t, k): metric_collection.clone() for k in knn_try.nb_knn}}
model_with_knn = torch.nn.Sequential(model, knn_module_dict)
# ============ evaluation ... ============
logger.info("Start the k-NN classification.")
_, results_dict = evaluate(model_with_knn, val_dataloader, postprocessors, metrics, device)
# Averaging the results over the n tries for each value of n_per_class
for n_per_class, knn_module in knn_module_dict.items():
first_try = list(knn_module.keys())[0]
k_list = knn_module[first_try].nb_knn
for k in k_list:
keys = results_dict[(n_per_class, first_try, k)].keys() # keys are e.g. `top-1` and `top-5`
results_dict[(n_per_class, k)] = {
key: torch.mean(torch.stack([results_dict[(n_per_class, t, k)][key] for t in knn_module.keys()]))
for key in keys
}
for t in knn_module.keys():
del results_dict[(n_per_class, t, k)]
return results_dict
def eval_knn_with_model(
model,
output_dir,
train_dataset_str="ImageNet:split=TRAIN",
val_dataset_str="ImageNet:split=VAL",
nb_knn=(10, 20, 100, 200),
temperature=0.07,
autocast_dtype=torch.float,
accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
transform=None,
gather_on_cpu=False,
batch_size=256,
num_workers=5,
n_per_class_list=[-1],
n_tries=1,
):
transform = transform or make_classification_eval_transform()
train_dataset = make_dataset(
dataset_str=train_dataset_str,
transform=transform,
)
val_dataset = make_dataset(
dataset_str=val_dataset_str,
transform=transform,
)
with torch.cuda.amp.autocast(dtype=autocast_dtype):
results_dict_knn = eval_knn(
model=model,
train_dataset=train_dataset,
val_dataset=val_dataset,
accuracy_averaging=accuracy_averaging,
nb_knn=nb_knn,
temperature=temperature,
batch_size=batch_size,
num_workers=num_workers,
gather_on_cpu=gather_on_cpu,
n_per_class_list=n_per_class_list,
n_tries=n_tries,
)
results_dict = {}
if distributed.is_main_process():
for knn_ in results_dict_knn.keys():
top1 = results_dict_knn[knn_]["top-1"].item() * 100.0
top5 = results_dict_knn[knn_]["top-5"].item() * 100.0
results_dict[f"{knn_} Top 1"] = top1
results_dict[f"{knn_} Top 5"] = top5
logger.info(f"{knn_} classifier result: Top1: {top1:.2f} Top5: {top5:.2f}")
metrics_file_path = os.path.join(output_dir, "results_eval_knn.json")
with open(metrics_file_path, "a") as f:
for k, v in results_dict.items():
f.write(json.dumps({k: v}) + "\n")
if distributed.is_enabled():
torch.distributed.barrier()
return results_dict
def main(args):
model, autocast_dtype = setup_and_build_model(args)
eval_knn_with_model(
model=model,
output_dir=args.output_dir,
train_dataset_str=args.train_dataset_str,
val_dataset_str=args.val_dataset_str,
nb_knn=args.nb_knn,
temperature=args.temperature,
autocast_dtype=autocast_dtype,
accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
transform=None,
gather_on_cpu=args.gather_on_cpu,
batch_size=args.batch_size,
num_workers=5,
n_per_class_list=args.n_per_class_list,
n_tries=args.n_tries,
)
return 0
if __name__ == "__main__":
description = "DINOv2 k-NN evaluation"
args_parser = get_args_parser(description=description)
args = args_parser.parse_args()
sys.exit(main(args))
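A standalone sketch (toy sizes, not from the commit) of the temperature-weighted voting performed in KnnModule.forward: one-hot neighbour labels are weighted by softmax(similarity / T) and summed over the k neighbours.

import torch
from torch.nn.functional import one_hot, softmax

num_classes, T = 5, 0.07
topk_sims = torch.tensor([[0.9, 0.8, 0.2]])          # similarities to the k nearest train samples
neighbors_labels = torch.tensor([[2, 2, 4]])         # labels of those neighbours
weights = softmax(topk_sims / T, dim=1)              # temperature-sharpened vote weights
votes = one_hot(neighbors_labels, num_classes).float()
probas = torch.sum(votes * weights.unsqueeze(-1), dim=1)   # class scores, shape (1, num_classes)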

View File

@@ -0,0 +1,625 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import argparse
from functools import partial
import json
import logging
import os
import sys
from typing import List, Optional
import numpy as np
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data.transforms import make_classification_eval_transform, make_classification_train_transform
import dinov2.distributed as distributed
from dinov2.eval.metrics import MetricType, build_metric
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
from dinov2.eval.setup import setup_and_build_model
from dinov2.eval.utils import ModelWithIntermediateLayers, evaluate
from dinov2.logging import MetricLogger
logger = logging.getLogger("dinov2")
def get_args_parser(
description: Optional[str] = None,
parents: Optional[List[argparse.ArgumentParser]] = None,
add_help: bool = True,
):
parents = parents or []
setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
parents = [setup_args_parser]
parser = argparse.ArgumentParser(
description=description,
parents=parents,
add_help=add_help,
)
parser.add_argument(
"--train-dataset",
dest="train_dataset_str",
type=str,
help="Training dataset",
)
parser.add_argument(
"--val-dataset",
dest="val_dataset_str",
type=str,
help="Validation dataset",
)
parser.add_argument(
"--test-datasets",
dest="test_dataset_strs",
type=str,
nargs="+",
help="Test datasets, none to reuse the validation dataset",
)
parser.add_argument(
"--epochs",
type=int,
help="Number of training epochs",
)
parser.add_argument(
"--batch-size",
type=int,
help="Batch Size (per GPU)",
)
parser.add_argument(
"--num-workers",
type=int,
help="Number de Workers",
)
parser.add_argument(
"--epoch-length",
type=int,
help="Length of an epoch in number of iterations",
)
parser.add_argument(
"--save-checkpoint-frequency",
type=int,
help="Number of epochs between two named checkpoint saves.",
)
parser.add_argument(
"--eval-period-iterations",
type=int,
help="Number of iterations between two evaluations.",
)
parser.add_argument(
"--learning-rates",
nargs="+",
type=float,
help="Learning rates to grid search.",
)
parser.add_argument(
"--no-resume",
action="store_true",
help="Whether to not resume from existing checkpoints",
)
parser.add_argument(
"--val-metric-type",
type=MetricType,
choices=list(MetricType),
help="Validation metric",
)
parser.add_argument(
"--test-metric-types",
type=MetricType,
choices=list(MetricType),
nargs="+",
help="Evaluation metric",
)
parser.add_argument(
"--classifier-fpath",
type=str,
help="Path to a file containing pretrained linear classifiers",
)
parser.add_argument(
"--val-class-mapping-fpath",
type=str,
help="Path to a file containing a mapping to adjust classifier outputs",
)
parser.add_argument(
"--test-class-mapping-fpaths",
nargs="+",
type=str,
help="Path to a file containing a mapping to adjust classifier outputs",
)
parser.set_defaults(
train_dataset_str="ImageNet:split=TRAIN",
val_dataset_str="ImageNet:split=VAL",
test_dataset_strs=None,
epochs=10,
batch_size=128,
num_workers=8,
epoch_length=1250,
save_checkpoint_frequency=20,
eval_period_iterations=1250,
learning_rates=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 0.1],
val_metric_type=MetricType.MEAN_ACCURACY,
test_metric_types=None,
classifier_fpath=None,
val_class_mapping_fpath=None,
test_class_mapping_fpaths=[None],
)
return parser
def has_ddp_wrapper(m: nn.Module) -> bool:
return isinstance(m, DistributedDataParallel)
def remove_ddp_wrapper(m: nn.Module) -> nn.Module:
return m.module if has_ddp_wrapper(m) else m
def _pad_and_collate(batch):
maxlen = max(len(targets) for image, targets in batch)
padded_batch = [
(image, np.pad(targets, (0, maxlen - len(targets)), constant_values=-1)) for image, targets in batch
]
return torch.utils.data.default_collate(padded_batch)
def create_linear_input(x_tokens_list, use_n_blocks, use_avgpool):
intermediate_output = x_tokens_list[-use_n_blocks:]
output = torch.cat([class_token for _, class_token in intermediate_output], dim=-1)
if use_avgpool:
output = torch.cat(
(
output,
torch.mean(intermediate_output[-1][0], dim=1), # patch tokens
),
dim=-1,
)
output = output.reshape(output.shape[0], -1)
return output.float()
class LinearClassifier(nn.Module):
"""Linear layer to train on top of frozen features"""
def __init__(self, out_dim, use_n_blocks, use_avgpool, num_classes=1000):
super().__init__()
self.out_dim = out_dim
self.use_n_blocks = use_n_blocks
self.use_avgpool = use_avgpool
self.num_classes = num_classes
self.linear = nn.Linear(out_dim, num_classes)
self.linear.weight.data.normal_(mean=0.0, std=0.01)
self.linear.bias.data.zero_()
def forward(self, x_tokens_list):
output = create_linear_input(x_tokens_list, self.use_n_blocks, self.use_avgpool)
return self.linear(output)
class AllClassifiers(nn.Module):
def __init__(self, classifiers_dict):
super().__init__()
self.classifiers_dict = nn.ModuleDict()
self.classifiers_dict.update(classifiers_dict)
def forward(self, inputs):
return {k: v.forward(inputs) for k, v in self.classifiers_dict.items()}
def __len__(self):
return len(self.classifiers_dict)
class LinearPostprocessor(nn.Module):
def __init__(self, linear_classifier, class_mapping=None):
super().__init__()
self.linear_classifier = linear_classifier
self.register_buffer("class_mapping", None if class_mapping is None else torch.LongTensor(class_mapping))
def forward(self, samples, targets):
preds = self.linear_classifier(samples)
return {
"preds": preds[:, self.class_mapping] if self.class_mapping is not None else preds,
"target": targets,
}
def scale_lr(learning_rates, batch_size):
return learning_rates * (batch_size * distributed.get_global_size()) / 256.0
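# Illustrative sketch (not part of the original file): the linear scaling rule applied
# by `scale_lr`. With a base learning rate of 1e-3, a per-GPU batch size of 128 and a
# made-up world size of 4 GPUs, the effective batch size is 512, so the learning rate
# is scaled by 512 / 256 = 2. At runtime the world size comes from
# `distributed.get_global_size()`.
def _example_scale_lr():
    base_lr, batch_size, world_size = 1e-3, 128, 4
    scaled = base_lr * (batch_size * world_size) / 256.0
    assert scaled == 2e-3
    return scaled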
def setup_linear_classifiers(sample_output, n_last_blocks_list, learning_rates, batch_size, num_classes=1000):
linear_classifiers_dict = nn.ModuleDict()
optim_param_groups = []
for n in n_last_blocks_list:
for avgpool in [False, True]:
for _lr in learning_rates:
lr = scale_lr(_lr, batch_size)
out_dim = create_linear_input(sample_output, use_n_blocks=n, use_avgpool=avgpool).shape[1]
linear_classifier = LinearClassifier(
out_dim, use_n_blocks=n, use_avgpool=avgpool, num_classes=num_classes
)
linear_classifier = linear_classifier.cuda()
linear_classifiers_dict[
f"classifier_{n}_blocks_avgpool_{avgpool}_lr_{lr:.5f}".replace(".", "_")
] = linear_classifier
optim_param_groups.append({"params": linear_classifier.parameters(), "lr": lr})
linear_classifiers = AllClassifiers(linear_classifiers_dict)
if distributed.is_enabled():
linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers)
return linear_classifiers, optim_param_groups
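# Illustrative sketch (not part of the original file): the naming scheme used by
# `setup_linear_classifiers` to register each (n_blocks, avgpool, lr) combination.
# The values below are made up; at runtime `lr` is the already-scaled learning rate.
def _example_classifier_key():
    n, avgpool, lr = 4, True, 0.001
    key = f"classifier_{n}_blocks_avgpool_{avgpool}_lr_{lr:.5f}".replace(".", "_")
    assert key == "classifier_4_blocks_avgpool_True_lr_0_00100"
    return key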
@torch.no_grad()
def evaluate_linear_classifiers(
feature_model,
linear_classifiers,
data_loader,
metric_type,
metrics_file_path,
training_num_classes,
iteration,
prefixstring="",
class_mapping=None,
best_classifier_on_val=None,
):
logger.info("running validation !")
num_classes = len(class_mapping) if class_mapping is not None else training_num_classes
metric = build_metric(metric_type, num_classes=num_classes)
postprocessors = {k: LinearPostprocessor(v, class_mapping) for k, v in linear_classifiers.classifiers_dict.items()}
metrics = {k: metric.clone() for k in linear_classifiers.classifiers_dict}
_, results_dict_temp = evaluate(
feature_model,
data_loader,
postprocessors,
metrics,
torch.cuda.current_device(),
)
logger.info("")
results_dict = {}
max_accuracy = 0
best_classifier = ""
for i, (classifier_string, metric) in enumerate(results_dict_temp.items()):
logger.info(f"{prefixstring} -- Classifier: {classifier_string} * {metric}")
if (
best_classifier_on_val is None and metric["top-1"].item() > max_accuracy
) or classifier_string == best_classifier_on_val:
max_accuracy = metric["top-1"].item()
best_classifier = classifier_string
results_dict["best_classifier"] = {"name": best_classifier, "accuracy": max_accuracy}
logger.info(f"best classifier: {results_dict['best_classifier']}")
if distributed.is_main_process():
with open(metrics_file_path, "a") as f:
f.write(f"iter: {iteration}\n")
for k, v in results_dict.items():
f.write(json.dumps({k: v}) + "\n")
f.write("\n")
return results_dict
def eval_linear(
*,
feature_model,
linear_classifiers,
train_data_loader,
val_data_loader,
metrics_file_path,
optimizer,
scheduler,
output_dir,
max_iter,
checkpoint_period, # In number of iter, creates a new file every period
running_checkpoint_period, # Period to update main checkpoint file
eval_period,
metric_type,
training_num_classes,
resume=True,
classifier_fpath=None,
val_class_mapping=None,
):
checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler)
start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1
periodic_checkpointer = PeriodicCheckpointer(checkpointer, checkpoint_period, max_iter=max_iter)
iteration = start_iter
logger.info("Starting training from iteration {}".format(start_iter))
metric_logger = MetricLogger(delimiter=" ")
header = "Training"
for data, labels in metric_logger.log_every(
train_data_loader,
10,
header,
max_iter,
start_iter,
):
data = data.cuda(non_blocking=True)
labels = labels.cuda(non_blocking=True)
features = feature_model(data)
outputs = linear_classifiers(features)
losses = {f"loss_{k}": nn.CrossEntropyLoss()(v, labels) for k, v in outputs.items()}
loss = sum(losses.values())
# compute the gradients
optimizer.zero_grad()
loss.backward()
# step
optimizer.step()
scheduler.step()
# log
if iteration % 10 == 0:
torch.cuda.synchronize()
metric_logger.update(loss=loss.item())
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
print("lr", optimizer.param_groups[0]["lr"])
if iteration - start_iter > 5:
if iteration % running_checkpoint_period == 0:
torch.cuda.synchronize()
if distributed.is_main_process():
logger.info("Checkpointing running_checkpoint")
periodic_checkpointer.save("running_checkpoint_linear_eval", iteration=iteration)
torch.cuda.synchronize()
periodic_checkpointer.step(iteration)
if eval_period > 0 and (iteration + 1) % eval_period == 0 and iteration != max_iter - 1:
_ = evaluate_linear_classifiers(
feature_model=feature_model,
linear_classifiers=remove_ddp_wrapper(linear_classifiers),
data_loader=val_data_loader,
metrics_file_path=metrics_file_path,
prefixstring=f"ITER: {iteration}",
metric_type=metric_type,
training_num_classes=training_num_classes,
iteration=iteration,
class_mapping=val_class_mapping,
)
torch.cuda.synchronize()
iteration = iteration + 1
val_results_dict = evaluate_linear_classifiers(
feature_model=feature_model,
linear_classifiers=remove_ddp_wrapper(linear_classifiers),
data_loader=val_data_loader,
metrics_file_path=metrics_file_path,
metric_type=metric_type,
training_num_classes=training_num_classes,
iteration=iteration,
class_mapping=val_class_mapping,
)
return val_results_dict, feature_model, linear_classifiers, iteration
def make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type):
test_dataset = make_dataset(
dataset_str=test_dataset_str,
transform=make_classification_eval_transform(),
)
test_data_loader = make_data_loader(
dataset=test_dataset,
batch_size=batch_size,
num_workers=num_workers,
sampler_type=SamplerType.DISTRIBUTED,
drop_last=False,
shuffle=False,
persistent_workers=False,
collate_fn=_pad_and_collate if metric_type == MetricType.IMAGENET_REAL_ACCURACY else None,
)
return test_data_loader
def test_on_datasets(
feature_model,
linear_classifiers,
test_dataset_strs,
batch_size,
num_workers,
test_metric_types,
metrics_file_path,
training_num_classes,
iteration,
best_classifier_on_val,
prefixstring="",
test_class_mappings=[None],
):
results_dict = {}
for test_dataset_str, class_mapping, metric_type in zip(test_dataset_strs, test_class_mappings, test_metric_types):
logger.info(f"Testing on {test_dataset_str}")
test_data_loader = make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type)
dataset_results_dict = evaluate_linear_classifiers(
feature_model,
remove_ddp_wrapper(linear_classifiers),
test_data_loader,
metric_type,
metrics_file_path,
training_num_classes,
iteration,
prefixstring="",
class_mapping=class_mapping,
best_classifier_on_val=best_classifier_on_val,
)
results_dict[f"{test_dataset_str}_accuracy"] = 100.0 * dataset_results_dict["best_classifier"]["accuracy"]
return results_dict
def run_eval_linear(
model,
output_dir,
train_dataset_str,
val_dataset_str,
batch_size,
epochs,
epoch_length,
num_workers,
save_checkpoint_frequency,
eval_period_iterations,
learning_rates,
autocast_dtype,
test_dataset_strs=None,
resume=True,
classifier_fpath=None,
val_class_mapping_fpath=None,
test_class_mapping_fpaths=[None],
val_metric_type=MetricType.MEAN_ACCURACY,
test_metric_types=None,
):
seed = 0
if test_dataset_strs is None:
test_dataset_strs = [val_dataset_str]
if test_metric_types is None:
test_metric_types = [val_metric_type] * len(test_dataset_strs)
else:
assert len(test_metric_types) == len(test_dataset_strs)
assert len(test_dataset_strs) == len(test_class_mapping_fpaths)
train_transform = make_classification_train_transform()
train_dataset = make_dataset(
dataset_str=train_dataset_str,
transform=train_transform,
)
training_num_classes = len(torch.unique(torch.Tensor(train_dataset.get_targets().astype(int))))
sampler_type = SamplerType.SHARDED_INFINITE
# sampler_type = SamplerType.INFINITE
n_last_blocks_list = [1, 4]
n_last_blocks = max(n_last_blocks_list)
autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype)
feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx)
sample_output = feature_model(train_dataset[0][0].unsqueeze(0).cuda())
linear_classifiers, optim_param_groups = setup_linear_classifiers(
sample_output,
n_last_blocks_list,
learning_rates,
batch_size,
training_num_classes,
)
optimizer = torch.optim.SGD(optim_param_groups, momentum=0.9, weight_decay=0)
max_iter = epochs * epoch_length
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0)
checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler)
start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1
train_data_loader = make_data_loader(
dataset=train_dataset,
batch_size=batch_size,
num_workers=num_workers,
shuffle=True,
seed=seed,
sampler_type=sampler_type,
sampler_advance=start_iter,
drop_last=True,
persistent_workers=True,
)
val_data_loader = make_eval_data_loader(val_dataset_str, batch_size, num_workers, val_metric_type)
checkpoint_period = save_checkpoint_frequency * epoch_length
if val_class_mapping_fpath is not None:
logger.info(f"Using class mapping from {val_class_mapping_fpath}")
val_class_mapping = np.load(val_class_mapping_fpath)
else:
val_class_mapping = None
test_class_mappings = []
for class_mapping_fpath in test_class_mapping_fpaths:
if class_mapping_fpath is not None and class_mapping_fpath != "None":
logger.info(f"Using class mapping from {class_mapping_fpath}")
class_mapping = np.load(class_mapping_fpath)
else:
class_mapping = None
test_class_mappings.append(class_mapping)
metrics_file_path = os.path.join(output_dir, "results_eval_linear.json")
val_results_dict, feature_model, linear_classifiers, iteration = eval_linear(
feature_model=feature_model,
linear_classifiers=linear_classifiers,
train_data_loader=train_data_loader,
val_data_loader=val_data_loader,
metrics_file_path=metrics_file_path,
optimizer=optimizer,
scheduler=scheduler,
output_dir=output_dir,
max_iter=max_iter,
checkpoint_period=checkpoint_period,
running_checkpoint_period=epoch_length,
eval_period=eval_period_iterations,
metric_type=val_metric_type,
training_num_classes=training_num_classes,
resume=resume,
val_class_mapping=val_class_mapping,
classifier_fpath=classifier_fpath,
)
results_dict = {}
if len(test_dataset_strs) > 1 or test_dataset_strs[0] != val_dataset_str:
results_dict = test_on_datasets(
feature_model,
linear_classifiers,
test_dataset_strs,
batch_size,
0, # num_workers,
test_metric_types,
metrics_file_path,
training_num_classes,
iteration,
val_results_dict["best_classifier"]["name"],
prefixstring="",
test_class_mappings=test_class_mappings,
)
results_dict["best_classifier"] = val_results_dict["best_classifier"]["name"]
results_dict[f"{val_dataset_str}_accuracy"] = 100.0 * val_results_dict["best_classifier"]["accuracy"]
logger.info("Test Results Dict " + str(results_dict))
return results_dict
def main(args):
model, autocast_dtype = setup_and_build_model(args)
run_eval_linear(
model=model,
output_dir=args.output_dir,
train_dataset_str=args.train_dataset_str,
val_dataset_str=args.val_dataset_str,
test_dataset_strs=args.test_dataset_strs,
batch_size=args.batch_size,
epochs=args.epochs,
epoch_length=args.epoch_length,
num_workers=args.num_workers,
save_checkpoint_frequency=args.save_checkpoint_frequency,
eval_period_iterations=args.eval_period_iterations,
learning_rates=args.learning_rates,
autocast_dtype=autocast_dtype,
resume=not args.no_resume,
classifier_fpath=args.classifier_fpath,
val_metric_type=args.val_metric_type,
test_metric_types=args.test_metric_types,
val_class_mapping_fpath=args.val_class_mapping_fpath,
test_class_mapping_fpaths=args.test_class_mapping_fpaths,
)
return 0
if __name__ == "__main__":
description = "DINOv2 linear evaluation"
args_parser = get_args_parser(description=description)
args = args_parser.parse_args()
sys.exit(main(args))

View File

@@ -0,0 +1,444 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import argparse
import gc
import logging
import sys
import time
from typing import List, Optional
from cuml.linear_model import LogisticRegression
import torch
import torch.backends.cudnn as cudnn
import torch.distributed
from torch import nn
from torch.utils.data import TensorDataset
from torchmetrics import MetricTracker
from dinov2.data import make_dataset
from dinov2.data.transforms import make_classification_eval_transform
from dinov2.distributed import get_global_rank, get_global_size
from dinov2.eval.metrics import MetricType, build_metric
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
from dinov2.eval.setup import setup_and_build_model
from dinov2.eval.utils import evaluate, extract_features
from dinov2.utils.dtype import as_torch_dtype
logger = logging.getLogger("dinov2")
DEFAULT_MAX_ITER = 1_000
C_POWER_RANGE = torch.linspace(-6, 5, 45)
_CPU_DEVICE = torch.device("cpu")
def get_args_parser(
description: Optional[str] = None,
parents: Optional[List[argparse.ArgumentParser]] = None,
add_help: bool = True,
):
parents = parents or []
setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
parents = [setup_args_parser]
parser = argparse.ArgumentParser(
description=description,
parents=parents,
add_help=add_help,
)
parser.add_argument(
"--train-dataset",
dest="train_dataset_str",
type=str,
help="Training dataset",
)
parser.add_argument(
"--val-dataset",
dest="val_dataset_str",
type=str,
help="Validation dataset",
)
parser.add_argument(
"--finetune-dataset-str",
dest="finetune_dataset_str",
type=str,
help="Fine-tuning dataset",
)
parser.add_argument(
"--finetune-on-val",
action="store_true",
help="If there is no finetune dataset, whether to choose the "
"hyperparameters on the val set instead of 10%% of the train dataset",
)
parser.add_argument(
"--metric-type",
type=MetricType,
choices=list(MetricType),
help="Metric type",
)
parser.add_argument(
"--train-features-device",
type=str,
help="Device to gather train features (cpu, cuda, cuda:0, etc.), default: %(default)s",
)
parser.add_argument(
"--train-dtype",
type=str,
help="Data type to convert the train features to (default: %(default)s)",
)
parser.add_argument(
"--max-train-iters",
type=int,
help="Maximum number of train iterations (default: %(default)s)",
)
parser.set_defaults(
train_dataset_str="ImageNet:split=TRAIN",
val_dataset_str="ImageNet:split=VAL",
finetune_dataset_str=None,
metric_type=MetricType.MEAN_ACCURACY,
train_features_device="cpu",
train_dtype="float64",
max_train_iters=DEFAULT_MAX_ITER,
finetune_on_val=False,
)
return parser
class LogRegModule(nn.Module):
def __init__(
self,
C,
max_iter=DEFAULT_MAX_ITER,
dtype=torch.float64,
device=_CPU_DEVICE,
):
super().__init__()
self.dtype = dtype
self.device = device
self.estimator = LogisticRegression(
penalty="l2",
C=C,
max_iter=max_iter,
output_type="numpy",
tol=1e-12,
linesearch_max_iter=50,
)
def forward(self, samples, targets):
samples_device = samples.device
samples = samples.to(dtype=self.dtype, device=self.device)
if self.device == _CPU_DEVICE:
samples = samples.numpy()
probas = self.estimator.predict_proba(samples)
return {"preds": torch.from_numpy(probas).to(samples_device), "target": targets}
def fit(self, train_features, train_labels):
train_features = train_features.to(dtype=self.dtype, device=self.device)
train_labels = train_labels.to(dtype=self.dtype, device=self.device)
if self.device == _CPU_DEVICE:
# both cuML and sklearn only work with numpy arrays on CPU
train_features = train_features.numpy()
train_labels = train_labels.numpy()
self.estimator.fit(train_features, train_labels)
def evaluate_model(*, logreg_model, logreg_metric, test_data_loader, device):
postprocessors = {"metrics": logreg_model}
metrics = {"metrics": logreg_metric}
return evaluate(nn.Identity(), test_data_loader, postprocessors, metrics, device)
def train_for_C(*, C, max_iter, train_features, train_labels, dtype=torch.float64, device=_CPU_DEVICE):
logreg_model = LogRegModule(C, max_iter=max_iter, dtype=dtype, device=device)
logreg_model.fit(train_features, train_labels)
return logreg_model
def train_and_evaluate(
*,
C,
max_iter,
train_features,
train_labels,
logreg_metric,
test_data_loader,
train_dtype=torch.float64,
train_features_device,
eval_device,
):
logreg_model = train_for_C(
C=C,
max_iter=max_iter,
train_features=train_features,
train_labels=train_labels,
dtype=train_dtype,
device=train_features_device,
)
return evaluate_model(
logreg_model=logreg_model,
logreg_metric=logreg_metric,
test_data_loader=test_data_loader,
device=eval_device,
)
def sweep_C_values(
*,
train_features,
train_labels,
test_data_loader,
metric_type,
num_classes,
train_dtype=torch.float64,
train_features_device=_CPU_DEVICE,
max_train_iters=DEFAULT_MAX_ITER,
):
if metric_type == MetricType.PER_CLASS_ACCURACY:
        # If we want to output per-class accuracy, select the hyperparameters with the best mean per-class accuracy
metric_type = MetricType.MEAN_PER_CLASS_ACCURACY
logreg_metric = build_metric(metric_type, num_classes=num_classes)
metric_tracker = MetricTracker(logreg_metric, maximize=True)
ALL_C = 10**C_POWER_RANGE
logreg_models = {}
train_features = train_features.to(dtype=train_dtype, device=train_features_device)
train_labels = train_labels.to(device=train_features_device)
for i in range(get_global_rank(), len(ALL_C), get_global_size()):
C = ALL_C[i].item()
logger.info(
f"Training for C = {C:.5f}, dtype={train_dtype}, "
f"features: {train_features.shape}, {train_features.dtype}, "
f"labels: {train_labels.shape}, {train_labels.dtype}"
)
logreg_models[C] = train_for_C(
C=C,
max_iter=max_train_iters,
train_features=train_features,
train_labels=train_labels,
dtype=train_dtype,
device=train_features_device,
)
gather_list = [None for _ in range(get_global_size())]
torch.distributed.all_gather_object(gather_list, logreg_models)
logreg_models_gathered = {}
for logreg_dict in gather_list:
logreg_models_gathered.update(logreg_dict)
for i in range(len(ALL_C)):
metric_tracker.increment()
C = ALL_C[i].item()
evals = evaluate_model(
logreg_model=logreg_models_gathered[C],
logreg_metric=metric_tracker,
test_data_loader=test_data_loader,
device=torch.cuda.current_device(),
)
logger.info(f"Trained for C = {C:.5f}, accuracies = {evals}")
best_stats, which_epoch = metric_tracker.best_metric(return_step=True)
best_stats_100 = {k: 100.0 * v for k, v in best_stats.items()}
if which_epoch["top-1"] == i:
best_C = C
logger.info(f"Sweep best {best_stats_100}, best C = {best_C:.6f}")
return best_stats, best_C
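# Illustrative sketch (not part of the original file): the regularization sweep used by
# `sweep_C_values` covers 45 logarithmically spaced values of C between 1e-6 and 1e5,
# computed as 10 ** C_POWER_RANGE.
def _example_c_sweep_values():
    all_c = 10**C_POWER_RANGE
    assert all_c.shape == (45,)
    assert torch.isclose(all_c[0], torch.tensor(1e-6))
    assert torch.isclose(all_c[-1], torch.tensor(1e5))
    return all_c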
def eval_log_regression(
*,
model,
train_dataset,
val_dataset,
finetune_dataset,
metric_type,
batch_size,
num_workers,
finetune_on_val=False,
train_dtype=torch.float64,
train_features_device=_CPU_DEVICE,
max_train_iters=DEFAULT_MAX_ITER,
):
"""
Implements the "standard" process for log regression evaluation:
The value of C is chosen by training on train_dataset and evaluating on
finetune_dataset. Then, the final model is trained on a concatenation of
train_dataset and finetune_dataset, and is evaluated on val_dataset.
If there is no finetune_dataset, the value of C is the one that yields
the best results on a random 10% subset of the train dataset
"""
start = time.time()
train_features, train_labels = extract_features(
model, train_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE)
)
val_features, val_labels = extract_features(
model, val_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE)
)
val_data_loader = torch.utils.data.DataLoader(
TensorDataset(val_features, val_labels),
batch_size=batch_size,
drop_last=False,
num_workers=0,
persistent_workers=False,
)
if finetune_dataset is None and finetune_on_val:
logger.info("Choosing hyperparameters on the val dataset")
finetune_features, finetune_labels = val_features, val_labels
elif finetune_dataset is None and not finetune_on_val:
logger.info("Choosing hyperparameters on 10% of the train dataset")
torch.manual_seed(0)
indices = torch.randperm(len(train_features), device=train_features.device)
finetune_index = indices[: len(train_features) // 10]
train_index = indices[len(train_features) // 10 :]
finetune_features, finetune_labels = train_features[finetune_index], train_labels[finetune_index]
train_features, train_labels = train_features[train_index], train_labels[train_index]
else:
logger.info("Choosing hyperparameters on the finetune dataset")
finetune_features, finetune_labels = extract_features(
model, finetune_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE)
)
# release the model - free GPU memory
del model
gc.collect()
torch.cuda.empty_cache()
finetune_data_loader = torch.utils.data.DataLoader(
TensorDataset(finetune_features, finetune_labels),
batch_size=batch_size,
drop_last=False,
)
if len(train_labels.shape) > 1:
num_classes = train_labels.shape[1]
else:
num_classes = train_labels.max() + 1
logger.info("Using cuML for logistic regression")
best_stats, best_C = sweep_C_values(
train_features=train_features,
train_labels=train_labels,
test_data_loader=finetune_data_loader,
metric_type=metric_type,
num_classes=num_classes,
train_dtype=train_dtype,
train_features_device=train_features_device,
max_train_iters=max_train_iters,
)
if not finetune_on_val:
logger.info("Best parameter found, concatenating features")
train_features = torch.cat((train_features, finetune_features))
train_labels = torch.cat((train_labels, finetune_labels))
logger.info("Training final model")
logreg_metric = build_metric(metric_type, num_classes=num_classes)
evals = train_and_evaluate(
C=best_C,
max_iter=max_train_iters,
train_features=train_features,
train_labels=train_labels,
logreg_metric=logreg_metric.clone(),
test_data_loader=val_data_loader,
eval_device=torch.cuda.current_device(),
train_dtype=train_dtype,
train_features_device=train_features_device,
)
best_stats = evals[1]["metrics"]
best_stats["best_C"] = best_C
logger.info(f"Log regression evaluation done in {int(time.time() - start)}s")
return best_stats
def eval_log_regression_with_model(
model,
train_dataset_str="ImageNet:split=TRAIN",
val_dataset_str="ImageNet:split=VAL",
finetune_dataset_str=None,
autocast_dtype=torch.float,
finetune_on_val=False,
metric_type=MetricType.MEAN_ACCURACY,
train_dtype=torch.float64,
train_features_device=_CPU_DEVICE,
max_train_iters=DEFAULT_MAX_ITER,
):
cudnn.benchmark = True
transform = make_classification_eval_transform(resize_size=224)
target_transform = None
train_dataset = make_dataset(dataset_str=train_dataset_str, transform=transform, target_transform=target_transform)
val_dataset = make_dataset(dataset_str=val_dataset_str, transform=transform, target_transform=target_transform)
if finetune_dataset_str is not None:
finetune_dataset = make_dataset(
dataset_str=finetune_dataset_str, transform=transform, target_transform=target_transform
)
else:
finetune_dataset = None
with torch.cuda.amp.autocast(dtype=autocast_dtype):
results_dict_logreg = eval_log_regression(
model=model,
train_dataset=train_dataset,
val_dataset=val_dataset,
finetune_dataset=finetune_dataset,
metric_type=metric_type,
batch_size=256,
num_workers=0, # 5,
finetune_on_val=finetune_on_val,
train_dtype=train_dtype,
train_features_device=train_features_device,
max_train_iters=max_train_iters,
)
results_dict = {
"top-1": results_dict_logreg["top-1"].cpu().numpy() * 100.0,
"top-5": results_dict_logreg.get("top-5", torch.tensor(0.0)).cpu().numpy() * 100.0,
"best_C": results_dict_logreg["best_C"],
}
logger.info(
"\n".join(
[
"Training of the supervised logistic regression on frozen features completed.\n"
"Top-1 test accuracy: {acc:.1f}".format(acc=results_dict["top-1"]),
"Top-5 test accuracy: {acc:.1f}".format(acc=results_dict["top-5"]),
"obtained for C = {c:.6f}".format(c=results_dict["best_C"]),
]
)
)
torch.distributed.barrier()
return results_dict
def main(args):
model, autocast_dtype = setup_and_build_model(args)
eval_log_regression_with_model(
model=model,
train_dataset_str=args.train_dataset_str,
val_dataset_str=args.val_dataset_str,
finetune_dataset_str=args.finetune_dataset_str,
autocast_dtype=autocast_dtype,
finetune_on_val=args.finetune_on_val,
metric_type=args.metric_type,
train_dtype=as_torch_dtype(args.train_dtype),
train_features_device=torch.device(args.train_features_device),
max_train_iters=args.max_train_iters,
)
return 0
if __name__ == "__main__":
description = "DINOv2 logistic regression evaluation"
args_parser = get_args_parser(description=description)
args = args_parser.parse_args()
sys.exit(main(args))

View File

@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from enum import Enum
import logging
from typing import Any, Dict, Optional
import torch
from torch import Tensor
from torchmetrics import Metric, MetricCollection
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.utilities.data import dim_zero_cat, select_topk
logger = logging.getLogger("dinov2")
class MetricType(Enum):
MEAN_ACCURACY = "mean_accuracy"
MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy"
PER_CLASS_ACCURACY = "per_class_accuracy"
IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy"
@property
def accuracy_averaging(self):
return getattr(AccuracyAveraging, self.name, None)
def __str__(self):
return self.value
class AccuracyAveraging(Enum):
MEAN_ACCURACY = "micro"
MEAN_PER_CLASS_ACCURACY = "macro"
PER_CLASS_ACCURACY = "none"
def __str__(self):
return self.value
def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tuple] = None):
if metric_type.accuracy_averaging is not None:
return build_topk_accuracy_metric(
average_type=metric_type.accuracy_averaging,
num_classes=num_classes,
ks=(1, 5) if ks is None else ks,
)
elif metric_type == MetricType.IMAGENET_REAL_ACCURACY:
return build_topk_imagenet_real_accuracy_metric(
num_classes=num_classes,
ks=(1, 5) if ks is None else ks,
)
raise ValueError(f"Unknown metric type {metric_type}")
def build_topk_accuracy_metric(average_type: AccuracyAveraging, num_classes: int, ks: tuple = (1, 5)):
metrics: Dict[str, Metric] = {
f"top-{k}": MulticlassAccuracy(top_k=k, num_classes=int(num_classes), average=average_type.value) for k in ks
}
return MetricCollection(metrics)
def build_topk_imagenet_real_accuracy_metric(num_classes: int, ks: tuple = (1, 5)):
metrics: Dict[str, Metric] = {f"top-{k}": ImageNetReaLAccuracy(top_k=k, num_classes=int(num_classes)) for k in ks}
return MetricCollection(metrics)
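# Illustrative sketch (not part of the original file): how a MetricType maps onto a
# torchmetrics averaging mode through `build_metric`. MEAN_ACCURACY corresponds to
# micro-averaging, MEAN_PER_CLASS_ACCURACY to macro-averaging, and PER_CLASS_ACCURACY
# reports one accuracy per class ("none"). The class count below is made up.
def _example_build_metric():
    metric = build_metric(MetricType.MEAN_PER_CLASS_ACCURACY, num_classes=10)
    # A MetricCollection holding "top-1" and "top-5" MulticlassAccuracy metrics,
    # each configured with average="macro".
    assert set(metric.keys()) == {"top-1", "top-5"}
    return metric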
class ImageNetReaLAccuracy(Metric):
is_differentiable: bool = False
higher_is_better: Optional[bool] = None
full_state_update: bool = False
def __init__(
self,
num_classes: int,
top_k: int = 1,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.num_classes = num_classes
self.top_k = top_k
self.add_state("tp", [], dist_reduce_fx="cat")
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
# preds [B, D]
# target [B, A]
# preds_oh [B, D] with 0 and 1
# select top K highest probabilities, use one hot representation
preds_oh = select_topk(preds, self.top_k)
# target_oh [B, D + 1] with 0 and 1
target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32)
target = target.long()
# for undefined targets (-1) use a fake value `num_classes`
target[target == -1] = self.num_classes
# fill targets, use one hot representation
target_oh.scatter_(1, target, 1)
# target_oh [B, D] (remove the fake target at index `num_classes`)
target_oh = target_oh[:, :-1]
# tp [B] with 0 and 1
tp = (preds_oh * target_oh == 1).sum(dim=1)
# at least one match between prediction and target
tp.clip_(max=1)
# ignore instances where no targets are defined
mask = target_oh.sum(dim=1) > 0
tp = tp[mask]
self.tp.append(tp) # type: ignore
def compute(self) -> Tensor:
tp = dim_zero_cat(self.tp) # type: ignore
return tp.float().mean()
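# Illustrative sketch (not part of the original file): ImageNet-ReaL accuracy on a toy
# batch. Each image may carry several valid labels (padded with -1); a prediction counts
# as correct when its top-k set intersects the label set, and images without any valid
# label are ignored. All values below are made up.
def _example_imagenet_real_accuracy():
    metric = ImageNetReaLAccuracy(num_classes=4, top_k=1)
    preds = torch.tensor(
        [
            [0.9, 0.05, 0.03, 0.02],  # argmax = 0, label set {0, 3} -> hit
            [0.1, 0.2, 0.6, 0.1],  # argmax = 2, label set {1} -> miss
            [0.25, 0.25, 0.25, 0.25],  # no valid labels -> ignored
        ]
    )
    target = torch.tensor([[0, 3, -1], [1, -1, -1], [-1, -1, -1]])
    metric.update(preds, target)
    assert metric.compute().item() == 0.5  # 1 hit out of 2 scored images
    return metric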

View File

@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .optimizer import DistOptimizerHook

View File

@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
try:
import apex
except ImportError:
print("apex is not installed")
from mmcv.runner import OptimizerHook, HOOKS
@HOOKS.register_module()
class DistOptimizerHook(OptimizerHook):
"""Optimizer hook for distributed training."""
def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False):
self.grad_clip = grad_clip
self.coalesce = coalesce
self.bucket_size_mb = bucket_size_mb
self.update_interval = update_interval
self.use_fp16 = use_fp16
def before_run(self, runner):
runner.optimizer.zero_grad()
def after_train_iter(self, runner):
runner.outputs["loss"] /= self.update_interval
if self.use_fp16:
# runner.outputs['loss'].backward()
with apex.amp.scale_loss(runner.outputs["loss"], runner.optimizer) as scaled_loss:
scaled_loss.backward()
else:
runner.outputs["loss"].backward()
if self.every_n_iters(runner, self.update_interval):
if self.grad_clip is not None:
self.clip_grads(runner.model.parameters())
runner.optimizer.step()
runner.optimizer.zero_grad()
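# Illustrative sketch (not part of the original file): the gradient-accumulation pattern
# implemented by DistOptimizerHook, written as a plain PyTorch training loop without the
# mmcv runner or apex. `model`, `optimizer` and `data_iter` are hypothetical placeholders.
def _example_gradient_accumulation(model, optimizer, data_iter, update_interval=4):
    import torch.nn.functional as F

    optimizer.zero_grad()
    for step, (x, y) in enumerate(data_iter, start=1):
        # Scale the loss so accumulated gradients match a single large-batch step,
        # mirroring `runner.outputs["loss"] /= self.update_interval`.
        loss = F.cross_entropy(model(x), y) / update_interval
        loss.backward()
        if step % update_interval == 0:  # mirrors every_n_iters(runner, update_interval)
            optimizer.step()
            optimizer.zero_grad()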

View File

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .backbones import * # noqa: F403
from .decode_heads import * # noqa: F403

View File

@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .vision_transformer import DinoVisionTransformer

View File

@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from mmcv.runner import BaseModule
from mmseg.models.builder import BACKBONES
@BACKBONES.register_module()
class DinoVisionTransformer(BaseModule):
"""Vision Transformer."""
def __init__(
self,
*args,
**kwargs,
):
super().__init__()

View File

@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .linear_head import BNHead

View File

@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from mmseg.models.builder import HEADS
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.ops import resize
@HEADS.register_module()
class BNHead(BaseDecodeHead):
"""Just a batchnorm."""
def __init__(self, resize_factors=None, **kwargs):
super().__init__(**kwargs)
assert self.in_channels == self.channels
self.bn = nn.SyncBatchNorm(self.in_channels)
self.resize_factors = resize_factors
def _forward_feature(self, inputs):
"""Forward function for feature maps before classifying each pixel with
``self.cls_seg`` fc.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
feats (Tensor): A tensor of shape (batch_size, self.channels,
H, W) which is feature map for last layer of decoder head.
"""
# print("inputs", [i.shape for i in inputs])
x = self._transform_inputs(inputs)
# print("x", x.shape)
feats = self.bn(x)
# print("feats", feats.shape)
return feats
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if self.input_transform == "resize_concat":
# accept lists (for cls token)
input_list = []
for x in inputs:
if isinstance(x, list):
input_list.extend(x)
else:
input_list.append(x)
inputs = input_list
# an image descriptor can be a local descriptor with resolution 1x1
for i, x in enumerate(inputs):
if len(x.shape) == 2:
inputs[i] = x[:, :, None, None]
# select indices
inputs = [inputs[i] for i in self.in_index]
# Resizing shenanigans
# print("before", *(x.shape for x in inputs))
if self.resize_factors is not None:
assert len(self.resize_factors) == len(inputs), (len(self.resize_factors), len(inputs))
inputs = [
resize(input=x, scale_factor=f, mode="bilinear" if f >= 1 else "area")
for x, f in zip(inputs, self.resize_factors)
]
# print("after", *(x.shape for x in inputs))
upsampled_inputs = [
resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners)
for x in inputs
]
inputs = torch.cat(upsampled_inputs, dim=1)
elif self.input_transform == "multiple_select":
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
def forward(self, inputs):
"""Forward function."""
output = self._forward_feature(inputs)
output = self.cls_seg(output)
return output
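# Illustrative sketch (not part of the original file): the "resize_concat" input transform
# in plain PyTorch. Multi-level features (and optional 1x1 image-level descriptors) are
# upsampled to the resolution of the first selected level and concatenated along the
# channel dimension, which is what BNHead normalizes and feeds to `cls_seg`. The real
# implementation uses mmseg's `resize` and honours `align_corners`; shapes are made up.
def _example_resize_concat():
    import torch.nn.functional as F

    feats = [torch.zeros(1, 4, 32, 32), torch.zeros(1, 8, 16, 16), torch.zeros(1, 16)]
    feats = [x[:, :, None, None] if x.dim() == 2 else x for x in feats]  # 1x1 descriptors
    upsampled = [F.interpolate(x, size=feats[0].shape[2:], mode="bilinear") for x in feats]
    fused = torch.cat(upsampled, dim=1)
    assert fused.shape == (1, 4 + 8 + 16, 32, 32)
    return fused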

View File

@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,362 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
ADE20K_COLORMAP = [
(0, 0, 0),
(120, 120, 120),
(180, 120, 120),
(6, 230, 230),
(80, 50, 50),
(4, 200, 3),
(120, 120, 80),
(140, 140, 140),
(204, 5, 255),
(230, 230, 230),
(4, 250, 7),
(224, 5, 255),
(235, 255, 7),
(150, 5, 61),
(120, 120, 70),
(8, 255, 51),
(255, 6, 82),
(143, 255, 140),
(204, 255, 4),
(255, 51, 7),
(204, 70, 3),
(0, 102, 200),
(61, 230, 250),
(255, 6, 51),
(11, 102, 255),
(255, 7, 71),
(255, 9, 224),
(9, 7, 230),
(220, 220, 220),
(255, 9, 92),
(112, 9, 255),
(8, 255, 214),
(7, 255, 224),
(255, 184, 6),
(10, 255, 71),
(255, 41, 10),
(7, 255, 255),
(224, 255, 8),
(102, 8, 255),
(255, 61, 6),
(255, 194, 7),
(255, 122, 8),
(0, 255, 20),
(255, 8, 41),
(255, 5, 153),
(6, 51, 255),
(235, 12, 255),
(160, 150, 20),
(0, 163, 255),
(140, 140, 140),
(250, 10, 15),
(20, 255, 0),
(31, 255, 0),
(255, 31, 0),
(255, 224, 0),
(153, 255, 0),
(0, 0, 255),
(255, 71, 0),
(0, 235, 255),
(0, 173, 255),
(31, 0, 255),
(11, 200, 200),
(255, 82, 0),
(0, 255, 245),
(0, 61, 255),
(0, 255, 112),
(0, 255, 133),
(255, 0, 0),
(255, 163, 0),
(255, 102, 0),
(194, 255, 0),
(0, 143, 255),
(51, 255, 0),
(0, 82, 255),
(0, 255, 41),
(0, 255, 173),
(10, 0, 255),
(173, 255, 0),
(0, 255, 153),
(255, 92, 0),
(255, 0, 255),
(255, 0, 245),
(255, 0, 102),
(255, 173, 0),
(255, 0, 20),
(255, 184, 184),
(0, 31, 255),
(0, 255, 61),
(0, 71, 255),
(255, 0, 204),
(0, 255, 194),
(0, 255, 82),
(0, 10, 255),
(0, 112, 255),
(51, 0, 255),
(0, 194, 255),
(0, 122, 255),
(0, 255, 163),
(255, 153, 0),
(0, 255, 10),
(255, 112, 0),
(143, 255, 0),
(82, 0, 255),
(163, 255, 0),
(255, 235, 0),
(8, 184, 170),
(133, 0, 255),
(0, 255, 92),
(184, 0, 255),
(255, 0, 31),
(0, 184, 255),
(0, 214, 255),
(255, 0, 112),
(92, 255, 0),
(0, 224, 255),
(112, 224, 255),
(70, 184, 160),
(163, 0, 255),
(153, 0, 255),
(71, 255, 0),
(255, 0, 163),
(255, 204, 0),
(255, 0, 143),
(0, 255, 235),
(133, 255, 0),
(255, 0, 235),
(245, 0, 255),
(255, 0, 122),
(255, 245, 0),
(10, 190, 212),
(214, 255, 0),
(0, 204, 255),
(20, 0, 255),
(255, 255, 0),
(0, 153, 255),
(0, 41, 255),
(0, 255, 204),
(41, 0, 255),
(41, 255, 0),
(173, 0, 255),
(0, 245, 255),
(71, 0, 255),
(122, 0, 255),
(0, 255, 184),
(0, 92, 255),
(184, 255, 0),
(0, 133, 255),
(255, 214, 0),
(25, 194, 194),
(102, 255, 0),
(92, 0, 255),
]
ADE20K_CLASS_NAMES = [
"",
"wall",
"building;edifice",
"sky",
"floor;flooring",
"tree",
"ceiling",
"road;route",
"bed",
"windowpane;window",
"grass",
"cabinet",
"sidewalk;pavement",
"person;individual;someone;somebody;mortal;soul",
"earth;ground",
"door;double;door",
"table",
"mountain;mount",
"plant;flora;plant;life",
"curtain;drape;drapery;mantle;pall",
"chair",
"car;auto;automobile;machine;motorcar",
"water",
"painting;picture",
"sofa;couch;lounge",
"shelf",
"house",
"sea",
"mirror",
"rug;carpet;carpeting",
"field",
"armchair",
"seat",
"fence;fencing",
"desk",
"rock;stone",
"wardrobe;closet;press",
"lamp",
"bathtub;bathing;tub;bath;tub",
"railing;rail",
"cushion",
"base;pedestal;stand",
"box",
"column;pillar",
"signboard;sign",
"chest;of;drawers;chest;bureau;dresser",
"counter",
"sand",
"sink",
"skyscraper",
"fireplace;hearth;open;fireplace",
"refrigerator;icebox",
"grandstand;covered;stand",
"path",
"stairs;steps",
"runway",
"case;display;case;showcase;vitrine",
"pool;table;billiard;table;snooker;table",
"pillow",
"screen;door;screen",
"stairway;staircase",
"river",
"bridge;span",
"bookcase",
"blind;screen",
"coffee;table;cocktail;table",
"toilet;can;commode;crapper;pot;potty;stool;throne",
"flower",
"book",
"hill",
"bench",
"countertop",
"stove;kitchen;stove;range;kitchen;range;cooking;stove",
"palm;palm;tree",
"kitchen;island",
"computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system",
"swivel;chair",
"boat",
"bar",
"arcade;machine",
"hovel;hut;hutch;shack;shanty",
"bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle",
"towel",
"light;light;source",
"truck;motortruck",
"tower",
"chandelier;pendant;pendent",
"awning;sunshade;sunblind",
"streetlight;street;lamp",
"booth;cubicle;stall;kiosk",
"television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box",
"airplane;aeroplane;plane",
"dirt;track",
"apparel;wearing;apparel;dress;clothes",
"pole",
"land;ground;soil",
"bannister;banister;balustrade;balusters;handrail",
"escalator;moving;staircase;moving;stairway",
"ottoman;pouf;pouffe;puff;hassock",
"bottle",
"buffet;counter;sideboard",
"poster;posting;placard;notice;bill;card",
"stage",
"van",
"ship",
"fountain",
"conveyer;belt;conveyor;belt;conveyer;conveyor;transporter",
"canopy",
"washer;automatic;washer;washing;machine",
"plaything;toy",
"swimming;pool;swimming;bath;natatorium",
"stool",
"barrel;cask",
"basket;handbasket",
"waterfall;falls",
"tent;collapsible;shelter",
"bag",
"minibike;motorbike",
"cradle",
"oven",
"ball",
"food;solid;food",
"step;stair",
"tank;storage;tank",
"trade;name;brand;name;brand;marque",
"microwave;microwave;oven",
"pot;flowerpot",
"animal;animate;being;beast;brute;creature;fauna",
"bicycle;bike;wheel;cycle",
"lake",
"dishwasher;dish;washer;dishwashing;machine",
"screen;silver;screen;projection;screen",
"blanket;cover",
"sculpture",
"hood;exhaust;hood",
"sconce",
"vase",
"traffic;light;traffic;signal;stoplight",
"tray",
"ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin",
"fan",
"pier;wharf;wharfage;dock",
"crt;screen",
"plate",
"monitor;monitoring;device",
"bulletin;board;notice;board",
"shower",
"radiator",
"glass;drinking;glass",
"clock",
"flag",
]
VOC2012_COLORMAP = [
(0, 0, 0),
(128, 0, 0),
(0, 128, 0),
(128, 128, 0),
(0, 0, 128),
(128, 0, 128),
(0, 128, 128),
(128, 128, 128),
(64, 0, 0),
(192, 0, 0),
(64, 128, 0),
(192, 128, 0),
(64, 0, 128),
(192, 0, 128),
(64, 128, 128),
(192, 128, 128),
(0, 64, 0),
(128, 64, 0),
(0, 192, 0),
(128, 192, 0),
(0, 64, 128),
]
VOC2012_CLASS_NAMES = [
"",
"aeroplane",
"bicycle",
"bird",
"boat",
"bottle",
"bus",
"car",
"cat",
"chair",
"cow",
"diningtable",
"dog",
"horse",
"motorbike",
"person",
"pottedplant",
"sheep",
"sofa",
"train",
"tvmonitor",
]

View File

@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .core import * # noqa: F403
from .models import * # noqa: F403
from .ops import * # noqa: F403

View File

@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from mmseg.core.evaluation import * # noqa: F403
from mmseg.core.seg import * # noqa: F403
from .anchor import * # noqa: F403
from .box import * # noqa: F403
from .utils import * # noqa: F403

View File

@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .point_generator import MlvlPointGenerator # noqa: F403

View File

@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import warnings
from mmcv.utils import Registry, build_from_cfg
PRIOR_GENERATORS = Registry("Generator for anchors and points")
ANCHOR_GENERATORS = PRIOR_GENERATORS
def build_prior_generator(cfg, default_args=None):
return build_from_cfg(cfg, PRIOR_GENERATORS, default_args)
def build_anchor_generator(cfg, default_args=None):
    warnings.warn("``build_anchor_generator`` will be deprecated soon, please use ``build_prior_generator``")
return build_prior_generator(cfg, default_args=default_args)

View File

@@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from torch.nn.modules.utils import _pair
from .builder import PRIOR_GENERATORS
@PRIOR_GENERATORS.register_module()
class MlvlPointGenerator:
"""Standard points generator for multi-level (Mlvl) feature maps in 2D
points-based detectors.
Args:
strides (list[int] | list[tuple[int, int]]): Strides of anchors
in multiple feature levels in order (w, h).
offset (float): The offset of points, the value is normalized with
corresponding stride. Defaults to 0.5.
"""
def __init__(self, strides, offset=0.5):
self.strides = [_pair(stride) for stride in strides]
self.offset = offset
@property
def num_levels(self):
"""int: number of feature levels that the generator will be applied"""
return len(self.strides)
@property
def num_base_priors(self):
"""list[int]: The number of priors (points) at a point
on the feature grid"""
return [1 for _ in range(len(self.strides))]
def _meshgrid(self, x, y, row_major=True):
yy, xx = torch.meshgrid(y, x)
if row_major:
# warning .flatten() would cause error in ONNX exporting
# have to use reshape here
return xx.reshape(-1), yy.reshape(-1)
else:
return yy.reshape(-1), xx.reshape(-1)
def grid_priors(self, featmap_sizes, dtype=torch.float32, device="cuda", with_stride=False):
"""Generate grid points of multiple feature levels.
Args:
            featmap_sizes (list[tuple]): List of feature map sizes in
                multiple feature levels, each size arranged as (h, w).
dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
device (str): The device where the anchors will be put on.
with_stride (bool): Whether to concatenate the stride to
the last dimension of points.
Return:
list[torch.Tensor]: Points of multiple feature levels.
The sizes of each tensor should be (N, 2) when with stride is
``False``, where N = width * height, width and height
are the sizes of the corresponding feature level,
and the last dimension 2 represent (coord_x, coord_y),
otherwise the shape should be (N, 4),
and the last dimension 4 represent
(coord_x, coord_y, stride_w, stride_h).
"""
assert self.num_levels == len(featmap_sizes)
multi_level_priors = []
for i in range(self.num_levels):
priors = self.single_level_grid_priors(
featmap_sizes[i], level_idx=i, dtype=dtype, device=device, with_stride=with_stride
)
multi_level_priors.append(priors)
return multi_level_priors
def single_level_grid_priors(self, featmap_size, level_idx, dtype=torch.float32, device="cuda", with_stride=False):
"""Generate grid Points of a single level.
Note:
This function is usually called by method ``self.grid_priors``.
Args:
            featmap_size (tuple[int]): Size of the feature maps, arranged as
                (h, w).
level_idx (int): The index of corresponding feature map level.
dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
device (str, optional): The device the tensor will be put on.
Defaults to 'cuda'.
with_stride (bool): Concatenate the stride to the last dimension
of points.
Return:
Tensor: Points of single feature levels.
The shape of tensor should be (N, 2) when with stride is
``False``, where N = width * height, width and height
are the sizes of the corresponding feature level,
and the last dimension 2 represent (coord_x, coord_y),
otherwise the shape should be (N, 4),
and the last dimension 4 represent
(coord_x, coord_y, stride_w, stride_h).
"""
feat_h, feat_w = featmap_size
stride_w, stride_h = self.strides[level_idx]
shift_x = (torch.arange(0, feat_w, device=device) + self.offset) * stride_w
# keep featmap_size as Tensor instead of int, so that we
# can convert to ONNX correctly
shift_x = shift_x.to(dtype)
shift_y = (torch.arange(0, feat_h, device=device) + self.offset) * stride_h
# keep featmap_size as Tensor instead of int, so that we
# can convert to ONNX correctly
shift_y = shift_y.to(dtype)
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
if not with_stride:
shifts = torch.stack([shift_xx, shift_yy], dim=-1)
else:
# use `shape[0]` instead of `len(shift_xx)` for ONNX export
stride_w = shift_xx.new_full((shift_xx.shape[0],), stride_w).to(dtype)
stride_h = shift_xx.new_full((shift_yy.shape[0],), stride_h).to(dtype)
shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h], dim=-1)
all_points = shifts.to(device)
return all_points
def valid_flags(self, featmap_sizes, pad_shape, device="cuda"):
"""Generate valid flags of points of multiple feature levels.
Args:
            featmap_sizes (list(tuple)): List of feature map sizes in
                multiple feature levels, each size arranged as (h, w).
            pad_shape (tuple(int)): The padded shape of the image,
                arranged as (h, w).
device (str): The device where the anchors will be put on.
Return:
list(torch.Tensor): Valid flags of points of multiple levels.
"""
assert self.num_levels == len(featmap_sizes)
multi_level_flags = []
for i in range(self.num_levels):
point_stride = self.strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w = pad_shape[:2]
valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h)
valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w)
flags = self.single_level_valid_flags((feat_h, feat_w), (valid_feat_h, valid_feat_w), device=device)
multi_level_flags.append(flags)
return multi_level_flags
def single_level_valid_flags(self, featmap_size, valid_size, device="cuda"):
"""Generate the valid flags of points of a single feature map.
Args:
            featmap_size (tuple[int]): The size of feature maps, arranged as
                (h, w).
            valid_size (tuple[int]): The valid size of the feature maps,
                arranged as (h, w).
device (str, optional): The device where the flags will be put on.
Defaults to 'cuda'.
Returns:
torch.Tensor: The valid flags of each points in a single level \
feature map.
"""
feat_h, feat_w = featmap_size
valid_h, valid_w = valid_size
assert valid_h <= feat_h and valid_w <= feat_w
valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
valid_x[:valid_w] = 1
valid_y[:valid_h] = 1
valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
valid = valid_xx & valid_yy
return valid
def sparse_priors(self, prior_idxs, featmap_size, level_idx, dtype=torch.float32, device="cuda"):
"""Generate sparse points according to the ``prior_idxs``.
Args:
prior_idxs (Tensor): The index of corresponding anchors
in the feature map.
            featmap_size (tuple[int]): Feature map size, arranged as (w, h).
            level_idx (int): The level index of the corresponding feature
                map.
            dtype (obj:`torch.dtype`): Data type of points. Defaults to
                ``torch.float32``.
            device (obj:`torch.device`): The device where the points are
                located.
Returns:
Tensor: Anchor with shape (N, 2), N should be equal to
the length of ``prior_idxs``. And last dimension
2 represent (coord_x, coord_y).
"""
height, width = featmap_size
x = (prior_idxs % width + self.offset) * self.strides[level_idx][0]
y = ((prior_idxs // width) % height + self.offset) * self.strides[level_idx][1]
        priors = torch.stack([x, y], 1).to(dtype)
        priors = priors.to(device)
        return priors
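# Illustrative sketch (not part of the original file): generating per-level grid points
# with MlvlPointGenerator on CPU for two feature levels with strides 8 and 16. The shapes
# follow the `grid_priors` docstring; feature map sizes are made up.
def _example_mlvl_point_generator():
    generator = MlvlPointGenerator(strides=[8, 16], offset=0.5)
    priors = generator.grid_priors([(4, 6), (2, 3)], device="cpu")
    assert priors[0].shape == (4 * 6, 2)  # (N, 2) with N = h * w, columns (coord_x, coord_y)
    assert priors[1].shape == (2 * 3, 2)
    return priors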

View File

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .builder import * # noqa: F403
from .samplers import MaskPseudoSampler # noqa: F403

View File

@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from mmcv.utils import Registry, build_from_cfg
BBOX_SAMPLERS = Registry("bbox_sampler")
BBOX_CODERS = Registry("bbox_coder")
def build_sampler(cfg, **default_args):
"""Builder of box sampler."""
return build_from_cfg(cfg, BBOX_SAMPLERS, default_args)
def build_bbox_coder(cfg, **default_args):
"""Builder of box coder."""
return build_from_cfg(cfg, BBOX_CODERS, default_args)

View File

@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .mask_pseudo_sampler import MaskPseudoSampler # noqa: F403

View File

@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
import torch
from .sampling_result import SamplingResult
class BaseSampler(metaclass=ABCMeta):
"""Base class of samplers."""
def __init__(self, num, pos_fraction, neg_pos_ub=-1, add_gt_as_proposals=True, **kwargs):
self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
self.pos_sampler = self
self.neg_sampler = self
@abstractmethod
def _sample_pos(self, assign_result, num_expected, **kwargs):
"""Sample positive samples."""
pass
@abstractmethod
def _sample_neg(self, assign_result, num_expected, **kwargs):
"""Sample negative samples."""
pass
def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None, **kwargs):
"""Sample positive and negative bboxes.
This is a simple implementation of bbox sampling given candidates,
        assignment results and ground truth bboxes.
Args:
assign_result (:obj:`AssignResult`): Bbox assigning results.
bboxes (Tensor): Boxes to be sampled from.
gt_bboxes (Tensor): Ground truth bboxes.
gt_labels (Tensor, optional): Class labels of ground truth bboxes.
Returns:
:obj:`SamplingResult`: Sampling result.
Example:
>>> from mmdet.core.bbox import RandomSampler
>>> from mmdet.core.bbox import AssignResult
>>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes
>>> rng = ensure_rng(None)
>>> assign_result = AssignResult.random(rng=rng)
>>> bboxes = random_boxes(assign_result.num_preds, rng=rng)
>>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
>>> gt_labels = None
>>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
>>> add_gt_as_proposals=False)
>>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels)
"""
if len(bboxes.shape) < 2:
bboxes = bboxes[None, :]
bboxes = bboxes[:, :4]
gt_flags = bboxes.new_zeros((bboxes.shape[0],), dtype=torch.uint8)
if self.add_gt_as_proposals and len(gt_bboxes) > 0:
if gt_labels is None:
raise ValueError("gt_labels must be given when add_gt_as_proposals is True")
bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
assign_result.add_gt_(gt_labels)
gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
gt_flags = torch.cat([gt_ones, gt_flags])
num_expected_pos = int(self.num * self.pos_fraction)
pos_inds = self.pos_sampler._sample_pos(assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
# We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch)
pos_inds = pos_inds.unique()
num_sampled_pos = pos_inds.numel()
num_expected_neg = self.num - num_sampled_pos
if self.neg_pos_ub >= 0:
_pos = max(1, num_sampled_pos)
neg_upper_bound = int(self.neg_pos_ub * _pos)
if num_expected_neg > neg_upper_bound:
num_expected_neg = neg_upper_bound
neg_inds = self.neg_sampler._sample_neg(assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
neg_inds = neg_inds.unique()
sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags)
return sampling_result

View File

@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py
import torch
from ..builder import BBOX_SAMPLERS
from .base_sampler import BaseSampler
from .mask_sampling_result import MaskSamplingResult
@BBOX_SAMPLERS.register_module()
class MaskPseudoSampler(BaseSampler):
"""A pseudo sampler that does not do sampling actually."""
def __init__(self, **kwargs):
pass
def _sample_pos(self, **kwargs):
"""Sample positive samples."""
raise NotImplementedError
def _sample_neg(self, **kwargs):
"""Sample negative samples."""
raise NotImplementedError
def sample(self, assign_result, masks, gt_masks, **kwargs):
"""Directly returns the positive and negative indices of samples.
Args:
assign_result (:obj:`AssignResult`): Assigned results
            masks (torch.Tensor): Segmentation masks
            gt_masks (torch.Tensor): Ground truth masks
Returns:
:obj:`SamplingResult`: sampler results
"""
pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
gt_flags = masks.new_zeros(masks.shape[0], dtype=torch.uint8)
sampling_result = MaskSamplingResult(pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags)
return sampling_result

View File

@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py
import torch
from .sampling_result import SamplingResult
class MaskSamplingResult(SamplingResult):
"""Mask sampling result."""
def __init__(self, pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags):
self.pos_inds = pos_inds
self.neg_inds = neg_inds
self.pos_masks = masks[pos_inds]
self.neg_masks = masks[neg_inds]
self.pos_is_gt = gt_flags[pos_inds]
self.num_gts = gt_masks.shape[0]
self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
if gt_masks.numel() == 0:
# hack for index error case
assert self.pos_assigned_gt_inds.numel() == 0
self.pos_gt_masks = torch.empty_like(gt_masks)
else:
self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :]
if assign_result.labels is not None:
self.pos_gt_labels = assign_result.labels[pos_inds]
else:
self.pos_gt_labels = None
@property
def masks(self):
"""torch.Tensor: concatenated positive and negative boxes"""
return torch.cat([self.pos_masks, self.neg_masks])
def __nice__(self):
data = self.info.copy()
data["pos_masks"] = data.pop("pos_masks").shape
data["neg_masks"] = data.pop("neg_masks").shape
parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
body = " " + ",\n ".join(parts)
return "{\n" + body + "\n}"
@property
def info(self):
"""Returns a dictionary of info about the object."""
return {
"pos_inds": self.pos_inds,
"neg_inds": self.neg_inds,
"pos_masks": self.pos_masks,
"neg_masks": self.neg_masks,
"pos_is_gt": self.pos_is_gt,
"num_gts": self.num_gts,
"pos_assigned_gt_inds": self.pos_assigned_gt_inds,
}

View File

@@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
class SamplingResult:
"""Bbox sampling result.
Example:
>>> # xdoctest: +IGNORE_WANT
>>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
>>> self = SamplingResult.random(rng=10)
>>> print(f'self = {self}')
self = <SamplingResult({
'neg_bboxes': torch.Size([12, 4]),
'neg_inds': tensor([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
'num_gts': 4,
'pos_assigned_gt_inds': tensor([], dtype=torch.int64),
'pos_bboxes': torch.Size([0, 4]),
'pos_inds': tensor([], dtype=torch.int64),
'pos_is_gt': tensor([], dtype=torch.uint8)
})>
"""
def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags):
self.pos_inds = pos_inds
self.neg_inds = neg_inds
self.pos_bboxes = bboxes[pos_inds]
self.neg_bboxes = bboxes[neg_inds]
self.pos_is_gt = gt_flags[pos_inds]
self.num_gts = gt_bboxes.shape[0]
self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
if gt_bboxes.numel() == 0:
# hack for index error case
assert self.pos_assigned_gt_inds.numel() == 0
self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
else:
if len(gt_bboxes.shape) < 2:
gt_bboxes = gt_bboxes.view(-1, 4)
self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long(), :]
if assign_result.labels is not None:
self.pos_gt_labels = assign_result.labels[pos_inds]
else:
self.pos_gt_labels = None
@property
def bboxes(self):
"""torch.Tensor: concatenated positive and negative boxes"""
return torch.cat([self.pos_bboxes, self.neg_bboxes])
def to(self, device):
"""Change the device of the data inplace.
Example:
>>> self = SamplingResult.random()
>>> print(f'self = {self.to(None)}')
>>> # xdoctest: +REQUIRES(--gpu)
>>> print(f'self = {self.to(0)}')
"""
_dict = self.__dict__
for key, value in _dict.items():
if isinstance(value, torch.Tensor):
_dict[key] = value.to(device)
return self
def __nice__(self):
data = self.info.copy()
data["pos_bboxes"] = data.pop("pos_bboxes").shape
data["neg_bboxes"] = data.pop("neg_bboxes").shape
parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
body = " " + ",\n ".join(parts)
return "{\n" + body + "\n}"
@property
def info(self):
"""Returns a dictionary of info about the object."""
return {
"pos_inds": self.pos_inds,
"neg_inds": self.neg_inds,
"pos_bboxes": self.pos_bboxes,
"neg_bboxes": self.neg_bboxes,
"pos_is_gt": self.pos_is_gt,
"num_gts": self.num_gts,
"pos_assigned_gt_inds": self.pos_assigned_gt_inds,
}
@classmethod
def random(cls, rng=None, **kwargs):
"""
Args:
rng (None | int | numpy.random.RandomState): seed or state.
kwargs (keyword arguments):
- num_preds: number of predicted boxes
- num_gts: number of true boxes
- p_ignore (float): probability of a predicted box assigned to \
an ignored truth.
- p_assigned (float): probability of a predicted box not being \
assigned.
- p_use_label (float | bool): with labels or not.
Returns:
:obj:`SamplingResult`: Randomly generated sampling result.
Example:
>>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
>>> self = SamplingResult.random()
>>> print(self.__dict__)
"""
from mmdet.core.bbox import demodata
from mmdet.core.bbox.assigners.assign_result import AssignResult
from mmdet.core.bbox.samplers.random_sampler import RandomSampler
rng = demodata.ensure_rng(rng)
        # make probabilistic?
num = 32
pos_fraction = 0.5
neg_pos_ub = -1
assign_result = AssignResult.random(rng=rng, **kwargs)
# Note we could just compute an assignment
bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)
if rng.rand() > 0.2:
# sometimes algorithms squeeze their data, be robust to that
gt_bboxes = gt_bboxes.squeeze()
bboxes = bboxes.squeeze()
        # label-aware sampling is not exercised by this random helper
        gt_labels = None
if gt_labels is None:
add_gt_as_proposals = False
else:
            add_gt_as_proposals = True  # make probabilistic?
sampler = RandomSampler(
num, pos_fraction, neg_pos_ub=neg_pos_ub, add_gt_as_proposals=add_gt_as_proposals, rng=rng
)
self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
return self

View File

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .dist_utils import reduce_mean
from .misc import add_prefix, multi_apply

View File

@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch.distributed as dist
def reduce_mean(tensor):
""" "Obtain the mean of tensor on different GPUs."""
if not (dist.is_available() and dist.is_initialized()):
return tensor
tensor = tensor.clone()
dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
return tensor
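# Minimal usage sketch (assumes torch is installed; no process group needed):
# without an initialized torch.distributed process group the tensor is
# returned unchanged, while under DDP every rank receives the mean of the
# per-rank values.
#
#   import torch
#   t = torch.tensor([3.0])
#   reduce_mean(t)  # tensor([3.]) on a single, non-distributed process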

View File

@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from functools import partial
def multi_apply(func, *args, **kwargs):
"""Apply function to a list of arguments.
Note:
        This function applies ``func`` to multiple inputs and maps
        the multiple outputs of ``func`` into different lists. Each
        list contains the same type of outputs corresponding to
        different inputs.
Args:
func (Function): A function that will be applied to a list of
arguments
Returns:
        tuple(list): A tuple of lists, where each list contains one kind of
            result returned by the function.
"""
pfunc = partial(func, **kwargs) if kwargs else func
map_results = map(pfunc, *args)
return tuple(map(list, zip(*map_results)))
def add_prefix(inputs, prefix):
"""Add prefix for dict.
Args:
inputs (dict): The input dict with str keys.
prefix (str): The prefix to add.
Returns:
dict: The dict with keys updated with ``prefix``.
"""
outputs = dict()
for name, value in inputs.items():
outputs[f"{prefix}.{name}"] = value
return outputs
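# Minimal usage sketch for the two helpers above (standard library only):
#
#   def _square_and_cube(x):
#       return x * x, x * x * x
#
#   squares, cubes = multi_apply(_square_and_cube, [1, 2, 3])
#   # squares == [1, 4, 9], cubes == [1, 8, 27]
#
#   add_prefix({"loss_mask": 0.5}, "decode")
#   # {'decode.loss_mask': 0.5}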

View File

@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .backbones import * # noqa: F403
from .builder import MASK_ASSIGNERS, MATCH_COST, TRANSFORMER, build_assigner, build_match_cost
from .decode_heads import * # noqa: F403
from .losses import * # noqa: F403
from .plugins import * # noqa: F403
from .segmentors import * # noqa: F403

View File

@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .vit_adapter import ViTAdapter

View File

@@ -0,0 +1,442 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from functools import partial
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from ...ops.modules import MSDeformAttn
from .drop_path import DropPath
def get_reference_points(spatial_shapes, device):
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
)
ref_y = ref_y.reshape(-1)[None] / H_
ref_x = ref_x.reshape(-1)[None] / W_
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None]
return reference_points
def deform_inputs(x, patch_size):
bs, c, h, w = x.shape
spatial_shapes = torch.as_tensor(
[(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], dtype=torch.long, device=x.device
)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
reference_points = get_reference_points([(h // patch_size, w // patch_size)], x.device)
deform_inputs1 = [reference_points, spatial_shapes, level_start_index]
spatial_shapes = torch.as_tensor([(h // patch_size, w // patch_size)], dtype=torch.long, device=x.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
reference_points = get_reference_points([(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], x.device)
deform_inputs2 = [reference_points, spatial_shapes, level_start_index]
return deform_inputs1, deform_inputs2
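# Shape sketch for deform_inputs (assuming a 1x3x518x518 input and patch_size=14):
# deform_inputs1 feeds the injector, whose reference points cover the 37x37 ViT
# token grid while its spatial_shapes describe the 1/8, 1/16 and 1/32 SPM levels;
# deform_inputs2 swaps the two roles for the extractor.
#
#   x = torch.randn(1, 3, 518, 518)
#   d1, d2 = deform_inputs(x, patch_size=14)
#   d1[0].shape  # torch.Size([1, 1369, 1, 2]) -> 37 * 37 query points
#   d1[1]        # tensor([[64, 64], [32, 32], [16, 16]]) key spatial shapes
#   d2[0].shape  # torch.Size([1, 5376, 1, 2]) -> 64*64 + 32*32 + 16*16 points
#   d2[1]        # tensor([[37, 37]])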
class ConvFFN(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv = DWConv(hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x, H, W):
x = self.fc1(x)
x = self.dwconv(x, H, W)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class DWConv(nn.Module):
def __init__(self, dim=768):
super().__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
B, N, C = x.shape
n = N // 21
x1 = x[:, 0 : 16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2).contiguous()
x2 = x[:, 16 * n : 20 * n, :].transpose(1, 2).view(B, C, H, W).contiguous()
x3 = x[:, 20 * n :, :].transpose(1, 2).view(B, C, H // 2, W // 2).contiguous()
x1 = self.dwconv(x1).flatten(2).transpose(1, 2)
x2 = self.dwconv(x2).flatten(2).transpose(1, 2)
x3 = self.dwconv(x3).flatten(2).transpose(1, 2)
x = torch.cat([x1, x2, x3], dim=1)
return x
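# Token-layout note for DWConv (following the adapter's three-level pyramid):
# the N input tokens concatenate features at 2H x 2W, H x W and H/2 x W/2
# resolutions, i.e. 16n + 4n + n = 21n tokens for n = N // 21, which is why the
# three slices above are reshaped to (H*2, W*2), (H, W) and (H//2, W//2) before
# the shared depthwise convolution is applied.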
class Extractor(nn.Module):
def __init__(
self,
dim,
num_heads=6,
n_points=4,
n_levels=1,
deform_ratio=1.0,
with_cffn=True,
cffn_ratio=0.25,
drop=0.0,
drop_path=0.0,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
with_cp=False,
):
super().__init__()
self.query_norm = norm_layer(dim)
self.feat_norm = norm_layer(dim)
self.attn = MSDeformAttn(
d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio
)
self.with_cffn = with_cffn
self.with_cp = with_cp
if with_cffn:
self.ffn = ConvFFN(in_features=dim, hidden_features=int(dim * cffn_ratio), drop=drop)
self.ffn_norm = norm_layer(dim)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, query, reference_points, feat, spatial_shapes, level_start_index, H, W):
def _inner_forward(query, feat):
attn = self.attn(
self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None
)
query = query + attn
if self.with_cffn:
query = query + self.drop_path(self.ffn(self.ffn_norm(query), H, W))
return query
if self.with_cp and query.requires_grad:
query = cp.checkpoint(_inner_forward, query, feat)
else:
query = _inner_forward(query, feat)
return query
class Injector(nn.Module):
def __init__(
self,
dim,
num_heads=6,
n_points=4,
n_levels=1,
deform_ratio=1.0,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_values=0.0,
with_cp=False,
):
super().__init__()
self.with_cp = with_cp
self.query_norm = norm_layer(dim)
self.feat_norm = norm_layer(dim)
self.attn = MSDeformAttn(
d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio
)
self.gamma = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
def forward(self, query, reference_points, feat, spatial_shapes, level_start_index):
def _inner_forward(query, feat):
attn = self.attn(
self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None
)
return query + self.gamma * attn
if self.with_cp and query.requires_grad:
query = cp.checkpoint(_inner_forward, query, feat)
else:
query = _inner_forward(query, feat)
return query
class InteractionBlock(nn.Module):
def __init__(
self,
dim,
num_heads=6,
n_points=4,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
drop=0.0,
drop_path=0.0,
with_cffn=True,
cffn_ratio=0.25,
init_values=0.0,
deform_ratio=1.0,
extra_extractor=False,
with_cp=False,
):
super().__init__()
self.injector = Injector(
dim=dim,
n_levels=3,
num_heads=num_heads,
init_values=init_values,
n_points=n_points,
norm_layer=norm_layer,
deform_ratio=deform_ratio,
with_cp=with_cp,
)
self.extractor = Extractor(
dim=dim,
n_levels=1,
num_heads=num_heads,
n_points=n_points,
norm_layer=norm_layer,
deform_ratio=deform_ratio,
with_cffn=with_cffn,
cffn_ratio=cffn_ratio,
drop=drop,
drop_path=drop_path,
with_cp=with_cp,
)
if extra_extractor:
self.extra_extractors = nn.Sequential(
*[
Extractor(
dim=dim,
num_heads=num_heads,
n_points=n_points,
norm_layer=norm_layer,
with_cffn=with_cffn,
cffn_ratio=cffn_ratio,
deform_ratio=deform_ratio,
drop=drop,
drop_path=drop_path,
with_cp=with_cp,
)
for _ in range(2)
]
)
else:
self.extra_extractors = None
def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
x = self.injector(
query=x,
reference_points=deform_inputs1[0],
feat=c,
spatial_shapes=deform_inputs1[1],
level_start_index=deform_inputs1[2],
)
for idx, blk in enumerate(blocks):
x = blk(x, H_toks, W_toks)
c = self.extractor(
query=c,
reference_points=deform_inputs2[0],
feat=x,
spatial_shapes=deform_inputs2[1],
level_start_index=deform_inputs2[2],
H=H_c,
W=W_c,
)
if self.extra_extractors is not None:
for extractor in self.extra_extractors:
c = extractor(
query=c,
reference_points=deform_inputs2[0],
feat=x,
spatial_shapes=deform_inputs2[1],
level_start_index=deform_inputs2[2],
H=H_c,
W=W_c,
)
return x, c
class InteractionBlockWithCls(nn.Module):
def __init__(
self,
dim,
num_heads=6,
n_points=4,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
drop=0.0,
drop_path=0.0,
with_cffn=True,
cffn_ratio=0.25,
init_values=0.0,
deform_ratio=1.0,
extra_extractor=False,
with_cp=False,
):
super().__init__()
self.injector = Injector(
dim=dim,
n_levels=3,
num_heads=num_heads,
init_values=init_values,
n_points=n_points,
norm_layer=norm_layer,
deform_ratio=deform_ratio,
with_cp=with_cp,
)
self.extractor = Extractor(
dim=dim,
n_levels=1,
num_heads=num_heads,
n_points=n_points,
norm_layer=norm_layer,
deform_ratio=deform_ratio,
with_cffn=with_cffn,
cffn_ratio=cffn_ratio,
drop=drop,
drop_path=drop_path,
with_cp=with_cp,
)
if extra_extractor:
self.extra_extractors = nn.Sequential(
*[
Extractor(
dim=dim,
num_heads=num_heads,
n_points=n_points,
norm_layer=norm_layer,
with_cffn=with_cffn,
cffn_ratio=cffn_ratio,
deform_ratio=deform_ratio,
drop=drop,
drop_path=drop_path,
with_cp=with_cp,
)
for _ in range(2)
]
)
else:
self.extra_extractors = None
def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
x = self.injector(
query=x,
reference_points=deform_inputs1[0],
feat=c,
spatial_shapes=deform_inputs1[1],
level_start_index=deform_inputs1[2],
)
x = torch.cat((cls, x), dim=1)
for idx, blk in enumerate(blocks):
x = blk(x, H_toks, W_toks)
cls, x = (
x[
:,
:1,
],
x[
:,
1:,
],
)
c = self.extractor(
query=c,
reference_points=deform_inputs2[0],
feat=x,
spatial_shapes=deform_inputs2[1],
level_start_index=deform_inputs2[2],
H=H_c,
W=W_c,
)
if self.extra_extractors is not None:
for extractor in self.extra_extractors:
c = extractor(
query=c,
reference_points=deform_inputs2[0],
feat=x,
spatial_shapes=deform_inputs2[1],
level_start_index=deform_inputs2[2],
H=H_c,
W=W_c,
)
return x, c, cls
class SpatialPriorModule(nn.Module):
def __init__(self, inplanes=64, embed_dim=384, with_cp=False):
super().__init__()
self.with_cp = with_cp
self.stem = nn.Sequential(
*[
nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False),
nn.SyncBatchNorm(inplanes),
nn.ReLU(inplace=True),
nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
nn.SyncBatchNorm(inplanes),
nn.ReLU(inplace=True),
nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
nn.SyncBatchNorm(inplanes),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
]
)
self.conv2 = nn.Sequential(
*[
nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
nn.SyncBatchNorm(2 * inplanes),
nn.ReLU(inplace=True),
]
)
self.conv3 = nn.Sequential(
*[
nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
nn.SyncBatchNorm(4 * inplanes),
nn.ReLU(inplace=True),
]
)
self.conv4 = nn.Sequential(
*[
nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
nn.SyncBatchNorm(4 * inplanes),
nn.ReLU(inplace=True),
]
)
self.fc1 = nn.Conv2d(inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
self.fc2 = nn.Conv2d(2 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
self.fc3 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
self.fc4 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
def forward(self, x):
def _inner_forward(x):
c1 = self.stem(x)
c2 = self.conv2(c1)
c3 = self.conv3(c2)
c4 = self.conv4(c3)
c1 = self.fc1(c1)
c2 = self.fc2(c2)
c3 = self.fc3(c3)
c4 = self.fc4(c4)
bs, dim, _, _ = c1.shape
# c1 = c1.view(bs, dim, -1).transpose(1, 2) # 4s
c2 = c2.view(bs, dim, -1).transpose(1, 2) # 8s
c3 = c3.view(bs, dim, -1).transpose(1, 2) # 16s
c4 = c4.view(bs, dim, -1).transpose(1, 2) # 32s
return c1, c2, c3, c4
if self.with_cp and x.requires_grad:
outs = cp.checkpoint(_inner_forward, x)
else:
outs = _inner_forward(x)
return outs

View File

@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
from torch import nn
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0:
random_tensor.div_(keep_prob)
return x * random_tensor
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: float = 0.0):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
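# Minimal behaviour sketch (assumes torch is imported by the caller): in eval
# mode, or with drop_prob == 0, DropPath is the identity; in training mode each
# sample's residual branch is dropped with probability drop_prob and survivors
# are rescaled by 1 / (1 - drop_prob).
#
#   dp = DropPath(0.1).eval()
#   x = torch.randn(4, 16, 768)
#   assert torch.equal(dp(x), x)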

View File

@@ -0,0 +1,552 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
"""Vision Transformer (ViT) in PyTorch.
A PyTorch implementation of Vision Transformers as described in:
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
- https://arxiv.org/abs/2010.11929
`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
- https://arxiv.org/abs/2106.10270
The official jax code is released and available at https://github.com/google-research/vision_transformer
DeiT model defs and weights from https://github.com/facebookresearch/deit,
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
Hacked together by / Copyright 2021 Ross Wightman
"""
import logging
import math
from functools import partial
from itertools import repeat
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.runner import BaseModule, load_checkpoint
from mmseg.ops import resize
from mmseg.utils import get_root_logger
from torch import Tensor
from .drop_path import DropPath
def to_2tuple(x):
return tuple(repeat(x, 2))
class Mlp(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = nn.GELU,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
self.drop = nn.Dropout(drop)
def forward(self, x: Tensor) -> Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class SwiGLUFFN(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = None,
drop: float = 0.0,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
swiglu_hidden_features = int(2 * hidden_features / 3)
align_as = 8
swiglu_hidden_features = (swiglu_hidden_features + align_as - 1) // align_as * align_as
self.w1 = nn.Linear(in_features, swiglu_hidden_features)
self.w2 = nn.Linear(in_features, swiglu_hidden_features)
self.w3 = nn.Linear(swiglu_hidden_features, out_features)
def forward(self, x: Tensor) -> Tensor:
x1 = self.w1(x)
x2 = self.w2(x)
hidden = F.silu(x1) * x2
return self.w3(hidden)
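# Hidden-width sketch for SwiGLUFFN: the effective hidden size is 2/3 of the
# requested one, rounded up to a multiple of 8, so hidden_features=3072 (a
# ViT-Base MLP width) yields swiglu_hidden_features = 2048.
#
#   ffn = SwiGLUFFN(in_features=768, hidden_features=3072)
#   ffn.w1.out_features  # 2048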
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding."""
def __init__(
self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, bias=True
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x, H, W
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, H, W):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MemEffAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
) -> None:
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: Tensor, H, W) -> Tensor:
from xformers.ops import memory_efficient_attention, unbind
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
q, k, v = unbind(qkv, 2)
x = memory_efficient_attention(q, k, v)
x = x.reshape([B, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
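# Round-trip sketch: window_partition followed by window_reverse is lossless
# whenever H and W are multiples of the window size.
#
#   x = torch.randn(2, 28, 28, 96)
#   w = window_partition(x, window_size=14)           # (8, 14, 14, 96)
#   assert torch.equal(window_reverse(w, 14, 28, 28), x)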
class WindowedAttention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, window_size=14, pad_mode="constant"
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.window_size = window_size
self.pad_mode = pad_mode
def forward(self, x, H, W):
B, N, C = x.shape
N_ = self.window_size * self.window_size
H_ = math.ceil(H / self.window_size) * self.window_size
W_ = math.ceil(W / self.window_size) * self.window_size
qkv = self.qkv(x) # [B, N, C]
qkv = qkv.transpose(1, 2).reshape(B, C * 3, H, W) # [B, C, H, W]
qkv = F.pad(qkv, [0, W_ - W, 0, H_ - H], mode=self.pad_mode)
qkv = F.unfold(
qkv, kernel_size=(self.window_size, self.window_size), stride=(self.window_size, self.window_size)
)
B, C_kw_kw, L = qkv.shape # L - the num of windows
qkv = qkv.reshape(B, C * 3, N_, L).permute(0, 3, 2, 1) # [B, L, N_, C]
qkv = qkv.reshape(B, L, N_, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
# q,k,v [B, L, num_head, N_, C/num_head]
attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_]
# if self.mask:
# attn = attn * mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn) # [B, L, num_head, N_, N_]
# attn @ v = [B, L, num_head, N_, C/num_head]
x = (attn @ v).permute(0, 2, 4, 3, 1).reshape(B, C_kw_kw // 3, L)
x = F.fold(
x,
output_size=(H_, W_),
kernel_size=(self.window_size, self.window_size),
stride=(self.window_size, self.window_size),
) # [B, C, H_, W_]
x = x[:, :, :H, :W].reshape(B, C, N).transpose(-1, -2)
x = self.proj(x)
x = self.proj_drop(x)
return x
# class WindowedAttention(nn.Module):
# def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., window_size=14, pad_mode="constant"):
# super().__init__()
# self.num_heads = num_heads
# head_dim = dim // num_heads
# self.scale = head_dim ** -0.5
#
# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(dim, dim)
# self.proj_drop = nn.Dropout(proj_drop)
# self.window_size = window_size
# self.pad_mode = pad_mode
#
# def forward(self, x, H, W):
# B, N, C = x.shape
#
# N_ = self.window_size * self.window_size
# H_ = math.ceil(H / self.window_size) * self.window_size
# W_ = math.ceil(W / self.window_size) * self.window_size
# x = x.view(B, H, W, C)
# x = F.pad(x, [0, 0, 0, W_ - W, 0, H_- H], mode=self.pad_mode)
#
# x = window_partition(x, window_size=self.window_size)# nW*B, window_size, window_size, C
# x = x.view(-1, N_, C)
#
# qkv = self.qkv(x).view(-1, N_, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
# q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
# attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_]
# attn = attn.softmax(dim=-1)
# attn = self.attn_drop(attn) # [B, L, num_head, N_, N_]
# x = (attn @ v).transpose(1, 2).reshape(-1, self.window_size, self.window_size, C)
#
# x = window_reverse(x, self.window_size, H_, W_)
# x = x[:, :H, :W, :].reshape(B, N, C).contiguous()
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
windowed=False,
window_size=14,
pad_mode="constant",
layer_scale=False,
with_cp=False,
ffn_layer=Mlp,
memeff=False,
):
super().__init__()
self.with_cp = with_cp
self.norm1 = norm_layer(dim)
if windowed:
self.attn = WindowedAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
window_size=window_size,
pad_mode=pad_mode,
)
elif memeff:
self.attn = MemEffAttention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop
)
else:
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.layer_scale = layer_scale
if layer_scale:
self.gamma1 = nn.Parameter(torch.ones((dim)), requires_grad=True)
self.gamma2 = nn.Parameter(torch.ones((dim)), requires_grad=True)
def forward(self, x, H, W):
def _inner_forward(x):
if self.layer_scale:
x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
class TIMMVisionTransformer(BaseModule):
"""Vision Transformer.
    A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
- https://arxiv.org/abs/2012.12877
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
layer_scale=True,
embed_layer=PatchEmbed,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
window_attn=False,
window_size=14,
pretrained=None,
with_cp=False,
pre_norm=False,
ffn_type="mlp",
memeff=False,
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
embed_layer (nn.Module): patch embedding layer
            norm_layer (nn.Module): normalization layer
            pretrained (str): pretrained path
"""
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.norm_layer = norm_layer
self.act_layer = act_layer
self.pretrain_size = img_size
self.drop_path_rate = drop_path_rate
self.drop_rate = drop_rate
self.patch_size = patch_size
window_attn = [window_attn] * depth if not isinstance(window_attn, list) else window_attn
window_size = [window_size] * depth if not isinstance(window_size, list) else window_size
logging.info("window attention:", window_attn)
logging.info("window size:", window_size)
logging.info("layer scale:", layer_scale)
self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm
)
num_patches = self.patch_embed.num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
ffn_types = {"mlp": Mlp, "swiglu": SwiGLUFFN}
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(
*[
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
windowed=window_attn[i],
window_size=window_size[i],
layer_scale=layer_scale,
with_cp=with_cp,
ffn_layer=ffn_types[ffn_type],
memeff=memeff,
)
for i in range(depth)
]
)
        self.norm = norm_layer(embed_dim)  # final norm; forward_features calls self.norm
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# For CLIP
if pre_norm:
norm_pre = norm_layer(embed_dim)
self.norm_pre = norm_pre
else:
self.norm_pre = nn.Identity()
self.init_weights(pretrained)
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, map_location="cpu", strict=False, logger=logger)
def forward_features(self, x):
x, H, W = self.patch_embed(x)
cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_token, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
# For CLIP
x = self.norm_pre(x)
for blk in self.blocks:
x = blk(x, H, W)
x = self.norm(x)
return x
def forward(self, x):
x = self.forward_features(x)
return x
@staticmethod
    def resize_pos_embed(pos_embed, input_shape, pos_shape, mode):
"""Resize pos_embed weights.
Resize pos_embed using bicubic interpolate method.
Args:
pos_embed (torch.Tensor): Position embedding weights.
            input_shape (tuple): Tuple for (downsampled input image height,
                downsampled input image width).
pos_shape (tuple): The resolution of downsampled origin training
image.
mode (str): Algorithm used for upsampling:
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
``'trilinear'``. Default: ``'nearest'``
        Returns:
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
"""
assert pos_embed.ndim == 3, "shape of pos_embed must be [B, L, C]"
pos_h, pos_w = pos_shape
# keep dim for easy deployment
cls_token_weight = pos_embed[:, 0:1]
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w) :]
pos_embed_weight = pos_embed_weight.reshape(1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
        pos_embed_weight = resize(pos_embed_weight, size=input_shape, align_corners=False, mode=mode)
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
return pos_embed
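# Shape sketch for resize_pos_embed (assuming a 14x14 pretraining grid, i.e. a
# 224 / 16 layout, interpolated to a 32x32 token grid; requires mmseg's resize):
#
#   pos_embed = torch.zeros(1, 1 + 14 * 14, 768)
#   resized = TIMMVisionTransformer.resize_pos_embed(pos_embed, (32, 32), (14, 14), "bicubic")
#   resized.shape  # torch.Size([1, 1 + 32 * 32, 768])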

View File

@@ -0,0 +1,217 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.models.builder import BACKBONES
from torch.nn.init import normal_
from ...ops.modules import MSDeformAttn
from .adapter_modules import InteractionBlock, InteractionBlockWithCls, SpatialPriorModule, deform_inputs
from .vit import TIMMVisionTransformer
@BACKBONES.register_module()
class ViTAdapter(TIMMVisionTransformer):
def __init__(
self,
pretrain_size=224,
num_heads=12,
conv_inplane=64,
n_points=4,
deform_num_heads=6,
init_values=0.0,
interaction_indexes=None,
with_cffn=True,
cffn_ratio=0.25,
deform_ratio=1.0,
add_vit_feature=True,
pretrained=None,
use_extra_extractor=True,
freeze_vit=False,
use_cls=True,
with_cp=False,
*args,
**kwargs
):
super().__init__(num_heads=num_heads, pretrained=pretrained, with_cp=with_cp, *args, **kwargs)
if freeze_vit:
for param in self.parameters():
param.requires_grad = False
# self.num_classes = 80
self.use_cls = use_cls
if not self.use_cls:
self.cls_token = None
self.num_block = len(self.blocks)
self.pretrain_size = (pretrain_size, pretrain_size)
self.interaction_indexes = interaction_indexes
self.add_vit_feature = add_vit_feature
embed_dim = self.embed_dim
block_fn = InteractionBlockWithCls if use_cls else InteractionBlock
self.level_embed = nn.Parameter(torch.zeros(3, embed_dim))
self.spm = SpatialPriorModule(inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False)
self.interactions = nn.Sequential(
*[
block_fn(
dim=embed_dim,
num_heads=deform_num_heads,
n_points=n_points,
init_values=init_values,
drop_path=self.drop_path_rate,
norm_layer=self.norm_layer,
with_cffn=with_cffn,
cffn_ratio=cffn_ratio,
deform_ratio=deform_ratio,
extra_extractor=((True if i == len(interaction_indexes) - 1 else False) and use_extra_extractor),
with_cp=with_cp,
)
for i in range(len(interaction_indexes))
]
)
self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2)
self.norm1 = nn.SyncBatchNorm(embed_dim)
self.norm2 = nn.SyncBatchNorm(embed_dim)
self.norm3 = nn.SyncBatchNorm(embed_dim)
self.norm4 = nn.SyncBatchNorm(embed_dim)
self.up.apply(self._init_weights)
self.spm.apply(self._init_weights)
self.interactions.apply(self._init_weights)
self.apply(self._init_deform_weights)
normal_(self.level_embed)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
torch.nn.init.trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def _get_pos_embed(self, pos_embed, H, W):
pos_embed = pos_embed.reshape(
1, self.pretrain_size[0] // self.patch_size, self.pretrain_size[1] // self.patch_size, -1
).permute(0, 3, 1, 2)
pos_embed = (
F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False)
.reshape(1, -1, H * W)
.permute(0, 2, 1)
)
return pos_embed
def _init_deform_weights(self, m):
if isinstance(m, MSDeformAttn):
m._reset_parameters()
def _add_level_embed(self, c2, c3, c4):
c2 = c2 + self.level_embed[0]
c3 = c3 + self.level_embed[1]
c4 = c4 + self.level_embed[2]
return c2, c3, c4
def forward(self, x):
deform_inputs1, deform_inputs2 = deform_inputs(x, self.patch_size)
# SPM forward
c1, c2, c3, c4 = self.spm(x)
c2, c3, c4 = self._add_level_embed(c2, c3, c4)
c = torch.cat([c2, c3, c4], dim=1)
# Patch Embedding forward
H_c, W_c = x.shape[2] // 16, x.shape[3] // 16
x, H_toks, W_toks = self.patch_embed(x)
# print("H_toks, W_toks =", H_toks, W_toks)
bs, n, dim = x.shape
pos_embed = self._get_pos_embed(self.pos_embed[:, 1:], H_toks, W_toks)
if self.use_cls:
cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_token, x), dim=1)
pos_embed = torch.cat((self.pos_embed[:, :1], pos_embed), dim=1)
x = self.pos_drop(x + pos_embed)
# For CLIP
x = self.norm_pre(x)
# Interaction
if self.use_cls:
cls, x = (
x[
:,
:1,
],
x[
:,
1:,
],
)
outs = list()
for i, layer in enumerate(self.interactions):
indexes = self.interaction_indexes[i]
if self.use_cls:
x, c, cls = layer(
x,
c,
cls,
self.blocks[indexes[0] : indexes[-1] + 1],
deform_inputs1,
deform_inputs2,
H_c,
W_c,
H_toks,
W_toks,
)
else:
x, c = layer(
x,
c,
self.blocks[indexes[0] : indexes[-1] + 1],
deform_inputs1,
deform_inputs2,
H_c,
W_c,
H_toks,
W_toks,
)
outs.append(x.transpose(1, 2).view(bs, dim, H_toks, W_toks).contiguous())
# Split & Reshape
c2 = c[:, 0 : c2.size(1), :]
c3 = c[:, c2.size(1) : c2.size(1) + c3.size(1), :]
c4 = c[:, c2.size(1) + c3.size(1) :, :]
c2 = c2.transpose(1, 2).view(bs, dim, H_c * 2, W_c * 2).contiguous()
c3 = c3.transpose(1, 2).view(bs, dim, H_c, W_c).contiguous()
c4 = c4.transpose(1, 2).view(bs, dim, H_c // 2, W_c // 2).contiguous()
c1 = self.up(c2) + c1
if self.add_vit_feature:
x1, x2, x3, x4 = outs
x1 = F.interpolate(x1, size=(4 * H_c, 4 * W_c), mode="bilinear", align_corners=False)
x2 = F.interpolate(x2, size=(2 * H_c, 2 * W_c), mode="bilinear", align_corners=False)
x3 = F.interpolate(x3, size=(1 * H_c, 1 * W_c), mode="bilinear", align_corners=False)
x4 = F.interpolate(x4, size=(H_c // 2, W_c // 2), mode="bilinear", align_corners=False)
# print(c1.shape, c2.shape, c3.shape, c4.shape, x1.shape, x2.shape, x3.shape, x4.shape, H_c, H_toks)
c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4
# Final Norm
f1 = self.norm1(c1)
f2 = self.norm2(c2)
f3 = self.norm3(c3)
f4 = self.norm4(c4)
return [f1, f2, f3, f4]

View File

@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from mmcv.utils import Registry
TRANSFORMER = Registry("Transformer")
MASK_ASSIGNERS = Registry("mask_assigner")
MATCH_COST = Registry("match_cost")
def build_match_cost(cfg):
"""Build Match Cost."""
return MATCH_COST.build(cfg)
def build_assigner(cfg):
"""Build Assigner."""
return MASK_ASSIGNERS.build(cfg)
def build_transformer(cfg):
"""Build Transformer."""
return TRANSFORMER.build(cfg)

View File

@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .mask2former_head import Mask2FormerHead

View File

@@ -0,0 +1,544 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init
from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence
from mmcv.ops import point_sample
from mmcv.runner import ModuleList, force_fp32
from mmseg.models.builder import HEADS, build_loss
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from ...core import build_sampler, multi_apply, reduce_mean
from ..builder import build_assigner
from ..utils import get_uncertain_point_coords_with_randomness
@HEADS.register_module()
class Mask2FormerHead(BaseDecodeHead):
"""Implements the Mask2Former head.
See `Masked-attention Mask Transformer for Universal Image
Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.
Args:
in_channels (list[int]): Number of channels in the input feature map.
feat_channels (int): Number of channels for features.
out_channels (int): Number of channels for output.
        num_things_classes (int): Number of "thing" classes.
        num_stuff_classes (int): Number of "stuff" classes.
        num_queries (int): Number of queries in the Transformer decoder.
pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel
decoder. Defaults to None.
enforce_decoder_input_project (bool, optional): Whether to add
            a layer to change the embed_dim of the transformer encoder in the
            pixel decoder to the embed_dim of the transformer decoder.
Defaults to False.
transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for
transformer decoder. Defaults to None.
positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
transformer decoder position encoding. Defaults to None.
loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification
loss. Defaults to None.
loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss.
Defaults to None.
loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss.
Defaults to None.
train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of
Mask2Former head.
test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of
Mask2Former head.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(
self,
in_channels,
feat_channels,
out_channels,
num_things_classes=80,
num_stuff_classes=53,
num_queries=100,
num_transformer_feat_level=3,
pixel_decoder=None,
enforce_decoder_input_project=False,
transformer_decoder=None,
positional_encoding=None,
loss_cls=None,
loss_mask=None,
loss_dice=None,
train_cfg=None,
test_cfg=None,
init_cfg=None,
**kwargs,
):
super(Mask2FormerHead, self).__init__(
in_channels=in_channels,
channels=feat_channels,
num_classes=(num_things_classes + num_stuff_classes),
init_cfg=init_cfg,
input_transform="multiple_select",
**kwargs,
)
self.num_things_classes = num_things_classes
self.num_stuff_classes = num_stuff_classes
self.num_classes = self.num_things_classes + self.num_stuff_classes
self.num_queries = num_queries
self.num_transformer_feat_level = num_transformer_feat_level
self.num_heads = transformer_decoder.transformerlayers.attn_cfgs.num_heads
self.num_transformer_decoder_layers = transformer_decoder.num_layers
assert pixel_decoder.encoder.transformerlayers.attn_cfgs.num_levels == num_transformer_feat_level
pixel_decoder_ = copy.deepcopy(pixel_decoder)
pixel_decoder_.update(in_channels=in_channels, feat_channels=feat_channels, out_channels=out_channels)
self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1]
self.transformer_decoder = build_transformer_layer_sequence(transformer_decoder)
self.decoder_embed_dims = self.transformer_decoder.embed_dims
self.decoder_input_projs = ModuleList()
# from low resolution to high resolution
for _ in range(num_transformer_feat_level):
if self.decoder_embed_dims != feat_channels or enforce_decoder_input_project:
self.decoder_input_projs.append(Conv2d(feat_channels, self.decoder_embed_dims, kernel_size=1))
else:
self.decoder_input_projs.append(nn.Identity())
self.decoder_positional_encoding = build_positional_encoding(positional_encoding)
self.query_embed = nn.Embedding(self.num_queries, feat_channels)
self.query_feat = nn.Embedding(self.num_queries, feat_channels)
# from low resolution to high resolution
self.level_embed = nn.Embedding(self.num_transformer_feat_level, feat_channels)
self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
self.mask_embed = nn.Sequential(
nn.Linear(feat_channels, feat_channels),
nn.ReLU(inplace=True),
nn.Linear(feat_channels, feat_channels),
nn.ReLU(inplace=True),
nn.Linear(feat_channels, out_channels),
)
        self.conv_seg = None  # conv_seg from BaseDecodeHead is not used by this head
self.test_cfg = test_cfg
self.train_cfg = train_cfg
if train_cfg:
self.assigner = build_assigner(self.train_cfg.assigner)
self.sampler = build_sampler(self.train_cfg.sampler, context=self)
self.num_points = self.train_cfg.get("num_points", 12544)
self.oversample_ratio = self.train_cfg.get("oversample_ratio", 3.0)
self.importance_sample_ratio = self.train_cfg.get("importance_sample_ratio", 0.75)
self.class_weight = loss_cls.class_weight
self.loss_cls = build_loss(loss_cls)
self.loss_mask = build_loss(loss_mask)
self.loss_dice = build_loss(loss_dice)
def init_weights(self):
for m in self.decoder_input_projs:
if isinstance(m, Conv2d):
caffe2_xavier_init(m, bias=0)
self.pixel_decoder.init_weights()
for p in self.transformer_decoder.parameters():
if p.dim() > 1:
nn.init.xavier_normal_(p)
def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas):
"""Compute classification and mask targets for all images for a decoder
layer.
Args:
cls_scores_list (list[Tensor]): Mask score logits from a single
decoder layer for all images. Each with shape [num_queries,
cls_out_channels].
mask_preds_list (list[Tensor]): Mask logits from a single decoder
layer for all images. Each with shape [num_queries, h, w].
gt_labels_list (list[Tensor]): Ground truth class indices for all
                images. Each with shape (n, ), where n is the sum of the
                number of stuff types and the number of instances in an image.
gt_masks_list (list[Tensor]): Ground truth mask for each image,
each with shape (n, h, w).
img_metas (list[dict]): List of image meta information.
Returns:
tuple[list[Tensor]]: a tuple containing the following targets.
- labels_list (list[Tensor]): Labels of all images.
Each with shape [num_queries, ].
- label_weights_list (list[Tensor]): Label weights of all
                    images. Each with shape [num_queries, ].
- mask_targets_list (list[Tensor]): Mask targets of all images.
Each with shape [num_queries, h, w].
- mask_weights_list (list[Tensor]): Mask weights of all images.
Each with shape [num_queries, ].
- num_total_pos (int): Number of positive samples in all
images.
- num_total_neg (int): Number of negative samples in all
images.
"""
(
labels_list,
label_weights_list,
mask_targets_list,
mask_weights_list,
pos_inds_list,
neg_inds_list,
) = multi_apply(
self._get_target_single, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas
)
num_total_pos = sum((inds.numel() for inds in pos_inds_list))
num_total_neg = sum((inds.numel() for inds in neg_inds_list))
return (labels_list, label_weights_list, mask_targets_list, mask_weights_list, num_total_pos, num_total_neg)
def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, img_metas):
"""Compute classification and mask targets for one image.
Args:
cls_score (Tensor): Mask score logits from a single decoder layer
for one image. Shape (num_queries, cls_out_channels).
mask_pred (Tensor): Mask logits for a single decoder layer for one
image. Shape (num_queries, h, w).
gt_labels (Tensor): Ground truth class indices for one image with
shape (num_gts, ).
gt_masks (Tensor): Ground truth mask for each image, each with
shape (num_gts, h, w).
            img_metas (dict): Image information.
Returns:
tuple[Tensor]: A tuple containing the following for one image.
- labels (Tensor): Labels of each image. \
shape (num_queries, ).
- label_weights (Tensor): Label weights of each image. \
shape (num_queries, ).
- mask_targets (Tensor): Mask targets of each image. \
shape (num_queries, h, w).
- mask_weights (Tensor): Mask weights of each image. \
shape (num_queries, ).
- pos_inds (Tensor): Sampled positive indices for each \
image.
- neg_inds (Tensor): Sampled negative indices for each \
image.
"""
# sample points
num_queries = cls_score.shape[0]
num_gts = gt_labels.shape[0]
point_coords = torch.rand((1, self.num_points, 2), device=cls_score.device)
# shape (num_queries, num_points)
mask_points_pred = point_sample(mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, 1)).squeeze(1)
# shape (num_gts, num_points)
gt_points_masks = point_sample(gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, 1)).squeeze(1)
# assign and sample
assign_result = self.assigner.assign(cls_score, mask_points_pred, gt_labels, gt_points_masks, img_metas)
sampling_result = self.sampler.sample(assign_result, mask_pred, gt_masks)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
# label target
labels = gt_labels.new_full((self.num_queries,), self.num_classes, dtype=torch.long)
labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
label_weights = gt_labels.new_ones((self.num_queries,))
# mask target
mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
mask_weights = mask_pred.new_zeros((self.num_queries,))
mask_weights[pos_inds] = 1.0
return (labels, label_weights, mask_targets, mask_weights, pos_inds, neg_inds)
def loss_single(self, cls_scores, mask_preds, gt_labels_list, gt_masks_list, img_metas):
"""Loss function for outputs from a single decoder layer.
Args:
cls_scores (Tensor): Mask score logits from a single decoder layer
for all images. Shape (batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should include the
background.
mask_preds (Tensor): Mask logits for a pixel decoder for all
images. Shape (batch_size, num_queries, h, w).
gt_labels_list (list[Tensor]): Ground truth class indices for each
image, each with shape (num_gts, ).
gt_masks_list (list[Tensor]): Ground truth mask for each image,
each with shape (num_gts, h, w).
img_metas (list[dict]): List of image meta information.
Returns:
tuple[Tensor]: Loss components for outputs from a single \
decoder layer.
"""
num_imgs = cls_scores.size(0)
cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
(
labels_list,
label_weights_list,
mask_targets_list,
mask_weights_list,
num_total_pos,
num_total_neg,
) = self.get_targets(cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas)
# shape (batch_size, num_queries)
labels = torch.stack(labels_list, dim=0)
# shape (batch_size, num_queries)
label_weights = torch.stack(label_weights_list, dim=0)
# shape (num_total_gts, h, w)
mask_targets = torch.cat(mask_targets_list, dim=0)
# shape (batch_size, num_queries)
mask_weights = torch.stack(mask_weights_list, dim=0)
        # classification loss
# shape (batch_size * num_queries, )
cls_scores = cls_scores.flatten(0, 1)
labels = labels.flatten(0, 1)
label_weights = label_weights.flatten(0, 1)
class_weight = cls_scores.new_tensor(self.class_weight)
loss_cls = self.loss_cls(cls_scores, labels, label_weights, avg_factor=class_weight[labels].sum())
num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
num_total_masks = max(num_total_masks, 1)
# extract positive ones
# shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
mask_preds = mask_preds[mask_weights > 0]
if mask_targets.shape[0] == 0:
# zero match
loss_dice = mask_preds.sum()
loss_mask = mask_preds.sum()
return loss_cls, loss_mask, loss_dice
with torch.no_grad():
points_coords = get_uncertain_point_coords_with_randomness(
mask_preds.unsqueeze(1), None, self.num_points, self.oversample_ratio, self.importance_sample_ratio
)
# shape (num_total_gts, h, w) -> (num_total_gts, num_points)
mask_point_targets = point_sample(mask_targets.unsqueeze(1).float(), points_coords).squeeze(1)
# shape (num_queries, h, w) -> (num_queries, num_points)
mask_point_preds = point_sample(mask_preds.unsqueeze(1), points_coords).squeeze(1)
# dice loss
loss_dice = self.loss_dice(mask_point_preds, mask_point_targets, avg_factor=num_total_masks)
# mask loss
# shape (num_queries, num_points) -> (num_queries * num_points, )
mask_point_preds = mask_point_preds.reshape(-1, 1)
# shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
mask_point_targets = mask_point_targets.reshape(-1)
loss_mask = self.loss_mask(mask_point_preds, mask_point_targets, avg_factor=num_total_masks * self.num_points)
return loss_cls, loss_mask, loss_dice
@force_fp32(apply_to=("all_cls_scores", "all_mask_preds"))
def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, gt_masks_list, img_metas):
"""Loss function.
Args:
all_cls_scores (Tensor): Classification scores for all decoder
layers with shape [num_decoder, batch_size, num_queries,
cls_out_channels].
all_mask_preds (Tensor): Mask scores for all decoder layers with
shape [num_decoder, batch_size, num_queries, h, w].
gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (n, ), where n is the sum of the number of
                stuff types and the number of instances in an image.
gt_masks_list (list[Tensor]): Ground truth mask for each image with
shape (n, h, w).
img_metas (list[dict]): List of image meta information.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
num_dec_layers = len(all_cls_scores)
all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)]
img_metas_list = [img_metas for _ in range(num_dec_layers)]
losses_cls, losses_mask, losses_dice = multi_apply(
self.loss_single, all_cls_scores, all_mask_preds, all_gt_labels_list, all_gt_masks_list, img_metas_list
)
loss_dict = dict()
# loss from the last decoder layer
loss_dict["loss_cls"] = losses_cls[-1]
loss_dict["loss_mask"] = losses_mask[-1]
loss_dict["loss_dice"] = losses_dice[-1]
# loss from other decoder layers
num_dec_layer = 0
for loss_cls_i, loss_mask_i, loss_dice_i in zip(losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]):
loss_dict[f"d{num_dec_layer}.loss_cls"] = loss_cls_i
loss_dict[f"d{num_dec_layer}.loss_mask"] = loss_mask_i
loss_dict[f"d{num_dec_layer}.loss_dice"] = loss_dice_i
num_dec_layer += 1
return loss_dict
def forward_head(self, decoder_out, mask_feature, attn_mask_target_size):
"""Forward for head part which is called after every decoder layer.
Args:
decoder_out (Tensor): in shape (num_queries, batch_size, c).
mask_feature (Tensor): in shape (batch_size, c, h, w).
attn_mask_target_size (tuple[int, int]): target attention
mask size.
Returns:
tuple: A tuple contain three elements.
- cls_pred (Tensor): Classification scores in shape \
(batch_size, num_queries, cls_out_channels). \
Note `cls_out_channels` should include background.
- mask_pred (Tensor): Mask scores in shape \
(batch_size, num_queries, h, w). \
- attn_mask (Tensor): Attention mask in shape \
(batch_size * num_heads, num_queries, h, w).
"""
decoder_out = self.transformer_decoder.post_norm(decoder_out)
decoder_out = decoder_out.transpose(0, 1)
# shape (batch_size, num_queries, c)
cls_pred = self.cls_embed(decoder_out)
# shape (batch_size, num_queries, c)
mask_embed = self.mask_embed(decoder_out)
# shape (batch_size, num_queries, h, w)
mask_pred = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_feature)
attn_mask = F.interpolate(mask_pred, attn_mask_target_size, mode="bilinear", align_corners=False)
# shape (batch_size, num_queries, h, w) ->
# (batch_size * num_heads, num_queries, h * w)
attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat((1, self.num_heads, 1, 1)).flatten(0, 1)
attn_mask = attn_mask.sigmoid() < 0.5
attn_mask = attn_mask.detach()
return cls_pred, mask_pred, attn_mask
def forward(self, feats, img_metas):
"""Forward function.
Args:
feats (list[Tensor]): Multi scale Features from the
upstream network, each is a 4D-tensor.
img_metas (list[dict]): List of image information.
Returns:
tuple: A tuple contains two elements.
- cls_pred_list (list[Tensor]): Classification logits \
for each decoder layer. Each is a 3D-tensor with shape \
(batch_size, num_queries, cls_out_channels). \
Note `cls_out_channels` should include background.
- mask_pred_list (list[Tensor]): Mask logits for each \
decoder layer. Each with shape (batch_size, num_queries, \
h, w).
"""
batch_size = len(img_metas)
mask_features, multi_scale_memorys = self.pixel_decoder(feats)
# multi_scale_memorys (from low resolution to high resolution)
decoder_inputs = []
decoder_positional_encodings = []
for i in range(self.num_transformer_feat_level):
decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i])
# shape (batch_size, c, h, w) -> (h*w, batch_size, c)
decoder_input = decoder_input.flatten(2).permute(2, 0, 1)
level_embed = self.level_embed.weight[i].view(1, 1, -1)
decoder_input = decoder_input + level_embed
# shape (batch_size, c, h, w) -> (h*w, batch_size, c)
mask = decoder_input.new_zeros((batch_size,) + multi_scale_memorys[i].shape[-2:], dtype=torch.bool)
decoder_positional_encoding = self.decoder_positional_encoding(mask)
decoder_positional_encoding = decoder_positional_encoding.flatten(2).permute(2, 0, 1)
decoder_inputs.append(decoder_input)
decoder_positional_encodings.append(decoder_positional_encoding)
# shape (num_queries, c) -> (num_queries, batch_size, c)
query_feat = self.query_feat.weight.unsqueeze(1).repeat((1, batch_size, 1))
query_embed = self.query_embed.weight.unsqueeze(1).repeat((1, batch_size, 1))
cls_pred_list = []
mask_pred_list = []
cls_pred, mask_pred, attn_mask = self.forward_head(query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
cls_pred_list.append(cls_pred)
mask_pred_list.append(mask_pred)
for i in range(self.num_transformer_decoder_layers):
level_idx = i % self.num_transformer_feat_level
# if a mask is all True (all background), set it to all False.
attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
# cross_attn + self_attn
layer = self.transformer_decoder.layers[i]
attn_masks = [attn_mask, None]
query_feat = layer(
query=query_feat,
key=decoder_inputs[level_idx],
value=decoder_inputs[level_idx],
query_pos=query_embed,
key_pos=decoder_positional_encodings[level_idx],
attn_masks=attn_masks,
query_key_padding_mask=None,
# here we do not apply masking on padded region
key_padding_mask=None,
)
cls_pred, mask_pred, attn_mask = self.forward_head(
query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:]
)
cls_pred_list.append(cls_pred)
mask_pred_list.append(mask_pred)
return cls_pred_list, mask_pred_list
def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels, gt_masks):
"""Forward function for training mode.
Args:
x (list[Tensor]): Multi-level features from the upstream network,
each is a 4D-tensor.
img_metas (list[Dict]): List of image information.
gt_semantic_seg (list[Tensor]): Each element is the ground truth
of semantic segmentation with shape (N, H, W).
train_cfg (dict): The training config, which is not used in
MaskFormer.
gt_labels (list[Tensor]): Each element is ground truth labels of
each box, shape (num_gts,).
gt_masks (list[BitmapMasks]): Each element is masks of instances
of an image, shape (num_gts, h, w).
Returns:
losses (dict[str, Tensor]): a dictionary of loss components
"""
# forward
all_cls_scores, all_mask_preds = self(x, img_metas)
# loss
losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks, img_metas)
return losses
def forward_test(self, inputs, img_metas, test_cfg):
"""Test segment without test-time aumengtation.
Only the output of last decoder layers was used.
Args:
inputs (list[Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
img_metas (list[dict]): List of image information.
test_cfg (dict): Testing config.
Returns:
seg_mask (Tensor): Predicted semantic segmentation logits.
"""
all_cls_scores, all_mask_preds = self(inputs, img_metas)
cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1]
ori_h, ori_w, _ = img_metas[0]["ori_shape"]
# semantic inference
cls_score = F.softmax(cls_score, dim=-1)[..., :-1]
mask_pred = mask_pred.sigmoid()
seg_mask = torch.einsum("bqc,bqhw->bchw", cls_score, mask_pred)
return seg_mask
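# A minimal, torch-only sketch (with made-up tensor sizes) of the semantic
# inference step used in `forward_test` above: softmax over classes (dropping
# the trailing "no-object" class), sigmoid on the masks, then the einsum that
# combines them. Illustrative only, not part of the upstream Mask2Former code.
if __name__ == "__main__":
    import torch
    import torch.nn.functional as F

    batch_size, num_queries, num_classes, h, w = 2, 100, 20, 32, 32
    cls_score = torch.randn(batch_size, num_queries, num_classes + 1)  # + "no-object"
    mask_pred = torch.randn(batch_size, num_queries, h, w)

    cls_prob = F.softmax(cls_score, dim=-1)[..., :-1]  # drop the no-object column
    mask_prob = mask_pred.sigmoid()
    # combine per-query class probabilities with per-query masks
    seg_mask = torch.einsum("bqc,bqhw->bchw", cls_prob, mask_prob)
    assert seg_mask.shape == (batch_size, num_classes, h, w)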

View File

@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .cross_entropy_loss import CrossEntropyLoss, binary_cross_entropy, cross_entropy, mask_cross_entropy
from .dice_loss import DiceLoss
from .match_costs import ClassificationCost, CrossEntropyLossCost, DiceCost

View File

@@ -0,0 +1,279 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.models.builder import LOSSES
from mmseg.models.losses.utils import get_class_weight, weight_reduce_loss
def cross_entropy(
pred,
label,
weight=None,
class_weight=None,
reduction="mean",
avg_factor=None,
ignore_index=-100,
avg_non_ignore=False,
):
"""cross_entropy. The wrapper function for :func:`F.cross_entropy`
Args:
pred (torch.Tensor): The prediction with shape (N, C), where C is the number of classes.
label (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
Default: None.
class_weight (list[float], optional): The weight for each class.
Default: None.
reduction (str, optional): The method used to reduce the loss.
Options are 'none', 'mean' and 'sum'. Default: 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. Default: None.
ignore_index (int): Specifies a target value that is ignored and
does not contribute to the input gradients. When
``avg_non_ignore`` is ``True`` and the ``reduction`` is
``'mean'``, the loss is averaged over non-ignored targets.
Defaults to -100.
avg_non_ignore (bool): Whether the loss is only averaged over
non-ignored targets. Default: False.
`New in version 0.23.0.`
"""
# class_weight is a manual rescaling weight given to each class.
# If given, it has to be a Tensor of size C. Compute element-wise losses first.
loss = F.cross_entropy(pred, label, weight=class_weight, reduction="none", ignore_index=ignore_index)
# apply weights and do the reduction
# average loss over non-ignored elements
# PyTorch's official cross_entropy averages the loss over non-ignored elements
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
if (avg_factor is None) and avg_non_ignore and reduction == "mean":
avg_factor = label.numel() - (label == ignore_index).sum().item()
if weight is not None:
weight = weight.float()
loss = weight_reduce_loss(loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
return loss
def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
"""Expand onehot labels to match the size of prediction."""
bin_labels = labels.new_zeros(target_shape)
valid_mask = (labels >= 0) & (labels != ignore_index)
inds = torch.nonzero(valid_mask, as_tuple=True)
if inds[0].numel() > 0:
if labels.dim() == 3:
bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
else:
bin_labels[inds[0], labels[valid_mask]] = 1
valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
if label_weights is None:
bin_label_weights = valid_mask
else:
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
bin_label_weights = bin_label_weights * valid_mask
return bin_labels, bin_label_weights, valid_mask
def binary_cross_entropy(
pred,
label,
weight=None,
reduction="mean",
avg_factor=None,
class_weight=None,
ignore_index=-100,
avg_non_ignore=False,
**kwargs,
):
"""Calculate the binary CrossEntropy loss.
Args:
pred (torch.Tensor): The prediction with shape (N, 1).
label (torch.Tensor): The learning label of the prediction.
Note: In bce loss, label < 0 is invalid.
weight (torch.Tensor, optional): Sample-wise loss weight.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (int): The label index to be ignored. Default: -100.
avg_non_ignore (bool): Whether the loss is only averaged over
non-ignored targets. Default: False.
`New in version 0.23.0.`
Returns:
torch.Tensor: The calculated loss
"""
if pred.size(1) == 1:
# For binary class segmentation, the shape of pred is
# [N, 1, H, W] and that of label is [N, H, W].
assert label.max() <= 1, "For pred with shape [N, 1, H, W], its label must have at " "most 2 classes"
pred = pred.squeeze()
if pred.dim() != label.dim():
assert (pred.dim() == 2 and label.dim() == 1) or (pred.dim() == 4 and label.dim() == 3), (
"Only pred shape [N, C], label shape [N] or pred shape [N, C, " "H, W], label shape [N, H, W] are supported"
)
# `weight` returned from `_expand_onehot_labels`
# has been treated for valid (non-ignore) pixels
label, weight, valid_mask = _expand_onehot_labels(label, weight, pred.shape, ignore_index)
else:
# should mask out the ignored elements
valid_mask = ((label >= 0) & (label != ignore_index)).float()
if weight is not None:
weight = weight * valid_mask
else:
weight = valid_mask
# average loss over non-ignored and valid elements
if reduction == "mean" and avg_factor is None and avg_non_ignore:
avg_factor = valid_mask.sum().item()
loss = F.binary_cross_entropy_with_logits(pred, label.float(), pos_weight=class_weight, reduction="none")
# do the reduction for the weighted loss
loss = weight_reduce_loss(loss, weight, reduction=reduction, avg_factor=avg_factor)
return loss
def mask_cross_entropy(
pred, target, label, reduction="mean", avg_factor=None, class_weight=None, ignore_index=None, **kwargs
):
"""Calculate the CrossEntropy loss for masks.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number
of classes.
target (torch.Tensor): The learning label of the prediction.
label (torch.Tensor): ``label`` indicates the class label of the mask's
corresponding object. It is used to select the mask of the class
the object belongs to when the mask prediction is not
class-agnostic.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (None): Placeholder, to be consistent with other loss.
Default: None.
Returns:
torch.Tensor: The calculated loss
"""
assert ignore_index is None, "BCE loss does not support ignore_index"
assert reduction == "mean" and avg_factor is None
num_rois = pred.size()[0]
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
pred_slice = pred[inds, label].squeeze(1)
return F.binary_cross_entropy_with_logits(pred_slice, target, weight=class_weight, reduction="mean")[None]
@LOSSES.register_module(force=True)
class CrossEntropyLoss(nn.Module):
"""CrossEntropyLoss.
Args:
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
or softmax. Defaults to False.
use_mask (bool, optional): Whether to use mask cross entropy loss.
Defaults to False.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum". Defaults to 'mean'.
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
loss_name (str, optional): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_ce'.
avg_non_ignore (bool): Whether the loss is only averaged over
non-ignored targets. Default: False.
`New in version 0.23.0.`
"""
def __init__(
self,
use_sigmoid=False,
use_mask=False,
reduction="mean",
class_weight=None,
loss_weight=1.0,
loss_name="loss_ce",
avg_non_ignore=False,
):
super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid
self.use_mask = use_mask
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = get_class_weight(class_weight)
self.avg_non_ignore = avg_non_ignore
if not self.avg_non_ignore and self.reduction == "mean":
warnings.warn(
"Default ``avg_non_ignore`` is False, if you would like to "
"ignore the certain label and average loss over non-ignore "
"labels, which is the same with PyTorch official "
"cross_entropy, set ``avg_non_ignore=True``."
)
if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
elif self.use_mask:
self.cls_criterion = mask_cross_entropy
else:
self.cls_criterion = cross_entropy
self._loss_name = loss_name
def extra_repr(self):
"""Extra repr."""
s = f"avg_non_ignore={self.avg_non_ignore}"
return s
def forward(
self, cls_score, label, weight=None, avg_factor=None, reduction_override=None, ignore_index=-100, **kwargs
):
"""Forward function."""
assert reduction_override in (None, "none", "mean", "sum")
reduction = reduction_override if reduction_override else self.reduction
if self.class_weight is not None:
class_weight = cls_score.new_tensor(self.class_weight)
else:
class_weight = None
# Note: for BCE loss, label < 0 is invalid.
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
weight,
class_weight=class_weight,
reduction=reduction,
avg_factor=avg_factor,
avg_non_ignore=self.avg_non_ignore,
ignore_index=ignore_index,
**kwargs,
)
return loss_cls
@property
def loss_name(self):
"""Loss Name.
This function must be implemented and will return the name of this
loss function. This name will be used to combine different loss items
by simple sum operation. In addition, if you want this loss item to be
included into the backward graph, `loss_` must be the prefix of the
name.
Returns:
str: The name of this loss item.
"""
return self._loss_name
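# A small, self-contained sketch of the ``avg_non_ignore`` behaviour handled by
# the `cross_entropy` wrapper above: when enabled, the mean is taken only over
# non-ignored targets, matching PyTorch's own 'mean' reduction with
# ``ignore_index``. Values and shapes below are illustrative.
if __name__ == "__main__":
    import torch
    import torch.nn.functional as F

    ignore_index = 255
    pred = torch.randn(4, 3)                       # (N, C) logits
    label = torch.tensor([0, 2, ignore_index, 1])  # one ignored target

    # per-sample losses; ignored targets contribute zero
    loss = F.cross_entropy(pred, label, reduction="none", ignore_index=ignore_index)

    # avg_non_ignore=True: divide by the number of non-ignored targets
    avg_factor = label.numel() - (label == ignore_index).sum().item()
    loss_non_ignore = loss.sum() / avg_factor

    # matches PyTorch's default 'mean' reduction with ignore_index
    loss_ref = F.cross_entropy(pred, label, ignore_index=ignore_index)
    assert torch.allclose(loss_non_ignore, loss_ref)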

View File

@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from mmseg.models.builder import LOSSES
from mmseg.models.losses.utils import weight_reduce_loss
def dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None):
"""Calculate dice loss, which is proposed in
`V-Net: Fully Convolutional Neural Networks for Volumetric
Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
Args:
pred (torch.Tensor): The prediction, has a shape (n, *)
target (torch.Tensor): The learning label of the prediction,
shape (n, *), same shape of pred.
weight (torch.Tensor, optional): The weight of loss for each
prediction, has a shape (n,). Defaults to None.
eps (float): Avoid dividing by zero. Default: 1e-3.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
input = pred.flatten(1)
target = target.flatten(1).float()
a = torch.sum(input * target, 1)
b = torch.sum(input * input, 1) + eps
c = torch.sum(target * target, 1) + eps
d = (2 * a) / (b + c)
loss = 1 - d
if weight is not None:
assert weight.ndim == loss.ndim
assert len(weight) == len(pred)
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
def naive_dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None):
"""Calculate naive dice loss, the coefficient in the denominator is the
first power instead of the second power.
Args:
pred (torch.Tensor): The prediction, has a shape (n, *)
target (torch.Tensor): The learning label of the prediction,
shape (n, *), same shape of pred.
weight (torch.Tensor, optional): The weight of loss for each
prediction, has a shape (n,). Defaults to None.
eps (float): Avoid dividing by zero. Default: 1e-3.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
input = pred.flatten(1)
target = target.flatten(1).float()
a = torch.sum(input * target, 1)
b = torch.sum(input, 1)
c = torch.sum(target, 1)
d = (2 * a + eps) / (b + c + eps)
loss = 1 - d
if weight is not None:
assert weight.ndim == loss.ndim
assert len(weight) == len(pred)
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
@LOSSES.register_module(force=True)
class DiceLoss(nn.Module):
def __init__(self, use_sigmoid=True, activate=True, reduction="mean", naive_dice=False, loss_weight=1.0, eps=1e-3):
"""Dice Loss, there are two forms of dice loss is supported:
- the one proposed in `V-Net: Fully Convolutional Neural
Networks for Volumetric Medical Image Segmentation
<https://arxiv.org/abs/1606.04797>`_.
- the naive dice loss, in which the terms in the
denominator are raised to the first power instead of
the second power.
Args:
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
or softmax. Defaults to True.
activate (bool): Whether to apply the activation (sigmoid) to the
predictions inside the loss; set it to False if the inputs are
already activated. Defaults to True.
reduction (str, optional): The method used
to reduce the loss. Options are "none",
"mean" and "sum". Defaults to 'mean'.
naive_dice (bool, optional): If False, use the dice
loss defined in the V-Net paper; otherwise, use the
naive dice loss in which the terms in the denominator
are raised to the first power instead of the second
power. Defaults to False.
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
eps (float): Avoid dividing by zero. Defaults to 1e-3.
"""
super(DiceLoss, self).__init__()
self.use_sigmoid = use_sigmoid
self.reduction = reduction
self.naive_dice = naive_dice
self.loss_weight = loss_weight
self.eps = eps
self.activate = activate
def forward(self, pred, target, weight=None, reduction_override=None, avg_factor=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction, has a shape (n, *).
target (torch.Tensor): The label of the prediction,
shape (n, *), same shape of pred.
weight (torch.Tensor, optional): The weight of loss for each
prediction, has a shape (n,). Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Options are "none", "mean" and "sum".
Returns:
torch.Tensor: The calculated loss
"""
assert reduction_override in (None, "none", "mean", "sum")
reduction = reduction_override if reduction_override else self.reduction
if self.activate:
if self.use_sigmoid:
pred = pred.sigmoid()
else:
raise NotImplementedError
if self.naive_dice:
loss = self.loss_weight * naive_dice_loss(
pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor
)
else:
loss = self.loss_weight * dice_loss(
pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor
)
return loss
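# A toy, torch-only comparison of the two dice variants implemented above
# (V-Net dice with squared denominator terms vs. naive dice with first-power
# terms). The inputs are assumed to be probabilities in [0, 1]; numbers are
# illustrative only.
if __name__ == "__main__":
    import torch

    pred = torch.tensor([[0.9, 0.1, 0.8, 0.2]])    # (n, *) flattened probabilities
    target = torch.tensor([[1.0, 0.0, 1.0, 0.0]])  # same shape as pred
    eps = 1e-3

    inter = (pred * target).sum(1)
    # V-Net dice: squared terms in the denominator
    vnet = 1 - 2 * inter / ((pred * pred).sum(1) + (target * target).sum(1) + 2 * eps)
    # naive dice: first-power terms in the denominator
    naive = 1 - (2 * inter + eps) / (pred.sum(1) + target.sum(1) + eps)
    print(f"v-net dice loss: {vnet.item():.4f}, naive dice loss: {naive.item():.4f}")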

View File

@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from ..builder import MATCH_COST
@MATCH_COST.register_module()
class ClassificationCost:
"""ClsSoftmaxCost.Borrow from
mmdet.core.bbox.match_costs.match_cost.ClassificationCost.
Args:
weight (int | float, optional): loss_weight
Examples:
>>> import torch
>>> self = ClassificationCost()
>>> cls_pred = torch.rand(4, 3)
>>> gt_labels = torch.tensor([0, 1, 2])
>>> factor = torch.tensor([10, 8, 10, 8])
>>> self(cls_pred, gt_labels)
tensor([[-0.3430, -0.3525, -0.3045],
[-0.3077, -0.2931, -0.3992],
[-0.3664, -0.3455, -0.2881],
[-0.3343, -0.2701, -0.3956]])
"""
def __init__(self, weight=1.0):
self.weight = weight
def __call__(self, cls_pred, gt_labels):
"""
Args:
cls_pred (Tensor): Predicted classification logits, shape
[num_query, num_class].
gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
Returns:
torch.Tensor: cls_cost value with weight
"""
# Following the official DETR repo, in contrast to the loss where
# NLL is used, we approximate it by 1 - cls_score[gt_label].
# The 1 is a constant that doesn't change the matching,
# so it can be omitted.
cls_score = cls_pred.softmax(-1)
cls_cost = -cls_score[:, gt_labels]
return cls_cost * self.weight
@MATCH_COST.register_module()
class DiceCost:
"""Cost of mask assignments based on dice losses.
Args:
weight (int | float, optional): loss_weight. Defaults to 1.
pred_act (bool, optional): Whether to apply sigmoid to mask_pred.
Defaults to False.
eps (float, optional): Defaults to 1e-3.
"""
def __init__(self, weight=1.0, pred_act=False, eps=1e-3):
self.weight = weight
self.pred_act = pred_act
self.eps = eps
def binary_mask_dice_loss(self, mask_preds, gt_masks):
"""
Args:
mask_preds (Tensor): Mask prediction in shape (N1, H, W).
gt_masks (Tensor): Ground truth in shape (N2, H, W)
store 0 or 1, 0 for negative class and 1 for
positive class.
Returns:
Tensor: Dice cost matrix in shape (N1, N2).
"""
mask_preds = mask_preds.reshape((mask_preds.shape[0], -1))
gt_masks = gt_masks.reshape((gt_masks.shape[0], -1)).float()
numerator = 2 * torch.einsum("nc,mc->nm", mask_preds, gt_masks)
denominator = mask_preds.sum(-1)[:, None] + gt_masks.sum(-1)[None, :]
loss = 1 - (numerator + self.eps) / (denominator + self.eps)
return loss
def __call__(self, mask_preds, gt_masks):
"""
Args:
mask_preds (Tensor): Mask prediction logits in shape (N1, H, W).
gt_masks (Tensor): Ground truth in shape (N2, H, W).
Returns:
Tensor: Dice cost matrix in shape (N1, N2).
"""
if self.pred_act:
mask_preds = mask_preds.sigmoid()
dice_cost = self.binary_mask_dice_loss(mask_preds, gt_masks)
return dice_cost * self.weight
@MATCH_COST.register_module()
class CrossEntropyLossCost:
"""CrossEntropyLossCost.
Args:
weight (int | float, optional): loss weight. Defaults to 1.
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
or softmax. Defaults to True.
"""
def __init__(self, weight=1.0, use_sigmoid=True):
assert use_sigmoid, "use_sigmoid = False is not supported yet."
self.weight = weight
self.use_sigmoid = use_sigmoid
def _binary_cross_entropy(self, cls_pred, gt_labels):
"""
Args:
cls_pred (Tensor): The prediction with shape (num_query, 1, *) or
(num_query, *).
gt_labels (Tensor): The learning label of prediction with
shape (num_gt, *).
Returns:
Tensor: Cross entropy cost matrix in shape (num_query, num_gt).
"""
cls_pred = cls_pred.flatten(1).float()
gt_labels = gt_labels.flatten(1).float()
n = cls_pred.shape[1]
pos = F.binary_cross_entropy_with_logits(cls_pred, torch.ones_like(cls_pred), reduction="none")
neg = F.binary_cross_entropy_with_logits(cls_pred, torch.zeros_like(cls_pred), reduction="none")
cls_cost = torch.einsum("nc,mc->nm", pos, gt_labels) + torch.einsum("nc,mc->nm", neg, 1 - gt_labels)
cls_cost = cls_cost / n
return cls_cost
def __call__(self, cls_pred, gt_labels):
"""
Args:
cls_pred (Tensor): Predicted classification logits.
gt_labels (Tensor): Labels.
Returns:
Tensor: Cross entropy cost matrix with weight in
shape (num_query, num_gt).
"""
if self.use_sigmoid:
cls_cost = self._binary_cross_entropy(cls_pred, gt_labels)
else:
raise NotImplementedError
return cls_cost * self.weight
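# A torch-only sketch of the pairwise matching costs computed above -- the
# dice cost of `DiceCost` and the binary cross-entropy cost of
# `CrossEntropyLossCost` -- on tiny random inputs. All sizes are illustrative.
if __name__ == "__main__":
    import torch
    import torch.nn.functional as F

    n1, n2, h, w, eps = 5, 3, 16, 16, 1e-3
    mask_preds = torch.rand(n1, h, w)               # assumed already sigmoid-activated
    gt_masks = (torch.rand(n2, h, w) > 0.5).float()

    p = mask_preds.reshape(n1, -1)
    g = gt_masks.reshape(n2, -1)
    # dice cost: one value per (prediction, ground-truth) pair
    numerator = 2 * torch.einsum("nc,mc->nm", p, g)
    denominator = p.sum(-1)[:, None] + g.sum(-1)[None, :]
    dice_cost = 1 - (numerator + eps) / (denominator + eps)
    assert dice_cost.shape == (n1, n2)

    # binary cross-entropy cost on logits, averaged over pixels
    logits = torch.randn(n1, h * w)
    pos = F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits), reduction="none")
    neg = F.binary_cross_entropy_with_logits(logits, torch.zeros_like(logits), reduction="none")
    ce_cost = (torch.einsum("nc,mc->nm", pos, g) + torch.einsum("nc,mc->nm", neg, 1 - g)) / (h * w)
    assert ce_cost.shape == (n1, n2)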

View File

@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .msdeformattn_pixel_decoder import MSDeformAttnPixelDecoder

View File

@@ -0,0 +1,242 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import PLUGIN_LAYERS, Conv2d, ConvModule, caffe2_xavier_init, normal_init, xavier_init
from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence
from mmcv.runner import BaseModule, ModuleList
from ...core.anchor import MlvlPointGenerator
from ..utils.transformer import MultiScaleDeformableAttention
@PLUGIN_LAYERS.register_module()
class MSDeformAttnPixelDecoder(BaseModule):
"""Pixel decoder with multi-scale deformable attention.
Args:
in_channels (list[int] | tuple[int]): Number of channels in the
input feature maps.
strides (list[int] | tuple[int]): Output strides of feature from
backbone.
feat_channels (int): Number of channels for feature.
out_channels (int): Number of channels for output.
num_outs (int): Number of output scales.
norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization.
Defaults to dict(type='GN', num_groups=32).
act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation.
Defaults to dict(type='ReLU').
encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer
encoder. Defaults to `DetrTransformerEncoder`.
positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
transformer encoder position encoding. Defaults to
dict(type='SinePositionalEncoding', num_feats=128,
normalize=True).
init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict.
"""
def __init__(
self,
in_channels=[256, 512, 1024, 2048],
strides=[4, 8, 16, 32],
feat_channels=256,
out_channels=256,
num_outs=3,
norm_cfg=dict(type="GN", num_groups=32),
act_cfg=dict(type="ReLU"),
encoder=dict(
type="DetrTransformerEncoder",
num_layers=6,
transformerlayers=dict(
type="BaseTransformerLayer",
attn_cfgs=dict(
type="MultiScaleDeformableAttention",
embed_dims=256,
num_heads=8,
num_levels=3,
num_points=4,
im2col_step=64,
dropout=0.0,
batch_first=False,
norm_cfg=None,
init_cfg=None,
),
feedforward_channels=1024,
ffn_dropout=0.0,
operation_order=("self_attn", "norm", "ffn", "norm"),
),
init_cfg=None,
),
positional_encoding=dict(type="SinePositionalEncoding", num_feats=128, normalize=True),
init_cfg=None,
):
super().__init__(init_cfg=init_cfg)
self.strides = strides
self.num_input_levels = len(in_channels)
self.num_encoder_levels = encoder.transformerlayers.attn_cfgs.num_levels
assert self.num_encoder_levels >= 1, "num_levels in attn_cfgs must be at least one"
input_conv_list = []
# from top to bottom (low to high resolution)
for i in range(self.num_input_levels - 1, self.num_input_levels - self.num_encoder_levels - 1, -1):
input_conv = ConvModule(
in_channels[i], feat_channels, kernel_size=1, norm_cfg=norm_cfg, act_cfg=None, bias=True
)
input_conv_list.append(input_conv)
self.input_convs = ModuleList(input_conv_list)
self.encoder = build_transformer_layer_sequence(encoder)
self.postional_encoding = build_positional_encoding(positional_encoding)
# high resolution to low resolution
self.level_encoding = nn.Embedding(self.num_encoder_levels, feat_channels)
# fpn-like structure
self.lateral_convs = ModuleList()
self.output_convs = ModuleList()
self.use_bias = norm_cfg is None
# from top to bottom (low to high resolution)
# FPN for the remaining features that are not passed to the encoder
for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1):
lateral_conv = ConvModule(
in_channels[i], feat_channels, kernel_size=1, bias=self.use_bias, norm_cfg=norm_cfg, act_cfg=None
)
output_conv = ConvModule(
feat_channels,
feat_channels,
kernel_size=3,
stride=1,
padding=1,
bias=self.use_bias,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
)
self.lateral_convs.append(lateral_conv)
self.output_convs.append(output_conv)
self.mask_feature = Conv2d(feat_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.num_outs = num_outs
self.point_generator = MlvlPointGenerator(strides)
def init_weights(self):
"""Initialize weights."""
for i in range(0, self.num_encoder_levels):
xavier_init(self.input_convs[i].conv, gain=1, bias=0, distribution="uniform")
for i in range(0, self.num_input_levels - self.num_encoder_levels):
caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
caffe2_xavier_init(self.output_convs[i].conv, bias=0)
caffe2_xavier_init(self.mask_feature, bias=0)
normal_init(self.level_encoding, mean=0, std=1)
for p in self.encoder.parameters():
if p.dim() > 1:
nn.init.xavier_normal_(p)
# init_weights defined in MultiScaleDeformableAttention
for layer in self.encoder.layers:
for attn in layer.attentions:
if isinstance(attn, MultiScaleDeformableAttention):
attn.init_weights()
def forward(self, feats):
"""
Args:
feats (list[Tensor]): Feature maps of each level. Each has
shape of (batch_size, c, h, w).
Returns:
tuple: A tuple containing the following:
- mask_feature (Tensor): shape (batch_size, c, h, w).
- multi_scale_features (list[Tensor]): Multi scale \
features, each in shape (batch_size, c, h, w).
"""
# generate padding mask for each level, for each image
batch_size = feats[0].shape[0]
encoder_input_list = []
padding_mask_list = []
level_positional_encoding_list = []
spatial_shapes = []
reference_points_list = []
for i in range(self.num_encoder_levels):
level_idx = self.num_input_levels - i - 1
feat = feats[level_idx]
feat_projected = self.input_convs[i](feat)
h, w = feat.shape[-2:]
# no padding
padding_mask_resized = feat.new_zeros((batch_size,) + feat.shape[-2:], dtype=torch.bool)
pos_embed = self.postional_encoding(padding_mask_resized)
level_embed = self.level_encoding.weight[i]
level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed
# (h_i * w_i, 2)
reference_points = self.point_generator.single_level_grid_priors(
feat.shape[-2:], level_idx, device=feat.device
)
# normalize
factor = feat.new_tensor([[w, h]]) * self.strides[level_idx]
reference_points = reference_points / factor
# shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c)
feat_projected = feat_projected.flatten(2).permute(2, 0, 1)
level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1)
padding_mask_resized = padding_mask_resized.flatten(1)
encoder_input_list.append(feat_projected)
padding_mask_list.append(padding_mask_resized)
level_positional_encoding_list.append(level_pos_embed)
spatial_shapes.append(feat.shape[-2:])
reference_points_list.append(reference_points)
# shape (batch_size, total_num_query),
# total_num_query = sum([..., h_i * w_i, ...])
padding_masks = torch.cat(padding_mask_list, dim=1)
# shape (total_num_query, batch_size, c)
encoder_inputs = torch.cat(encoder_input_list, dim=0)
level_positional_encodings = torch.cat(level_positional_encoding_list, dim=0)
device = encoder_inputs.device
# shape (num_encoder_levels, 2), from low
# resolution to high resolution
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=device)
# level_start_index: (0, h_0*w_0, h_0*w_0+h_1*w_1, ...)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
reference_points = torch.cat(reference_points_list, dim=0)
reference_points = reference_points[None, :, None].repeat(batch_size, 1, self.num_encoder_levels, 1)
valid_radios = reference_points.new_ones((batch_size, self.num_encoder_levels, 2))
# shape (num_total_query, batch_size, c)
memory = self.encoder(
query=encoder_inputs,
key=None,
value=None,
query_pos=level_positional_encodings,
key_pos=None,
attn_masks=None,
key_padding_mask=None,
query_key_padding_mask=padding_masks,
spatial_shapes=spatial_shapes,
reference_points=reference_points,
level_start_index=level_start_index,
valid_radios=valid_radios,
)
# (num_total_query, batch_size, c) -> (batch_size, c, num_total_query)
memory = memory.permute(1, 2, 0)
# from low resolution to high resolution
num_query_per_level = [e[0] * e[1] for e in spatial_shapes]
outs = torch.split(memory, num_query_per_level, dim=-1)
outs = [x.reshape(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1]) for i, x in enumerate(outs)]
for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1):
x = feats[i]
cur_feat = self.lateral_convs[i](x)
y = cur_feat + F.interpolate(outs[-1], size=cur_feat.shape[-2:], mode="bilinear", align_corners=False)
y = self.output_convs[i](y)
outs.append(y)
multi_scale_features = outs[: self.num_outs]
mask_feature = self.mask_feature(outs[-1])
return mask_feature, multi_scale_features
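# A small, torch-only sketch of the flatten/split bookkeeping used in
# `forward` above: per-level spatial shapes are turned into
# `level_start_index`, and the flattened encoder output is split back into
# per-level feature maps. Sizes below are illustrative.
if __name__ == "__main__":
    import torch

    batch_size, channels = 2, 8
    spatial_shapes = torch.as_tensor([[8, 8], [16, 16], [32, 32]], dtype=torch.long)

    # start offset of each level inside the flattened query sequence
    level_start_index = torch.cat(
        (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
    )
    total_queries = int(spatial_shapes.prod(1).sum())
    memory = torch.randn(batch_size, channels, total_queries)  # (bs, c, sum(h_i*w_i))

    num_query_per_level = [int(h * w) for h, w in spatial_shapes]
    outs = torch.split(memory, num_query_per_level, dim=-1)
    outs = [x.reshape(batch_size, channels, int(h), int(w))
            for x, (h, w) in zip(outs, spatial_shapes)]
    assert [tuple(o.shape[-2:]) for o in outs] == [(8, 8), (16, 16), (32, 32)]
    print("level_start_index:", level_start_index.tolist())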

View File

@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .encoder_decoder_mask2former import EncoderDecoderMask2Former

View File

@@ -0,0 +1,271 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.core import add_prefix
from mmseg.models import builder
from mmseg.models.builder import SEGMENTORS
from mmseg.models.segmentors.base import BaseSegmentor
from mmseg.ops import resize
@SEGMENTORS.register_module()
class EncoderDecoderMask2Former(BaseSegmentor):
"""Encoder Decoder segmentors.
EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
Note that auxiliary_head is only used for deep supervision during training,
which could be dumped during inference.
"""
def __init__(
self,
backbone,
decode_head,
neck=None,
auxiliary_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None,
):
super(EncoderDecoderMask2Former, self).__init__(init_cfg)
if pretrained is not None:
assert backbone.get("pretrained") is None, "both backbone and segmentor set pretrained weight"
backbone.pretrained = pretrained
self.backbone = builder.build_backbone(backbone)
if neck is not None:
self.neck = builder.build_neck(neck)
decode_head.update(train_cfg=train_cfg)
decode_head.update(test_cfg=test_cfg)
self._init_decode_head(decode_head)
self._init_auxiliary_head(auxiliary_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
assert self.with_decode_head
def _init_decode_head(self, decode_head):
"""Initialize ``decode_head``"""
self.decode_head = builder.build_head(decode_head)
self.align_corners = self.decode_head.align_corners
self.num_classes = self.decode_head.num_classes
def _init_auxiliary_head(self, auxiliary_head):
"""Initialize ``auxiliary_head``"""
if auxiliary_head is not None:
if isinstance(auxiliary_head, list):
self.auxiliary_head = nn.ModuleList()
for head_cfg in auxiliary_head:
self.auxiliary_head.append(builder.build_head(head_cfg))
else:
self.auxiliary_head = builder.build_head(auxiliary_head)
def extract_feat(self, img):
"""Extract features from images."""
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
return x
def encode_decode(self, img, img_metas):
"""Encode images with backbone and decode into a semantic segmentation
map of the same size as input."""
x = self.extract_feat(img)
out = self._decode_head_forward_test(x, img_metas)
out = resize(input=out, size=img.shape[2:], mode="bilinear", align_corners=self.align_corners)
return out
def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg, **kwargs):
"""Run forward function and calculate loss for decode head in
training."""
losses = dict()
loss_decode = self.decode_head.forward_train(x, img_metas, gt_semantic_seg, **kwargs)
losses.update(add_prefix(loss_decode, "decode"))
return losses
def _decode_head_forward_test(self, x, img_metas):
"""Run forward function and calculate loss for decode head in
inference."""
seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg)
return seg_logits
def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg):
"""Run forward function and calculate loss for auxiliary head in
training."""
losses = dict()
if isinstance(self.auxiliary_head, nn.ModuleList):
for idx, aux_head in enumerate(self.auxiliary_head):
loss_aux = aux_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg)
losses.update(add_prefix(loss_aux, f"aux_{idx}"))
else:
loss_aux = self.auxiliary_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg)
losses.update(add_prefix(loss_aux, "aux"))
return losses
def forward_dummy(self, img):
"""Dummy forward function."""
seg_logit = self.encode_decode(img, None)
return seg_logit
def forward_train(self, img, img_metas, gt_semantic_seg, **kwargs):
"""Forward function for training.
Args:
img (Tensor): Input images.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:Collect`.
gt_semantic_seg (Tensor): Semantic segmentation masks
used if the architecture supports semantic segmentation task.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
x = self.extract_feat(img)
losses = dict()
loss_decode = self._decode_head_forward_train(x, img_metas, gt_semantic_seg, **kwargs)
losses.update(loss_decode)
if self.with_auxiliary_head:
loss_aux = self._auxiliary_head_forward_train(x, img_metas, gt_semantic_seg)
losses.update(loss_aux)
return losses
def slide_inference(self, img, img_meta, rescale):
"""Inference by sliding-window with overlap.
If h_crop > h_img or w_crop > w_img, the small patch will be used to
decode without padding.
"""
h_stride, w_stride = self.test_cfg.stride
h_crop, w_crop = self.test_cfg.crop_size
batch_size, _, h_img, w_img = img.size()
num_classes = self.num_classes
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
for h_idx in range(h_grids):
for w_idx in range(w_grids):
y1 = h_idx * h_stride
x1 = w_idx * w_stride
y2 = min(y1 + h_crop, h_img)
x2 = min(x1 + w_crop, w_img)
y1 = max(y2 - h_crop, 0)
x1 = max(x2 - w_crop, 0)
crop_img = img[:, :, y1:y2, x1:x2]
crop_seg_logit = self.encode_decode(crop_img, img_meta)
preds += F.pad(crop_seg_logit, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
count_mat[:, :, y1:y2, x1:x2] += 1
assert (count_mat == 0).sum() == 0
if torch.onnx.is_in_onnx_export():
# cast count_mat to constant while exporting to ONNX
count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
preds = preds / count_mat
if rescale:
preds = resize(
preds,
size=img_meta[0]["ori_shape"][:2],
mode="bilinear",
align_corners=self.align_corners,
warning=False,
)
return preds
def whole_inference(self, img, img_meta, rescale):
"""Inference with full image."""
seg_logit = self.encode_decode(img, img_meta)
if rescale:
# support dynamic shape for onnx
if torch.onnx.is_in_onnx_export():
size = img.shape[2:]
else:
size = img_meta[0]["ori_shape"][:2]
seg_logit = resize(seg_logit, size=size, mode="bilinear", align_corners=self.align_corners, warning=False)
return seg_logit
def inference(self, img, img_meta, rescale):
"""Inference with slide/whole style.
Args:
img (Tensor): The input image of shape (N, 3, H, W).
img_meta (dict): Image info dict where each dict has: 'img_shape',
'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:Collect`.
rescale (bool): Whether rescale back to original shape.
Returns:
Tensor: The output segmentation map.
"""
assert self.test_cfg.mode in ["slide", "whole"]
ori_shape = img_meta[0]["ori_shape"]
assert all(_["ori_shape"] == ori_shape for _ in img_meta)
if self.test_cfg.mode == "slide":
seg_logit = self.slide_inference(img, img_meta, rescale)
else:
seg_logit = self.whole_inference(img, img_meta, rescale)
output = F.softmax(seg_logit, dim=1)
flip = img_meta[0]["flip"]
if flip:
flip_direction = img_meta[0]["flip_direction"]
assert flip_direction in ["horizontal", "vertical"]
if flip_direction == "horizontal":
output = output.flip(dims=(3,))
elif flip_direction == "vertical":
output = output.flip(dims=(2,))
return output
def simple_test(self, img, img_meta, rescale=True):
"""Simple test with single image."""
seg_logit = self.inference(img, img_meta, rescale)
seg_pred = seg_logit.argmax(dim=1)
if torch.onnx.is_in_onnx_export():
# our inference backend only supports 4D output
seg_pred = seg_pred.unsqueeze(0)
return seg_pred
seg_pred = seg_pred.cpu().numpy()
# unravel batch dim
seg_pred = list(seg_pred)
return seg_pred
def aug_test(self, imgs, img_metas, rescale=True):
"""Test with augmentations.
Only rescale=True is supported.
"""
# aug_test rescale all imgs back to ori_shape for now
assert rescale
# to save memory, we accumulate the augmented seg logits in place
seg_logit = self.inference(imgs[0], img_metas[0], rescale)
for i in range(1, len(imgs)):
cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale)
seg_logit += cur_seg_logit
seg_logit /= len(imgs)
seg_pred = seg_logit.argmax(dim=1)
seg_pred = seg_pred.cpu().numpy()
# unravel batch dim
seg_pred = list(seg_pred)
return seg_pred
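# A plain-Python sketch of the sliding-window grid arithmetic used by
# `slide_inference` above; it only prints the crop windows. The image, crop
# and stride sizes are made up for illustration.
if __name__ == "__main__":
    h_img, w_img = 512, 683
    h_crop, w_crop = 512, 512
    h_stride, w_stride = 341, 341

    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y1, x1 = h_idx * h_stride, w_idx * w_stride
            y2, x2 = min(y1 + h_crop, h_img), min(x1 + w_crop, w_img)
            # shift the window back so every crop has the full crop size
            y1, x1 = max(y2 - h_crop, 0), max(x2 - w_crop, 0)
            print(f"crop window: rows [{y1}, {y2}), cols [{x1}, {x2})")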

View File

@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .assigner import MaskHungarianAssigner
from .point_sample import get_uncertain_point_coords_with_randomness
from .positional_encoding import LearnedPositionalEncoding, SinePositionalEncoding
from .transformer import DetrTransformerDecoder, DetrTransformerDecoderLayer, DynamicConv, Transformer

View File

@@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
import torch
from ..builder import MASK_ASSIGNERS, build_match_cost
try:
from scipy.optimize import linear_sum_assignment
except ImportError:
linear_sum_assignment = None
class AssignResult(metaclass=ABCMeta):
"""Collection of assign results."""
def __init__(self, num_gts, gt_inds, labels):
self.num_gts = num_gts
self.gt_inds = gt_inds
self.labels = labels
@property
def info(self):
info = {
"num_gts": self.num_gts,
"gt_inds": self.gt_inds,
"labels": self.labels,
}
return info
class BaseAssigner(metaclass=ABCMeta):
"""Base assigner that assigns boxes to ground truth boxes."""
@abstractmethod
def assign(self, masks, gt_masks, gt_masks_ignore=None, gt_labels=None):
"""Assign boxes to either a ground truth boxes or a negative boxes."""
pass
@MASK_ASSIGNERS.register_module()
class MaskHungarianAssigner(BaseAssigner):
"""Computes one-to-one matching between predictions and ground truth for
mask.
This class computes an assignment between the targets and the predictions
based on the costs. The costs are a weighted sum of three components:
classification cost, mask cost and dice cost. The
targets don't include the no_object, so generally there are more
predictions than targets. After the one-to-one matching, the un-matched
are treated as backgrounds. Thus each query prediction will be assigned
with `0` or a positive integer indicating the ground truth index:
- 0: negative sample, no assigned gt
- positive integer: positive sample, index (1-based) of assigned gt
Args:
cls_cost (obj:`mmcv.ConfigDict`|dict): Classification cost config.
mask_cost (obj:`mmcv.ConfigDict`|dict): Mask cost config.
dice_cost (obj:`mmcv.ConfigDict`|dict): Dice cost config.
"""
def __init__(
self,
cls_cost=dict(type="ClassificationCost", weight=1.0),
dice_cost=dict(type="DiceCost", weight=1.0),
mask_cost=dict(type="MaskFocalCost", weight=1.0),
):
self.cls_cost = build_match_cost(cls_cost)
self.dice_cost = build_match_cost(dice_cost)
self.mask_cost = build_match_cost(mask_cost)
def assign(self, cls_pred, mask_pred, gt_labels, gt_masks, img_meta, gt_masks_ignore=None, eps=1e-7):
"""Computes one-to-one matching based on the weighted costs.
This method assign each query prediction to a ground truth or
background. The `assigned_gt_inds` with -1 means don't care,
0 means negative sample, and positive number is the index (1-based)
of assigned gt.
The assignment is done in the following steps, the order matters.
1. assign every prediction to -1
2. compute the weighted costs
3. do Hungarian matching on CPU based on the costs
4. assign all to 0 (background) first, then for each matched pair
between predictions and gts, treat this prediction as foreground
and assign the corresponding gt index (plus 1) to it.
Args:
mask_pred (Tensor): Predicted mask, shape [num_query, h, w]
cls_pred (Tensor): Predicted classification logits, shape
[num_query, num_class].
gt_masks (Tensor): Ground truth mask, shape [num_gt, h, w].
gt_labels (Tensor): Label of `gt_masks`, shape (num_gt,).
img_meta (dict): Meta information for current image.
gt_masks_ignore (Tensor, optional): Ground truth masks that are
labelled as `ignored`. Default None.
eps (int | float, optional): A value added to the denominator for
numerical stability. Default 1e-7.
Returns:
:obj:`AssignResult`: The assigned result.
"""
assert gt_masks_ignore is None, "Only case when gt_masks_ignore is None is supported."
num_gts, num_queries = gt_labels.shape[0], cls_pred.shape[0]
# 1. assign -1 by default
assigned_gt_inds = cls_pred.new_full((num_queries,), -1, dtype=torch.long)
assigned_labels = cls_pred.new_full((num_queries,), -1, dtype=torch.long)
if num_gts == 0 or num_queries == 0:
# No ground truth or boxes, return empty assignment
if num_gts == 0:
# No ground truth, assign all to background
assigned_gt_inds[:] = 0
return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels)
# 2. compute the weighted costs
# classification and maskcost.
if self.cls_cost.weight != 0 and cls_pred is not None:
cls_cost = self.cls_cost(cls_pred, gt_labels)
else:
cls_cost = 0
if self.mask_cost.weight != 0:
# mask_pred shape = [nq, h, w]
# gt_mask shape = [ng, h, w]
# mask_cost shape = [nq, ng]
mask_cost = self.mask_cost(mask_pred, gt_masks)
else:
mask_cost = 0
if self.dice_cost.weight != 0:
dice_cost = self.dice_cost(mask_pred, gt_masks)
else:
dice_cost = 0
cost = cls_cost + mask_cost + dice_cost
# 3. do Hungarian matching on CPU using linear_sum_assignment
cost = cost.detach().cpu()
if linear_sum_assignment is None:
raise ImportError('Please run "pip install scipy" ' "to install scipy first.")
matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
matched_row_inds = torch.from_numpy(matched_row_inds).to(cls_pred.device)
matched_col_inds = torch.from_numpy(matched_col_inds).to(cls_pred.device)
# 4. assign backgrounds and foregrounds
# assign all indices to backgrounds first
assigned_gt_inds[:] = 0
# assign foregrounds based on matching results
assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels)
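# A minimal sketch of the Hungarian-matching step used by
# `MaskHungarianAssigner.assign` above, assuming scipy is available. The cost
# matrix here is random and purely illustrative.
if __name__ == "__main__":
    import torch
    from scipy.optimize import linear_sum_assignment

    num_queries, num_gts = 10, 4
    cost = torch.rand(num_queries, num_gts)  # weighted sum of cls/mask/dice costs

    row_inds, col_inds = linear_sum_assignment(cost.numpy())
    assigned_gt_inds = torch.zeros(num_queries, dtype=torch.long)  # 0 = background
    assigned_gt_inds[torch.from_numpy(row_inds)] = torch.from_numpy(col_inds) + 1
    print("assigned gt index per query (1-based, 0=background):", assigned_gt_inds.tolist())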

View File

@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
from mmcv.ops import point_sample
def get_uncertainty(mask_pred, labels):
"""Estimate uncertainty based on pred logits.
We estimate uncertainty as L1 distance between 0.0 and the logits
prediction in 'mask_pred' for the foreground class in `classes`.
Args:
mask_pred (Tensor): mask prediction logits, shape (num_rois,
num_classes, mask_height, mask_width).
labels (list[Tensor]): Either predicted or ground truth label for
each predicted mask, of length num_rois.
Returns:
scores (Tensor): Uncertainty scores with the most uncertain
locations having the highest uncertainty score,
shape (num_rois, 1, mask_height, mask_width)
"""
if mask_pred.shape[1] == 1:
gt_class_logits = mask_pred.clone()
else:
inds = torch.arange(mask_pred.shape[0], device=mask_pred.device)
gt_class_logits = mask_pred[inds, labels].unsqueeze(1)
return -torch.abs(gt_class_logits)
def get_uncertain_point_coords_with_randomness(
mask_pred, labels, num_points, oversample_ratio, importance_sample_ratio
):
"""Get ``num_points`` most uncertain points with random points during
training.
Sample points in [0, 1] x [0, 1] coordinate space based on their
uncertainty. The uncertainties are calculated for each point using
'get_uncertainty()' function that takes point's logit prediction as
input.
Args:
mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
mask_height, mask_width) for class-specific or class-agnostic
prediction.
labels (list): The ground truth class for each instance.
num_points (int): The number of points to sample.
oversample_ratio (int): Oversampling parameter.
importance_sample_ratio (float): Ratio of points that are sampled
via importance sampling.
Returns:
point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
that contains the coordinates of the sampled points.
"""
assert oversample_ratio >= 1
assert 0 <= importance_sample_ratio <= 1
batch_size = mask_pred.shape[0]
num_sampled = int(num_points * oversample_ratio)
point_coords = torch.rand(batch_size, num_sampled, 2, device=mask_pred.device)
point_logits = point_sample(mask_pred, point_coords)
# It is crucial to calculate uncertainty based on the sampled
# prediction value for the points. Calculating uncertainties of the
# coarse predictions first and sampling them for points leads to
# incorrect results. To illustrate this: assume uncertainty func(
# logits)=-abs(logits), a sampled point between two coarse
# predictions with -1 and 1 logits has 0 logits, and therefore 0
# uncertainty value. However, if we calculate uncertainties for the
# coarse predictions first, both will have -1 uncertainty,
# and sampled point will get -1 uncertainty.
point_uncertainties = get_uncertainty(point_logits, labels)
num_uncertain_points = int(importance_sample_ratio * num_points)
num_random_points = num_points - num_uncertain_points
idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
shift = num_sampled * torch.arange(batch_size, dtype=torch.long, device=mask_pred.device)
idx += shift[:, None]
point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(batch_size, num_uncertain_points, 2)
if num_random_points > 0:
rand_roi_coords = torch.rand(batch_size, num_random_points, 2, device=mask_pred.device)
point_coords = torch.cat((point_coords, rand_roi_coords), dim=1)
return point_coords
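# A torch-only sketch of the flat-index "shift" trick used above to gather
# the top-k most uncertain point coordinates per sample. Shapes are
# illustrative and the uncertainty values are random.
if __name__ == "__main__":
    import torch

    batch_size, num_sampled, k = 2, 12, 3
    point_coords = torch.rand(batch_size, num_sampled, 2)
    # uncertainty = -|logit|: values near zero are the most uncertain
    point_uncertainties = -torch.randn(batch_size, num_sampled).abs()

    idx = point_uncertainties.topk(k, dim=1)[1]                      # (bs, k) per-sample indices
    shift = num_sampled * torch.arange(batch_size, dtype=torch.long)
    idx = idx + shift[:, None]                                       # offsets into the flattened batch
    top_coords = point_coords.view(-1, 2)[idx.view(-1)].view(batch_size, k, 2)
    assert top_coords.shape == (batch_size, k, 2)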

View File

@@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING
from mmcv.runner import BaseModule
@POSITIONAL_ENCODING.register_module()
class SinePositionalEncoding(BaseModule):
"""Position encoding with sine and cosine functions.
See `End-to-End Object Detection with Transformers
<https://arxiv.org/pdf/2005.12872>`_ for details.
Args:
num_feats (int): The feature dimension for each position
along x-axis or y-axis. Note the final returned dimension
for each position is 2 times of this value.
temperature (int, optional): The temperature used for scaling
the position embedding. Defaults to 10000.
normalize (bool, optional): Whether to normalize the position
embedding. Defaults to False.
scale (float, optional): A scale factor that scales the position
embedding. The scale will be used only when `normalize` is True.
Defaults to 2*pi.
eps (float, optional): A value added to the denominator for
numerical stability. Defaults to 1e-6.
offset (float): An offset added to the embedding when doing the
normalization. Defaults to 0.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""
def __init__(
self, num_feats, temperature=10000, normalize=False, scale=2 * math.pi, eps=1e-6, offset=0.0, init_cfg=None
):
super(SinePositionalEncoding, self).__init__(init_cfg)
if normalize:
assert isinstance(scale, (float, int)), (
"when normalize is set," "scale should be provided and in float or int type, " f"found {type(scale)}"
)
self.num_feats = num_feats
self.temperature = temperature
self.normalize = normalize
self.scale = scale
self.eps = eps
self.offset = offset
def forward(self, mask):
"""Forward function for `SinePositionalEncoding`.
Args:
mask (Tensor): ByteTensor mask. Non-zero values represent
ignored positions, while zero values mean valid positions
for this image. Shape [bs, h, w].
Returns:
pos (Tensor): Returned position embedding with shape
[bs, num_feats*2, h, w].
"""
# For convenience of exporting to ONNX, it's required to convert
# `masks` from bool to int.
mask = mask.to(torch.int)
not_mask = 1 - mask # logical_not
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
y_embed = (y_embed + self.offset) / (y_embed[:, -1:, :] + self.eps) * self.scale
x_embed = (x_embed + self.offset) / (x_embed[:, :, -1:] + self.eps) * self.scale
dim_t = torch.arange(self.num_feats, dtype=torch.float32, device=mask.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
# use `view` instead of `flatten` for dynamically exporting to ONNX
B, H, W = mask.size()
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).view(B, H, W, -1)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).view(B, H, W, -1)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
def __repr__(self):
"""str: a string that describes the module"""
repr_str = self.__class__.__name__
repr_str += f"(num_feats={self.num_feats}, "
repr_str += f"temperature={self.temperature}, "
repr_str += f"normalize={self.normalize}, "
repr_str += f"scale={self.scale}, "
repr_str += f"eps={self.eps})"
return repr_str
@POSITIONAL_ENCODING.register_module()
class LearnedPositionalEncoding(BaseModule):
"""Position embedding with learnable embedding weights.
Args:
num_feats (int): The feature dimension for each position
along x-axis or y-axis. The final returned dimension for
each position is 2 times of this value.
row_num_embed (int, optional): The dictionary size of row embeddings.
Default 50.
col_num_embed (int, optional): The dictionary size of col embeddings.
Default 50.
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
def __init__(self, num_feats, row_num_embed=50, col_num_embed=50, init_cfg=dict(type="Uniform", layer="Embedding")):
super(LearnedPositionalEncoding, self).__init__(init_cfg)
self.row_embed = nn.Embedding(row_num_embed, num_feats)
self.col_embed = nn.Embedding(col_num_embed, num_feats)
self.num_feats = num_feats
self.row_num_embed = row_num_embed
self.col_num_embed = col_num_embed
def forward(self, mask):
"""Forward function for `LearnedPositionalEncoding`.
Args:
mask (Tensor): ByteTensor mask. Non-zero values represent
ignored positions, while zero values mean valid positions
for this image. Shape [bs, h, w].
Returns:
pos (Tensor): Returned position embedding with shape
[bs, num_feats*2, h, w].
"""
h, w = mask.shape[-2:]
x = torch.arange(w, device=mask.device)
y = torch.arange(h, device=mask.device)
x_embed = self.col_embed(x)
y_embed = self.row_embed(y)
pos = (
torch.cat((x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat(1, w, 1)), dim=-1)
.permute(2, 0, 1)
.unsqueeze(0)
.repeat(mask.shape[0], 1, 1, 1)
)
return pos
def __repr__(self):
"""str: a string that describes the module"""
repr_str = self.__class__.__name__
repr_str += f"(num_feats={self.num_feats}, "
repr_str += f"row_num_embed={self.row_num_embed}, "
repr_str += f"col_num_embed={self.col_num_embed})"
return repr_str

View File

@@ -0,0 +1,989 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import math
import warnings
from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer, xavier_init
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.registry import FEEDFORWARD_NETWORK, TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE
from mmcv.cnn.bricks.transformer import BaseTransformerLayer, TransformerLayerSequence, build_transformer_layer_sequence
from mmcv.runner.base_module import BaseModule, Sequential
from mmcv.utils import deprecated_api_warning, to_2tuple
from torch.nn.init import normal_
from ..builder import TRANSFORMER
try:
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
except ImportError:
warnings.warn(
"`MultiScaleDeformableAttention` in MMCV has been moved to "
"`mmcv.ops.multi_scale_deform_attn`, please update your MMCV"
)
from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
class AdaptivePadding(nn.Module):
"""Applies padding to input (if needed) so that input can get fully covered
by filter you specified. It support two modes "same" and "corner". The
"same" mode is same with "SAME" padding mode in TensorFlow, pad zero around
input. The "corner" mode would pad zero to bottom right.
Args:
kernel_size (int | tuple): Size of the kernel.
stride (int | tuple): Stride of the filter. Default: 1.
dilation (int | tuple): Spacing between kernel elements.
Default: 1
padding (str): Support "same" and "corner", "corner" mode
would pad zero to bottom right, and "same" mode would
pad zero around input. Default: "corner".
Example:
>>> kernel_size = 16
>>> stride = 16
>>> dilation = 1
>>> input = torch.rand(1, 1, 15, 17)
>>> adap_pad = AdaptivePadding(
>>> kernel_size=kernel_size,
>>> stride=stride,
>>> dilation=dilation,
>>> padding="corner")
>>> out = adap_pad(input)
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
>>> input = torch.rand(1, 1, 16, 17)
>>> out = adap_pad(input)
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
"""
def __init__(self, kernel_size=1, stride=1, dilation=1, padding="corner"):
super(AdaptivePadding, self).__init__()
assert padding in ("same", "corner")
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
padding = to_2tuple(padding)
dilation = to_2tuple(dilation)
self.padding = padding
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
def get_pad_shape(self, input_shape):
input_h, input_w = input_shape
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.stride
output_h = math.ceil(input_h / stride_h)
output_w = math.ceil(input_w / stride_w)
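# "SAME"-style padding: the total padding needed so that a (dilated)
# kernel with this stride covers the whole input, clamped at zero.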
pad_h = max((output_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
pad_w = max((output_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
return pad_h, pad_w
def forward(self, x):
pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
if pad_h > 0 or pad_w > 0:
if self.padding == "corner":
x = F.pad(x, [0, pad_w, 0, pad_h])
elif self.padding == "same":
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
return x
class PatchMerging(BaseModule):
"""Merge patch feature map.
This layer groups the feature map by kernel_size and applies norm and
linear layers to the grouped feature map. Our implementation uses
`nn.Unfold` to merge patches, which is about 25% faster than the original
implementation; however, pretrained models need to be modified for
compatibility.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
kernel_size (int | tuple, optional): the kernel size in the unfold
layer. Defaults to 2.
stride (int | tuple, optional): the stride of the sliding blocks in the
unfold layer. Default: None. (Would be set as `kernel_size`)
padding (int | tuple | str): The padding length passed to the unfold
layer. When it is a string, it means the mode of adaptive padding;
"same" and "corner" are supported. Default: "corner".
dilation (int | tuple, optional): dilation parameter in the unfold
layer. Default: 1.
bias (bool, optional): Whether to add bias in linear layer or not.
Defaults: False.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: dict(type='LN').
init_cfg (dict, optional): The extra config for initialization.
Default: None.
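Example (a minimal sketch added for illustration; the channel and
spatial sizes are assumptions):
>>> merge = PatchMerging(in_channels=8, out_channels=16)
>>> x = torch.rand(1, 14 * 14, 8)  # (B, H*W, C_in)
>>> out, out_size = merge(x, (14, 14))
>>> assert out.shape == (1, 7 * 7, 16) and out_size == (7, 7)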
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size=2,
stride=None,
padding="corner",
dilation=1,
bias=False,
norm_cfg=dict(type="LN"),
init_cfg=None,
):
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
if stride:
stride = stride
else:
stride = kernel_size
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
if isinstance(padding, str):
self.adap_padding = AdaptivePadding(
kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding
)
# disable the padding of unfold
padding = 0
else:
self.adap_padding = None
padding = to_2tuple(padding)
self.sampler = nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride)
sample_dim = kernel_size[0] * kernel_size[1] * in_channels
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
else:
self.norm = None
self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
def forward(self, x, input_size):
"""
Args:
x (Tensor): Has shape (B, H*W, C_in).
input_size (tuple[int]): The spatial shape of x, arranged as (H, W).
Default: None.
Returns:
tuple: Contains merged results and its spatial shape.
- x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
- out_size (tuple[int]): Spatial shape of x, arranged as
(Merged_H, Merged_W).
"""
B, L, C = x.shape
assert isinstance(input_size, Sequence), f"Expected input_size to be a Sequence, but got {input_size}"
H, W = input_size
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
# Use nn.Unfold to merge patches. About 25% faster than the original method,
# but pretrained models need to be modified for compatibility
if self.adap_padding:
x = self.adap_padding(x)
H, W = x.shape[-2:]
x = self.sampler(x)
# if kernel_size=2 and stride=2, x should have shape (B, 4*C, H/2*W/2)
out_h = (
H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * (self.sampler.kernel_size[0] - 1) - 1
) // self.sampler.stride[0] + 1
out_w = (
W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * (self.sampler.kernel_size[1] - 1) - 1
) // self.sampler.stride[1] + 1
output_size = (out_h, out_w)
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
x = self.norm(x) if self.norm else x
x = self.reduction(x)
return x, output_size
def inverse_sigmoid(x, eps=1e-5):
"""Inverse function of sigmoid.
Args:
x (Tensor): The tensor to apply the inverse sigmoid to.
eps (float): A small value added to avoid numerical
overflow. Defaults to 1e-5.
Returns:
Tensor: The inverse sigmoid of x, with the same
shape as the input.
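Example (a small sanity-check sketch, not from the original docs):
>>> x = torch.tensor([0.1, 0.5, 0.9])
>>> y = inverse_sigmoid(torch.sigmoid(x))
>>> assert torch.allclose(x, y, atol=1e-4)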
"""
x = x.clamp(min=0, max=1)
x1 = x.clamp(min=eps)
x2 = (1 - x).clamp(min=eps)
return torch.log(x1 / x2)
@FEEDFORWARD_NETWORK.register_module(force=True)
class FFN(BaseModule):
"""Implements feed-forward networks (FFNs) with identity connection.
Args:
embed_dims (int): The feature dimension. Same as
`MultiheadAttention`. Defaults: 256.
feedforward_channels (int): The hidden dimension of FFNs.
Defaults: 1024.
num_fcs (int, optional): The number of fully-connected layers in
FFNs. Default: 2.
act_cfg (dict, optional): The activation config for FFNs.
Default: dict(type='ReLU')
ffn_drop (float, optional): Probability of an element to be
zeroed in FFN. Default 0.0.
add_identity (bool, optional): Whether to add the
identity connection. Default: `True`.
dropout_layer (obj:`ConfigDict`): The dropout_layer used
when adding the shortcut.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
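Example (a hedged usage sketch; the shapes are illustrative and assume
the default ReLU activation built by mmcv is available):
>>> ffn = FFN(embed_dims=256, feedforward_channels=1024)
>>> x = torch.rand(2, 100, 256)
>>> out = ffn(x)  # identity connection keeps the shape
>>> assert out.shape == (2, 100, 256)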
"""
@deprecated_api_warning({"dropout": "ffn_drop", "add_residual": "add_identity"}, cls_name="FFN")
def __init__(
self,
embed_dims=256,
feedforward_channels=1024,
num_fcs=2,
act_cfg=dict(type="ReLU", inplace=True),
ffn_drop=0.0,
dropout_layer=None,
add_identity=True,
init_cfg=None,
with_cp=False,
**kwargs,
):
super().__init__(init_cfg)
assert num_fcs >= 2, f"num_fcs should be no less than 2, but got {num_fcs}."
self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels
self.num_fcs = num_fcs
self.act_cfg = act_cfg
self.activate = build_activation_layer(act_cfg)
self.with_cp = with_cp
layers = []
in_channels = embed_dims
for _ in range(num_fcs - 1):
layers.append(Sequential(Linear(in_channels, feedforward_channels), self.activate, nn.Dropout(ffn_drop)))
in_channels = feedforward_channels
layers.append(Linear(feedforward_channels, embed_dims))
layers.append(nn.Dropout(ffn_drop))
self.layers = Sequential(*layers)
self.dropout_layer = build_dropout(dropout_layer) if dropout_layer else torch.nn.Identity()
self.add_identity = add_identity
@deprecated_api_warning({"residual": "identity"}, cls_name="FFN")
def forward(self, x, identity=None):
"""Forward function for `FFN`.
The function adds x to the output tensor if `identity` is None.
"""
if self.with_cp and x.requires_grad:
out = cp.checkpoint(self.layers, x)
else:
out = self.layers(x)
if not self.add_identity:
return self.dropout_layer(out)
if identity is None:
identity = x
return identity + self.dropout_layer(out)
@TRANSFORMER_LAYER.register_module()
class DetrTransformerDecoderLayer(BaseTransformerLayer):
"""Implements decoder layer in DETR transformer.
Args:
attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict):
Configs for self_attention or cross_attention; the order
should be consistent with `operation_order`. If it is
a dict, it would be expanded to the number of attentions in
`operation_order`.
feedforward_channels (int): The hidden dimension for FFNs.
ffn_dropout (float): Probability of an element to be zeroed
in ffn. Default 0.0.
operation_order (tuple[str]): The execution order of operation
in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
Default: None.
act_cfg (dict): The activation config for FFNs.
Default: dict(type='ReLU').
norm_cfg (dict): Config dict for normalization layer.
Default: `LN`.
ffn_num_fcs (int): The number of fully-connected layers in FFNs.
Default: 2.
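Example (a hedged config sketch; the attention config assumes the
standard `MultiheadAttention` registered in mmcv):
>>> layer = DetrTransformerDecoderLayer(
>>> attn_cfgs=dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
>>> feedforward_channels=2048,
>>> operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'))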
"""
def __init__(
self,
attn_cfgs,
feedforward_channels,
ffn_dropout=0.0,
operation_order=None,
act_cfg=dict(type="ReLU", inplace=True),
norm_cfg=dict(type="LN"),
ffn_num_fcs=2,
**kwargs,
):
super(DetrTransformerDecoderLayer, self).__init__(
attn_cfgs=attn_cfgs,
feedforward_channels=feedforward_channels,
ffn_dropout=ffn_dropout,
operation_order=operation_order,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
ffn_num_fcs=ffn_num_fcs,
**kwargs,
)
assert len(operation_order) == 6
assert set(operation_order) == set(["self_attn", "norm", "cross_attn", "ffn"])
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DetrTransformerEncoder(TransformerLayerSequence):
"""TransformerEncoder of DETR.
Args:
post_norm_cfg (dict): Config of last normalization layer. Default
`LN`. Only used when `self.pre_norm` is `True`
"""
def __init__(self, *args, post_norm_cfg=dict(type="LN"), **kwargs):
super(DetrTransformerEncoder, self).__init__(*args, **kwargs)
if post_norm_cfg is not None:
self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None
else:
assert not self.pre_norm, f"Use prenorm in {self.__class__.__name__}, please specify post_norm_cfg"
self.post_norm = None
def forward(self, *args, **kwargs):
"""Forward function for `TransformerCoder`.
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
x = super(DetrTransformerEncoder, self).forward(*args, **kwargs)
if self.post_norm is not None:
x = self.post_norm(x)
return x
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DetrTransformerDecoder(TransformerLayerSequence):
"""Implements the decoder in DETR transformer.
Args:
return_intermediate (bool): Whether to return intermediate outputs.
post_norm_cfg (dict): Config of last normalization layer. Default
`LN`.
"""
def __init__(self, *args, post_norm_cfg=dict(type="LN"), return_intermediate=False, **kwargs):
super(DetrTransformerDecoder, self).__init__(*args, **kwargs)
self.return_intermediate = return_intermediate
if post_norm_cfg is not None:
self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1]
else:
self.post_norm = None
def forward(self, query, *args, **kwargs):
"""Forward function for `TransformerDecoder`.
Args:
query (Tensor): Input query with shape
`(num_query, bs, embed_dims)`.
Returns:
Tensor: Results with shape [1, num_query, bs, embed_dims] when
return_intermediate is `False`, otherwise it has shape
[num_layers, num_query, bs, embed_dims].
"""
if not self.return_intermediate:
x = super().forward(query, *args, **kwargs)
if self.post_norm:
x = self.post_norm(x)[None]
return x
intermediate = []
for layer in self.layers:
query = layer(query, *args, **kwargs)
if self.return_intermediate:
if self.post_norm is not None:
intermediate.append(self.post_norm(query))
else:
intermediate.append(query)
return torch.stack(intermediate)
@TRANSFORMER.register_module()
class Transformer(BaseModule):
"""Implements the DETR transformer.
Following the official DETR implementation, this module is copy-pasted
from torch.nn.Transformer with modifications:
* positional encodings are passed in MultiheadAttention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
See `paper: End-to-End Object Detection with Transformers
<https://arxiv.org/pdf/2005.12872>`_ for details.
Args:
encoder (`mmcv.ConfigDict` | Dict): Config of
TransformerEncoder. Defaults to None.
decoder (`mmcv.ConfigDict` | Dict): Config of
TransformerDecoder. Defaults to None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Defaults to None.
"""
def __init__(self, encoder=None, decoder=None, init_cfg=None):
super(Transformer, self).__init__(init_cfg=init_cfg)
self.encoder = build_transformer_layer_sequence(encoder)
self.decoder = build_transformer_layer_sequence(decoder)
self.embed_dims = self.encoder.embed_dims
def init_weights(self):
# follow the official DETR to init parameters
for m in self.modules():
if hasattr(m, "weight") and m.weight.dim() > 1:
xavier_init(m, distribution="uniform")
self._is_init = True
def forward(self, x, mask, query_embed, pos_embed):
"""Forward function for `Transformer`.
Args:
x (Tensor): Input query with shape [bs, c, h, w] where
c = embed_dims.
mask (Tensor): The key_padding_mask used for encoder and decoder,
with shape [bs, h, w].
query_embed (Tensor): The query embedding for decoder, with shape
[num_query, c].
pos_embed (Tensor): The positional encoding for encoder and
decoder, with the same shape as `x`.
Returns:
tuple[Tensor]: results of decoder containing the following tensor.
- out_dec: Output from decoder. If return_intermediate_dec \
is True output has shape [num_dec_layers, bs,
num_query, embed_dims], else has shape [1, bs, \
num_query, embed_dims].
- memory: Output results from encoder, with shape \
[bs, embed_dims, h, w].
"""
bs, c, h, w = x.shape
# use `view` instead of `flatten` for dynamically exporting to ONNX
x = x.view(bs, c, -1).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c]
pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # [num_query, dim] -> [num_query, bs, dim]
mask = mask.view(bs, -1) # [bs, h, w] -> [bs, h*w]
memory = self.encoder(query=x, key=None, value=None, query_pos=pos_embed, query_key_padding_mask=mask)
target = torch.zeros_like(query_embed)
# out_dec: [num_layers, num_query, bs, dim]
out_dec = self.decoder(
query=target, key=memory, value=memory, key_pos=pos_embed, query_pos=query_embed, key_padding_mask=mask
)
out_dec = out_dec.transpose(1, 2)
memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)
return out_dec, memory
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DeformableDetrTransformerDecoder(TransformerLayerSequence):
"""Implements the decoder in DETR transformer.
Args:
return_intermediate (bool): Whether to return intermediate outputs.
coder_norm_cfg (dict): Config of last normalization layer. Default
`LN`.
"""
def __init__(self, *args, return_intermediate=False, **kwargs):
super(DeformableDetrTransformerDecoder, self).__init__(*args, **kwargs)
self.return_intermediate = return_intermediate
def forward(self, query, *args, reference_points=None, valid_ratios=None, reg_branches=None, **kwargs):
"""Forward function for `TransformerDecoder`.
Args:
query (Tensor): Input query with shape
`(num_query, bs, embed_dims)`.
reference_points (Tensor): The reference
points used for the offsets, with shape
(bs, num_query, 4) when as_two_stage,
otherwise with shape (bs, num_query, 2).
valid_ratios (Tensor): The ratios of valid
points on the feature map, has shape
(bs, num_levels, 2)
reg_branches (obj:`nn.ModuleList`): Used for
refining the regression results. Only passed
when with_box_refine is True, otherwise
`None` is passed.
Returns:
Tensor: Results with shape [1, num_query, bs, embed_dims] when
return_intermediate is `False`, otherwise it has shape
[num_layers, num_query, bs, embed_dims].
"""
output = query
intermediate = []
intermediate_reference_points = []
for lid, layer in enumerate(self.layers):
if reference_points.shape[-1] == 4:
reference_points_input = (
reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
)
else:
assert reference_points.shape[-1] == 2
reference_points_input = reference_points[:, :, None] * valid_ratios[:, None]
output = layer(output, *args, reference_points=reference_points_input, **kwargs)
output = output.permute(1, 0, 2)
if reg_branches is not None:
tmp = reg_branches[lid](output)
if reference_points.shape[-1] == 4:
new_reference_points = tmp + inverse_sigmoid(reference_points)
new_reference_points = new_reference_points.sigmoid()
else:
assert reference_points.shape[-1] == 2
new_reference_points = tmp
new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
new_reference_points = new_reference_points.sigmoid()
reference_points = new_reference_points.detach()
output = output.permute(1, 0, 2)
if self.return_intermediate:
intermediate.append(output)
intermediate_reference_points.append(reference_points)
if self.return_intermediate:
return torch.stack(intermediate), torch.stack(intermediate_reference_points)
return output, reference_points
@TRANSFORMER.register_module()
class DeformableDetrTransformer(Transformer):
"""Implements the DeformableDETR transformer.
Args:
as_two_stage (bool): Generate query from encoder features.
Default: False.
num_feature_levels (int): Number of feature maps from FPN.
Default: 4.
two_stage_num_proposals (int): Number of proposals when set
`as_two_stage` as True. Default: 300.
"""
def __init__(self, as_two_stage=False, num_feature_levels=4, two_stage_num_proposals=300, **kwargs):
super(DeformableDetrTransformer, self).__init__(**kwargs)
self.as_two_stage = as_two_stage
self.num_feature_levels = num_feature_levels
self.two_stage_num_proposals = two_stage_num_proposals
self.embed_dims = self.encoder.embed_dims
self.init_layers()
def init_layers(self):
"""Initialize layers of the DeformableDetrTransformer."""
self.level_embeds = nn.Parameter(torch.Tensor(self.num_feature_levels, self.embed_dims))
if self.as_two_stage:
self.enc_output = nn.Linear(self.embed_dims, self.embed_dims)
self.enc_output_norm = nn.LayerNorm(self.embed_dims)
self.pos_trans = nn.Linear(self.embed_dims * 2, self.embed_dims * 2)
self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
else:
self.reference_points = nn.Linear(self.embed_dims, 2)
def init_weights(self):
"""Initialize the transformer weights."""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MultiScaleDeformableAttention):
m.init_weights()
if not self.as_two_stage:
xavier_init(self.reference_points, distribution="uniform", bias=0.0)
normal_(self.level_embeds)
def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
"""Generate proposals from encoded memory.
Args:
memory (Tensor): The output of the encoder,
has shape (bs, num_key, embed_dim). num_key is
equal to the number of points on the feature maps
from all levels.
memory_padding_mask (Tensor): Padding mask for memory,
has shape (bs, num_key).
spatial_shapes (Tensor): The shape of all feature maps,
has shape (num_level, 2).
Returns:
tuple: A tuple of feature map and bbox prediction.
- output_memory (Tensor): The input of decoder, \
has shape (bs, num_key, embed_dim). num_key is \
equal to the number of points on the feature maps from \
all levels.
- output_proposals (Tensor): The normalized proposal \
after an inverse sigmoid, has shape \
(bs, num_keys, 4).
"""
N, S, C = memory.shape
proposals = []
_cur = 0
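# For each level, build a grid of normalized (cx, cy, w, h) proposals whose
# box size grows with the level index (0.05 * 2**lvl).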
for lvl, (H, W) in enumerate(spatial_shapes):
mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H * W)].view(N, H, W, 1)
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
grid_y, grid_x = torch.meshgrid(
torch.linspace(0, H - 1, H, dtype=torch.float32, device=memory.device),
torch.linspace(0, W - 1, W, dtype=torch.float32, device=memory.device),
)
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2)
grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale
wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
proposal = torch.cat((grid, wh), -1).view(N, -1, 4)
proposals.append(proposal)
_cur += H * W
output_proposals = torch.cat(proposals, 1)
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
output_proposals = torch.log(output_proposals / (1 - output_proposals))
output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf"))
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
output_memory = memory
output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
output_memory = self.enc_output_norm(self.enc_output(output_memory))
return output_memory, output_proposals
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
"""Get the reference points used in decoder.
Args:
spatial_shapes (Tensor): The shape of all
feature maps, has shape (num_level, 2).
valid_ratios (Tensor): The ratios of valid
points on the feature map, has shape
(bs, num_levels, 2)
device (obj:`device`): The device where
reference_points should be.
Returns:
Tensor: reference points used in decoder, has \
shape (bs, num_keys, num_levels, 2).
"""
reference_points_list = []
for lvl, (H, W) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device),
torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device),
)
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def get_valid_ratio(self, mask):
"""Get the valid radios of feature maps of all level."""
_, H, W = mask.shape
valid_H = torch.sum(~mask[:, :, 0], 1)
valid_W = torch.sum(~mask[:, 0, :], 1)
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
def get_proposal_pos_embed(self, proposals, num_pos_feats=128, temperature=10000):
"""Get the position embedding of proposal."""
scale = 2 * math.pi
dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
# N, L, 4
proposals = proposals.sigmoid() * scale
# N, L, 4, 128
pos = proposals[:, :, :, None] / dim_t
# N, L, 4, 64, 2
pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
return pos
def forward(
self, mlvl_feats, mlvl_masks, query_embed, mlvl_pos_embeds, reg_branches=None, cls_branches=None, **kwargs
):
"""Forward function for `Transformer`.
Args:
mlvl_feats (list(Tensor)): Input queries from
different level. Each element has shape
[bs, embed_dims, h, w].
mlvl_masks (list(Tensor)): The key_padding_mask from
different level used for encoder and decoder,
each element has shape [bs, h, w].
query_embed (Tensor): The query embedding for decoder,
with shape [num_query, c].
mlvl_pos_embeds (list(Tensor)): The positional encoding
of feats from different level, has the shape
[bs, embed_dims, h, w].
reg_branches (obj:`nn.ModuleList`): Regression heads for
feature maps from each decoder layer. Only would
be passed when
`with_box_refine` is True. Default to None.
cls_branches (obj:`nn.ModuleList`): Classification heads
for feature maps from each decoder layer. Only would
be passed when `as_two_stage`
is True. Default to None.
Returns:
tuple[Tensor]: results of decoder containing the following tensor.
- inter_states: Outputs from decoder. If
return_intermediate_dec is True output has shape \
(num_dec_layers, bs, num_query, embed_dims), else has \
shape (1, bs, num_query, embed_dims).
- init_reference_out: The initial value of reference \
points, has shape (bs, num_queries, 4).
- inter_references_out: The internal value of reference \
points in decoder, has shape \
(num_dec_layers, bs,num_query, embed_dims)
- enc_outputs_class: The classification score of \
proposals generated from \
encoder's feature maps, has shape \
(batch, h*w, num_classes). \
Only would be returned when `as_two_stage` is True, \
otherwise None.
- enc_outputs_coord_unact: The regression results \
generated from encoder's feature maps., has shape \
(batch, h*w, 4). Only would \
be returned when `as_two_stage` is True, \
otherwise None.
"""
assert self.as_two_stage or query_embed is not None
feat_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (feat, mask, pos_embed) in enumerate(zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
bs, c, h, w = feat.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
feat = feat.flatten(2).transpose(1, 2)
mask = mask.flatten(1)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
lvl_pos_embed_flatten.append(lvl_pos_embed)
feat_flatten.append(feat)
mask_flatten.append(mask)
feat_flatten = torch.cat(feat_flatten, 1)
mask_flatten = torch.cat(mask_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=feat_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in mlvl_masks], 1)
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=feat.device)
feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims)
lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims)
memory = self.encoder(
query=feat_flatten,
key=None,
value=None,
query_pos=lvl_pos_embed_flatten,
query_key_padding_mask=mask_flatten,
spatial_shapes=spatial_shapes,
reference_points=reference_points,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
**kwargs,
)
memory = memory.permute(1, 0, 2)
bs, _, c = memory.shape
if self.as_two_stage:
output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)
enc_outputs_class = cls_branches[self.decoder.num_layers](output_memory)
enc_outputs_coord_unact = reg_branches[self.decoder.num_layers](output_memory) + output_proposals
topk = self.two_stage_num_proposals
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
topk_coords_unact = topk_coords_unact.detach()
reference_points = topk_coords_unact.sigmoid()
init_reference_out = reference_points
pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
query_pos, query = torch.split(pos_trans_out, c, dim=2)
else:
query_pos, query = torch.split(query_embed, c, dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
query = query.unsqueeze(0).expand(bs, -1, -1)
reference_points = self.reference_points(query_pos).sigmoid()
init_reference_out = reference_points
# decoder
query = query.permute(1, 0, 2)
memory = memory.permute(1, 0, 2)
query_pos = query_pos.permute(1, 0, 2)
inter_states, inter_references = self.decoder(
query=query,
key=None,
value=memory,
query_pos=query_pos,
key_padding_mask=mask_flatten,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
reg_branches=reg_branches,
**kwargs,
)
inter_references_out = inter_references
if self.as_two_stage:
return inter_states, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact
return inter_states, init_reference_out, inter_references_out, None, None
@TRANSFORMER.register_module()
class DynamicConv(BaseModule):
"""Implements Dynamic Convolution.
This module generates parameters for each sample and
uses bmm to implement a 1x1 convolution. Code is modified
from the `official github repo <https://github.com/PeizeSun/
SparseR-CNN/blob/main/projects/SparseRCNN/sparsercnn/head.py#L258>`_ .
Args:
in_channels (int): The input feature channel.
Defaults to 256.
feat_channels (int): The inner feature channel.
Defaults to 64.
out_channels (int, optional): The output feature channel.
When not specified, it will be set to `in_channels`
by default
input_feat_shape (int): The shape of input feature.
Defaults to 7.
with_proj (bool): Whether to project the two-dimensional feature to
a one-dimensional feature. Default: True.
act_cfg (dict): The activation config for DynamicConv.
norm_cfg (dict): Config dict for normalization layer. Default
layer normalization.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
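Example (a minimal sketch with assumed proposal counts; the defaults
give 256 input channels and 7x7 RoI features):
>>> dynamic_conv = DynamicConv()
>>> param_feature = torch.rand(100, 256)
>>> input_feature = torch.rand(100, 256, 7, 7)
>>> out = dynamic_conv(param_feature, input_feature)
>>> assert out.shape == (100, 256)  # with_proj is True by default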
"""
def __init__(
self,
in_channels=256,
feat_channels=64,
out_channels=None,
input_feat_shape=7,
with_proj=True,
act_cfg=dict(type="ReLU", inplace=True),
norm_cfg=dict(type="LN"),
init_cfg=None,
):
super(DynamicConv, self).__init__(init_cfg)
self.in_channels = in_channels
self.feat_channels = feat_channels
self.out_channels_raw = out_channels
self.input_feat_shape = input_feat_shape
self.with_proj = with_proj
self.act_cfg = act_cfg
self.norm_cfg = norm_cfg
self.out_channels = out_channels if out_channels else in_channels
self.num_params_in = self.in_channels * self.feat_channels
self.num_params_out = self.out_channels * self.feat_channels
self.dynamic_layer = nn.Linear(self.in_channels, self.num_params_in + self.num_params_out)
self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]
self.activation = build_activation_layer(act_cfg)
num_output = self.out_channels * input_feat_shape**2
if self.with_proj:
self.fc_layer = nn.Linear(num_output, self.out_channels)
self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]
def forward(self, param_feature, input_feature):
"""Forward function for `DynamicConv`.
Args:
param_feature (Tensor): The feature can be used
to generate the parameter, has shape
(num_all_proposals, in_channels).
input_feature (Tensor): Feature that
interact with parameters, has shape
(num_all_proposals, in_channels, H, W).
Returns:
Tensor: The output feature has shape
(num_all_proposals, out_channels).
"""
input_feature = input_feature.flatten(2).permute(2, 0, 1)
input_feature = input_feature.permute(1, 0, 2)
parameters = self.dynamic_layer(param_feature)
param_in = parameters[:, : self.num_params_in].view(-1, self.in_channels, self.feat_channels)
param_out = parameters[:, -self.num_params_out :].view(-1, self.feat_channels, self.out_channels)
# input_feature has shape (num_all_proposals, H*W, in_channels)
# param_in has shape (num_all_proposals, in_channels, feat_channels)
# feature has shape (num_all_proposals, H*W, feat_channels)
features = torch.bmm(input_feature, param_in)
features = self.norm_in(features)
features = self.activation(features)
# param_out has shape (batch_size, feat_channels, out_channels)
features = torch.bmm(features, param_out)
features = self.norm_out(features)
features = self.activation(features)
if self.with_proj:
features = features.flatten(1)
features = self.fc_layer(features)
features = self.fc_norm(features)
features = self.activation(features)
return features

View File

@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/fundamentalvision/Deformable-DETR/tree/main/models/ops/modules
# https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
from .ms_deform_attn import MSDeformAttn

View File

@@ -0,0 +1,185 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import math
import warnings
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Function
from torch.cuda.amp import custom_fwd
from torch.nn.init import constant_, xavier_uniform_
class MSDeformAttnFunction(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(
ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step
):
output = ms_deform_attn_core_pytorch(
value,
value_spatial_shapes,
# value_level_start_index,
sampling_locations,
attention_weights,
)
return output
def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
# for debug and test only,
# need to use cuda version instead
N_, S_, M_, D_ = value.shape
_, Lq_, M_, L_, P_, _ = sampling_locations.shape
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for lid_, (H_, W_) in enumerate(value_spatial_shapes):
# N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_)
# N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
# N_*M_, D_, Lq_, P_
sampling_value_l_ = F.grid_sample(
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
)
sampling_value_list.append(sampling_value_l_)
# (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_)
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_ * D_, Lq_)
return output.transpose(1, 2).contiguous()
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
return (n & (n - 1) == 0) and n != 0
class MSDeformAttn(nn.Module):
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, ratio=1.0):
"""Multi-Scale Deformable Attention Module.
:param d_model hidden dimension
:param n_levels number of feature levels
:param n_heads number of attention heads
:param n_points number of sampling points per attention head per feature level
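A hedged, shape-only usage sketch (all sizes below are assumptions; single level):
>>> attn = MSDeformAttn(d_model=256, n_levels=1, n_heads=8, n_points=4)
>>> query = torch.rand(2, 100, 256)
>>> value = torch.rand(2, 30 * 40, 256)
>>> shapes = torch.as_tensor([[30, 40]], dtype=torch.long)
>>> start_index = torch.as_tensor([0], dtype=torch.long)
>>> ref = torch.rand(2, 100, 1, 2)
>>> out = attn(query, ref, value, shapes, start_index)
>>> assert out.shape == (2, 100, 256)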
"""
super().__init__()
if d_model % n_heads != 0:
raise ValueError("d_model must be divisible by n_heads, " "but got {} and {}".format(d_model, n_heads))
_d_per_head = d_model // n_heads
# you'd better set _d_per_head to a power of 2
# which is more efficient in our CUDA implementation
if not _is_power_of_2(_d_per_head):
warnings.warn(
"You'd better set d_model in MSDeformAttn to make "
"the dimension of each attention head a power of 2 "
"which is more efficient in our CUDA implementation."
)
self.im2col_step = 64
self.d_model = d_model
self.n_levels = n_levels
self.n_heads = n_heads
self.n_points = n_points
self.ratio = ratio
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
self.value_proj = nn.Linear(d_model, int(d_model * ratio))
self.output_proj = nn.Linear(int(d_model * ratio), d_model)
self._reset_parameters()
def _reset_parameters(self):
constant_(self.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
.view(self.n_heads, 1, 1, 2)
.repeat(1, self.n_levels, self.n_points, 1)
)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
constant_(self.attention_weights.weight.data, 0.0)
constant_(self.attention_weights.bias.data, 0.0)
xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.0)
xavier_uniform_(self.output_proj.weight.data)
constant_(self.output_proj.bias.data, 0.0)
def forward(
self,
query,
reference_points,
input_flatten,
input_spatial_shapes,
input_level_start_index,
input_padding_mask=None,
):
"""
:param query (N, Length_{query}, C)
:param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
:param input_flatten (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l, C)
:param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
:param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
:param input_padding_mask (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l), True for padding elements, False for non-padding elements
:return output (N, Length_{query}, C)
"""
# print(query.shape)
# print(reference_points.shape)
# print(input_flatten.shape)
# print(input_spatial_shapes.shape)
# print(input_level_start_index.shape)
# print(input_spatial_shapes)
# print(input_level_start_index)
N, Len_q, _ = query.shape
N, Len_in, _ = input_flatten.shape
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
value = self.value_proj(input_flatten)
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0))
value = value.view(N, Len_in, self.n_heads, int(self.ratio * self.d_model) // self.n_heads)
sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
sampling_locations = (
reference_points[:, :, None, :, None, :]
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
)
elif reference_points.shape[-1] == 4:
sampling_locations = (
reference_points[:, :, None, :, None, :2]
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
)
else:
raise ValueError(
"Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1])
)
output = MSDeformAttnFunction.apply(
value,
input_spatial_shapes,
input_level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
output = self.output_proj(output)
return output

Some files were not shown because too many files have changed in this diff.