# Source code for health_multimodal.text.data.io

#  -------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  -------------------------------------------------------------------------------------------

import logging
from typing import Any, List, Union

from transformers import BertTokenizer


TypePrompts = Union[str, List[str]]

logger = logging.getLogger(__name__)


class TextInput:
    """Text input class that can be used for inference and deployment.

    Implements tokenizer related operations and ensures that input strings
    conform with the standards expected from a BERT model.

    :param tokenizer: A BertTokenizer object.
    """

    def __init__(self, tokenizer: BertTokenizer) -> None:
        self.tokenizer = tokenizer

    def tokenize_input_prompts(self, prompts: TypePrompts, verbose: bool = False) -> Any:
        """Tokenize the input sentence(s) and add special tokens as defined by the tokenizer.

        :param prompts: Either a string containing a single sentence, or a list of strings each
            containing a single sentence. Note that this method will not correctly tokenize
            multiple sentences if they are input as a single string.
        :param verbose: If set to True, logs the tokens of each sentence after tokenization.
            Defaults to False.
        :return: A 2D tensor containing the tokenized sentences.
        :raises ValueError: If any prompt contains a special token other than the mask token.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        self.assert_special_tokens_not_present(" ".join(prompts))
        # Remove sentence-ending punctuation so all prompts terminate consistently.
        cleaned_prompts = [prompt.rstrip("!?.") for prompt in prompts]
        tokenizer_output = self.tokenizer.batch_encode_plus(
            batch_text_or_text_pairs=cleaned_prompts,
            add_special_tokens=True,
            padding='longest',
            return_tensors='pt',
        )
        if verbose:
            for token_ids in tokenizer_output.input_ids:
                input_tokens = self.tokenizer.convert_ids_to_tokens(token_ids.tolist())
                # Lazy %-formatting avoids building the message when INFO is disabled.
                logger.info("Input tokens: %s", input_tokens)
        return tokenizer_output

    def assert_special_tokens_not_present(self, prompt: str) -> None:
        """Check that the prompt contains no special tokens (the mask token is allowed).

        :param prompt: The prompt text to check.
        :raises ValueError: If the prompt contains a disallowed special token.
        """
        mask_token = self.tokenizer.mask_token  # [MASK] is allowed in prompts
        # Build a filtered copy rather than calling .remove() on the list returned by
        # `all_special_tokens`: mutating it in place would corrupt shared state if the
        # tokenizer returned a cached list, and .remove() raises when no mask token exists.
        special_tokens = [token for token in self.tokenizer.all_special_tokens if token != mask_token]
        if any(token in prompt for token in special_tokens):
            raise ValueError(f"The input \"{prompt}\" contains at least one special token ({special_tokens})")