diff --git a/tensorflow/lite/micro/kernels/decode.cc b/tensorflow/lite/micro/kernels/decode.cc index 9f4d34cff15..8db24505f3e 100644 --- a/tensorflow/lite/micro/kernels/decode.cc +++ b/tensorflow/lite/micro/kernels/decode.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -49,6 +51,27 @@ TfLiteStatus SetOutputTensorData(TfLiteContext* context, const TfLiteNode* node, return kTfLiteOk; } +DecodeState* GetDecodeStateFromCustomRegistration(const TfLiteContext* context, + uint8_t type) { + const MicroContext* mc = GetMicroContext(context); + const MicroContext::CustomDecodeRegistration* registrations; + size_t registrations_count; + std::tie(registrations, registrations_count) = + mc->GetCustomDecodeRegistrations(); + if (registrations == nullptr) { + return nullptr; + } + + for (size_t i = 0; i < registrations_count; i++) { + auto& reg = registrations[i]; + if (reg.type == type && reg.create_state != nullptr) { + return reg.create_state(context, mc->GetAlternateProfiler()); + } + } + + return nullptr; +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const size_t num_inputs = NumInputs(node); const size_t num_outputs = NumOutputs(node); @@ -113,21 +136,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { dsp = DecodeState::CreateDecodeStateHuffman( context, micro_context->GetAlternateProfiler()); break; - case DecodeState::kDcmTypeCustom: - MicroPrintf("Custom decode type not yet supported"); - break; default: - MicroPrintf("unsupported decode type %u", - DecodeState::Type(*ancillary)); + uint32_t type = DecodeState::Type(*ancillary); + if (type >= DecodeState::kDcmTypeCustomFirst && + type <= 
DecodeState::kDcmTypeCustomLast) { + dsp = GetDecodeStateFromCustomRegistration(context, type); + } else { + MicroPrintf("unsupported decode type %u", type); + } break; } - status = SetOutputTensorData(context, node, i / 2, output); - if (status != kTfLiteOk) { - break; - } - if (dsp != nullptr) { + status = SetOutputTensorData(context, node, i / 2, output); + if (status != kTfLiteOk) { + break; + } status = dsp->Setup(*input, *ancillary, *output); if (status != kTfLiteOk) { break; diff --git a/tensorflow/lite/micro/kernels/decode_state.h b/tensorflow/lite/micro/kernels/decode_state.h index 06f821dbc3c..9be36e32de3 100644 --- a/tensorflow/lite/micro/kernels/decode_state.h +++ b/tensorflow/lite/micro/kernels/decode_state.h @@ -72,7 +72,8 @@ class DecodeState { static constexpr uint8_t kDcmTypeLUT = 0; static constexpr uint8_t kDcmTypeHuffman = 1; static constexpr uint8_t kDcmTypePrune = 2; - static constexpr uint8_t kDcmTypeCustom = 127; + static constexpr uint8_t kDcmTypeCustomFirst = 128; + static constexpr uint8_t kDcmTypeCustomLast = 255; static constexpr size_t kDcmSizeInBytes = 16; diff --git a/tensorflow/lite/micro/kernels/decode_state_huffman_test.cc b/tensorflow/lite/micro/kernels/decode_state_huffman_test.cc index 0030b371d14..269bdd17e11 100644 --- a/tensorflow/lite/micro/kernels/decode_state_huffman_test.cc +++ b/tensorflow/lite/micro/kernels/decode_state_huffman_test.cc @@ -271,7 +271,7 @@ TF_LITE_MICRO_TEST(DecodeHuffmanTable16BitsInt16Fail) { tflite::testing::TestDecode( encodes, ancillaries, outputs, expected, tflite::Register_DECODE(), - nullptr, kTfLiteError); + nullptr, nullptr, kTfLiteError); } TF_LITE_MICRO_TEST(DecodeHuffmanTable32BitsInt8) { diff --git a/tensorflow/lite/micro/kernels/decode_state_prune_test.cc b/tensorflow/lite/micro/kernels/decode_state_prune_test.cc index 636a5d9a746..955c4008157 100644 --- a/tensorflow/lite/micro/kernels/decode_state_prune_test.cc +++ b/tensorflow/lite/micro/kernels/decode_state_prune_test.cc @@ -575,7 
+575,7 @@ TF_LITE_MICRO_TEST(DecodePruneQuantizedInvalidZeroPointInt16) { tflite::testing::TestDecode( kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE(), - nullptr, kTfLiteError); + nullptr, nullptr, kTfLiteError); } TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/decode_test.cc b/tensorflow/lite/micro/kernels/decode_test.cc index b07afda1b14..fdf4c1477b8 100644 --- a/tensorflow/lite/micro/kernels/decode_test.cc +++ b/tensorflow/lite/micro/kernels/decode_test.cc @@ -66,6 +66,76 @@ constexpr int kEncodedShapeLUT[] = {1, sizeof(kEncodedLUT)}; constexpr int8_t kExpectLUT0[] = {1, 2, 3, 4, 4, 3, 2, 1}; constexpr int16_t kExpectLUT1[] = {5, 6, 7, 8, 8, 7, 6, 5}; +// +// Custom DECODE test data +// +constexpr int kDecodeTypeCustom = 200; + +constexpr int8_t kAncillaryDataCustom[] = {0x42}; + +constexpr uint8_t kDcmCustom[tflite::DecodeState::kDcmSizeInBytes] = { + kDecodeTypeCustom, // type: custom + 1, // DCM version: 1 +}; + +// Align the tensor data the same as a Buffer in the TfLite schema +alignas(16) const uint8_t kEncodedCustom[] = {0x42, 0x43, 0x40, 0x46, + 0x4A, 0x52, 0x62, 0x02}; + +// Tensor shapes as TfLiteIntArray +constexpr int kOutputShapeCustom[] = {1, 8}; +constexpr int kEncodedShapeCustom[] = {1, sizeof(kEncodedCustom)}; + +constexpr int8_t kExpectCustom[] = {0x00, 0x01, 0x02, 0x04, + 0x08, 0x10, 0x20, 0x40}; + +class DecodeStateCustom : public tflite::DecodeState { + public: + DecodeStateCustom() = delete; + + DecodeStateCustom(const TfLiteContext* context, + tflite::MicroProfilerInterface* profiler) + : DecodeState(context, profiler) {} + + virtual TfLiteStatus Setup(const TfLiteTensor& input, + const TfLiteTensor& ancillary, + const TfLiteTensor& output) override { + return kTfLiteOk; + } + + virtual TfLiteStatus Decode(const TfLiteEvalTensor& input, + const TfLiteEvalTensor& ancillary, + const TfLiteEvalTensor& output) override { + const uint8_t* inp = tflite::micro::GetTensorData(&input); + 
TF_LITE_ENSURE(const_cast(context_), inp != nullptr); + uint8_t* outp = tflite::micro::GetTensorData( + const_cast(&output)); + TF_LITE_ENSURE(const_cast(context_), outp != nullptr); + const uint8_t* vp = tflite::micro::GetTensorData(&ancillary); + TF_LITE_ENSURE(const_cast(context_), vp != nullptr); + vp += kDcmSizeInBytes; + + // simple XOR de-obfuscation + std::transform(inp, inp + input.dims->data[0], outp, + [vp](uint8_t i) { return i ^ *vp; }); + + return kTfLiteOk; + } + + static DecodeState* CreateDecodeStateCustom( + const TfLiteContext* context, tflite::MicroProfilerInterface* profiler) { + alignas(DecodeStateCustom) static uint8_t buffer[sizeof(DecodeStateCustom)]; + DecodeState* instance = new (buffer) DecodeStateCustom(context, profiler); + return instance; + } + + protected: + virtual ~DecodeStateCustom() = default; + + private: + TF_LITE_REMOVE_VIRTUAL_DELETE +}; + } // namespace TF_LITE_MICRO_TESTS_BEGIN @@ -246,4 +316,63 @@ TF_LITE_MICRO_TEST(DecodeWithAltDecompressionMemory) { encodes, ancillaries, outputs, expected, tflite::Register_DECODE(), &amr); } +TF_LITE_MICRO_TEST(DecodeWithCustomRegistration) { + // Align the tensor data the same as a Buffer in the TfLite schema + alignas(16) int8_t output_data[std::size(kExpectCustom)] = {}; + alignas(16) const AncillaryData + kAncillaryData = {{kDcmCustom}, {kAncillaryDataCustom}}; + + constexpr int kAncillaryShapeCustom[] = {1, sizeof(kAncillaryData)}; + + const TfLiteIntArray* const encoded_dims = + tflite::testing::IntArrayFromInts(kEncodedShapeCustom); + static const TensorInDatum tid_encode = { + kEncodedCustom, + *encoded_dims, + }; + static constexpr std::initializer_list encodes = { + &tid_encode, + }; + + const TfLiteIntArray* const ancillary_dims = + tflite::testing::IntArrayFromInts(kAncillaryShapeCustom); + static const TensorInDatum tid_ancillary = { + &kAncillaryData, + *ancillary_dims, + }; + static constexpr std::initializer_list ancillaries = { + &tid_ancillary}; + + const TfLiteIntArray* const 
output_dims = + tflite::testing::IntArrayFromInts(kOutputShapeCustom); + constexpr int kOutputZeroPointsData[] = {0}; + const TfLiteIntArray* const kOutputZeroPoints = + tflite::testing::IntArrayFromInts(kOutputZeroPointsData); + const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size}; + static const TensorOutDatum tod = { + output_data, *output_dims, kTfLiteInt8, kOutputScales, *kOutputZeroPoints, + 0, {}, + }; + static constexpr std::initializer_list outputs = { + &tod}; + + const std::initializer_list expected = {kExpectCustom}; + + const std::initializer_list + cdr = { + { + kDecodeTypeCustom, + 0, // reserved + 0, // reserved + 0, // reserved + DecodeStateCustom::CreateDecodeStateCustom, + }, + }; + + tflite::testing::TestDecode( + encodes, ancillaries, outputs, expected, tflite::Register_DECODE(), + nullptr, &cdr); +} + TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/decode_test_helpers.h b/tensorflow/lite/micro/kernels/decode_test_helpers.h index 7afd9fd96af..accd6661052 100644 --- a/tensorflow/lite/micro/kernels/decode_test_helpers.h +++ b/tensorflow/lite/micro/kernels/decode_test_helpers.h @@ -85,6 +85,8 @@ TfLiteStatus ExecuteDecodeTest( TfLiteTensor* tensors, const TFLMRegistration& registration, const std::initializer_list& expected, const std::initializer_list* amr = + nullptr, + const std::initializer_list* cdr = nullptr) { int kInputArrayData[kNumInputs + 1] = {kNumInputs}; for (size_t i = 0; i < kNumInputs; i++) { @@ -105,6 +107,10 @@ TfLiteStatus ExecuteDecodeTest( runner.GetFakeMicroContext()->SetDecompressionMemory(amr->begin(), amr->size()); } + if (cdr != nullptr) { + runner.GetFakeMicroContext()->SetCustomDecodeRegistrations(cdr->begin(), + cdr->size()); + } if (runner.InitAndPrepare() != kTfLiteOk || runner.Invoke() != kTfLiteOk) { return kTfLiteError; @@ -150,6 +156,8 @@ void TestDecode( const TFLMRegistration& registration, const std::initializer_list* amr = nullptr, + const std::initializer_list* cdr = + 
nullptr, const TfLiteStatus expected_status = kTfLiteOk) { TfLiteTensor tensors[kNumInputs + kNumOutputs] = {}; @@ -183,7 +191,7 @@ void TestDecode( } TfLiteStatus s = ExecuteDecodeTest( - tensors, registration, expected, amr); + tensors, registration, expected, amr, cdr); TF_LITE_MICRO_EXPECT_EQ(s, expected_status); } diff --git a/tensorflow/lite/micro/kernels/kernel_util.h b/tensorflow/lite/micro/kernels/kernel_util.h index 5cb71af7953..84e541f735d 100644 --- a/tensorflow/lite/micro/kernels/kernel_util.h +++ b/tensorflow/lite/micro/kernels/kernel_util.h @@ -23,7 +23,9 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/types.h" +#include "tensorflow/lite/micro/micro_common.h" #include "tensorflow/lite/micro/micro_context.h" +#include "tensorflow/lite/micro/micro_graph.h" #ifdef USE_TFLM_COMPRESSION diff --git a/tensorflow/lite/micro/micro_context.cc b/tensorflow/lite/micro/micro_context.cc index fb03ac08351..acb21688170 100644 --- a/tensorflow/lite/micro/micro_context.cc +++ b/tensorflow/lite/micro/micro_context.cc @@ -174,4 +174,14 @@ void MicroContext::ResetDecompressionMemoryAllocations() { std::fill_n(decompress_regions_allocations_, decompress_regions_size_, 0); } +TfLiteStatus MicroContext::SetCustomDecodeRegistrations( + const CustomDecodeRegistration* registrations, size_t count) { + if (custom_decode_registrations_ != nullptr) { + return kTfLiteError; + } + custom_decode_registrations_ = registrations; + custom_decode_registrations_size_ = count; + return kTfLiteOk; +} + } // namespace tflite diff --git a/tensorflow/lite/micro/micro_context.h b/tensorflow/lite/micro/micro_context.h index 0b8120abe1d..73c90ee3a10 100644 --- a/tensorflow/lite/micro/micro_context.h +++ b/tensorflow/lite/micro/micro_context.h @@ -17,7 +17,7 @@ limitations under the License. 
#define TENSORFLOW_LITE_MICRO_MICRO_CONTEXT_H_ #include -#include +#include #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/micro_graph.h" @@ -33,6 +33,8 @@ namespace tflite { // TODO(b/149795762): kTfLiteAbort cannot be part of the tflite TfLiteStatus. const TfLiteStatus kTfLiteAbort = static_cast(15); +class DecodeState; // can't use decode_state.h due to circular include + // MicroContext is eventually going to become the API between TFLM and the // kernels, replacing all the functions in TfLiteContext. The end state is code // kernels to have code like: @@ -136,7 +138,7 @@ class MicroContext { }; // Set the alternate decompression memory regions. - // Can only be called during the MicroInterpreter kInit state. + // Can only be called during the kInit state. virtual TfLiteStatus SetDecompressionMemory( const AlternateMemoryRegion* regions, size_t count); @@ -169,12 +171,36 @@ class MicroContext { return nullptr; } + struct CustomDecodeRegistration { + uint8_t type; // custom decode type + uint8_t reserved1; // reserved + uint8_t reserved2; // reserved + uint8_t reserved3; // reserved + tflite::DecodeState* (*create_state)(const TfLiteContext*, + MicroProfilerInterface*); + }; + + // Set the DECODE operator custom registrations. + // Can only be called during the kInit state. + virtual TfLiteStatus SetCustomDecodeRegistrations( + const CustomDecodeRegistration* registrations, size_t count); + + // Get the DECODE operator custom registrations. 
+ virtual const std::pair + GetCustomDecodeRegistrations() const { + return std::make_pair(custom_decode_registrations_, + custom_decode_registrations_size_); + } + private: const AlternateMemoryRegion* decompress_regions_ = nullptr; size_t decompress_regions_size_ = 0; // array of size_t elements with length equal to decompress_regions_size_ size_t* decompress_regions_allocations_ = nullptr; + const CustomDecodeRegistration* custom_decode_registrations_ = nullptr; + size_t custom_decode_registrations_size_ = 0; + TF_LITE_REMOVE_VIRTUAL_DELETE }; diff --git a/tensorflow/lite/micro/micro_interpreter.cc b/tensorflow/lite/micro/micro_interpreter.cc index 871e90de13d..010c652a45c 100644 --- a/tensorflow/lite/micro/micro_interpreter.cc +++ b/tensorflow/lite/micro/micro_interpreter.cc @@ -344,4 +344,9 @@ TfLiteStatus MicroInterpreter::SetDecompressionMemory( return micro_context_.SetDecompressionMemory(regions, count); } +TfLiteStatus MicroInterpreter::SetCustomDecodeRegistrations( + const MicroContext::CustomDecodeRegistration* registrations, size_t count) { + return micro_context_.SetCustomDecodeRegistrations(registrations, count); +} + } // namespace tflite diff --git a/tensorflow/lite/micro/micro_interpreter.h b/tensorflow/lite/micro/micro_interpreter.h index adec9ff148b..36b767f1f46 100644 --- a/tensorflow/lite/micro/micro_interpreter.h +++ b/tensorflow/lite/micro/micro_interpreter.h @@ -171,6 +171,18 @@ class MicroInterpreter { TfLiteStatus SetDecompressionMemory( const MicroContext::AlternateMemoryRegion* regions, size_t count); + // Set the DECODE operator custom registrations. + // Can only be called during the MicroInterpreter kInit state (i.e. must + // be called before MicroInterpreter::AllocateTensors). + // The registrations pointer argument is the start of a + // MicroContext::CustomDecodeRegistration array where the length of the array + // is given by the count argument. 
The lifetime of the + // MicroContext::CustomDecodeRegistration array must be at least that of the + // MicroInterpreter. + TfLiteStatus SetCustomDecodeRegistrations( + const MicroContext::CustomDecodeRegistration* registrations, + size_t count); + protected: const MicroAllocator& allocator() const { return allocator_; } const TfLiteContext& context() const { return context_; } diff --git a/tensorflow/lite/micro/micro_interpreter_context.cc b/tensorflow/lite/micro/micro_interpreter_context.cc index e454509362c..841b7e7e00e 100644 --- a/tensorflow/lite/micro/micro_interpreter_context.cc +++ b/tensorflow/lite/micro/micro_interpreter_context.cc @@ -247,4 +247,12 @@ MicroProfilerInterface* MicroInterpreterContext::GetAlternateProfiler() const { return alt_profiler_; } +TfLiteStatus MicroInterpreterContext::SetCustomDecodeRegistrations( + const CustomDecodeRegistration* registrations, size_t count) { + if (state_ != InterpreterState::kInit) { + return kTfLiteError; + } + return MicroContext::SetCustomDecodeRegistrations(registrations, count); +} + } // namespace tflite diff --git a/tensorflow/lite/micro/micro_interpreter_context.h b/tensorflow/lite/micro/micro_interpreter_context.h index 72a5262f30c..8f18359d7af 100644 --- a/tensorflow/lite/micro/micro_interpreter_context.h +++ b/tensorflow/lite/micro/micro_interpreter_context.h @@ -159,6 +159,11 @@ class MicroInterpreterContext : public MicroContext { // decompression subsystem. MicroProfilerInterface* GetAlternateProfiler() const override; + // Set the DECODE operator custom registrations. + // Can only be called during the kInit state. 
+ virtual TfLiteStatus SetCustomDecodeRegistrations( + const CustomDecodeRegistration* registrations, size_t count) override; + private: MicroAllocator& allocator_; MicroInterpreterGraph& graph_; diff --git a/tensorflow/lite/micro/micro_interpreter_context_test.cc b/tensorflow/lite/micro/micro_interpreter_context_test.cc index ca432a05cc3..9126306b54c 100644 --- a/tensorflow/lite/micro/micro_interpreter_context_test.cc +++ b/tensorflow/lite/micro/micro_interpreter_context_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/micro/micro_interpreter_context.h" #include +#include #include "tensorflow/lite/micro/micro_allocator.h" #include "tensorflow/lite/micro/micro_arena_constants.h" @@ -317,4 +318,58 @@ TF_LITE_MICRO_TEST(TestResetDecompressionMemory) { TF_LITE_MICRO_EXPECT(p == &g_alt_memory[0]); } +TF_LITE_MICRO_TEST(TestSetCustomDecode) { + tflite::MicroInterpreterContext micro_context = + tflite::CreateMicroInterpreterContext(); + + constexpr int kDecodeTypeCustom = 200; + const std::initializer_list + cdr = { + { + kDecodeTypeCustom, + 0, // reserved + 0, // reserved + 0, // reserved + nullptr, // the test won't instantiate tflite::DecodeState + }, + }; + TfLiteStatus status; + + // Test that all of the MicroInterpreterContext fences are correct, by + // forcing the MicroInterpreterContext state. The SetCustomDecodeRegistrations + // method should only be allowed during the kInit state, and can only be + // set once. 
+ + // fail during Prepare state + micro_context.SetInterpreterState( + tflite::MicroInterpreterContext::InterpreterState::kPrepare); + status = micro_context.SetCustomDecodeRegistrations(cdr.begin(), cdr.size()); + TF_LITE_MICRO_EXPECT(status == kTfLiteError); + + // fail during Invoke state + micro_context.SetInterpreterState( + tflite::MicroInterpreterContext::InterpreterState::kInvoke); + status = micro_context.SetCustomDecodeRegistrations(cdr.begin(), cdr.size()); + TF_LITE_MICRO_EXPECT(status == kTfLiteError); + + // succeed during Init state + micro_context.SetInterpreterState( + tflite::MicroInterpreterContext::InterpreterState::kInit); + status = micro_context.SetCustomDecodeRegistrations(cdr.begin(), cdr.size()); + TF_LITE_MICRO_EXPECT(status == kTfLiteOk); + + // fail on second Init state attempt + micro_context.SetInterpreterState( + tflite::MicroInterpreterContext::InterpreterState::kInit); + status = micro_context.SetCustomDecodeRegistrations(cdr.begin(), cdr.size()); + TF_LITE_MICRO_EXPECT(status == kTfLiteError); + + // check registered info. matches + const tflite::MicroContext::CustomDecodeRegistration* registration; + size_t count; + std::tie(registration, count) = micro_context.GetCustomDecodeRegistrations(); + TF_LITE_MICRO_EXPECT(registration == cdr.begin()); + TF_LITE_MICRO_EXPECT(count == 1); +} + TF_LITE_MICRO_TESTS_END