health_multimodal.image.model.encoder

Functions

get_encoder_from_type(img_encoder_type)

Returns an encoder instance of the class matching the given encoder type.

get_encoder_output_dim(module, device)

Calculate the output dimension of an encoder by making a single forward pass.

restore_training_mode(module)

Restore the training mode of a module after some operation.

Classes

ImageEncoder(img_encoder_type)

Image encoder trunk module for the ImageModel class.

MultiImageEncoder(img_encoder_type)

Multi-image encoder trunk module for the ImageModel class.

class health_multimodal.image.model.encoder.ImageEncoder(img_encoder_type)

Image encoder trunk module for the ImageModel class.

Parameters

img_encoder_type (str) – Type of image encoder model to use, either "resnet18_multi_image" or "resnet50_multi_image".

forward(current_image, return_patch_embeddings=False)

Get the global image embedding, plus the patch embeddings if return_patch_embeddings is True.

Return type

Union[Tensor, Tuple[Tensor, Tensor]]
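The documentation does not spell out how the two outputs relate. A common pattern, and a plausible reading of "global and patch embeddings", is to treat the backbone's final feature map as a grid of patch embeddings and average-pool it into one global vector per image. The sketch below assumes that pattern; TinyBackbone and SketchImageEncoder are hypothetical stand-ins for the ResNet trunk and this class, and the (patch, global) tuple order is an assumption.

```python
from typing import Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a ResNet trunk: [B, 3, H, W] -> [B, C, H/4, W/4] for H, W divisible by 4."""

    def __init__(self, out_channels: int = 8) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, out_channels, kernel_size=3, stride=4, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class SketchImageEncoder(nn.Module):
    """Sketch of a global/patch embedding encoder (assumed design, not the library's code)."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = TinyBackbone()

    def forward(
        self, current_image: torch.Tensor, return_patch_embeddings: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Patch embeddings: one C-dimensional vector per spatial location, shape [B, C, h, w].
        patch_emb = self.encoder(current_image)
        # Global embedding: average over all spatial locations, shape [B, C].
        global_emb = torch.flatten(F.adaptive_avg_pool2d(patch_emb, (1, 1)), 1)
        if return_patch_embeddings:
            return patch_emb, global_emb
        return global_emb


encoder = SketchImageEncoder()
images = torch.rand(2, 3, 64, 64)
global_emb = encoder(images)
patch_emb, global_emb2 = encoder(images, return_patch_embeddings=True)
```

Returning only the pooled vector by default keeps the common case (one embedding per image) cheap to consume, while the flag exposes the spatial grid for tasks that need localized features.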

reload_encoder_with_dilation(replace_stride_with_dilation=None)

Workaround for enabling dilated convolutions after model initialization.

Parameters

replace_stride_with_dilation (Optional[Sequence[bool]]) – For each of the last three blocks of the ResNet architecture, whether to replace the standard 2x2 convolution stride with a dilated convolution (one flag per block).

Return type

None
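Each True flag removes one stride-2 downsampling step, so it doubles the side length of the final feature map (the patch grid). The arithmetic below assumes a torchvision-style ResNet, where the stem downsamples by 4 and each of the last three blocks (layer2 to layer4) by 2, and that the sequence holds exactly three flags, as torchvision requires. output_grid_size is a hypothetical helper, not part of this module.

```python
from typing import Optional, Sequence


def output_grid_size(input_size: int, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> int:
    """Side length of the final ResNet feature map for a square input divisible by the overall stride."""
    flags = [False, False, False] if replace_stride_with_dilation is None else list(replace_stride_with_dilation)
    if len(flags) != 3:
        raise ValueError("expected one flag for each of the last three ResNet blocks")
    stride = 4  # stem: stride-2 convolution followed by stride-2 max pooling
    for dilate in flags:
        # A block keeps its stride of 2 unless it is replaced by a dilated convolution.
        stride *= 1 if dilate else 2
    return input_size // stride


print(output_grid_size(448))                        # 14 (overall stride 32)
print(output_grid_size(448, [False, False, True]))  # 28 (overall stride 16)
print(output_grid_size(448, [False, True, True]))   # 56 (overall stride 8)
```

Dilation keeps the receptive field growing without further downsampling, which is why it is a common way to obtain a finer patch grid from an already trained ResNet.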

class health_multimodal.image.model.encoder.MultiImageEncoder(img_encoder_type)

Multi-image encoder trunk module for the ImageModel class. It can be used to encode multiple images into a combined latent representation. Currently it supports only two input images (a current image and an optional previous image), but it could be extended to support more in the future.

Parameters

img_encoder_type (str) – Type of image encoder model to use: either "resnet18" or "resnet50".

forward(current_image, previous_image=None, return_patch_embeddings=False)

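Get the combined global embedding of current_image and the optional previous_image, plus the patch embeddings if return_patch_embeddings is True.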

Return type

Union[Tensor, Tuple[Tensor, Tensor]]
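The documentation does not describe how the two images are combined. The sketch below shows one simple, hypothetical scheme: a shared backbone encodes both images, a learned placeholder stands in for the previous image's embedding when previous_image is None, and the output concatenates the current embedding with its difference from the previous one. SketchMultiImageEncoder, placeholder_previous_emb and the concatenation scheme are assumptions for illustration, not the library's method, and the sketch omits return_patch_embeddings.

```python
from typing import Optional

import torch
from torch import nn


class SketchMultiImageEncoder(nn.Module):
    """Hypothetical two-image encoder; illustrates the interface, not the library's implementation."""

    def __init__(self, dim: int = 8) -> None:
        super().__init__()
        # Shared backbone applied to both images, followed by global average pooling to [B, dim].
        self.backbone = nn.Sequential(
            nn.Conv2d(3, dim, kernel_size=3, stride=4, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        # Learned placeholder used when no previous image is available (initialized to zeros).
        self.placeholder_previous_emb = nn.Parameter(torch.zeros(1, dim))

    def forward(self, current_image: torch.Tensor, previous_image: Optional[torch.Tensor] = None) -> torch.Tensor:
        current = self.backbone(current_image)
        if previous_image is None:
            previous = self.placeholder_previous_emb.expand(current.shape[0], -1)
        else:
            previous = self.backbone(previous_image)
        # Concatenate the current embedding with the change relative to the previous one: [B, 2 * dim].
        return torch.cat([current, current - previous], dim=1)


encoder = SketchMultiImageEncoder()
current, previous = torch.rand(2, 3, 64, 64), torch.rand(2, 3, 64, 64)
with_prior = encoder(current, previous)
without_prior = encoder(current)
```

Using a learned placeholder rather than zero-filled pixels lets one model handle both cases while keeping the output shape fixed.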

reload_encoder_with_dilation(replace_stride_with_dilation=None)

Workaround for enabling dilated convolutions after model initialization.

Parameters

replace_stride_with_dilation (Optional[Sequence[bool]]) – For each of the last three blocks of the ResNet architecture, whether to replace the standard 2x2 convolution stride with a dilated convolution (one flag per block).

Return type

None

health_multimodal.image.model.encoder.get_encoder_from_type(img_encoder_type)

Returns an encoder instance of the class matching the given encoder type.

Parameters

img_encoder_type (str) – Encoder type; one of {RESNET18, RESNET50, RESNET18_MULTI_IMAGE, RESNET50_MULTI_IMAGE}.

Return type

ImageEncoder
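The return type is ImageEncoder even though the module also provides MultiImageEncoder, which suggests the function instantiates one of the two classes depending on the type string, with MultiImageEncoder being a subclass of ImageEncoder. The pure-Python sketch below assumes exactly that dispatch rule and the lowercase string values quoted in the class parameter descriptions; the stub classes and get_encoder_from_type_sketch are hypothetical.

```python
class StubImageEncoder:
    """Hypothetical stand-in for ImageEncoder."""

    def __init__(self, img_encoder_type: str) -> None:
        self.img_encoder_type = img_encoder_type


class StubMultiImageEncoder(StubImageEncoder):
    """Hypothetical stand-in for MultiImageEncoder, assumed to subclass ImageEncoder."""


SINGLE_IMAGE_TYPES = {"resnet18", "resnet50"}
MULTI_IMAGE_TYPES = {"resnet18_multi_image", "resnet50_multi_image"}


def get_encoder_from_type_sketch(img_encoder_type: str) -> StubImageEncoder:
    """Pick the encoder class from the type string (assumed dispatch rule) and instantiate it."""
    if img_encoder_type in MULTI_IMAGE_TYPES:
        return StubMultiImageEncoder(img_encoder_type)
    if img_encoder_type in SINGLE_IMAGE_TYPES:
        return StubImageEncoder(img_encoder_type)
    raise ValueError(f"Unsupported image encoder type: {img_encoder_type!r}")


multi = get_encoder_from_type_sketch("resnet50_multi_image")
single = get_encoder_from_type_sketch("resnet18")
print(type(multi).__name__, type(single).__name__)  # StubMultiImageEncoder StubImageEncoder
```

Because the multi-image class is assumed to subclass the single-image one, a single return annotation covers both cases.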

health_multimodal.image.model.encoder.get_encoder_output_dim(module, device)

Calculate the output dimension of an encoder by making a single forward pass.

Parameters
  • module (Module) – Encoder module.

  • device (device) – Compute device to use.

Return type

int
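Running one dummy batch through the encoder and reading the channel dimension of its output is enough to size downstream layers without hard-coding backbone widths (for example, 512 for ResNet-18 and 2048 for ResNet-50). The sketch below assumes a 3-channel 448x448 dummy input and an encoder that returns a [B, C, H, W] feature map; output_dim_sketch is a hypothetical name. It switches to eval mode for the probe, which also stops batch-norm layers from updating their running statistics, then restores the previous mode.

```python
import torch
from torch import nn


def output_dim_sketch(module: nn.Module, device: torch.device, input_size: int = 448) -> int:
    """Return the channel dimension of the module's output for one dummy RGB image."""
    was_training = module.training
    try:
        module.eval()  # disable dropout and use running batch-norm statistics for the probe
        with torch.no_grad():
            x = torch.rand(1, 3, input_size, input_size, device=device)
            output = module(x)
    finally:
        module.train(was_training)  # put the module back in its original mode
    return output.shape[1]


encoder = nn.Sequential(nn.Conv2d(3, 16, kernel_size=7, stride=2, padding=3), nn.BatchNorm2d(16), nn.ReLU())
encoder.train()
dim = output_dim_sketch(encoder, torch.device("cpu"))
print(dim)  # 16
```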

health_multimodal.image.model.encoder.restore_training_mode(module)

Restore the training mode of a module after an operation that may change it, such as temporarily switching to eval mode.

Parameters

module (Module) – PyTorch module.

Return type

Generator[None, None, None]
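The Generator[None, None, None] return type suggests a generator-based context manager, presumably wrapped with contextlib.contextmanager, although the decorator is not shown here. A minimal sketch under that assumption: record module.training on entry and pass it back to module.train() on exit, so code inside the with block can switch modes freely. This version uses try/finally so the mode is also restored if the body raises; whether the library does the same is not stated. restore_training_mode_sketch and DummyModule are hypothetical, and the sketch works with any object exposing training and train(mode), including torch.nn.Module.

```python
from contextlib import contextmanager
from typing import Generator, Protocol


class SupportsTrainingMode(Protocol):
    training: bool

    def train(self, mode: bool = True) -> object: ...


@contextmanager
def restore_training_mode_sketch(module: SupportsTrainingMode) -> Generator[None, None, None]:
    """Restore module.training to its entry value when the with block exits."""
    training_mode = module.training
    try:
        yield
    finally:
        module.train(mode=training_mode)


class DummyModule:
    """Minimal stand-in for nn.Module's training flag."""

    def __init__(self) -> None:
        self.training = True

    def train(self, mode: bool = True) -> "DummyModule":
        self.training = mode
        return self

    def eval(self) -> "DummyModule":
        return self.train(False)


module = DummyModule()
with restore_training_mode_sketch(module):
    module.eval()
    inside = module.training  # False while inside the block
after = module.training       # True again after the block
```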