health_multimodal.image.model.transformer

Classes

Block(dim, num_heads[, mlp_ratio, qkv_bias, …])

Encapsulates multi-layer perceptron and multi-head self-attention modules into a block.

MultiHeadAttentionLayer(dim[, num_heads, …])

Multi-head self-attention module.

MultiHeadAttentionOutput(mha_output[, attention])

SinePositionEmbedding([embedding_dim, …])

This is a more standard version of the position embedding, very similar to the one used in the "Attention Is All You Need" paper, generalized to work on images.

VisionTransformerPooler(input_dim, grid_shape)


class health_multimodal.image.model.transformer.Block(dim, num_heads, mlp_ratio=1.0, qkv_bias=False, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]

Encapsulates multi-layer perceptron and multi-head self-attention modules into a block.

This implementation builds on top of the TIMM library (vision_transformer.py) and differs from it as follows:
  • It uses spatio-temporal positional embeddings instead of 2D positional embeddings only, and these are taken into account within the forward pass of each ViT block.

  • It uses the custom MultiHeadAttentionLayer, which is not restricted to self-attention and can be generalised to arbitrary (query, key, value) tuples. This is valuable when processing more than 2 scans.

Positional and type embeddings are handled in a similar fashion to the DETR object detection paper (https://alcinos.github.io/detr_page/): a fixed set of sine/cos positional embeddings is added to the Q and K tensors.
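Example (a minimal sketch): the constructor arguments and the forward call below follow the documented signatures, while the tensor shapes and the random placeholder embeddings are assumptions for illustration only; in practice the positional/type embeddings come from SinePositionEmbedding and the pooler's type embeddings.

    import torch
    from health_multimodal.image.model.transformer import Block

    # Assumed shapes: batch of 2, 196 patch tokens, 768-dim features.
    batch, num_tokens, dim = 2, 196, 768
    block = Block(dim=dim, num_heads=8, mlp_ratio=4.0, qkv_bias=True)

    x = torch.randn(batch, num_tokens, dim)
    # Placeholder for the fixed sine/cos positional (and type) embeddings that are
    # added to the Q and K tensors inside the block.
    pos_and_type_embed = torch.randn(batch, num_tokens, dim)

    out = block(x, pos_and_type_embed)  # call the module, not forward(), so hooks run
    # out is a Tensor; it is assumed to preserve the (batch, num_tokens, dim) shape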

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x, pos_and_type_embed)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type

Tensor

class health_multimodal.image.model.transformer.MultiHeadAttentionLayer(dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0)[source]

Multi-head self-attention module.

This implementation builds on top of the TIMM library (vision_transformer.py) and differs from it as follows:
  • The layer is not restricted to self-attention and can be generalised to arbitrary (query, key, value) input tuples, which is valuable when processing more than 2 scans at a time (see the sketch below).

  • The self-attention-only use case can still be invoked by calling the forward_as_mhsa method.
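Example (a minimal sketch): the constructor and the forward(k, q, v) signature follow the documentation below; the tensor shapes, the attribute names on the returned MultiHeadAttentionOutput (taken from its constructor arguments), and the single-tensor call to forward_as_mhsa are assumptions.

    import torch
    from health_multimodal.image.model.transformer import MultiHeadAttentionLayer

    batch, num_tokens, dim = 2, 196, 768
    mha = MultiHeadAttentionLayer(dim=dim, num_heads=8, qkv_bias=True)

    current = torch.randn(batch, num_tokens, dim)   # e.g. tokens of the current scan
    previous = torch.randn(batch, num_tokens, dim)  # e.g. tokens of the prior scan

    # Cross-attention: queries from the current scan attend to keys/values of the
    # prior scan. Note the documented argument order is (k, q, v).
    out = mha(k=previous, q=current, v=previous)

    # Assumed attribute names, mirroring the MultiHeadAttentionOutput constructor.
    features = out.mha_output
    attn_weights = out.attention  # may be None depending on configuration

    # Self-attention-only use case (single-tensor signature assumed).
    self_attended = mha.forward_as_mhsa(current)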

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(k, q, v)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type

MultiHeadAttentionOutput

class health_multimodal.image.model.transformer.MultiHeadAttentionOutput(mha_output, attention=None)[source]
class health_multimodal.image.model.transformer.SinePositionEmbedding(embedding_dim=64, temperature=10000, normalize=False, scale=None)[source]

This is a more standard version of the position embedding, very similar to the one used in the "Attention Is All You Need" paper, generalized to work on images.
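For intuition, the sketch below is a standalone illustration of fixed 2D sine/cosine embeddings over a patch grid; it is not the implementation of SinePositionEmbedding, whose forward signature and normalize/scale handling may differ, and the function name and argument layout are hypothetical.

    import torch

    def sine_position_embedding_2d(height: int, width: int,
                                   embedding_dim: int = 64,
                                   temperature: float = 10000.0) -> torch.Tensor:
        """Fixed sine/cos embeddings on a (height, width) grid: half of the channels
        encode the row index and half encode the column index."""
        half_dim = embedding_dim // 2
        y = torch.arange(height, dtype=torch.float32).unsqueeze(1)  # (H, 1)
        x = torch.arange(width, dtype=torch.float32).unsqueeze(1)   # (W, 1)
        # Geometric frequency schedule, as in "Attention Is All You Need" / DETR.
        div = temperature ** (2 * (torch.arange(half_dim) // 2) / half_dim)  # (half_dim,)

        y_embed = y / div  # (H, half_dim)
        x_embed = x / div  # (W, half_dim)
        # Interleave sine and cosine along the channel dimension.
        y_embed = torch.stack((y_embed[:, 0::2].sin(), y_embed[:, 1::2].cos()), dim=2).flatten(1)
        x_embed = torch.stack((x_embed[:, 0::2].sin(), x_embed[:, 1::2].cos()), dim=2).flatten(1)

        # Broadcast to the full grid and concatenate row and column embeddings.
        pos = torch.cat(
            [y_embed[:, None, :].expand(height, width, -1),
             x_embed[None, :, :].expand(height, width, -1)],
            dim=-1,
        )  # (H, W, embedding_dim)
        return pos

    pos = sine_position_embedding_2d(height=15, width=15, embedding_dim=64)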

class health_multimodal.image.model.transformer.VisionTransformerPooler(input_dim, grid_shape, num_heads=8, num_blocks=3, norm_layer=functools.partial(<class 'torch.nn.modules.normalization.LayerNorm'>, eps=1e-06))[source]
Parameters
  • input_dim (int) – Input feature dimension (i.e., channels in old CNN terminology)

  • grid_shape (Tuple[int, int]) – Shape of the grid of patches per image

  • num_heads (int) – Number of self-attention heads within the MHA block

  • num_blocks (int) – Number of blocks per attention layer

  • norm_layer (Any) – Normalisation layer

self.type_embed: used to characterise prior and current scans and to create permutation variance across modalities/series.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(current_image, previous_image=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type

Tensor
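
Example (a minimal sketch): the constructor and forward signatures follow the documentation above, while the assumption that the inputs are patch-grid feature maps of shape (batch, input_dim, grid_height, grid_width), and the 15 x 15 grid itself, are illustrative only.

    import torch
    from health_multimodal.image.model.transformer import VisionTransformerPooler

    batch, input_dim, grid_shape = 2, 256, (15, 15)
    pooler = VisionTransformerPooler(input_dim=input_dim, grid_shape=grid_shape)

    # Assumed layout: patch-grid feature maps produced by the image encoder backbone.
    current = torch.randn(batch, input_dim, *grid_shape)
    previous = torch.randn(batch, input_dim, *grid_shape)

    pooled_single = pooler(current)           # single-scan pooling
    pooled_pair = pooler(current, previous)   # current + prior scan; the type
                                              # embeddings distinguish the two series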