-
-
Notifications
You must be signed in to change notification settings - Fork 237
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #97 from hunyadi/embedding_layer
New layer type "Embedding"
- Loading branch information
Showing
6 changed files
with
193 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// Copyright 2016, Tobias Hermann. | ||
// https://github.com/Dobiasd/frugally-deep | ||
// Distributed under the MIT License. | ||
// (See accompanying LICENSE file or at | ||
// https://opensource.org/licenses/MIT) | ||
|
||
#pragma once | ||
|
||
#include "fdeep/layers/layer.hpp" | ||
|
||
#include <string> | ||
#include <functional> | ||
|
||
namespace fdeep | ||
{ | ||
namespace internal | ||
{ | ||
|
||
class embedding_layer : public layer | ||
{ | ||
public: | ||
explicit embedding_layer(const std::string& name, | ||
std::size_t input_dim, | ||
std::size_t output_dim, | ||
const float_vec& weights) | ||
: layer(name) | ||
, input_dim_(input_dim) | ||
, output_dim_(output_dim) | ||
, weights_(weights) | ||
{} | ||
|
||
protected: | ||
tensor5s apply_impl(const tensor5s &inputs) const override final | ||
{ | ||
const auto input_shapes = fplus::transform(fplus_c_mem_fn_t(tensor5, shape, shape5), inputs); | ||
|
||
// ensure that tensor5 shape is (1, 1, 1, 1, seq_len) | ||
assertion(inputs.front().shape().size_dim_5_ == 1 | ||
&& inputs.front().shape().size_dim_4_ == 1 | ||
&& inputs.front().shape().height_ == 1 | ||
&& inputs.front().shape().width_ == 1, | ||
"size_dim_5, size_dim_4, height and width dimension must be 1, but shape is '" + show_shape5s(input_shapes) + "'"); | ||
|
||
tensor5s results; | ||
for (auto&& input : inputs) | ||
{ | ||
const std::size_t sequence_len = input.shape().depth_; | ||
float_vec output_vec(sequence_len * output_dim_); | ||
auto&& it = output_vec.begin(); | ||
|
||
for (std::size_t i = 0; i < sequence_len; ++i) | ||
{ | ||
std::size_t index = static_cast<std::size_t>(input.get(0, 0, 0, 0, i)); | ||
assertion(index < input_dim_, "vocabulary item indices must all be strictly less than the value of input_dim"); | ||
it = std::copy_n(weights_.cbegin() + static_cast<float_vec::const_iterator::difference_type>(index * output_dim_), output_dim_, it); | ||
} | ||
|
||
results.push_back(tensor5(shape5(1, 1, 1, sequence_len, output_dim_), std::move(output_vec))); | ||
} | ||
return results; | ||
} | ||
|
||
const std::size_t input_dim_; | ||
const std::size_t output_dim_; | ||
const float_vec weights_; | ||
}; | ||
|
||
} // namespace internal | ||
} // namespace fdeep |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
// Copyright 2016, Tobias Hermann. | ||
// https://github.com/Dobiasd/frugally-deep | ||
// Distributed under the MIT License. | ||
// (See accompanying LICENSE file or at | ||
// https://opensource.org/licenses/MIT) | ||
|
||
#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN | ||
#include "doctest.h" | ||
#include <fdeep/fdeep.hpp> | ||
|
||
#define FDEEP_FLOAT_TYPE double | ||
|
||
TEST_CASE("test_model_embedding_test, load_model") | ||
{ | ||
const auto model = fdeep::load_model("../test_model_embedding.json", | ||
true, fdeep::cout_logger, static_cast<fdeep::float_type>(0.00001)); | ||
const auto multi_inputs = fplus::generate<std::vector<fdeep::tensor5s>>( | ||
[&]() -> fdeep::tensor5s {return model.generate_dummy_inputs();}, | ||
10); | ||
|
||
model.predict_multi(multi_inputs, false); | ||
model.predict_multi(multi_inputs, true); | ||
} |