From 212d60999d28dfa07cf1a3bb6a9c718ef8a649ab Mon Sep 17 00:00:00 2001
From: Dobiasd
Date: Tue, 26 Dec 2023 21:01:50 +0100
Subject: [PATCH] Separate weights and biases

---
 .../layers/multi_head_attention_layer.hpp | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/include/fdeep/layers/multi_head_attention_layer.hpp b/include/fdeep/layers/multi_head_attention_layer.hpp
index 4ad8c85c..f649d3fe 100644
--- a/include/fdeep/layers/multi_head_attention_layer.hpp
+++ b/include/fdeep/layers/multi_head_attention_layer.hpp
@@ -20,12 +20,22 @@ class multi_head_attention_layer : public layer
     explicit multi_head_attention_layer(const std::string& name,
         std::size_t num_heads, std::size_t key_dim, std::size_t value_dim,
         bool use_bias, const std::vector<std::size_t>& attention_axes,
-        const std::vector<tensor>& weights)
+        const std::vector<tensor>& saved_weights)
         : layer(name), num_heads_(num_heads), key_dim_(key_dim),
-        value_dim_(value_dim), use_bias_(use_bias), attention_axes_(attention_axes),
-        weights_(weights)
+        value_dim_(value_dim), attention_axes_(attention_axes),
+        weights_(extract_weights(saved_weights, use_bias)),
+        biases_(extract_biases(saved_weights, use_bias))
     {
     }
+private:
+    tensors extract_weights(const tensors& saved_weights, bool use_bias)
+    {
+        return use_bias ? fplus::unweave(saved_weights).first : saved_weights;
+    }
+    tensors extract_biases(const tensors& saved_weights, bool use_bias)
+    {
+        return use_bias ? fplus::unweave(saved_weights).second : tensors(); // todo: create biases with zeroes in right shape
+    }
 protected:
     tensors apply_impl(const tensors& input) const override
     {
@@ -42,9 +52,9 @@ class multi_head_attention_layer : public layer
     std::size_t num_heads_;
     std::size_t key_dim_;
     std::size_t value_dim_;
-    bool use_bias_;
     std::vector<std::size_t> attention_axes_;
     std::vector<tensor> weights_;
+    std::vector<tensor> biases_;
 };
 }
 } // namespace fdeep, namespace internal
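-- 
Context on fplus::unweave (not part of the commit): FunctionalPlus's
unweave splits a sequence into its even-indexed and odd-indexed elements,
returned as the two halves of a std::pair. The patch relies on the saved
tensors alternating weight, bias, weight, bias, so .first yields the
weights and .second the biases. A minimal sketch with plain ints,
assuming that alternating layout:

    #include <fplus/fplus.hpp>
    #include <cassert>
    #include <vector>

    int main()
    {
        // Interleaved sequence: weight, bias, weight, bias, ...
        const std::vector<int> saved = { 10, 1, 20, 2, 30, 3 };
        const auto parts = fplus::unweave(saved);
        // Even indices land in .first, odd indices in .second.
        assert((parts.first == std::vector<int>{ 10, 20, 30 })); // "weights"
        assert((parts.second == std::vector<int>{ 1, 2, 3 }));   // "biases"
    }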