# Source code for health_multimodal.image.model.types

#  -------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  -------------------------------------------------------------------------------------------


from __future__ import annotations

from dataclasses import dataclass
from enum import Enum, unique
from typing import List

import torch


@dataclass
class ImageModelOutput:
    """Container grouping the tensors produced by an image model.

    All five fields are required ``torch.Tensor`` values. The dataclass defines
    no defaults and performs no validation of the tensors it is given.
    """

    # NOTE(review): the per-field meanings below are inferred from the field
    # names only; shapes and semantics are not visible here — confirm against
    # the model code that populates this object.
    img_embedding: torch.Tensor  # presumably the global image embedding
    patch_embeddings: torch.Tensor  # presumably per-patch embeddings
    projected_global_embedding: torch.Tensor  # presumably the projected global embedding
    class_logits: torch.Tensor  # presumably classifier logits
    projected_patch_embeddings: torch.Tensor  # presumably the projected per-patch embeddings
@unique
class ImageEncoderType(str, Enum):
    """Identifiers of the available image encoder types.

    Members also derive from ``str``, so each one compares equal to its raw
    string value (for example ``ImageEncoderType.RESNET18 == "resnet18"``).
    ``@unique`` guarantees that no two members share a value.
    """

    RESNET18 = "resnet18"
    RESNET50 = "resnet50"
    RESNET18_MULTI_IMAGE = "resnet18_multi_image"
    RESNET50_MULTI_IMAGE = "resnet50_multi_image"

    @classmethod
    def get_members(cls, multi_image_encoders_only: bool) -> List["ImageEncoderType"]:
        """Return the encoder types, optionally limited to multi-image ones.

        :param multi_image_encoders_only: If true, return only the two
            ``*_MULTI_IMAGE`` members; otherwise return every member in
            definition order.
        :return: A newly created list of enum members.
        """
        if multi_image_encoders_only:
            return [cls.RESNET18_MULTI_IMAGE, cls.RESNET50_MULTI_IMAGE]
        return list(cls)
@unique
class ImageEncoderWeightTypes(str, Enum):
    """Identifiers of the available image encoder weight types.

    Members also derive from ``str``, so each one compares equal to its raw
    string value. ``@unique`` guarantees that no two members share a value.
    """

    # NOTE(review): what each option loads is not visible here; the names
    # suggest random initialisation, ImageNet weights and two BioViL
    # checkpoints — confirm against the code that consumes this enum.
    RANDOM = "random"
    IMAGENET = "imagenet"
    BIOVIL = "biovil"
    BIOVIL_T = "biovil_t"