# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional, Set, Tuple
import torch
import torch.nn as nn
from timm.models.layers import DropPath, Mlp, trunc_normal_
from transformers.pytorch_utils import torch_int_div


@dataclass
class MultiHeadAttentionOutput:
    """Holds the projected multi-head attention output and, optionally, the attention weights."""

    mha_output: torch.Tensor
    attention: Optional[torch.Tensor] = None


class MultiHeadAttentionLayer(nn.Module):
    """
    Multi-head self-attention module.

    The content builds on top of the TIMM library (vision_transformer.py) and differs by the following:

    - Defines a custom `MultiHeadAttentionLayer` which is not restricted to self-attention but can be
      generalised to arbitrary (query, key, value) input tuples. This can be valuable when processing
      more than two scans at a time.
    - The self-attention use case can still be invoked by calling the `forward_as_mhsa` method.
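
    Example (a minimal usage sketch; the tensor sizes below are illustrative only):

        >>> layer = MultiHeadAttentionLayer(dim=64, num_heads=8)
        >>> tokens = torch.randn(2, 16, 64)  # (batch, sequence length, embedding dim)
        >>> output = layer.forward_as_mhsa(tokens)
        >>> tuple(output.mha_output.shape)
        (2, 16, 64)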
"""

    def __init__(
        self, dim: int, num_heads: int = 8, qkv_bias: bool = False, attn_drop: float = 0.0, proj_drop: float = 0.0
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        assert dim % num_heads == 0, f"The embedding dim ({dim}) must be divisible by the number of heads ({num_heads})"
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.return_attention = False
        self.proj_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, k: torch.Tensor, q: torch.Tensor, v: torch.Tensor) -> MultiHeadAttentionOutput:
        B, N, C = v.shape
        assert (
            C % self.num_heads == 0
        ), f"The embedding dim ({C}) must be divisible by the number of heads ({self.num_heads})"
        # Project the inputs and split the embedding dim across heads: (B, num_heads, N, C // num_heads)
        w_q = self.proj_q(q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        w_k = self.proj_k(k).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        w_v = self.proj_v(v).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        # Scaled dot-product attention over the token dimension
        attn = (w_q @ w_k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # Merge the heads back into the embedding dim and apply the output projection
        o = (attn @ w_v).transpose(1, 2).reshape(B, N, C)
        o = self.proj(o)
        o = self.proj_drop(o)
        attention_output = attn if self.return_attention else None
        return MultiHeadAttentionOutput(mha_output=o, attention=attention_output)

    def forward_as_mhsa(self, input: torch.Tensor) -> MultiHeadAttentionOutput:
        return self(k=input, q=input, v=input)


class Block(nn.Module):
    """
    Encapsulates multi-layer perceptron and multi-head self-attention modules into a block.

    The content builds on top of the TIMM library (vision_transformer.py) and differs by the following:

    - This implementation uses spatio-temporal positional embeddings instead of 2D positional embeddings only,
      and they are taken into account within the forward pass of each ViT block.
    - Utilises the custom `MultiHeadAttentionLayer`, which is not restricted to self-attention but can be
      generalised to arbitrary (query, key, value) tuples. This can be valuable when processing more than
      two scans at a time.

    Positional and type embeddings are handled in a similar fashion to the DETR object detection paper
    https://alcinos.github.io/detr_page/, where a fixed set of sine/cosine positional embeddings is added
    to the Q and K tensors.
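
    Example (a minimal usage sketch; the positional embedding is optional here and the sizes are illustrative only):

        >>> block = Block(dim=64, num_heads=8)
        >>> tokens = torch.randn(2, 16, 64)  # (batch, sequence length, embedding dim)
        >>> out = block(tokens, pos_and_type_embed=None)
        >>> tuple(out.shape)
        (2, 16, 64)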
"""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 1.0,
        qkv_bias: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = MultiHeadAttentionLayer(
            dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def with_pos_and_type_embed(self, tensor: torch.Tensor, emb: Optional[torch.Tensor]) -> torch.Tensor:
        # Add the positional and type embeddings to the key and query tensors; no-op if no embedding is given
        return tensor if emb is None else tensor + emb

    def forward(self, x: torch.Tensor, pos_and_type_embed: Optional[torch.Tensor]) -> torch.Tensor:
        x_with_emb = self.with_pos_and_type_embed(self.norm1(x), emb=pos_and_type_embed)
        x = x + self.drop_path(self.attn.forward_as_mhsa(x_with_emb).mha_output)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class SinePositionEmbedding:
    """
    This is a more standard version of the position embedding, very similar to the one used in the
    "Attention Is All You Need" paper, generalised to work on images.
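
    Example (a minimal usage sketch; an all-ones pixel mask over a 4 x 4 grid is assumed):

        >>> embedder = SinePositionEmbedding(embedding_dim=32)
        >>> mask = torch.ones(2, 4, 4)  # (batch, height, width)
        >>> pos = embedder(mask)
        >>> tuple(pos.shape)  # (batch, height * width, 2 * embedding_dim)
        (2, 16, 64)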
"""

    def __init__(
        self, embedding_dim: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def __call__(self, mask: torch.Tensor) -> torch.Tensor:
        assert mask is not None, "No pixel mask provided"
        B, H, W = mask.shape
        # Cumulative sums give each valid pixel its (row, column) position within the mask
        y_embed = mask.cumsum(1, dtype=torch.float32)
        x_embed = mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
        # Geometric progression of frequencies, one per embedding channel
        dim_t = torch.arange(self.embedding_dim, dtype=torch.float32)
        dim_t = self.temperature ** (2 * torch_int_div(dim_t, 2) / self.embedding_dim)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sine and cosine terms, then concatenate the y and x embeddings
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).view(B, H * W, self.embedding_dim * 2)
        return pos