health_multimodal.image.model.modules
Classes
| Class | Description |
|---|---|
| MLP | Fully connected layers to map between image embeddings and projection space where pairs of images are compared. |
| MultiTaskModel | Torch module for multi-task classification heads. |
- class health_multimodal.image.model.modules.MLP(input_dim, output_dim, hidden_dim=None, use_1x1_convs=False)
Fully connected layers to map between image embeddings and projection space where pairs of images are compared.
- Parameters
  - input_dim (int) – Input embedding feature size.
  - hidden_dim (Optional[int]) – Hidden layer size in the MLP.
  - output_dim (int) – Output projection size.
  - use_1x1_convs (bool) – Use 1x1 conv kernels instead of 2D linear transformations for speed and memory efficiency.
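A minimal PyTorch sketch of a projection head with this signature may help show what the two modes mean. It is not the library's implementation: the names `ProjectionMLP` and `_projection` are hypothetical, and both the layer stack (projection, batch norm, ReLU, projection) and the behaviour when `hidden_dim=None` (a single projection) are assumptions. The final lines check the property behind `use_1x1_convs`: a 1x1 convolution is the same linear map applied at every spatial position.

```python
from typing import Optional

import torch
from torch import nn


def _projection(in_dim: int, out_dim: int, use_1x1_convs: bool, bias: bool = True) -> nn.Module:
    """Linear map on (B, C) vectors, or the same map applied per pixel of (B, C, H, W)."""
    if use_1x1_convs:
        return nn.Conv2d(in_dim, out_dim, kernel_size=1, bias=bias)
    return nn.Linear(in_dim, out_dim, bias=bias)


class ProjectionMLP(nn.Module):
    """Hypothetical stand-in for MLP; the layer layout is an assumption."""

    def __init__(self, input_dim: int, output_dim: int,
                 hidden_dim: Optional[int] = None, use_1x1_convs: bool = False) -> None:
        super().__init__()
        if hidden_dim is None:
            # Assumed behaviour: no hidden layer means a single projection.
            self.model = _projection(input_dim, output_dim, use_1x1_convs)
        else:
            norm = nn.BatchNorm2d if use_1x1_convs else nn.BatchNorm1d
            self.model = nn.Sequential(
                _projection(input_dim, hidden_dim, use_1x1_convs, bias=False),
                norm(hidden_dim),
                nn.ReLU(inplace=True),
                _projection(hidden_dim, output_dim, use_1x1_convs),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


# Vector embeddings: (batch, input_dim) -> (batch, output_dim).
head = ProjectionMLP(input_dim=512, output_dim=128, hidden_dim=256).eval()
print(head(torch.randn(4, 512)).shape)  # torch.Size([4, 128])

# Patch-level feature maps: the 1x1-conv variant keeps the spatial grid.
conv_head = ProjectionMLP(512, 128, hidden_dim=256, use_1x1_convs=True).eval()
print(conv_head(torch.randn(4, 512, 7, 7)).shape)  # torch.Size([4, 128, 7, 7])

# A 1x1 conv with copied weights equals nn.Linear applied at every position.
lin = nn.Linear(8, 3)
conv = nn.Conv2d(8, 3, kernel_size=1)
with torch.no_grad():
    conv.weight.copy_(lin.weight.view(3, 8, 1, 1))
    conv.bias.copy_(lin.bias)
x = torch.randn(2, 8, 5, 5)
per_pixel = lin(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(conv(x), per_pixel, atol=1e-5))  # True
```

Because the two modes compute the same per-position map, the conv variant can project a (B, C, H, W) feature map directly, without the permute and reshape that an `nn.Linear` head would need.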
- class health_multimodal.image.model.modules.MultiTaskModel(input_dim, classifier_hidden_dim, num_classes, num_tasks)
Torch module for multi-task classification heads. We create a separate classification head for each task and perform a forward pass on each head independently in forward(). Classification heads are instances of MLP.
- Parameters
  - input_dim (int) – Number of dimensions of the input feature map.
  - classifier_hidden_dim (Optional[int]) – Number of dimensions of hidden features in the MLP.
  - num_classes (int) – Number of output classes per task.
  - num_tasks (int) – Number of classification tasks or heads required.
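The per-task design described above can be sketched as follows: one head with its own parameters per task, each run on the same input. This is an illustration, not the library's code. `MultiTaskHeads` and `_make_head` are hypothetical names, `_make_head` is a simplified stand-in for the documented `MLP` head, and stacking the per-task logits into a (batch, num_classes, num_tasks) tensor is an assumed output layout.

```python
from typing import Optional

import torch
from torch import nn


def _make_head(input_dim: int, hidden_dim: Optional[int], num_classes: int) -> nn.Module:
    """Simplified stand-in for the documented MLP head (assumed layout)."""
    if hidden_dim is None:
        return nn.Linear(input_dim, num_classes)
    return nn.Sequential(
        nn.Linear(input_dim, hidden_dim, bias=False),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, num_classes),
    )


class MultiTaskHeads(nn.Module):
    """Hypothetical stand-in for MultiTaskModel: one independent head per task."""

    def __init__(self, input_dim: int, classifier_hidden_dim: Optional[int],
                 num_classes: int, num_tasks: int) -> None:
        super().__init__()
        # Separate parameters per task; nn.ModuleList registers every head.
        self.heads = nn.ModuleList(
            _make_head(input_dim, classifier_hidden_dim, num_classes)
            for _ in range(num_tasks)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each head sees the same features; results are stacked on a new last axis,
        # giving (batch, num_classes, num_tasks).
        return torch.stack([head(x) for head in self.heads], dim=-1)


model = MultiTaskHeads(input_dim=512, classifier_hidden_dim=256, num_classes=3, num_tasks=5).eval()
logits = model(torch.randn(8, 512))
print(logits.shape)  # torch.Size([8, 3, 5])
```

Registering the heads in an `nn.ModuleList` (rather than a plain Python list) ensures their parameters are visible to optimizers and moved by `.to(device)`; slicing `logits[..., t]` then recovers the output of task `t`.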