Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions tensorflow/lite/micro/kernels/decode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"
Expand Down Expand Up @@ -49,6 +51,27 @@ TfLiteStatus SetOutputTensorData(TfLiteContext* context, const TfLiteNode* node,
return kTfLiteOk;
}

DecodeState* GetDecodeStateFromCustomRegistration(const TfLiteContext* context,
uint8_t type) {
const MicroContext* mc = GetMicroContext(context);
const MicroContext::CustomDecodeRegistration* registrations;
size_t registrations_count;
std::tie(registrations, registrations_count) =
mc->GetCustomDecodeRegistrations();
if (registrations == nullptr) {
return nullptr;
}

for (size_t i = 0; i < registrations_count; i++) {
auto& reg = registrations[i];
if (reg.type == type && reg.create_state != nullptr) {
return reg.create_state(context, mc->GetAlternateProfiler());
}
}

return nullptr;
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const size_t num_inputs = NumInputs(node);
const size_t num_outputs = NumOutputs(node);
Expand Down Expand Up @@ -113,21 +136,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
dsp = DecodeState::CreateDecodeStateHuffman(
context, micro_context->GetAlternateProfiler());
break;
case DecodeState::kDcmTypeCustom:
MicroPrintf("Custom decode type not yet supported");
break;
default:
MicroPrintf("unsupported decode type %u",
DecodeState::Type(*ancillary));
uint32_t type = DecodeState::Type(*ancillary);
if (type >= DecodeState::kDcmTypeCustomFirst &&
type <= DecodeState::kDcmTypeCustomLast) {
dsp = GetDecodeStateFromCustomRegistration(context, type);
} else {
MicroPrintf("unsupported decode type %u", type);
}
break;
}

status = SetOutputTensorData(context, node, i / 2, output);
if (status != kTfLiteOk) {
break;
}

if (dsp != nullptr) {
status = SetOutputTensorData(context, node, i / 2, output);
if (status != kTfLiteOk) {
break;
}
status = dsp->Setup(*input, *ancillary, *output);
if (status != kTfLiteOk) {
break;
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/lite/micro/kernels/decode_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ class DecodeState {
static constexpr uint8_t kDcmTypeLUT = 0;
static constexpr uint8_t kDcmTypeHuffman = 1;
static constexpr uint8_t kDcmTypePrune = 2;
static constexpr uint8_t kDcmTypeCustom = 127;
static constexpr uint8_t kDcmTypeCustomFirst = 128;
static constexpr uint8_t kDcmTypeCustomLast = 255;

static constexpr size_t kDcmSizeInBytes = 16;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ TF_LITE_MICRO_TEST(DecodeHuffmanTable16BitsInt16Fail) {
tflite::testing::TestDecode<encodes.size() + ancillaries.size(),
outputs.size()>(
encodes, ancillaries, outputs, expected, tflite::Register_DECODE(),
nullptr, kTfLiteError);
nullptr, nullptr, kTfLiteError);
}

TF_LITE_MICRO_TEST(DecodeHuffmanTable32BitsInt8) {
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/lite/micro/kernels/decode_state_prune_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ TF_LITE_MICRO_TEST(DecodePruneQuantizedInvalidZeroPointInt16) {
tflite::testing::TestDecode<kEncodes.size() + kAncillaries.size(),
kOutputs.size()>(
kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE(),
nullptr, kTfLiteError);
nullptr, nullptr, kTfLiteError);
}

TF_LITE_MICRO_TESTS_END
129 changes: 129 additions & 0 deletions tensorflow/lite/micro/kernels/decode_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,76 @@ constexpr int kEncodedShapeLUT[] = {1, sizeof(kEncodedLUT)};
constexpr int8_t kExpectLUT0[] = {1, 2, 3, 4, 4, 3, 2, 1};
constexpr int16_t kExpectLUT1[] = {5, 6, 7, 8, 8, 7, 6, 5};

//
// Custom DECODE test data
//
constexpr int kDecodeTypeCustom = 200;

constexpr int8_t kAncillaryDataCustom[] = {0x42};

constexpr uint8_t kDcmCustom[tflite::DecodeState::kDcmSizeInBytes] = {
kDecodeTypeCustom, // type: custom
1, // DCM version: 1
};

// Align the tensor data the same as a Buffer in the TfLite schema
alignas(16) const uint8_t kEncodedCustom[] = {0x42, 0x43, 0x40, 0x46,
0x4A, 0x52, 0x62, 0x02};

// Tensor shapes as TfLiteIntArray
constexpr int kOutputShapeCustom[] = {1, 8};
constexpr int kEncodedShapeCustom[] = {1, sizeof(kEncodedCustom)};

constexpr int8_t kExpectCustom[] = {0x00, 0x01, 0x02, 0x04,
0x08, 0x10, 0x20, 0x40};

class DecodeStateCustom : public tflite::DecodeState {
public:
DecodeStateCustom() = delete;

DecodeStateCustom(const TfLiteContext* context,
tflite::MicroProfilerInterface* profiler)
: DecodeState(context, profiler) {}

virtual TfLiteStatus Setup(const TfLiteTensor& input,
const TfLiteTensor& ancillary,
const TfLiteTensor& output) override {
return kTfLiteOk;
}

virtual TfLiteStatus Decode(const TfLiteEvalTensor& input,
const TfLiteEvalTensor& ancillary,
const TfLiteEvalTensor& output) override {
const uint8_t* inp = tflite::micro::GetTensorData<uint8_t>(&input);
TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), inp != nullptr);
uint8_t* outp = tflite::micro::GetTensorData<uint8_t>(
const_cast<TfLiteEvalTensor*>(&output));
TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), outp != nullptr);
const uint8_t* vp = tflite::micro::GetTensorData<uint8_t>(&ancillary);
TF_LITE_ENSURE(const_cast<TfLiteContext*>(context_), vp != nullptr);
vp += kDcmSizeInBytes;

// simple XOR de-obfuscation
std::transform(inp, inp + input.dims->data[0], outp,
[vp](uint8_t i) { return i ^ *vp; });

return kTfLiteOk;
}

static DecodeState* CreateDecodeStateCustom(
const TfLiteContext* context, tflite::MicroProfilerInterface* profiler) {
alignas(4) static uint8_t buffer[sizeof(DecodeStateCustom)];
DecodeState* instance = new (buffer) DecodeStateCustom(context, profiler);
return instance;
}

protected:
virtual ~DecodeStateCustom() = default;

private:
TF_LITE_REMOVE_VIRTUAL_DELETE
};

} // namespace

TF_LITE_MICRO_TESTS_BEGIN
Expand Down Expand Up @@ -246,4 +316,63 @@ TF_LITE_MICRO_TEST(DecodeWithAltDecompressionMemory) {
encodes, ancillaries, outputs, expected, tflite::Register_DECODE(), &amr);
}

TF_LITE_MICRO_TEST(DecodeWithCustomRegistration) {
// Align the tensor data the same as a Buffer in the TfLite schema
alignas(16) int8_t output_data[std::size(kExpectCustom)] = {};
alignas(16) const AncillaryData<int8_t, std::size(kAncillaryDataCustom)>
kAncillaryData = {{kDcmCustom}, {kAncillaryDataCustom}};

constexpr int kAncillaryShapeCustom[] = {1, sizeof(kAncillaryData)};

const TfLiteIntArray* const encoded_dims =
tflite::testing::IntArrayFromInts(kEncodedShapeCustom);
static const TensorInDatum tid_encode = {
kEncodedCustom,
*encoded_dims,
};
static constexpr std::initializer_list<const TensorInDatum*> encodes = {
&tid_encode,
};

const TfLiteIntArray* const ancillary_dims =
tflite::testing::IntArrayFromInts(kAncillaryShapeCustom);
static const TensorInDatum tid_ancillary = {
&kAncillaryData,
*ancillary_dims,
};
static constexpr std::initializer_list<const TensorInDatum*> ancillaries = {
&tid_ancillary};

const TfLiteIntArray* const output_dims =
tflite::testing::IntArrayFromInts(kOutputShapeCustom);
constexpr int kOutputZeroPointsData[] = {0};
const TfLiteIntArray* const kOutputZeroPoints =
tflite::testing::IntArrayFromInts(kOutputZeroPointsData);
const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size};
static const TensorOutDatum tod = {
output_data, *output_dims, kTfLiteInt8, kOutputScales, *kOutputZeroPoints,
0, {},
};
static constexpr std::initializer_list<const TensorOutDatum*> outputs = {
&tod};

const std::initializer_list<const void*> expected = {kExpectCustom};

const std::initializer_list<tflite::MicroContext::CustomDecodeRegistration>
cdr = {
{
kDecodeTypeCustom,
0, // reserved
0, // reserved
0, // reserved
DecodeStateCustom::CreateDecodeStateCustom,
},
};

tflite::testing::TestDecode<encodes.size() + ancillaries.size(),
outputs.size()>(
encodes, ancillaries, outputs, expected, tflite::Register_DECODE(),
nullptr, &cdr);
}

TF_LITE_MICRO_TESTS_END
10 changes: 9 additions & 1 deletion tensorflow/lite/micro/kernels/decode_test_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ TfLiteStatus ExecuteDecodeTest(
TfLiteTensor* tensors, const TFLMRegistration& registration,
const std::initializer_list<const void*>& expected,
const std::initializer_list<MicroContext::AlternateMemoryRegion>* amr =
nullptr,
const std::initializer_list<MicroContext::CustomDecodeRegistration>* cdr =
nullptr) {
int kInputArrayData[kNumInputs + 1] = {kNumInputs};
for (size_t i = 0; i < kNumInputs; i++) {
Expand All @@ -105,6 +107,10 @@ TfLiteStatus ExecuteDecodeTest(
runner.GetFakeMicroContext()->SetDecompressionMemory(amr->begin(),
amr->size());
}
if (cdr != nullptr) {
runner.GetFakeMicroContext()->SetCustomDecodeRegistrations(cdr->begin(),
cdr->size());
}

if (runner.InitAndPrepare() != kTfLiteOk || runner.Invoke() != kTfLiteOk) {
return kTfLiteError;
Expand Down Expand Up @@ -150,6 +156,8 @@ void TestDecode(
const TFLMRegistration& registration,
const std::initializer_list<MicroContext::AlternateMemoryRegion>* amr =
nullptr,
const std::initializer_list<MicroContext::CustomDecodeRegistration>* cdr =
nullptr,
const TfLiteStatus expected_status = kTfLiteOk) {
TfLiteTensor tensors[kNumInputs + kNumOutputs] = {};

Expand Down Expand Up @@ -183,7 +191,7 @@ void TestDecode(
}

TfLiteStatus s = ExecuteDecodeTest<kNumInputs, kNumOutputs>(
tensors, registration, expected, amr);
tensors, registration, expected, amr, cdr);
TF_LITE_MICRO_EXPECT_EQ(s, expected_status);
}

Expand Down
2 changes: 2 additions & 0 deletions tensorflow/lite/micro/kernels/kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/micro/micro_common.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_graph.h"

#ifdef USE_TFLM_COMPRESSION

Expand Down
10 changes: 10 additions & 0 deletions tensorflow/lite/micro/micro_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,14 @@ void MicroContext::ResetDecompressionMemoryAllocations() {
std::fill_n(decompress_regions_allocations_, decompress_regions_size_, 0);
}

TfLiteStatus MicroContext::SetCustomDecodeRegistrations(
const CustomDecodeRegistration* registrations, size_t count) {
if (custom_decode_registrations_ != nullptr) {
return kTfLiteError;
}
custom_decode_registrations_ = registrations;
custom_decode_registrations_size_ = count;
return kTfLiteOk;
}

} // namespace tflite
30 changes: 28 additions & 2 deletions tensorflow/lite/micro/micro_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_LITE_MICRO_MICRO_CONTEXT_H_

#include <cstddef>
#include <initializer_list>
#include <utility>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/micro_graph.h"
Expand All @@ -33,6 +33,8 @@ namespace tflite {
// TODO(b/149795762): kTfLiteAbort cannot be part of the tflite TfLiteStatus.
const TfLiteStatus kTfLiteAbort = static_cast<TfLiteStatus>(15);

class DecodeState; // can't use decode_state.h due to circular include

// MicroContext is eventually going to become the API between TFLM and the
// kernels, replacing all the functions in TfLiteContext. The end state is code
// kernels to have code like:
Expand Down Expand Up @@ -136,7 +138,7 @@ class MicroContext {
};

// Set the alternate decompression memory regions.
// Can only be called during the MicroInterpreter kInit state.
// Can only be called during the kInit state.
virtual TfLiteStatus SetDecompressionMemory(
const AlternateMemoryRegion* regions, size_t count);

Expand Down Expand Up @@ -169,12 +171,36 @@ class MicroContext {
return nullptr;
}

struct CustomDecodeRegistration {
uint8_t type; // custom decode type
uint8_t reserved1; // reserved
uint8_t reserved2; // reserved
uint8_t reserved3; // reserved
tflite::DecodeState* (*create_state)(const TfLiteContext*,
MicroProfilerInterface*);
};

// Set the DECODE operator custom registrations.
// Can only be called during the kInit state.
virtual TfLiteStatus SetCustomDecodeRegistrations(
const CustomDecodeRegistration* registrations, size_t count);

// Get the custom decompression registrations.
virtual const std::pair<const CustomDecodeRegistration*, size_t /*count*/>
GetCustomDecodeRegistrations() const {
return std::make_pair(custom_decode_registrations_,
custom_decode_registrations_size_);
}

private:
const AlternateMemoryRegion* decompress_regions_ = nullptr;
size_t decompress_regions_size_ = 0;
// array of size_t elements with length equal to decompress_regions_size_
size_t* decompress_regions_allocations_ = nullptr;

const CustomDecodeRegistration* custom_decode_registrations_ = nullptr;
size_t custom_decode_registrations_size_ = 0;

TF_LITE_REMOVE_VIRTUAL_DELETE
};

Expand Down
5 changes: 5 additions & 0 deletions tensorflow/lite/micro/micro_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,4 +344,9 @@ TfLiteStatus MicroInterpreter::SetDecompressionMemory(
return micro_context_.SetDecompressionMemory(regions, count);
}

TfLiteStatus MicroInterpreter::SetCustomDecodeRegistrations(
const MicroContext::CustomDecodeRegistration* registrations, size_t count) {
return micro_context_.SetCustomDecodeRegistrations(registrations, count);
}

} // namespace tflite
12 changes: 12 additions & 0 deletions tensorflow/lite/micro/micro_interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,18 @@ class MicroInterpreter {
TfLiteStatus SetDecompressionMemory(
const MicroContext::AlternateMemoryRegion* regions, size_t count);

// Set the DECODE operator custom registrations.
// Can only be called during the MicroInterpreter kInit state (i.e. must
// be called before MicroInterpreter::AllocateTensors).
// The regions pointer argument is the start of a
// MicroContext::CustomDecodeRegistration array where the length of the array
// is given by the count argument. The lifetime of the
// MicroContext::CustomDecodeRegistration array must be at least that of the
// MicroInterpreter.
TfLiteStatus SetCustomDecodeRegistrations(
const MicroContext::CustomDecodeRegistration* registrations,
size_t count);

protected:
const MicroAllocator& allocator() const { return allocator_; }
const TfLiteContext& context() const { return context_; }
Expand Down
Loading
Loading