# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
from typing import Any, List, Union
import torch
from transformers import BertForMaskedLM, BertTokenizer
from health_multimodal.text.data.io import TextInput
class TextInferenceEngine(TextInput):
"""
Text inference class that implements functionalities required to extract
sentence embedding, similarity and MLM prediction tasks.
:param tokenizer: A BertTokenizer object.
:param text_model: Text model either default HuggingFace class
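
    Example (illustrative sketch; ``tokenizer`` and ``text_model`` are assumed
    to be loaded elsewhere, e.g. via HuggingFace ``from_pretrained``)::

        engine = TextInferenceEngine(tokenizer=tokenizer, text_model=text_model)
        engine.model.eval()
        tokenizer_output = engine.tokenize_input_prompts("There is no pneumothorax.")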
"""
def __init__(self, tokenizer: BertTokenizer, text_model: BertForMaskedLM) -> None:
super().__init__(tokenizer=tokenizer)
assert isinstance(text_model, BertForMaskedLM), f"Expected a BertForMaskedLM, got {type(text_model)}"
self.model = text_model
self.max_allowed_input_length = self.model.config.max_position_embeddings
        # Expose the wrapped model's ``.to`` so the engine can be moved between devices directly.
        self.to = self.model.to

    def is_in_eval(self) -> bool:
"""Returns True if the model is in eval mode."""
return not self.model.training

    def tokenize_input_prompts(self, prompts: Union[str, List[str]], verbose: bool = True) -> Any:
        """Tokenize the input prompt(s), move the tensors to the model's device, and validate the sequence length."""
tokenizer_output = super().tokenize_input_prompts(prompts, verbose=verbose)
device = next(self.model.parameters()).device
tokenizer_output.input_ids = tokenizer_output.input_ids.to(device)
tokenizer_output.attention_mask = tokenizer_output.attention_mask.to(device)
        max_length = tokenizer_output.input_ids.shape[1]
        if max_length > self.max_allowed_input_length:
raise ValueError(
f"The sequence length of the input ({max_length}) is "
f"longer than the maximum allowed sequence length ({self.max_allowed_input_length})."
)
return tokenizer_output

    @torch.no_grad()
def get_embeddings_from_prompt(
self, prompts: Union[str, List[str]], normalize: bool = True, verbose: bool = True
) -> torch.Tensor:
"""Generate L2-normalised embeddings for a list of input text prompts.
:param prompts: Input text prompt(s) either in string or list of string format.
:param normalize: If True, L2-normalise the embeddings.
:param verbose: If set to True, tokenized words are displayed in the console.
:return: Tensor of shape (batch_size, embedding_size).
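
        Example (illustrative; assumes the wrapped model provides
        ``get_projected_text_embeddings``, which vanilla ``BertForMaskedLM`` does not)::

            embeddings = engine.get_embeddings_from_prompt(["No pleural effusion.", "Mild cardiomegaly."])
            # embeddings.shape == (2, embedding_size)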
"""
assert self.is_in_eval()
tokenizer_output = self.tokenize_input_prompts(prompts=prompts, verbose=verbose)
        # ``get_projected_text_embeddings`` is not part of the vanilla ``BertForMaskedLM`` API;
        # the wrapped model is expected to provide it (hence the ``type: ignore``).
        txt_emb = self.model.get_projected_text_embeddings(  # type: ignore
input_ids=tokenizer_output.input_ids,
attention_mask=tokenizer_output.attention_mask,
normalize_embeddings=normalize,
)
return txt_emb

    @torch.no_grad()
def get_pairwise_similarities(
self, prompt_set_1: Union[str, List[str]], prompt_set_2: Union[str, List[str]]
) -> torch.Tensor:
"""Compute pairwise cosine similarities between the embeddings of the given prompts."""
emb_1 = self.get_embeddings_from_prompt(prompts=prompt_set_1, verbose=False)
emb_2 = self.get_embeddings_from_prompt(prompts=prompt_set_2, verbose=False)
sim = torch.diag(torch.mm(emb_1, emb_2.t())).detach()
return sim

    @torch.no_grad()
def predict_masked_tokens(self, prompts: Union[str, List[str]]) -> List[List[str]]:
"""Predict masked tokens for a single or list of input text prompts.
Requires models to be trained with a MLM prediction head.
:param prompts: Input text prompt(s) either in string or list of string format.
:return: Predicted token candidates (Top-1) at masked position.
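
        Example (illustrative; the exact predictions depend on the checkpoint)::

            engine.predict_masked_tokens("There is a small [MASK] effusion.")
            # -> e.g. [['pleural']]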
"""
assert self.is_in_eval()
# Tokenize the input prompts
tokenized_prompts = self.tokenize_input_prompts(prompts)
# Collect all token predictions
text_model_output = self.model.forward(
input_ids=tokenized_prompts.input_ids, attention_mask=tokenized_prompts.attention_mask
)
logits = text_model_output.logits
logits = logits.detach()
predicted_token_ids = torch.argmax(logits, dim=-1) # Batch x Seq
# Identify the masked token indices
batch_size = predicted_token_ids.shape[0]
mask_token_id = self.tokenizer.mask_token_id
mlm_mask = tokenized_prompts.input_ids == mask_token_id # Batch x Seq
# Convert the predicted token ids to token strings
        output: List[List[str]] = []
for b in range(batch_size):
_ids = predicted_token_ids[b, mlm_mask[b]].cpu().tolist()
_tokens = self.tokenizer.convert_ids_to_tokens(_ids)
output.append(_tokens)
return output
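

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the library API).
    # Assumption: "bert-base-uncased" stands in for a domain-specific checkpoint;
    # any ``BertForMaskedLM`` checkpoint exercises the MLM path, while the
    # embedding/similarity path additionally requires a model class that
    # provides ``get_projected_text_embeddings``.
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    text_model = BertForMaskedLM.from_pretrained("bert-base-uncased")

    engine = TextInferenceEngine(tokenizer=tokenizer, text_model=text_model)
    engine.model.eval()  # the inference methods assert that the model is in eval mode
    assert engine.is_in_eval()

    # MLM prediction: top-1 candidate per masked position, one list per prompt.
    print(engine.predict_masked_tokens(["The chest X-ray shows a small [MASK] effusion."]))

    # Embeddings and similarities are only available when the wrapped model
    # exposes ``get_projected_text_embeddings`` (vanilla BERT does not).
    if hasattr(engine.model, "get_projected_text_embeddings"):
        sim = engine.get_pairwise_similarities("pleural effusion", "fluid in the pleural space")
        print(sim)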