From 3176f89314185794fa8122fd0de8e68cc6a06091 Mon Sep 17 00:00:00 2001
From: Dobiasd
Date: Sun, 31 Dec 2023 09:06:39 +0100
Subject: [PATCH] create dense output layer separately

---
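Notes:

Keras computes the output projection of MultiHeadAttention as a single
einsum over all heads (the output kernel is stored with one slice per
head), so concatenating the per-head attention results depth-wise and
feeding them through one dense layer yields the same result as slicing
the projection per head. A minimal NumPy sketch of that equivalence
(shapes and names are illustrative, not taken from this codebase):

    import numpy as np

    tokens, num_heads, value_dim, units = 2, 3, 4, 5
    rng = np.random.default_rng(0)

    # Per-head attention results: (tokens, num_heads, value_dim).
    heads = rng.random((tokens, num_heads, value_dim))

    # Output-projection kernel with one slice per head, matching the
    # Keras layout: (num_heads, value_dim, units).
    w_o = rng.random((num_heads, value_dim, units))

    # Reference: the einsum projection, summed over all heads.
    ref = np.einsum('thd,hdu->tu', heads, w_o)

    # Equivalent: one dense layer on the depth-wise concatenation of
    # the heads, with the kernel flattened to (num_heads * value_dim, units).
    merged = heads.reshape(tokens, num_heads * value_dim)
    out = merged @ w_o.reshape(num_heads * value_dim, units)

    assert np.allclose(ref, out)

This is why a single dense_layer (output_dense_) applied to
concatenate_tensors_depth(outputs) below can replace the previous
per-head output_dense_ vector.
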
 .../layers/multi_head_attention_layer.hpp | 52 ++++++++++++++-----
 keras_export/generate_test_models.py      |  3 ++
 2 files changed, 43 insertions(+), 12 deletions(-)

diff --git a/include/fdeep/layers/multi_head_attention_layer.hpp b/include/fdeep/layers/multi_head_attention_layer.hpp
index 74dc992a..bd293fd4 100644
--- a/include/fdeep/layers/multi_head_attention_layer.hpp
+++ b/include/fdeep/layers/multi_head_attention_layer.hpp
@@ -27,7 +27,7 @@ class multi_head_attention_layer : public layer
         query_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 0, name + "_query_dense")),
         value_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 2, name + "_value_dense")),
         key_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 1, name + "_key_dense")),
-        output_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 3, name + "_output_dense"))
+        output_dense_(create_output_dense_layer(weights_and_biases, use_bias, name + "_output_dense"))
     {
     }
 private:
@@ -35,16 +35,22 @@ class multi_head_attention_layer : public layer
         const tensors& weights_and_biases, bool use_bias, const std::size_t num_heads,
         const std::size_t index, const std::string& name)
     {
+        assertion(index <= 2, "Invalid dense layer index.");
+
         const std::size_t index_factor = use_bias ? 2 : 1;
-        const tensor weights = weights_and_biases[index_factor * index];
+
+        tensor weights = weights_and_biases[index_factor * index];
+        if (index == 3)
+            weights = permute_tensor(weights, {3, 1, 2});
+
         const std::size_t units = weights.shape().depth_;
-        const tensor biases = use_bias ?
+
+        tensor biases = use_bias ?
             weights_and_biases[index_factor * index + 1] :
-            tensor(index == 3 ? tensor_shape(num_heads, 1, units) : tensor_shape(num_heads, units), 0);
-        const auto weights_per_head =
-            index == 3 ? tensor_to_tensors_height_slices(weights) : tensor_to_tensors_width_slices(weights);
-        const auto biases_per_head =
-            index == 3 ? tensor_to_tensors_height_slices(biases) : tensor_to_tensors_width_slices(biases);
+            tensor(index == 3 ? tensor_shape(units) : tensor_shape(num_heads, units), 0);
+
+        const auto weights_per_head = tensor_to_tensors_width_slices(weights);
+        const auto biases_per_head = tensor_to_tensors_width_slices(biases);
         assertion(weights_per_head.size() == num_heads, "Invalid weights for number of heads.");
         assertion(biases_per_head.size() == num_heads, "Invalid biases for number of heads.");
         const std::vector<dense_layer> dense_layers =
@@ -60,6 +66,23 @@ class multi_head_attention_layer : public layer
             fplus::enumerate(fplus::zip(weights_per_head, biases_per_head)));
         return dense_layers;
     }
+    dense_layer create_output_dense_layer(
+        const tensors& weights_and_biases, bool use_bias, const std::string& name)
+    {
+        const std::size_t index_factor = use_bias ? 2 : 1;
+
+        tensor weights = weights_and_biases[index_factor * 3];
+
+        const std::size_t units = weights.shape().depth_;
+
+        tensor biases = use_bias ?
+            weights_and_biases[index_factor * 3 + 1] :
+            tensor(tensor_shape(units), 0);
+
+        const auto weights_per_head = tensor_to_tensors_width_slices(weights);
+        const auto biases_per_head = tensor_to_tensors_width_slices(biases);
+        return dense_layer(name + "_output", units, *weights.as_vector(), *biases.as_vector());
+    }
     tensors extract_biases(const tensors& saved_weights, bool use_bias)
     {
         return use_bias ? fplus::unweave(saved_weights).second : tensors();
@@ -89,8 +112,7 @@ class multi_head_attention_layer : public layer
         // https://gist.github.com/sevagh/b71d253a347a9b59c026580625452fc5
         const tensor scores = dot_product_tensors(query, transpose(key), std::vector<int>({2, 1}), false);
         const tensor distribution = softmax(scores);
-        const tensor output = dot_product_tensors(distribution, value, std::vector<int>({2, 1}), false);
-        return output_dense_[head_index].apply({output}).front(); // todo
+        return dot_product_tensors(distribution, value, std::vector<int>({2, 1}), false);
     }
 protected:
     tensors apply_impl(const tensors& input) const override
@@ -99,16 +121,22 @@ class multi_head_attention_layer : public layer
     {
         const tensor query_raw = input[0];
         const tensor value_raw = input[1];
         const tensor key_raw = input.size() > 2 ? input[2] : value_raw;
-        return {apply_head(query_raw, value_raw, key_raw, 0)}; // todo: all
+        const auto outputs = fplus::transform([&](const std::size_t head_idx)
+        {
+            return apply_head(query_raw, value_raw, key_raw, head_idx);
+        }, fplus::numbers<std::size_t>(0, num_heads_));
+        const tensor merged = concatenate_tensors_depth(outputs);
+        return output_dense_.apply({merged});
     }
     std::size_t num_heads_;
     std::size_t key_dim_;
     std::size_t value_dim_;
     std::vector<std::size_t> attention_axes_;
+    // todo: store each head as a separate object?
     std::vector<dense_layer> query_dense_;
     std::vector<dense_layer> value_dense_;
     std::vector<dense_layer> key_dense_;
-    std::vector<dense_layer> output_dense_;
+    dense_layer output_dense_;
 };
 } } // namespace fdeep, namespace internal
diff --git a/keras_export/generate_test_models.py b/keras_export/generate_test_models.py
index 37bb3886..018a2124 100644
--- a/keras_export/generate_test_models.py
+++ b/keras_export/generate_test_models.py
@@ -457,6 +457,9 @@ def get_test_model_exhaustive():
     outputs.append(MultiHeadAttention(
         num_heads=3, key_dim=1, value_dim=None,
         use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
+    outputs.append(MultiHeadAttention(
+        num_heads=3, key_dim=1, value_dim=None,
+        use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
     outputs.append(MultiHeadAttention(
         num_heads=1, key_dim=1, value_dim=None,
         use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51]))