|
| 1 | +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. |
| 14 | +==============================================================================*/ |
| 15 | +#include "tensorflow_lite_support/cc/task/processor/bert_preprocessor.h" |
| 16 | + |
| 17 | +#include "absl/status/status.h" // from @com_google_absl |
| 18 | +#include "absl/strings/ascii.h" // from @com_google_absl |
| 19 | +#include "tensorflow_lite_support/cc/common.h" |
| 20 | +#include "tensorflow_lite_support/cc/port/status_macros.h" |
| 21 | +#include "tensorflow_lite_support/cc/task/core/task_utils.h" |
| 22 | +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h" |
| 23 | +#include "tensorflow_lite_support/cc/utils/common_utils.h" |
| 24 | + |
| 25 | +namespace tflite { |
| 26 | +namespace task { |
| 27 | +namespace processor { |
| 28 | + |
| 29 | +using ::absl::StatusCode; |
| 30 | +using ::tflite::support::CreateStatusWithPayload; |
| 31 | +using ::tflite::support::StatusOr; |
| 32 | +using ::tflite::support::TfLiteSupportStatus; |
| 33 | +using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit; |
| 34 | +using ::tflite::support::text::tokenizer::TokenizerResult; |
| 35 | +using ::tflite::task::core::FindIndexByMetadataTensorName; |
| 36 | +using ::tflite::task::core::PopulateTensor; |
| 37 | + |
// Index of the tokenizer ProcessUnit inside the input process units of the
// SubgraphMetadata (looked up in Init()).
constexpr int kTokenizerProcessUnitIndex = 0;
// Metadata tensor names used to locate the three Bert input tensors by name.
constexpr char kIdsTensorName[] = "ids";
constexpr char kMaskTensorName[] = "mask";
constexpr char kSegmentIdsTensorName[] = "segment_ids";
// Special tokens wrapped around the tokenized query: [CLS] first, [SEP] last.
constexpr char kClassificationToken[] = "[CLS]";
constexpr char kSeparator[] = "[SEP]";
| 44 | + |
| 45 | +/* static */ |
| 46 | +StatusOr<std::unique_ptr<BertPreprocessor>> BertPreprocessor::Create( |
| 47 | + tflite::task::core::TfLiteEngine* engine, |
| 48 | + const std::initializer_list<int> input_tensor_indices) { |
| 49 | + ASSIGN_OR_RETURN(auto processor, Processor::Create<BertPreprocessor>( |
| 50 | + /* num_expected_tensors = */ 3, engine, |
| 51 | + input_tensor_indices, |
| 52 | + /* requires_metadata = */ false)); |
| 53 | + RETURN_IF_ERROR(processor->Init()); |
| 54 | + return processor; |
| 55 | +} |
| 56 | + |
| 57 | +absl::Status BertPreprocessor::Init() { |
| 58 | + // Try if RegexTokenzier can be found. |
| 59 | + // BertTokenzier is packed in the processing unit of the InputTensors in |
| 60 | + // SubgraphMetadata. |
| 61 | + const tflite::ProcessUnit* tokenzier_metadata = |
| 62 | + GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex); |
| 63 | + // Identify the tensor index for three Bert input tensors. |
| 64 | + auto tensors_metadata = GetMetadataExtractor()->GetInputTensorMetadata(); |
| 65 | + int ids_tensor_index = |
| 66 | + FindIndexByMetadataTensorName(tensors_metadata, kIdsTensorName); |
| 67 | + ids_tensor_index_ = |
| 68 | + ids_tensor_index == -1 ? tensor_indices_[0] : ids_tensor_index; |
| 69 | + int mask_tensor_index = |
| 70 | + FindIndexByMetadataTensorName(tensors_metadata, kMaskTensorName); |
| 71 | + mask_tensor_index_ = |
| 72 | + mask_tensor_index == -1 ? tensor_indices_[1] : mask_tensor_index; |
| 73 | + int segment_ids_tensor_index = |
| 74 | + FindIndexByMetadataTensorName(tensors_metadata, kSegmentIdsTensorName); |
| 75 | + segment_ids_tensor_index_ = segment_ids_tensor_index == -1 |
| 76 | + ? tensor_indices_[2] |
| 77 | + : segment_ids_tensor_index; |
| 78 | + |
| 79 | + if (GetLastDimSize(ids_tensor_index_) != GetLastDimSize(mask_tensor_index_) || |
| 80 | + GetLastDimSize(ids_tensor_index_) != |
| 81 | + GetLastDimSize(segment_ids_tensor_index_)) { |
| 82 | + return CreateStatusWithPayload( |
| 83 | + absl::StatusCode::kInternal, |
| 84 | + absl::StrFormat("The three input tensors in Bert models are " |
| 85 | + "expected to have same length, but got ids_tensor " |
| 86 | + "(%d), mask_tensor (%d), segment_ids_tensor (%d).", |
| 87 | + GetLastDimSize(ids_tensor_index_), |
| 88 | + GetLastDimSize(mask_tensor_index_), |
| 89 | + GetLastDimSize(segment_ids_tensor_index_)), |
| 90 | + TfLiteSupportStatus::kInvalidNumOutputTensorsError); |
| 91 | + } |
| 92 | + bert_max_seq_len_ = GetLastDimSize(ids_tensor_index_); |
| 93 | + |
| 94 | + ASSIGN_OR_RETURN(tokenizer_, CreateTokenizerFromProcessUnit( |
| 95 | + tokenzier_metadata, GetMetadataExtractor())); |
| 96 | + return absl::OkStatus(); |
| 97 | +} |
| 98 | + |
| 99 | +absl::Status BertPreprocessor::Preprocess(const std::string& input_text) { |
| 100 | + auto* ids_tensor = |
| 101 | + engine_->GetInput(engine_->interpreter(), ids_tensor_index_); |
| 102 | + auto* mask_tensor = |
| 103 | + engine_->GetInput(engine_->interpreter(), mask_tensor_index_); |
| 104 | + auto* segment_ids_tensor = |
| 105 | + engine_->GetInput(engine_->interpreter(), segment_ids_tensor_index_); |
| 106 | + |
| 107 | + std::string processed_input = input_text; |
| 108 | + absl::AsciiStrToLower(&processed_input); |
| 109 | + |
| 110 | + TokenizerResult input_tokenize_results; |
| 111 | + input_tokenize_results = tokenizer_->Tokenize(processed_input); |
| 112 | + |
| 113 | + // 2 accounts for [CLS], [SEP] |
| 114 | + absl::Span<const std::string> query_tokens = |
| 115 | + absl::MakeSpan(input_tokenize_results.subwords.data(), |
| 116 | + input_tokenize_results.subwords.data() + |
| 117 | + std::min(static_cast<size_t>(bert_max_seq_len_ - 2), |
| 118 | + input_tokenize_results.subwords.size())); |
| 119 | + |
| 120 | + std::vector<std::string> tokens; |
| 121 | + tokens.reserve(2 + query_tokens.size()); |
| 122 | + // Start of generating the features. |
| 123 | + tokens.push_back(kClassificationToken); |
| 124 | + // For query input. |
| 125 | + for (const auto& query_token : query_tokens) { |
| 126 | + tokens.push_back(query_token); |
| 127 | + } |
| 128 | + // For Separation. |
| 129 | + tokens.push_back(kSeparator); |
| 130 | + |
| 131 | + std::vector<int> input_ids(bert_max_seq_len_, 0); |
| 132 | + std::vector<int> input_mask(bert_max_seq_len_, 0); |
| 133 | + // Convert tokens back into ids and set mask |
| 134 | + for (int i = 0; i < tokens.size(); ++i) { |
| 135 | + tokenizer_->LookupId(tokens[i], &input_ids[i]); |
| 136 | + input_mask[i] = 1; |
| 137 | + } |
| 138 | + // |<--------bert_max_seq_len_--------->| |
| 139 | + // input_ids [CLS] s1 s2... sn [SEP] 0 0... 0 |
| 140 | + // input_masks 1 1 1... 1 1 0 0... 0 |
| 141 | + // segment_ids 0 0 0... 0 0 0 0... 0 |
| 142 | + |
| 143 | + RETURN_IF_ERROR(PopulateTensor(input_ids, ids_tensor)); |
| 144 | + RETURN_IF_ERROR(PopulateTensor(input_mask, mask_tensor)); |
| 145 | + RETURN_IF_ERROR(PopulateTensor(std::vector<int>(bert_max_seq_len_, 0), |
| 146 | + segment_ids_tensor)); |
| 147 | + return absl::OkStatus(); |
| 148 | +} |
| 149 | + |
| 150 | +int BertPreprocessor::GetLastDimSize(int tensor_index) { |
| 151 | + auto tensor = engine_->GetInput(engine_->interpreter(), tensor_index); |
| 152 | + return tensor->dims->data[tensor->dims->size - 1]; |
| 153 | +} |
| 154 | + |
| 155 | +} // namespace processor |
| 156 | +} // namespace task |
| 157 | +} // namespace tflite |
0 commit comments