health_multimodal.image.model.encoder
Functions
- get_encoder_from_type: Returns the encoder class for the given encoder type.
- Calculate the output dimension of an encoder by making a single forward pass.
- Restore the training mode of a module after some operation.
Classes
- ImageEncoder: Image encoder trunk module for the ImageModel class.
- MultiImageEncoder: Multi-image encoder trunk module for the ImageModel class.
- class health_multimodal.image.model.encoder.ImageEncoder(img_encoder_type)
Image encoder trunk module for the ImageModel class.
- Parameters
img_encoder_type – Type of image encoder model to use, either "resnet18_multi_image" or "resnet50_multi_image".
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(current_image, return_patch_embeddings=False)
Get image global and patch embeddings.
- Return type
Union[Tensor, Tuple[Tensor, Tensor]]
- reload_encoder_with_dilation(replace_stride_with_dilation=None)
Workaround for enabling dilated convolutions after model initialization.
- Parameters
replace_stride_with_dilation (Optional[Sequence[bool]]) – Replace the 2x2 standard convolution stride with a dilated convolution in each layer in the last three blocks of ResNet architecture.
- Return type
None
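To make the effect of replace_stride_with_dilation concrete, the following is a minimal sketch of the feature-map arithmetic. It assumes the usual torchvision-style ResNet layout: a stride-2 stem convolution, a stride-2 max pool, and stride-2 downsampling at the start of each of the last three blocks. The helper name resnet_feature_map_size is hypothetical and is not part of health_multimodal; it only computes sizes and does not load or run any model.

```python
from typing import Optional, Sequence


def resnet_feature_map_size(
    input_size: int,
    replace_stride_with_dilation: Optional[Sequence[bool]] = None,
) -> int:
    """Spatial size of the final feature map for a square input (illustrative only)."""
    # One flag per block for the last three blocks; None means no dilation.
    flags = [False, False, False] if replace_stride_with_dilation is None \
        else list(replace_stride_with_dilation)
    if len(flags) != 3:
        raise ValueError("expected exactly three flags, one per block")

    def halve(n: int) -> int:
        # Output size of a stride-2 layer with "same"-style padding
        # (kernel 7 / padding 3, kernel 3 / padding 1, or kernel 1 / padding 0).
        return (n - 1) // 2 + 1

    size = halve(halve(input_size))  # stem convolution, then max pool
    for use_dilation in flags:
        # A dilated block keeps stride 1, so the resolution is preserved;
        # otherwise the block downsamples by a factor of two.
        if not use_dilation:
            size = halve(size)
    return size


default_size = resnet_feature_map_size(448)
dilated_size = resnet_feature_map_size(448, [False, False, True])
```

Under these assumptions, a 448-pixel input gives a 14x14 grid of patch embeddings by default and a 28x28 grid when only the last block is dilated, which is why dilation is useful when finer patch-level features are needed.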
- class health_multimodal.image.model.encoder.MultiImageEncoder(img_encoder_type)
Multi-image encoder trunk module for the ImageModel class. It can be used to encode multiple images into a combined latent representation. Currently it supports only two input images, but it can be extended to support more in the future.
- Parameters
img_encoder_type (str) – Type of image encoder model to use: either "resnet18" or "resnet50".
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(current_image, previous_image=None, return_patch_embeddings=False)
Get image global and patch embeddings.
- Return type
Union[Tensor, Tuple[Tensor, Tensor]]
- reload_encoder_with_dilation(replace_stride_with_dilation=None)
Workaround for enabling dilated convolutions after model initialization.
- Parameters
replace_stride_with_dilation (Optional[Sequence[bool]]) – Replace the 2x2 standard convolution stride with a dilated convolution in each layer in the last three blocks of ResNet architecture.
- Return type
None
- health_multimodal.image.model.encoder.get_encoder_from_type(img_encoder_type)
Returns the encoder class for the given encoder type.
- Parameters
img_encoder_type (str) – Encoder type, one of {RESNET18, RESNET50, RESNET18_MULTI_IMAGE, RESNET50_MULTI_IMAGE}.
- Return type