# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
from enum import Enum, unique
from typing import Tuple
from ..image.model.pretrained import (
BIOMED_VLP_BIOVIL_T,
BIOMED_VLP_CXR_BERT_SPECIALIZED,
BIOVIL_T_COMMIT_TAG,
CXR_BERT_COMMIT_TAG,
)
from .inference_engine import TextInferenceEngine
from .model import CXRBertModel, CXRBertTokenizer
@unique
class BertEncoderType(str, Enum):
    """Identifiers for the text encoder variants that can be loaded from the Hub.

    Because the enum mixes in ``str``, every member is itself a string and compares
    equal to its raw value (e.g. ``BertEncoderType.CXR_BERT == "cxr_bert"``).
    ``@unique`` guarantees no two members share a value.
    """

    # Selects ``get_cxr_bert`` (the CXR-BERT-specialized checkpoint) in ``get_bert_inference``.
    CXR_BERT = "cxr_bert"
    # Selects ``get_biovil_t_bert`` (the BioViL-T checkpoint) in ``get_bert_inference``.
    BIOVIL_T_BERT = "biovil_t_bert"
def get_biovil_t_bert() -> Tuple[CXRBertTokenizer, CXRBertModel]:
    """Load the BioViL-T BERT tokenizer and text model.

    Both are loaded from the
    `Hugging Face Hub <https://huggingface.co/microsoft/BiomedVLP-BioViL-T>`_,
    pinned to the ``BIOVIL_T_COMMIT_TAG`` revision so results are reproducible.

    :return: A ``(tokenizer, text_model)`` pair.
    """
    repo_id, revision = BIOMED_VLP_BIOVIL_T, BIOVIL_T_COMMIT_TAG
    # Tokenizer first, then model -- same order as the reference implementation.
    biovil_tokenizer = CXRBertTokenizer.from_pretrained(repo_id, revision=revision)
    biovil_model = CXRBertModel.from_pretrained(repo_id, revision=revision)
    return biovil_tokenizer, biovil_model
def get_cxr_bert() -> Tuple[CXRBertTokenizer, CXRBertModel]:
    """Load the CXR-BERT (specialized) tokenizer and text model.

    Both are loaded from the
    `Hugging Face Hub <https://huggingface.co/microsoft/BiomedVLP-CXR-BERT-specialized>`_,
    pinned to the ``CXR_BERT_COMMIT_TAG`` revision so results are reproducible.

    :return: A ``(tokenizer, text_model)`` pair.
    """
    repo_id, revision = BIOMED_VLP_CXR_BERT_SPECIALIZED, CXR_BERT_COMMIT_TAG
    # Tokenizer first, then model -- same order as the reference implementation.
    cxr_tokenizer = CXRBertTokenizer.from_pretrained(repo_id, revision=revision)
    cxr_model = CXRBertModel.from_pretrained(repo_id, revision=revision)
    return cxr_tokenizer, cxr_model
def get_bert_inference(bert_encoder_type: BertEncoderType = BertEncoderType.BIOVIL_T_BERT) -> TextInferenceEngine:
    """Create a :class:`TextInferenceEngine` for a text encoder model.

    The tokenizer and model are downloaded from the Hugging Face Hub. The resulting engine
    can be used to get embeddings from text prompts or masked token predictions.

    :param bert_encoder_type: The type of text encoder model to use, `CXR_BERT` or `BIOVIL_T_BERT`.
        Because ``BertEncoderType`` mixes in ``str``, the raw values ``"cxr_bert"`` and
        ``"biovil_t_bert"`` are accepted as well.
    :return: An engine wrapping the selected tokenizer and text model.
    :raises ValueError: If ``bert_encoder_type`` matches neither supported encoder type.
    """
    # Resolve which loader to use before doing any (network-bound) loading.
    if bert_encoder_type == BertEncoderType.CXR_BERT:
        load_encoder = get_cxr_bert
    elif bert_encoder_type == BertEncoderType.BIOVIL_T_BERT:
        load_encoder = get_biovil_t_bert
    else:
        raise ValueError(f"Unknown bert_encoder_type: {bert_encoder_type}")

    tokenizer, text_model = load_encoder()
    return TextInferenceEngine(tokenizer=tokenizer, text_model=text_model)