Commit

Re-add support for Embedding and CategoryEncoding layers
Dobiasd committed Apr 16, 2024
1 parent 2ad8e09 commit 2b43bc2
Showing 8 changed files with 254 additions and 4 deletions.
3 changes: 1 addition & 2 deletions README.md
@@ -41,7 +41,7 @@ Would you like to build/train a model using Keras/Python? And would you like to

* `Add`, `Concatenate`, `Subtract`, `Multiply`, `Average`, `Maximum`, `Minimum`, `Dot`
* `AveragePooling1D/2D/3D`, `GlobalAveragePooling1D/2D/3D`
* `TimeDistributed`
* `TimeDistributed`, `Embedding`, `CategoryEncoding`
* `Conv1D/2D`, `SeparableConv2D`, `DepthwiseConv2D`
* `Cropping1D/2D/3D`, `ZeroPadding1D/2D/3D`, `CenterCrop`
* `BatchNormalization`, `Dense`, `Flatten`, `Normalization`
@@ -80,7 +80,6 @@ Would you like to build/train a model using Keras/Python? And would you like to
`LSTMCell`, `Masking`,
`RepeatVector`, `RNN`, `SimpleRNN`,
`SimpleRNNCell`, `StackedRNNCells`, `StringLookup`, `TextVectorization`,
`Embedding`, `CategoryEncoding`,
`Bidirectional`, `GRU`, `LSTM`, `CuDNNGRU`, `CuDNNLSTM`,
`ThresholdedReLU`, `Upsampling3D`, `temporal` models

24 changes: 24 additions & 0 deletions include/fdeep/import_model.hpp
@@ -33,6 +33,7 @@
#include "fdeep/layers/average_layer.hpp"
#include "fdeep/layers/average_pooling_3d_layer.hpp"
#include "fdeep/layers/batch_normalization_layer.hpp"
#include "fdeep/layers/category_encoding_layer.hpp"
#include "fdeep/layers/centercrop_layer.hpp"
#include "fdeep/layers/concatenate_layer.hpp"
#include "fdeep/layers/conv_2d_layer.hpp"
@@ -41,6 +42,7 @@
#include "fdeep/layers/depthwise_conv_2d_layer.hpp"
#include "fdeep/layers/dot_layer.hpp"
#include "fdeep/layers/elu_layer.hpp"
#include "fdeep/layers/embedding_layer.hpp"
#include "fdeep/layers/exponential_layer.hpp"
#include "fdeep/layers/flatten_layer.hpp"
#include "fdeep/layers/gelu_layer.hpp"
@@ -998,6 +1000,15 @@ namespace internal {
return std::make_shared<normalization_layer>(name, axex, mean, variance);
}

inline layer_ptr create_category_encoding_layer(
const get_param_f&,
const nlohmann::json& data, const std::string& name)
{
const std::size_t num_tokens = data["config"]["num_tokens"];
const std::string output_mode = data["config"]["output_mode"];
return std::make_shared<category_encoding_layer>(name, num_tokens, output_mode);
}

inline layer_ptr create_attention_layer(
const get_param_f& get_param,
const nlohmann::json& data, const std::string& name)
@@ -1136,6 +1147,17 @@ namespace internal {
return fplus::transform(create_node, inbound_nodes_data);
}

inline layer_ptr create_embedding_layer(const get_param_f& get_param,
const nlohmann::json& data,
const std::string& name)
{
const std::size_t input_dim = data["config"]["input_dim"];
const std::size_t output_dim = data["config"]["output_dim"];
const float_vec weights = decode_floats(get_param(name, "weights"));

return std::make_shared<embedding_layer>(name, input_dim, output_dim, weights);
}

inline layer_ptr create_time_distributed_layer(const get_param_f& get_param,
const nlohmann::json& data,
const std::string& name,
@@ -1228,8 +1250,10 @@ namespace internal {
{ "Rescaling", create_rescaling_layer },
{ "Reshape", create_reshape_layer },
{ "Resizing", create_resizing_layer },
{ "Embedding", create_embedding_layer },
{ "Softmax", create_softmax_layer },
{ "Normalization", create_normalization_layer },
{ "CategoryEncoding", create_category_encoding_layer },
{ "Attention", create_attention_layer },
{ "AdditiveAttention", create_additive_attention_layer },
{ "MultiHeadAttention", create_multi_head_attention_layer },
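
For orientation: the two creator functions added above plug into frugally-deep's name-to-factory table, so the "class_name" field of each layer in the exported JSON picks the matching builder during load_model. Below is a minimal, self-contained sketch of that dispatch pattern; the types and signatures are simplified stand-ins, not the actual fdeep ones.

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// Simplified stand-ins; the real creators also receive the layer's JSON
// config and a callback for loading weights.
struct layer { virtual ~layer() = default; };
struct embedding_layer : layer {};
struct category_encoding_layer : layer {};

using layer_ptr = std::shared_ptr<layer>;
using layer_creator = std::function<layer_ptr(const std::string& /*name*/)>;

layer_ptr create_embedding_layer(const std::string&) { return std::make_shared<embedding_layer>(); }
layer_ptr create_category_encoding_layer(const std::string&) { return std::make_shared<category_encoding_layer>(); }

int main()
{
    // Name-to-factory table, analogous to the map extended in the hunk above.
    const std::map<std::string, layer_creator> creators = {
        { "Embedding", create_embedding_layer },
        { "CategoryEncoding", create_category_encoding_layer },
    };

    const std::string class_name = "Embedding"; // would come from the model JSON
    const auto it = creators.find(class_name);
    if (it == creators.end())
        throw std::runtime_error("unknown layer type: " + class_name);
    const layer_ptr created = it->second("embedding_1");
    std::cout << "created a layer for " << class_name << "\n";
    (void)created;
}
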
63 changes: 63 additions & 0 deletions include/fdeep/layers/category_encoding_layer.hpp
@@ -0,0 +1,63 @@
// Copyright 2016, Tobias Hermann.
// https://github.com/Dobiasd/frugally-deep
// Distributed under the MIT License.
// (See accompanying LICENSE file or at
// https://opensource.org/licenses/MIT)

#pragma once

#include "fdeep/layers/layer.hpp"

#include <string>
#include <vector>

namespace fdeep {
namespace internal {

class category_encoding_layer : public layer {
public:
explicit category_encoding_layer(const std::string& name,
const std::size_t& num_tokens,
const std::string& output_mode)
: layer(name)
, num_tokens_(num_tokens)
, output_mode_(output_mode)
{
assertion(output_mode_ == "one_hot" || output_mode_ == "multi_hot" || output_mode_ == "count",
"Unsupported output mode (" + output_mode_ + ").");
}

protected:
tensors apply_impl(const tensors& inputs) const override
{
assertion(inputs.size() == 1, "need exactly one input");
const auto input = inputs[0];
assertion(input.shape().rank() == 1, "Tensor of rank 1 required, but shape is '" + show_tensor_shape(input.shape()) + "'");

if (output_mode_ == "one_hot") {
assertion(input.shape().depth_ == 1, "Tensor of depth 1 required, but is: " + fplus::show(input.shape().depth_));
tensor out(tensor_shape(num_tokens_), float_type(0));
const std::size_t idx = fplus::floor<float_type, std::size_t>(input.get_ignore_rank(tensor_pos(0)));
            assertion(idx < num_tokens_, "Invalid input value (>= num_tokens).");
out.set_ignore_rank(tensor_pos(idx), 1);
return { out };
} else {
tensor out(tensor_shape(num_tokens_), float_type(0));
for (const auto& x : *(input.as_vector())) {
const std::size_t idx = fplus::floor<float_type, std::size_t>(x);
                assertion(idx < num_tokens_, "Invalid input value (>= num_tokens).");
if (output_mode_ == "multi_hot") {
out.set_ignore_rank(tensor_pos(idx), 1);
} else if (output_mode_ == "count") {
out.set_ignore_rank(tensor_pos(idx), out.get_ignore_rank(tensor_pos(idx)) + 1);
}
}
return { out };
}
}
std::size_t num_tokens_;
std::string output_mode_;
};

}
}
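
To make the three output modes concrete, here is a worked sketch of the same semantics on plain std::vectors (illustrative only; it deliberately avoids the fdeep tensor API). For indices {1, 2, 2} and num_tokens = 4, multi_hot yields {0, 1, 1, 0} and count yields {0, 1, 2, 0}; one_hot maps a single index to a one-hot vector of length num_tokens.

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Sketch of the CategoryEncoding semantics on a plain vector of indices.
// All indices are expected to be smaller than num_tokens.
std::vector<float> category_encode(const std::vector<std::size_t>& indices,
                                   std::size_t num_tokens,
                                   const std::string& output_mode)
{
    std::vector<float> out(num_tokens, 0.0f);
    if (output_mode == "one_hot") {
        out[indices.front()] = 1.0f; // one_hot encodes a single index
    } else {
        for (const std::size_t idx : indices) {
            if (output_mode == "multi_hot")
                out[idx] = 1.0f;  // presence only
            else if (output_mode == "count")
                out[idx] += 1.0f; // number of occurrences
        }
    }
    return out;
}

int main()
{
    const std::vector<std::size_t> indices = { 1, 2, 2 };
    for (const char* mode : { "multi_hot", "count" }) {
        std::cout << mode << ":";
        for (const float x : category_encode(indices, 4, mode))
            std::cout << " " << x;
        std::cout << "\n"; // multi_hot: 0 1 1 0, count: 0 1 2 0
    }
}
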
65 changes: 65 additions & 0 deletions include/fdeep/layers/embedding_layer.hpp
@@ -0,0 +1,65 @@
// Copyright 2016, Tobias Hermann.
// https://github.com/Dobiasd/frugally-deep
// Distributed under the MIT License.
// (See accompanying LICENSE file or at
// https://opensource.org/licenses/MIT)

#pragma once

#include "fdeep/layers/layer.hpp"

#include <functional>
#include <string>

namespace fdeep {
namespace internal {

class embedding_layer : public layer {
public:
explicit embedding_layer(const std::string& name,
std::size_t input_dim,
std::size_t output_dim,
const float_vec& weights)
: layer(name)
, input_dim_(input_dim)
, output_dim_(output_dim)
, weights_(weights)
{
}

protected:
tensors apply_impl(const tensors& inputs) const override final
{
const auto input_shapes = fplus::transform(fplus_c_mem_fn_t(tensor, shape, tensor_shape), inputs);

// ensure that tensor shape is (1, 1, 1, 1, seq_len)
assertion(inputs.front().shape().size_dim_5_ == 1
&& inputs.front().shape().size_dim_4_ == 1
&& inputs.front().shape().height_ == 1
&& inputs.front().shape().width_ == 1,
"size_dim_5, size_dim_4, height and width dimension must be 1, but shape is '" + show_tensor_shapes(input_shapes) + "'");

tensors results;
for (auto&& input : inputs) {
const std::size_t sequence_len = input.shape().depth_;
float_vec output_vec(sequence_len * output_dim_);
auto&& it = output_vec.begin();

for (std::size_t i = 0; i < sequence_len; ++i) {
std::size_t index = static_cast<std::size_t>(input.get(tensor_pos(i)));
assertion(index < input_dim_, "vocabulary item indices must all be strictly less than the value of input_dim");
it = std::copy_n(weights_.cbegin() + static_cast<float_vec::const_iterator::difference_type>(index * output_dim_), output_dim_, it);
}

results.push_back(tensor(tensor_shape(sequence_len, output_dim_), std::move(output_vec)));
}
return results;
}

const std::size_t input_dim_;
const std::size_t output_dim_;
const float_vec weights_;
};

}
}
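
The lookup itself is a row copy out of a flat, row-major input_dim x output_dim weight matrix. Here is a standalone sketch of that indexing with plain vectors and a hypothetical helper name, not the fdeep tensor types:

#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Sketch: weights hold input_dim rows of output_dim floats in row-major
// order; each token index selects one row of the matrix.
std::vector<float> embed(const std::vector<std::size_t>& token_ids,
                         const std::vector<float>& weights,
                         std::size_t input_dim,
                         std::size_t output_dim)
{
    assert(weights.size() == input_dim * output_dim);
    std::vector<float> out;
    out.reserve(token_ids.size() * output_dim);
    for (const std::size_t id : token_ids) {
        assert(id < input_dim); // same bound the layer asserts
        const auto row = weights.begin() + static_cast<std::ptrdiff_t>(id * output_dim);
        out.insert(out.end(), row, row + static_cast<std::ptrdiff_t>(output_dim));
    }
    return out; // flattened (token_ids.size(), output_dim) result
}

int main()
{
    // Vocabulary of 3 entries, embedding dimension 2.
    const std::vector<float> weights = { 0.0f, 0.1f, 1.0f, 1.1f, 2.0f, 2.1f };
    for (const float x : embed({ 2, 0 }, weights, 3, 2))
        std::cout << x << " "; // prints: 2 2.1 0 0.1
    std::cout << "\n";
}
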
33 changes: 31 additions & 2 deletions keras_export/convert_model.py
@@ -10,7 +10,7 @@

import numpy as np
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Input, Embedding, CategoryEncoding
from tensorflow.keras.models import Model, load_model

__author__ = "Tobias Hermann"
@@ -102,8 +102,20 @@ def set_shape_idx_0_to_1_if_none(shape):

def generate_input_data(input_layer):
"""Random data fitting the input shape of a layer."""
print("input input_layer type", type(input_layer).__name__) # todo: remove
print("input_layer._outbound_nodes type", type(input_layer._outbound_nodes).__name__) # todo: remove
if input_layer._outbound_nodes and isinstance(
get_first_outbound_op(input_layer), Embedding):
random_fn = lambda size: np.random.randint(
0, get_first_outbound_op(input_layer).input_dim, size)
elif input_layer._outbound_nodes and isinstance(
get_first_outbound_op(input_layer), CategoryEncoding):
random_fn = lambda size: np.random.randint(
0, get_first_outbound_op(input_layer).num_tokens, size)
else:
random_fn = np.random.normal
shape = get_layer_input_shape(input_layer)
return np.random.normal(
return random_fn(
size=replace_none_with(32, set_shape_idx_0_to_1_if_none(singleton_list_to_value(shape)))).astype(np.float32)

data_in = list(map(generate_input_data, get_model_input_layers(model)))
@@ -304,6 +316,16 @@ def show_prelu_layer(layer):
return result


def show_embedding_layer(layer):
"""Serialize Embedding layer to dict"""
weights = layer.get_weights()
assert len(weights) == 1
result = {
'weights': encode_floats(weights[0])
}
return result


def show_input_layer(layer):
"""Serialize input layer to dict"""
assert not layer.sparse
@@ -340,6 +362,11 @@ def show_rescaling_layer(layer):
assert isinstance(layer.scale, float)


def show_category_encoding_layer(layer):
"""Serialize CategoryEncoding layer to dict"""
assert layer.output_mode in ["multi_hot", "count", "one_hot"]


def show_attention_layer(layer):
"""Serialize Attention layer to dict"""
assert layer.score_mode in ["dot", "concat"]
@@ -381,6 +408,7 @@ def get_layer_functions_dict():
'Dense': show_dense_layer,
'Dot': show_dot_layer,
'PReLU': show_prelu_layer,
'Embedding': show_embedding_layer,
'LayerNormalization': show_layer_normalization_layer,
'TimeDistributed': show_time_distributed_layer,
'Input': show_input_layer,
@@ -389,6 +417,7 @@
'UpSampling2D': show_upsampling2d_layer,
'Resizing': show_resizing_layer,
'Rescaling': show_rescaling_layer,
'CategoryEncoding': show_category_encoding_layer,
'Attention': show_attention_layer,
'AdditiveAttention': show_additive_attention_layer,
'MultiHeadAttention': show_multi_head_attention_layer,
45 changes: 45 additions & 0 deletions keras_export/generate_test_models.py
@@ -9,6 +9,7 @@
from tensorflow.keras.layers import AdditiveAttention
from tensorflow.keras.layers import Attention
from tensorflow.keras.layers import BatchNormalization, Concatenate, LayerNormalization, UnitNormalization
from tensorflow.keras.layers import CategoryEncoding, Embedding
from tensorflow.keras.layers import Conv1D, ZeroPadding1D, Cropping1D
from tensorflow.keras.layers import Conv2D, ZeroPadding2D, Cropping2D, CenterCrop
from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalMaxPooling1D
@@ -580,6 +581,49 @@ def get_test_model_exhaustive():
return model


def get_test_model_embedding():
"""Returns a minimalistic test model for the Embedding and CategoryEncoding layers."""

input_dims = [
1023, # maximum integer value in input data
255,
15,
]
input_shapes = [
(100,), # must be single-element tuple (for sequence length)
(1000,),
(1,),
]
assert len(input_dims) == len(input_shapes)
output_dims = [8, 3] # embedding dimension

inputs = [Input(shape=s) for s in input_shapes]

outputs = []
for k in range(2):
embedding = Embedding(input_dim=input_dims[k], output_dim=output_dims[k])(inputs[k])
outputs.append(embedding)

outputs.append(CategoryEncoding(1024, output_mode='multi_hot', sparse=False)(inputs[0]))
# No longer working since TF 2.16: https://github.com/tensorflow/tensorflow/issues/65390
# Error: Value passed to parameter 'values' has DataType float32 not in list of allowed values: int32, int64
# outputs.append(CategoryEncoding(1024, output_mode='count', sparse=False)(inputs[0]))
# outputs.append(CategoryEncoding(16, output_mode='one_hot', sparse=False)(inputs[2]))
# Error: Value passed to parameter 'values' has DataType float32 not in list of allowed values: int32, int64
# outputs.append(CategoryEncoding(1023, output_mode='multi_hot', sparse=True)(inputs[0]))

model = Model(inputs=inputs, outputs=outputs, name='test_model_embedding')
model.compile(loss='mse', optimizer='adam')

# fit to dummy data
training_data_size = 2
data_in = generate_integer_input_data(training_data_size, 0, input_dims, input_shapes)
initial_data_out = model.predict(data_in)
data_out = generate_output_data(training_data_size, initial_data_out)
model.fit(data_in, data_out, epochs=1)
return model


def get_test_model_variable():
"""Returns a model with variably shaped input tensors."""

@@ -663,6 +707,7 @@ def main():

get_model_functions = {
'exhaustive': get_test_model_exhaustive,
'embedding': get_test_model_embedding,
'variable': get_test_model_variable,
'sequential': get_test_model_sequential,
}
4 changes: 4 additions & 0 deletions test/CMakeLists.txt
@@ -25,6 +25,10 @@ add_custom_command ( OUTPUT test_model_exhaustive.json
COMMAND bash -c "${Python3_EXECUTABLE} ${FDEEP_TOP_DIR}/keras_export/convert_model.py test_model_exhaustive.keras test_model_exhaustive.json"
WORKING_DIRECTORY ${CMAKE_BINARY_DIR}/)

add_custom_command ( OUTPUT test_model_embedding.keras
COMMAND bash -c "${Python3_EXECUTABLE} ${FDEEP_TOP_DIR}/keras_export/generate_test_models.py embedding test_model_embedding.keras"
WORKING_DIRECTORY ${CMAKE_BINARY_DIR}/)

add_custom_command ( OUTPUT test_model_variable.json
DEPENDS test_model_variable.keras
COMMAND bash -c "${Python3_EXECUTABLE} ${FDEEP_TOP_DIR}/keras_export/convert_model.py test_model_variable.keras test_model_variable.json"
21 changes: 21 additions & 0 deletions test/test_model_embedding_test.cpp
@@ -0,0 +1,21 @@
// Copyright 2016, Tobias Hermann.
// https://github.com/Dobiasd/frugally-deep
// Distributed under the MIT License.
// (See accompanying LICENSE file or at
// https://opensource.org/licenses/MIT)

#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN
#include "doctest/doctest.h"
#define FDEEP_FLOAT_TYPE double
#include <fdeep/fdeep.hpp>

TEST_CASE("test_model_embedding_test, load_model")
{
const auto model = fdeep::load_model("../test_model_embedding.json",
true, fdeep::cout_logger, static_cast<fdeep::float_type>(0.00001));
const auto multi_inputs = fplus::generate<std::vector<fdeep::tensors>>(
[&]() -> fdeep::tensors { return model.generate_dummy_inputs(); },
10);
model.predict_multi(multi_inputs, false);
model.predict_multi(multi_inputs, true);
}
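
For completeness, here is a sketch of how such a converted model would be consumed from application code, following the usage pattern from the frugally-deep README. It assumes a converted JSON file and a model with a single sequence input of length 3; the file name and values are illustrative.

#include <fdeep/fdeep.hpp>

#include <iostream>
#include <vector>

int main()
{
    // Assumes the Keras model was converted with keras_export/convert_model.py.
    const auto model = fdeep::load_model("test_model_embedding.json");

    // Token indices are passed as floats; the embedding layer casts them
    // to row indices into its weight matrix.
    const fdeep::tensor tokens(
        fdeep::tensor_shape(static_cast<std::size_t>(3)),
        std::vector<float> { 4.0f, 7.0f, 2.0f });

    const auto results = model.predict({ tokens });
    std::cout << fdeep::show_tensors(results) << std::endl;
}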
