# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
import logging
from typing import Any, List, Union

from transformers import BertTokenizer

TypePrompts = Union[str, List[str]]

logger = logging.getLogger(__name__)


class TextInput:
    """Text input class that can be used for inference and deployment.

    Implements tokenizer-related operations and ensures that input strings
    conform to the format expected by a BERT model.

    :param tokenizer: A BertTokenizer object.
    """

    def __init__(self, tokenizer: BertTokenizer) -> None:
        self.tokenizer = tokenizer
    def tokenize_input_prompts(self, prompts: TypePrompts, verbose: bool) -> Any:
        """
        Tokenizes the input sentence(s) and adds special tokens as defined by the tokenizer.

        :param prompts: Either a string containing a single sentence, or a list of strings, each
            containing a single sentence. Note that this method will not correctly tokenize
            multiple sentences passed in as a single string.
        :param verbose: If set to True, logs the tokens of each sentence after tokenization.
        :return: A BatchEncoding whose input_ids field is a 2D tensor of token IDs, padded to the
            length of the longest sentence in the batch.
        """
        prompts = [prompts] if isinstance(prompts, str) else prompts
        self.assert_special_tokens_not_present(" ".join(prompts))
        prompts = [prompt.rstrip("!?.") for prompt in prompts]  # Remove punctuation from the end of each prompt
        tokenizer_output = self.tokenizer.batch_encode_plus(
            batch_text_or_text_pairs=prompts, add_special_tokens=True, padding='longest', return_tensors='pt'
        )
        if verbose:
            for prompt in tokenizer_output.input_ids:
                input_tokens = self.tokenizer.convert_ids_to_tokens(prompt.tolist())
                logger.info(f"Input tokens: {input_tokens}")
        return tokenizer_output
    def assert_special_tokens_not_present(self, prompt: str) -> None:
        """Raise a ValueError if the input prompt contains any special token other than the mask token."""
        special_tokens = self.tokenizer.all_special_tokens
        special_tokens.remove(self.tokenizer.mask_token)  # [MASK] is allowed
        if any(token in prompt for token in special_tokens):
            raise ValueError(f"The input \"{prompt}\" contains at least one special token ({special_tokens})")