health_multimodal.image.model.model

Classes

BaseImageModel()

Abstract class for image models.

ImageModel(img_encoder_type, joint_feature_size)

Image encoder module

MultiImageModel(**kwargs)

Image model that takes a current image and, optionally, a previous image.

class health_multimodal.image.model.model.BaseImageModel

Abstract class for image models.

abstract forward(*args, **kwargs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass is defined within this function, call the Module instance itself rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() silently ignores them.

Return type

ImageModelOutput
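The abstract forward() contract and the note on hooks can be illustrated with a minimal sketch. SketchBaseImageModel and TinyImageModel are hypothetical stand-ins, not the library's classes; the sketch only assumes the common pattern of combining nn.Module with abc.ABC so that the abstract base cannot be instantiated, and it shows that a forward hook fires when the instance is called but not when forward() is called directly.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class SketchBaseImageModel(nn.Module, ABC):
    """Hypothetical stand-in for BaseImageModel: forward() is abstract."""

    @abstractmethod
    def forward(self, *args, **kwargs):
        raise NotImplementedError


class TinyImageModel(SketchBaseImageModel):
    """Hypothetical concrete subclass: global-average-pools [B, C, H, W] to [B, C]."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mean(dim=(2, 3))


model = TinyImageModel()
hook_calls = []

# Forward hooks receive (module, inputs, output) after forward() has run.
model.register_forward_hook(lambda module, inputs, output: hook_calls.append(output.shape))

x = torch.randn(2, 3, 8, 8)
model(x)          # calling the instance runs forward() and then the hook
model.forward(x)  # calling forward() directly skips the hook
print(len(hook_calls))  # 1
```

Instantiating SketchBaseImageModel() directly raises TypeError, because forward() is still abstract at that level.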

class health_multimodal.image.model.model.ImageModel(img_encoder_type, joint_feature_size, freeze_encoder=False, pretrained_model_path=None, **downstream_classifier_kwargs)

Image encoder module

create_downstream_classifier(**kwargs)

Create the classification module for the downstream task.

Return type

MultiTaskModel
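The exact constructor of MultiTaskModel is not documented in this section. The following is a hedged sketch of what a multi-task classification head of this kind can look like: one shared MLP whose output is reshaped into per-task class logits. The class name SketchMultiTaskHead, its arguments (input_dim, hidden_dim, num_classes, num_tasks), and the [B, num_classes, num_tasks] output layout are all assumptions for illustration, not the library's API.

```python
import torch
from torch import nn


class SketchMultiTaskHead(nn.Module):
    """Hypothetical multi-task classifier: shared MLP -> num_classes logits per task."""

    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, num_tasks: int) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.num_tasks = num_tasks
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes * num_tasks),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [B, input_dim] -> [B, num_classes * num_tasks] -> [B, num_classes, num_tasks]
        return self.mlp(x).view(-1, self.num_classes, self.num_tasks)


head = SketchMultiTaskHead(input_dim=512, hidden_dim=256, num_classes=3, num_tasks=4)
logits = head(torch.randn(8, 512))
print(logits.shape)  # torch.Size([8, 3, 4])
```

Sharing one hidden layer across tasks keeps the head small; each task then reads its own slice of the final linear layer.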

forward(x)

Defines the computation performed at every call: encodes the input image batch x and returns the result as an ImageModelOutput.

Note

Although the recipe for the forward pass is defined within this function, call the Module instance itself rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() silently ignores them.

Return type

ImageModelOutput
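To make the data flow of forward(x) concrete, here is a minimal sketch under stated assumptions: a toy convolutional "encoder" stands in for the backbone selected by img_encoder_type, the feature map is global-average-pooled, and joint_feature_size is assumed to be the width of a linear projection into a shared embedding space. SketchOutput and its field names are hypothetical stand-ins for ImageModelOutput, whose fields are not listed here.

```python
from typing import NamedTuple

import torch
from torch import nn


class SketchOutput(NamedTuple):
    """Hypothetical stand-in for ImageModelOutput; field names are assumptions."""
    img_embedding: torch.Tensor               # pooled encoder features, [B, D_enc]
    projected_global_embedding: torch.Tensor  # projection into the joint space, [B, joint_feature_size]


class SketchImageModel(nn.Module):
    def __init__(self, encoder_dim: int, joint_feature_size: int) -> None:
        super().__init__()
        # Toy "encoder": a single conv layer standing in for a CNN backbone.
        self.encoder = nn.Conv2d(3, encoder_dim, kernel_size=3, padding=1)
        self.projector = nn.Linear(encoder_dim, joint_feature_size)

    def forward(self, x: torch.Tensor) -> SketchOutput:
        feature_map = self.encoder(x)                 # [B, D_enc, H, W]
        img_embedding = feature_map.mean(dim=(2, 3))  # global average pool -> [B, D_enc]
        projected = self.projector(img_embedding)     # [B, joint_feature_size]
        return SketchOutput(img_embedding, projected)


model = SketchImageModel(encoder_dim=16, joint_feature_size=8)
out = model(torch.randn(2, 3, 32, 32))
print(out.projected_global_embedding.shape)  # torch.Size([2, 8])
```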

get_patchwise_projected_embeddings(input_img, normalize)

Get patch-wise projected embeddings from the CNN model.

Parameters
  • input_img (Tensor) – Input image tensor of shape [B, C, H, W].

  • normalize (bool) – If True, the embeddings are L2-normalized.

Returns

projected_embeddings – tensor of embeddings of shape [batch, n_patches_h, n_patches_w, feature_size].

Return type

Tensor
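The shape contract above can be sketched as follows, assuming (this is not confirmed by the section) that each spatial cell of the CNN feature map is treated as one patch and that the projection is a 1x1 convolution. The function name and its feature_map/projector arguments are hypothetical; the sketch starts from the backbone output rather than from input_img, then moves channels last and optionally L2-normalizes each patch embedding.

```python
import torch
import torch.nn.functional as F
from torch import nn


def patchwise_projected_embeddings(
    feature_map: torch.Tensor, projector: nn.Module, normalize: bool
) -> torch.Tensor:
    """Hypothetical sketch: project a CNN feature map patch by patch.

    feature_map: [B, D_enc, h, w] backbone output; each spatial cell is one "patch".
    projector:   maps D_enc -> feature_size channels (assumed here: a 1x1 conv).
    Returns:     [B, h, w, feature_size], optionally L2-normalized per patch.
    """
    projected = projector(feature_map)         # [B, feature_size, h, w]
    projected = projected.permute(0, 2, 3, 1)  # channels last: [B, h, w, feature_size]
    if normalize:
        projected = F.normalize(projected, dim=-1)  # unit L2 norm along feature_size
    return projected


feature_map = torch.randn(2, 64, 7, 7)         # e.g. a 7x7 grid of patches
projector = nn.Conv2d(64, 128, kernel_size=1)  # 1x1 conv = per-patch linear projection
emb = patchwise_projected_embeddings(feature_map, projector, normalize=True)
print(emb.shape)  # torch.Size([2, 7, 7, 128])
```

A 1x1 convolution applies the same linear map to every spatial position, which is why it is a natural choice for a per-patch projection.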

train(mode=True)

Switch the model between training and evaluation modes.

Return type

Any
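Why ImageModel overrides train() is not stated here. One common reason, given the freeze_encoder constructor argument, is to keep a frozen encoder in evaluation mode (so that, for example, BatchNorm running statistics stay fixed) even when the rest of the model is switched to training mode. The sketch below assumes that behavior; SketchFreezableModel and its layers are hypothetical.

```python
import torch
from torch import nn


class SketchFreezableModel(nn.Module):
    """Hypothetical: keeps a frozen encoder in eval mode whatever train() is called with."""

    def __init__(self, freeze_encoder: bool) -> None:
        super().__init__()
        self.freeze_encoder = freeze_encoder
        self.encoder = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
        self.head = nn.Linear(8, 2)
        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad_(False)  # no gradient updates to the encoder

    def train(self, mode: bool = True) -> "SketchFreezableModel":
        super().train(mode)  # sets self.training and recurses into all children
        if self.freeze_encoder:
            # Keep e.g. BatchNorm running statistics fixed while the head trains.
            self.encoder.train(False)
        return self


model = SketchFreezableModel(freeze_encoder=True)
model.train()
print(model.training, model.encoder.training, model.head.training)  # True False True
```

Because nn.Module.eval() is implemented as self.train(False), the override also covers calls to eval().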

class health_multimodal.image.model.model.MultiImageModel(**kwargs)

Image model that takes a current image and, optionally, a previous image.

forward(current_image, previous_image=None)

Defines the computation performed at every call: encodes current_image, together with previous_image when one is given, and returns the result as an ImageModelOutput.

Note

Although the recipe for the forward pass is defined within this function, call the Module instance itself rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() silently ignores them.

Return type

ImageModelOutput
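How MultiImageModel combines the two images is not documented in this section. The sketch below only illustrates the optional-argument pattern of forward(current_image, previous_image=None) under explicit assumptions: both images go through one shared encoder, the two embeddings are concatenated, and a learned placeholder vector substitutes for the previous image when it is missing. SketchMultiImageModel and all of these design choices are hypothetical.

```python
from typing import Optional

import torch
from torch import nn


class SketchMultiImageModel(nn.Module):
    """Hypothetical: encodes a current image and, if given, a previous image."""

    def __init__(self, feature_dim: int = 16) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, feature_dim, kernel_size=3, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )  # [B, 3, H, W] -> [B, feature_dim]
        # Learned placeholder used when no previous image is available (an assumption).
        self.missing_previous = nn.Parameter(torch.zeros(feature_dim))

    def forward(
        self, current_image: torch.Tensor, previous_image: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        current = self.encoder(current_image)
        if previous_image is None:
            previous = self.missing_previous.expand_as(current)  # broadcast to [B, feature_dim]
        else:
            previous = self.encoder(previous_image)  # shared weights for both time points
        return torch.cat([current, previous], dim=-1)  # [B, 2 * feature_dim]


model = SketchMultiImageModel()
cur = torch.randn(2, 3, 32, 32)
with_prev = model(cur, torch.randn(2, 3, 32, 32))
without_prev = model(cur)
print(with_prev.shape, without_prev.shape)  # torch.Size([2, 32]) torch.Size([2, 32])
```

Returning the same output shape with or without a previous image lets downstream heads stay agnostic to whether a prior study was available.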