-
-
Notifications
You must be signed in to change notification settings - Fork 237
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Re-add support for Embedding and CategoryEncoding layers
- Loading branch information
Showing
8 changed files
with
254 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
// Copyright 2016, Tobias Hermann. | ||
// https://github.com/Dobiasd/frugally-deep | ||
// Distributed under the MIT License. | ||
// (See accompanying LICENSE file or at | ||
// https://opensource.org/licenses/MIT) | ||
|
||
#pragma once | ||
|
||
#include "fdeep/layers/layer.hpp" | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
namespace fdeep { | ||
namespace internal { | ||
|
||
class category_encoding_layer : public layer { | ||
public: | ||
explicit category_encoding_layer(const std::string& name, | ||
const std::size_t& num_tokens, | ||
const std::string& output_mode) | ||
: layer(name) | ||
, num_tokens_(num_tokens) | ||
, output_mode_(output_mode) | ||
{ | ||
assertion(output_mode_ == "one_hot" || output_mode_ == "multi_hot" || output_mode_ == "count", | ||
"Unsupported output mode (" + output_mode_ + ")."); | ||
} | ||
|
||
protected: | ||
tensors apply_impl(const tensors& inputs) const override | ||
{ | ||
assertion(inputs.size() == 1, "need exactly one input"); | ||
const auto input = inputs[0]; | ||
assertion(input.shape().rank() == 1, "Tensor of rank 1 required, but shape is '" + show_tensor_shape(input.shape()) + "'"); | ||
|
||
if (output_mode_ == "one_hot") { | ||
assertion(input.shape().depth_ == 1, "Tensor of depth 1 required, but is: " + fplus::show(input.shape().depth_)); | ||
tensor out(tensor_shape(num_tokens_), float_type(0)); | ||
const std::size_t idx = fplus::floor<float_type, std::size_t>(input.get_ignore_rank(tensor_pos(0))); | ||
assertion(idx <= num_tokens_, "Invalid input value (> num_tokens)."); | ||
out.set_ignore_rank(tensor_pos(idx), 1); | ||
return { out }; | ||
} else { | ||
tensor out(tensor_shape(num_tokens_), float_type(0)); | ||
for (const auto& x : *(input.as_vector())) { | ||
const std::size_t idx = fplus::floor<float_type, std::size_t>(x); | ||
assertion(idx <= num_tokens_, "Invalid input value (> num_tokens)."); | ||
if (output_mode_ == "multi_hot") { | ||
out.set_ignore_rank(tensor_pos(idx), 1); | ||
} else if (output_mode_ == "count") { | ||
out.set_ignore_rank(tensor_pos(idx), out.get_ignore_rank(tensor_pos(idx)) + 1); | ||
} | ||
} | ||
return { out }; | ||
} | ||
} | ||
std::size_t num_tokens_; | ||
std::string output_mode_; | ||
}; | ||
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
// Copyright 2016, Tobias Hermann. | ||
// https://github.com/Dobiasd/frugally-deep | ||
// Distributed under the MIT License. | ||
// (See accompanying LICENSE file or at | ||
// https://opensource.org/licenses/MIT) | ||
|
||
#pragma once | ||
|
||
#include "fdeep/layers/layer.hpp" | ||
|
||
#include <functional> | ||
#include <string> | ||
|
||
namespace fdeep { | ||
namespace internal { | ||
|
||
class embedding_layer : public layer { | ||
public: | ||
explicit embedding_layer(const std::string& name, | ||
std::size_t input_dim, | ||
std::size_t output_dim, | ||
const float_vec& weights) | ||
: layer(name) | ||
, input_dim_(input_dim) | ||
, output_dim_(output_dim) | ||
, weights_(weights) | ||
{ | ||
} | ||
|
||
protected: | ||
tensors apply_impl(const tensors& inputs) const override final | ||
{ | ||
const auto input_shapes = fplus::transform(fplus_c_mem_fn_t(tensor, shape, tensor_shape), inputs); | ||
|
||
// ensure that tensor shape is (1, 1, 1, 1, seq_len) | ||
assertion(inputs.front().shape().size_dim_5_ == 1 | ||
&& inputs.front().shape().size_dim_4_ == 1 | ||
&& inputs.front().shape().height_ == 1 | ||
&& inputs.front().shape().width_ == 1, | ||
"size_dim_5, size_dim_4, height and width dimension must be 1, but shape is '" + show_tensor_shapes(input_shapes) + "'"); | ||
|
||
tensors results; | ||
for (auto&& input : inputs) { | ||
const std::size_t sequence_len = input.shape().depth_; | ||
float_vec output_vec(sequence_len * output_dim_); | ||
auto&& it = output_vec.begin(); | ||
|
||
for (std::size_t i = 0; i < sequence_len; ++i) { | ||
std::size_t index = static_cast<std::size_t>(input.get(tensor_pos(i))); | ||
assertion(index < input_dim_, "vocabulary item indices must all be strictly less than the value of input_dim"); | ||
it = std::copy_n(weights_.cbegin() + static_cast<float_vec::const_iterator::difference_type>(index * output_dim_), output_dim_, it); | ||
} | ||
|
||
results.push_back(tensor(tensor_shape(sequence_len, output_dim_), std::move(output_vec))); | ||
} | ||
return results; | ||
} | ||
|
||
const std::size_t input_dim_; | ||
const std::size_t output_dim_; | ||
const float_vec weights_; | ||
}; | ||
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
// Copyright 2016, Tobias Hermann. | ||
// https://github.com/Dobiasd/frugally-deep | ||
// Distributed under the MIT License. | ||
// (See accompanying LICENSE file or at | ||
// https://opensource.org/licenses/MIT) | ||
|
||
#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN | ||
#include "doctest/doctest.h" | ||
#define FDEEP_FLOAT_TYPE double | ||
#include <fdeep/fdeep.hpp> | ||
|
||
TEST_CASE("test_model_embedding_test, load_model") | ||
{ | ||
const auto model = fdeep::load_model("../test_model_embedding.json", | ||
true, fdeep::cout_logger, static_cast<fdeep::float_type>(0.00001)); | ||
const auto multi_inputs = fplus::generate<std::vector<fdeep::tensors>>( | ||
[&]() -> fdeep::tensors { return model.generate_dummy_inputs(); }, | ||
10); | ||
model.predict_multi(multi_inputs, false); | ||
model.predict_multi(multi_inputs, true); | ||
} |