
Commit 612a65f

lu-wang-g authored and tflite-support-robot committed
Split TextPreprocessor into RegexPreprocessor and BertPreprocessor
The previous TextPreprocessor contained tokenizer support for both RegexTokenizer and BertTokenizer. When migrating NLClassifier to TextPreprocessor, clients pulled in extra dependencies for BertTokenizer, which is not ideal for clients who care about binary size. TextPreprocessor is therefore split into RegexPreprocessor and BertPreprocessor, and RegexPreprocessor is used as the implementation for NLClassifier.

PiperOrigin-RevId: 412955296
1 parent 728ac00 commit 612a65f

File tree

7 files changed: +568 −284 lines

tensorflow_lite_support/cc/task/processor/BUILD

Lines changed: 51 additions & 7 deletions
@@ -101,25 +101,69 @@ cc_library_with_tflite(
 )
 
 cc_library_with_tflite(
-    name = "text_preprocessor",
-    srcs = ["text_preprocessor.cc"],
-    hdrs = ["text_preprocessor.h"],
+    name = "regex_preprocessor",
+    srcs = ["regex_preprocessor.cc"],
+    hdrs = ["regex_preprocessor.h"],
     tflite_deps = [
-        ":processor",
-        "//tensorflow_lite_support/cc/task/core:tflite_engine",
+        ":text_preprocessor_header",
     ],
     deps = [
         "//tensorflow_lite_support/cc:common",
         "//tensorflow_lite_support/cc/port:status_macros",
         "//tensorflow_lite_support/cc/port:statusor",
         "//tensorflow_lite_support/cc/task/core:task_utils",
         "//tensorflow_lite_support/cc/text/tokenizers:regex_tokenizer",
+        "@com_google_absl//absl/status",
+    ],
+)
+
+cc_library_with_tflite(
+    name = "bert_preprocessor",
+    srcs = ["bert_preprocessor.cc"],
+    hdrs = ["bert_preprocessor.h"],
+    tflite_deps = [
+        ":text_preprocessor_header",
+    ],
+    deps = [
+        "//tensorflow_lite_support/cc:common",
+        "//tensorflow_lite_support/cc/port:status_macros",
+        "//tensorflow_lite_support/cc/port:statusor",
+        "//tensorflow_lite_support/cc/task/core:task_utils",
         "//tensorflow_lite_support/cc/text/tokenizers:tokenizer",
         "//tensorflow_lite_support/cc/text/tokenizers:tokenizer_utils",
         "//tensorflow_lite_support/cc/utils:common_utils",
-        "@com_google_absl//absl/memory",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
+    ],
+)
+
+cc_library_with_tflite(
+    name = "text_preprocessor_header",
+    hdrs = ["text_preprocessor.h"],
+    tflite_deps = [
+        ":processor",
+        "//tensorflow_lite_support/cc/task/core:tflite_engine",
+    ],
+    deps = [
+        "//tensorflow_lite_support/cc/port:statusor",
+        "@com_google_absl//absl/status",
+    ],
+)
+
+cc_library_with_tflite(
+    name = "text_preprocessor",
+    srcs = ["text_preprocessor.cc"],
+    hdrs = ["text_preprocessor.h"],
+    tflite_deps = [
+        ":processor",
+        ":bert_preprocessor",
+        ":regex_preprocessor",
+        "//tensorflow_lite_support/cc/task/core:tflite_engine",
+    ],
+    deps = [
+        "//tensorflow_lite_support/cc:common",
+        "//tensorflow_lite_support/cc/port:status_macros",
+        "//tensorflow_lite_support/cc/port:statusor",
+        "@com_google_absl//absl/status",
+    ],
 )
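
The practical effect of the split above is that a regex-only client can depend on :regex_preprocessor alone, so nothing behind :bert_preprocessor (in particular BertTokenizer) is linked into the binary; the :text_preprocessor aggregate keeps both paths available for existing clients. A rough sketch of the regex-only path follows. It assumes RegexPreprocessor exposes a Create factory taking the engine and a single input tensor index, mirroring the BertPreprocessor API in the next file, so treat the exact signature as illustrative rather than part of this commit:

// Sketch only: a client that needs just the regex path. Only
// regex_preprocessor.h (and the RegexTokenizer behind it) is included, so no
// BertTokenizer dependency ends up in the binary.
// Assumption: RegexPreprocessor::Create(engine, input_tensor_index) mirrors
// BertPreprocessor::Create; check regex_preprocessor.h for the real interface.
#include <string>

#include "absl/status/status.h"  // from @com_google_absl
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "tensorflow_lite_support/cc/task/processor/regex_preprocessor.h"

absl::Status PreprocessWithRegex(tflite::task::core::TfLiteEngine* engine,
                                 const std::string& text) {
  // The engine is assumed to be already built and its interpreter initialized.
  ASSIGN_OR_RETURN(auto preprocessor,
                   tflite::task::processor::RegexPreprocessor::Create(
                       engine, /*input_tensor_index=*/0));
  return preprocessor->Preprocess(text);
}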

tensorflow_lite_support/cc/task/processor/bert_preprocessor.cc

Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_lite_support/cc/task/processor/bert_preprocessor.h"

#include <algorithm>
#include <string>
#include <vector>

#include "absl/status/status.h"  // from @com_google_absl
#include "absl/strings/ascii.h"  // from @com_google_absl
#include "absl/strings/str_format.h"  // from @com_google_absl
#include "absl/types/span.h"  // from @com_google_absl
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h"
#include "tensorflow_lite_support/cc/utils/common_utils.h"

namespace tflite {
namespace task {
namespace processor {

using ::absl::StatusCode;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit;
using ::tflite::support::text::tokenizer::TokenizerResult;
using ::tflite::task::core::FindIndexByMetadataTensorName;
using ::tflite::task::core::PopulateTensor;

constexpr int kTokenizerProcessUnitIndex = 0;
constexpr char kIdsTensorName[] = "ids";
constexpr char kMaskTensorName[] = "mask";
constexpr char kSegmentIdsTensorName[] = "segment_ids";
constexpr char kClassificationToken[] = "[CLS]";
constexpr char kSeparator[] = "[SEP]";

/* static */
StatusOr<std::unique_ptr<BertPreprocessor>> BertPreprocessor::Create(
    tflite::task::core::TfLiteEngine* engine,
    const std::initializer_list<int> input_tensor_indices) {
  ASSIGN_OR_RETURN(auto processor, Processor::Create<BertPreprocessor>(
                                       /* num_expected_tensors = */ 3, engine,
                                       input_tensor_indices,
                                       /* requires_metadata = */ false));
  RETURN_IF_ERROR(processor->Init());
  return processor;
}

absl::Status BertPreprocessor::Init() {
  // BertTokenizer is packed in the processing unit of the InputTensors in
  // SubgraphMetadata.
  const tflite::ProcessUnit* tokenizer_metadata =
      GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex);
  // Identify the tensor index for the three Bert input tensors.
  auto tensors_metadata = GetMetadataExtractor()->GetInputTensorMetadata();
  int ids_tensor_index =
      FindIndexByMetadataTensorName(tensors_metadata, kIdsTensorName);
  ids_tensor_index_ =
      ids_tensor_index == -1 ? tensor_indices_[0] : ids_tensor_index;
  int mask_tensor_index =
      FindIndexByMetadataTensorName(tensors_metadata, kMaskTensorName);
  mask_tensor_index_ =
      mask_tensor_index == -1 ? tensor_indices_[1] : mask_tensor_index;
  int segment_ids_tensor_index =
      FindIndexByMetadataTensorName(tensors_metadata, kSegmentIdsTensorName);
  segment_ids_tensor_index_ = segment_ids_tensor_index == -1
                                  ? tensor_indices_[2]
                                  : segment_ids_tensor_index;

  if (GetLastDimSize(ids_tensor_index_) != GetLastDimSize(mask_tensor_index_) ||
      GetLastDimSize(ids_tensor_index_) !=
          GetLastDimSize(segment_ids_tensor_index_)) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInternal,
        absl::StrFormat("The three input tensors in Bert models are "
                        "expected to have the same length, but got ids_tensor "
                        "(%d), mask_tensor (%d), segment_ids_tensor (%d).",
                        GetLastDimSize(ids_tensor_index_),
                        GetLastDimSize(mask_tensor_index_),
                        GetLastDimSize(segment_ids_tensor_index_)),
        TfLiteSupportStatus::kInvalidNumOutputTensorsError);
  }
  bert_max_seq_len_ = GetLastDimSize(ids_tensor_index_);

  ASSIGN_OR_RETURN(tokenizer_, CreateTokenizerFromProcessUnit(
                                   tokenizer_metadata, GetMetadataExtractor()));
  return absl::OkStatus();
}

absl::Status BertPreprocessor::Preprocess(const std::string& input_text) {
  auto* ids_tensor =
      engine_->GetInput(engine_->interpreter(), ids_tensor_index_);
  auto* mask_tensor =
      engine_->GetInput(engine_->interpreter(), mask_tensor_index_);
  auto* segment_ids_tensor =
      engine_->GetInput(engine_->interpreter(), segment_ids_tensor_index_);

  std::string processed_input = input_text;
  absl::AsciiStrToLower(&processed_input);

  TokenizerResult input_tokenize_results =
      tokenizer_->Tokenize(processed_input);

  // 2 accounts for [CLS] and [SEP].
  absl::Span<const std::string> query_tokens =
      absl::MakeSpan(input_tokenize_results.subwords.data(),
                     input_tokenize_results.subwords.data() +
                         std::min(static_cast<size_t>(bert_max_seq_len_ - 2),
                                  input_tokenize_results.subwords.size()));

  std::vector<std::string> tokens;
  tokens.reserve(2 + query_tokens.size());
  // Start of generating the features.
  tokens.push_back(kClassificationToken);
  // For query input.
  for (const auto& query_token : query_tokens) {
    tokens.push_back(query_token);
  }
  // For separation.
  tokens.push_back(kSeparator);

  std::vector<int> input_ids(bert_max_seq_len_, 0);
  std::vector<int> input_mask(bert_max_seq_len_, 0);
  // Convert tokens back into ids and set the mask.
  for (int i = 0; i < tokens.size(); ++i) {
    tokenizer_->LookupId(tokens[i], &input_ids[i]);
    input_mask[i] = 1;
  }
  //               |<--------bert_max_seq_len_--------->|
  // input_ids     [CLS] s1  s2...  sn [SEP]  0  0...  0
  // input_masks     1    1   1...   1    1   0  0...  0
  // segment_ids     0    0   0...   0    0   0  0...  0

  RETURN_IF_ERROR(PopulateTensor(input_ids, ids_tensor));
  RETURN_IF_ERROR(PopulateTensor(input_mask, mask_tensor));
  RETURN_IF_ERROR(PopulateTensor(std::vector<int>(bert_max_seq_len_, 0),
                                 segment_ids_tensor));
  return absl::OkStatus();
}

int BertPreprocessor::GetLastDimSize(int tensor_index) {
  auto tensor = engine_->GetInput(engine_->interpreter(), tensor_index);
  return tensor->dims->data[tensor->dims->size - 1];
}

}  // namespace processor
}  // namespace task
}  // namespace tflite
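
For orientation, a minimal usage sketch of the new class follows, based on the Create and Preprocess signatures above. The tensor indices {0, 1, 2}, the tflite_engine.h include path, and the assumption that the engine is already built and initialized are illustrative, not part of this commit:

// Sketch only: wiring BertPreprocessor into an already-initialized engine.
#include <initializer_list>
#include <string>

#include "absl/status/status.h"  // from @com_google_absl
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "tensorflow_lite_support/cc/task/processor/bert_preprocessor.h"

absl::Status PreprocessQuery(tflite::task::core::TfLiteEngine* engine,
                             const std::string& query) {
  // The indices {0, 1, 2} are an assumption about the model's input layout;
  // when the metadata names the tensors "ids", "mask" and "segment_ids",
  // Init() resolves the real indices from metadata instead.
  const std::initializer_list<int> input_indices = {0, 1, 2};
  // Create() checks that exactly 3 input tensors are configured, then Init()
  // builds the tokenizer from the model's input process unit metadata.
  ASSIGN_OR_RETURN(auto preprocessor,
                   tflite::task::processor::BertPreprocessor::Create(
                       engine, input_indices));
  // Lowercases, tokenizes, adds [CLS]/[SEP], pads to the tensor length, and
  // populates the ids/mask/segment_ids input tensors.
  return preprocessor->Preprocess(query);
}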

tensorflow_lite_support/cc/task/processor/bert_preprocessor.h

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_BERT_PREPROCESSOR_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_BERT_PREPROCESSOR_H_

#include <initializer_list>
#include <memory>
#include <string>

#include "absl/status/status.h"  // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/processor/text_preprocessor.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"

namespace tflite {
namespace task {
namespace processor {

// Processes input text and populates the associated Bert input tensors.
// Requirements for the input tensors:
//   - The 3 input tensors should be named "ids", "mask", and "segment_ids"
//     in the model metadata, respectively.
//   - The input_process_units metadata should contain WordPiece or
//     SentencePiece Tokenizer metadata.
class BertPreprocessor : public TextPreprocessor {
 public:
  static tflite::support::StatusOr<std::unique_ptr<BertPreprocessor>> Create(
      tflite::task::core::TfLiteEngine* engine,
      const std::initializer_list<int> input_tensor_indices);

  absl::Status Preprocess(const std::string& text);

 private:
  using TextPreprocessor::TextPreprocessor;

  absl::Status Init();

  int GetLastDimSize(int tensor_index);

  std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_;
  int ids_tensor_index_;
  int mask_tensor_index_;
  int segment_ids_tensor_index_;
  int bert_max_seq_len_;
};

}  // namespace processor
}  // namespace task
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_BERT_PREPROCESSOR_H_
