# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
from enum import Enum, unique
from .data.transforms import create_chest_xray_transform_for_inference
from .inference_engine import ImageInferenceEngine
from .model.pretrained import get_biovil_image_encoder
from .model.pretrained import get_biovil_t_image_encoder
# Value passed as ``resize`` to the chest X-ray inference transform, shared by
# every image model type (presumably a side length in pixels -- confirm).
TRANSFORM_RESIZE: int = 512
@unique
class ImageModelType(str, Enum):
    """Identifiers of the supported pretrained image models.

    Because the enum mixes in ``str``, each member compares equal to its plain
    string value (e.g. ``ImageModelType.BIOVIL == "biovil"``), and
    ``ImageModelType("biovil")`` looks a member up by value. ``@unique``
    guarantees that no two members share a value.
    """

    BIOVIL = "biovil"
    BIOVIL_T = "biovil_t"
def get_image_inference(image_model_type: ImageModelType = ImageModelType.BIOVIL_T) -> ImageInferenceEngine:
    """Create an :class:`ImageInferenceEngine` for the requested image model.

    The model is downloaded from the Hugging Face Hub.

    :param image_model_type: The type of image model to use, ``BIOVIL`` or
        ``BIOVIL_T``. Since :class:`ImageModelType` is a ``str`` enum, the
        plain strings ``"biovil"`` and ``"biovil_t"`` are accepted too.
    :return: An inference engine wrapping the selected pretrained image
        encoder together with a chest X-ray inference transform (resize to
        ``TRANSFORM_RESIZE``, then a model-specific center crop).
    :raises ValueError: If ``image_model_type`` is not a supported model type.
    """
    # Each model uses its own center-crop size; presumably these match the
    # input resolution each encoder was trained with -- TODO confirm.
    if image_model_type == ImageModelType.BIOVIL_T:
        image_model = get_biovil_t_image_encoder()
        transform_center_crop_size = 448
    elif image_model_type == ImageModelType.BIOVIL:
        image_model = get_biovil_image_encoder()
        transform_center_crop_size = 480
    else:
        raise ValueError(f"Unknown image_model_type: {image_model_type}")

    transform = create_chest_xray_transform_for_inference(
        resize=TRANSFORM_RESIZE,
        center_crop_size=transform_center_crop_size,
    )
    return ImageInferenceEngine(image_model=image_model, transform=transform)