health_multimodal.text.inference_engine
Classes
TextInferenceEngine: Text inference class that implements functionalities required to extract sentence embedding, similarity and MLM prediction tasks.
- class health_multimodal.text.inference_engine.TextInferenceEngine(tokenizer, text_model)
Text inference class that implements the functionality required for sentence embedding extraction, similarity computation, and masked language modelling (MLM) prediction.
- Parameters
tokenizer (BertTokenizer) – A BertTokenizer object.
text_model (BertForMaskedLM) – Text model either default HuggingFace class
- get_embeddings_from_prompt(prompts, normalize=True, verbose=True)
Generate embeddings for one or more input text prompts. The embeddings are L2-normalised by default.
- Parameters
prompts (Union[str, List[str]]) – Input text prompt(s), given either as a single string or as a list of strings.
normalize (bool) – If True, L2-normalise the embeddings.
verbose (bool) – If True, the tokenized words are displayed in the console.
- Return type
Tensor
- Returns
Tensor of shape (batch_size, embedding_size).
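To illustrate what the normalize flag refers to, here is a minimal pure-Python sketch of L2 normalisation applied row by row to a batch of embeddings. The helper name l2_normalize, the eps guard, and the toy vectors are hypothetical and are not part of the library, which operates on tensors.

```python
import math
from typing import List


def l2_normalize(batch: List[List[float]], eps: float = 1e-12) -> List[List[float]]:
    """Scale each row to unit Euclidean length (hypothetical helper, not the library's code)."""
    normalized = []
    for row in batch:
        norm = math.sqrt(sum(x * x for x in row))
        # Guard against division by zero for all-zero rows.
        normalized.append([x / max(norm, eps) for x in row])
    return normalized


# Toy batch with shape (batch_size=2, embedding_size=2).
embeddings = [[3.0, 4.0], [0.0, 2.0]]
unit_embeddings = l2_normalize(embeddings)
```

After normalisation every row has Euclidean length 1, so a dot product between two rows can be read directly as their cosine similarity.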
- get_pairwise_similarities(prompt_set_1, prompt_set_2)
Compute pairwise cosine similarities between the embeddings of the given prompts.
- Return type
Tensor
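As a reminder of what cosine similarity measures, the following sketch computes it for plain Python lists. It assumes that the i-th prompt of the first set is compared with the i-th prompt of the second set; this page does not state the shape of the returned tensor, so that pairing is an assumption. The function name and the toy vectors are hypothetical.

```python
import math
from typing import Sequence


def cosine_similarity(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine of the angle between two non-zero vectors (hypothetical helper)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Toy embeddings standing in for two prompt sets, compared pair by pair (assumed pairing).
set_1 = [[1.0, 0.0], [1.0, 1.0]]
set_2 = [[0.0, 1.0], [2.0, 2.0]]
similarities = [cosine_similarity(u, v) for u, v in zip(set_1, set_2)]
```

Orthogonal vectors score 0 and vectors pointing in the same direction score 1, regardless of their lengths.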
- predict_masked_tokens(prompts)
Predict masked tokens for a single input text prompt or a list of prompts.
Requires a model trained with an MLM prediction head.
- Parameters
prompts (Union[str, List[str]]) – Input text prompt(s), given either as a single string or as a list of strings.
- Return type
List[List[str]]
- Returns
Predicted token candidates (top-1) at the masked position.
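Top-1 prediction means keeping only the highest-scoring vocabulary entry at each masked position. The sketch below shows that selection on hand-written scores; the function name, the vocabulary, and the scores are hypothetical, and reading the inner list as one entry per mask in a prompt is an assumption based on the List[List[str]] return type.

```python
from typing import Dict, List


def top1_tokens(scores_per_mask: List[Dict[str, float]]) -> List[str]:
    """Pick the highest-scoring token for each masked position (hypothetical helper)."""
    return [max(scores, key=scores.get) for scores in scores_per_mask]


# One outer entry per prompt, one inner entry per masked position (assumed layout).
prompt_scores = [
    [{"pneumonia": 2.1, "effusion": 0.3, "fracture": -1.0}],
    [{"left": 0.2, "right": 1.7}, {"lung": 3.0, "heart": 0.5}],
]
predictions = [top1_tokens(masks) for masks in prompt_scores]
```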
- tokenize_input_prompts(prompts, verbose=True)
Tokenizes the input sentence(s) and adds special tokens as defined by the tokenizer.
- Parameters
prompts (Union[str, List[str]]) – Either a string containing a single sentence, or a list of strings each containing a single sentence. Note that this method will not correctly tokenize multiple sentences if they are input as a single string.
verbose (bool) – If True, the sentence is logged after tokenization.
- Return type
Any
- Returns
A 2D tensor containing the tokenized sentences.
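The toy sketch below illustrates the general idea of adding special tokens and producing a rectangular (2D) batch from one or more sentences. It is not the library's tokenizer: it splits on whitespace and returns token strings, whereas a BERT tokenizer splits text into WordPiece units and returns integer ids. The function name toy_tokenize and the padding to the longest sentence are assumptions made for illustration.

```python
from typing import List, Union


def toy_tokenize(prompts: Union[str, List[str]]) -> List[List[str]]:
    """Whitespace tokenizer with BERT-style special tokens (illustrative only)."""
    # Accept a single sentence or a list of sentences.
    if isinstance(prompts, str):
        prompts = [prompts]
    # Wrap each sentence in [CLS] ... [SEP].
    rows = [["[CLS]", *p.lower().split(), "[SEP]"] for p in prompts]
    # Pad shorter rows so that every row has the same length.
    width = max(len(r) for r in rows)
    return [r + ["[PAD]"] * (width - len(r)) for r in rows]


batch = toy_tokenize(["No pleural effusion", "Clear lungs"])
```

Each row starts with [CLS], ends its sentence with [SEP], and is padded to a common length, which is what allows the batch to be stored as a single 2D array.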