Add MultiHeadAttention layer #392

Merged: 38 commits, Dec 31, 2023

Commits (38)
a2992f9  Add MultiHeadAttention layer (Dobiasd, Aug 14, 2023)
e0a886d  Pass parameters to multi_head_attention_layer (Dobiasd, Aug 14, 2023)
931bc27  Persist weight shapes too (Dobiasd, Aug 14, 2023)
3a5f8ec  Merge branch 'master' into multi-head-attention (Dobiasd, Sep 12, 2023)
a1e0a58  Merge branch 'master' into multi-head-attention (Dobiasd, Sep 13, 2023)
60f77d5  Merge master (Dobiasd, Nov 1, 2023)
14aa03a  Add comment with link to stackoverflow question (Dobiasd, Nov 1, 2023)
e8efb24  Implement special treatment for the inbound_nodes format of MultiHead… (Dobiasd, Nov 7, 2023)
7a7928d  Adjust comment in multi_head_attention_layer (Dobiasd, Nov 7, 2023)
84c492e  Merge branch 'master' into multi-head-attention (Dobiasd, Nov 28, 2023)
3bfbf07  comment (Dobiasd, Dec 16, 2023)
1b6597e  Merge branch 'master' into multi-head-attention (Dobiasd, Dec 16, 2023)
9743a49  add comment (Dobiasd, Dec 26, 2023)
9231a60  Merge branch 'master' into multi-head-attention (Dobiasd, Dec 26, 2023)
78c0d4e  Merge branch 'master' into multi-head-attention (Dobiasd, Dec 26, 2023)
09d5868  Convert weights to tensors for ctor (Dobiasd, Dec 26, 2023)
212d609  Separate weights and biases (Dobiasd, Dec 26, 2023)
118f663  apply dense layers to query, value and key (Dobiasd, Dec 29, 2023)
53f2d9b  only allow the usual shapes (Dobiasd, Dec 29, 2023)
4184515  adjust tests (Dobiasd, Dec 29, 2023)
291e127  Fix default bias and loading order of key and value weights (Dobiasd, Dec 30, 2023)
5b3fbd2  fix shapes, add tests (Dobiasd, Dec 30, 2023)
0e17267  decode weights into multiple heads (Dobiasd, Dec 31, 2023)
d896c4c  enable tests (Dobiasd, Dec 31, 2023)
3176f89  create dense output layer separately (Dobiasd, Dec 31, 2023)
c7adc7c  clean up (Dobiasd, Dec 31, 2023)
db62540  add more tests (Dobiasd, Dec 31, 2023)
702cb60  shorten (Dobiasd, Dec 31, 2023)
b825b57  const (Dobiasd, Dec 31, 2023)
67264ea  teeeests (Dobiasd, Dec 31, 2023)
8f630a8  fix the distribution calculation by dividing with the square root of … (Dobiasd, Dec 31, 2023)
0c6cc0a  remove debug tests (Dobiasd, Dec 31, 2023)
0d2be86  remove todo comment (Dobiasd, Dec 31, 2023)
fd6e7c4  Check for attention_axes=None in conversion (Dobiasd, Dec 31, 2023)
a95abf4  Do not pass unused attention_axes (Dobiasd, Dec 31, 2023)
41ac53a  double-check weights shapes (Dobiasd, Dec 31, 2023)
b35deb9  Revert debug output (Dobiasd, Dec 31, 2023)
7a574b7  Add MultiHeadAttention layer to list of supported layers in README (Dobiasd, Dec 31, 2023)
4 changes: 2 additions & 2 deletions README.md
@@ -57,7 +57,7 @@ Would you like to build/train a model using Keras/Python? And would you like to
* `UpSampling1D/2D`, `Resizing`
* `Reshape`, `Permute`, `RepeatVector`
* `Embedding`, `CategoryEncoding`
-* `Attention`, `AdditiveAttention`
+* `Attention`, `AdditiveAttention`, `MultiHeadAttention`


### Also supported
@@ -78,7 +78,7 @@ Would you like to build/train a model using Keras/Python? And would you like to
`GRUCell`, `Hashing`,
`IntegerLookup`,
`LocallyConnected1D`, `LocallyConnected2D`,
-`LSTMCell`, `Masking`, `MultiHeadAttention`,
+`LSTMCell`, `Masking`,
`RepeatVector`, `RNN`, `SimpleRNN`,
`SimpleRNNCell`, `StackedRNNCells`, `StringLookup`, `TextVectorization`,
`ThresholdedReLU`, `Upsampling3D`, `temporal` models
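
For orientation (not part of the diff): a minimal sketch of the kind of model the new README entry covers, i.e. `MultiHeadAttention` with `attention_axes=None` and `output_shape=None`. The file names and the `convert_model.py` invocation in the comments are assumptions based on the project's usual export workflow.

```python
# Hypothetical example: a functional Keras model using MultiHeadAttention in the
# form supported by this PR (attention_axes=None, output_shape=None).
from tensorflow.keras.layers import Input, MultiHeadAttention
from tensorflow.keras.models import Model

query = Input(shape=(4, 8))   # (T, dim)
value = Input(shape=(6, 8))   # (S, dim); also reused as key when no key is passed
attended = MultiHeadAttention(num_heads=2, key_dim=3)(query, value)
model = Model(inputs=[query, value], outputs=attended)

model.save("mha_model.h5", include_optimizer=False)
# Assumed conversion step (see keras_export/convert_model.py):
#   python3 keras_export/convert_model.py mha_model.h5 mha_model.json
```
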
50 changes: 50 additions & 0 deletions include/fdeep/import_model.hpp
@@ -64,6 +64,7 @@
#include "fdeep/layers/maximum_layer.hpp"
#include "fdeep/layers/minimum_layer.hpp"
#include "fdeep/layers/model_layer.hpp"
#include "fdeep/layers/multi_head_attention_layer.hpp"
#include "fdeep/layers/multiply_layer.hpp"
#include "fdeep/layers/normalization_layer.hpp"
#include "fdeep/layers/pooling_3d_layer.hpp"
@@ -1068,6 +1069,30 @@ inline layer_ptr create_additive_attention_layer(
return std::make_shared<additive_attention_layer>(name, scale);
}

inline layer_ptr create_multi_head_attention_layer(
const get_param_f& get_param,
const nlohmann::json& data, const std::string& name)
{
const std::size_t num_heads = data["config"]["num_heads"];
const std::size_t key_dim = data["config"]["key_dim"];
const std::size_t value_dim = data["config"]["value_dim"];
const bool use_bias = data["config"]["use_bias"];
const auto weight_shapes =
create_vector<std::vector<std::size_t>>(fplus::bind_1st_of_2(
create_vector<std::size_t, decltype(create_size_t)>, create_size_t),
get_param(name, "weight_shapes"));
const auto weight_values = create_vector<float_vec>(decode_floats, get_param(name, "weights"));
const auto weights_and_biases = fplus::zip_with(
[](const std::vector<std::size_t>& shape, const float_vec& values) -> tensor
{
return tensor(
create_tensor_shape_from_dims(shape),
fplus::convert_container<float_vec>(values));
}, weight_shapes, weight_values);
return std::make_shared<multi_head_attention_layer>(name,
num_heads, key_dim, value_dim, use_bias, weights_and_biases);
}

inline std::string get_activation_type(const nlohmann::json& data)
{
assertion(data.is_string(), "Layer activation must be a string.");
@@ -1141,11 +1166,35 @@ inline node create_node(const nlohmann::json& inbound_nodes_data)
inbound_nodes_data));
}

inline nodes create_multi_head_attention_nodes(const std::vector<nlohmann::json> inbound_nodes_data)
{
assertion(inbound_nodes_data.size() == 1 && inbound_nodes_data.front().size() == 1,
"multi_head_attention needs to have exactly one primary inbound node; see https://stackoverflow.com/q/77400589/1866775");
const auto inbound_node_data = inbound_nodes_data.front().front();
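// Keras calls this layer as mha(query, value, key=None): the query connection is
// serialized positionally, while value (and key, if given) end up in the
// call-kwargs object at index 3 of the inbound-node entry.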
const auto value = inbound_node_data[3]["value"];
if (json_obj_has_member(inbound_node_data[3], "key")) {
return {
node({
create_node_connection(inbound_node_data),
create_node_connection(value),
create_node_connection(inbound_node_data[3]["key"])
})};
}
return {
node({
create_node_connection(inbound_node_data),
create_node_connection(value)
})};
}

inline nodes create_nodes(const nlohmann::json& data)
{
assertion(data["inbound_nodes"].is_array(), "no inbound nodes");
const std::vector<nlohmann::json> inbound_nodes_data =
data["inbound_nodes"];
if (data["class_name"] == "MultiHeadAttention") {
return create_multi_head_attention_nodes(inbound_nodes_data);
}
return fplus::transform(create_node, inbound_nodes_data);
}

@@ -1378,6 +1427,7 @@ inline layer_ptr create_layer(const get_param_f& get_param,
{"CategoryEncoding", create_category_encoding_layer},
{"Attention", create_attention_layer},
{"AdditiveAttention", create_additive_attention_layer},
{"MultiHeadAttention", create_multi_head_attention_layer},
};

const wrapper_layer_creators wrapper_creators = {
128 changes: 128 additions & 0 deletions include/fdeep/layers/multi_head_attention_layer.hpp
@@ -0,0 +1,128 @@
// Copyright 2016, Tobias Hermann.
// https://github.com/Dobiasd/frugally-deep
// Distributed under the MIT License.
// (See accompanying LICENSE file or at
// https://opensource.org/licenses/MIT)

#pragma once

#include "fdeep/layers/layer.hpp"
#include "fdeep/layers/dense_layer.hpp"
#include "fdeep/layers/softmax_layer.hpp"

#include <string>

namespace fdeep { namespace internal
{

class multi_head_attention_layer : public layer
{
public:
explicit multi_head_attention_layer(const std::string& name,
std::size_t num_heads, std::size_t key_dim, std::size_t value_dim,
bool use_bias, const std::vector<tensor>& weights_and_biases)
: layer(name), num_heads_(num_heads), key_dim_(key_dim),
value_dim_(value_dim),
query_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 0, key_dim, name + "_query_dense")),
value_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 2, value_dim, name + "_value_dense")),
key_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 1, key_dim, name + "_key_dense")),
output_dense_(create_output_dense_layer(weights_and_biases, use_bias, name + "_output_dense"))
{
}
private:
std::vector<dense_layer> create_dense_layers(
const tensors& weights_and_biases, bool use_bias, const std::size_t num_heads,
const std::size_t index, const std::size_t units, const std::string& name)
{
assertion(index <= 2, "Invalid dense layer index.");
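// Layout of weights_and_biases: query (index 0), key (1), value (2), output (3);
// with use_bias, each kernel tensor is immediately followed by its bias tensor.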
const std::size_t index_factor = use_bias ? 2 : 1;
const tensor weights = weights_and_biases[index_factor * index];
const tensor biases = use_bias ?
weights_and_biases[index_factor * index + 1] :
tensor(tensor_shape(num_heads, units), 0);

assertion(weights.shape().depth_ == units, "Invalid weights shape for attention head dimension.");
assertion(biases.shape().depth_ == units, "Invalid biases shape for attention head dimension.");

const auto weights_per_head = tensor_to_tensors_width_slices(weights);
const auto biases_per_head = tensor_to_tensors_width_slices(biases);
assertion(weights_per_head.size() == num_heads, "Invalid weights for number of heads.");
assertion(biases_per_head.size() == num_heads, "Invalid biases for number of heads.");
return fplus::transform(
[&](const std::pair<std::size_t, std::pair<tensor, tensor>>& n_and_w_with_b)
{
return dense_layer(
name + "_" + std::to_string(n_and_w_with_b.first),
units,
*n_and_w_with_b.second.first.as_vector(),
*n_and_w_with_b.second.second.as_vector());
},
fplus::enumerate(fplus::zip(weights_per_head, biases_per_head)));
}
dense_layer create_output_dense_layer(
const tensors& weights_and_biases, bool use_bias, const std::string& name)
{
const std::size_t index_factor = use_bias ? 2 : 1;
const tensor weights = weights_and_biases[index_factor * 3];
const std::size_t units = weights.shape().depth_;
const tensor biases = use_bias ?
weights_and_biases[index_factor * 3 + 1] :
tensor(tensor_shape(units), 0);
return dense_layer(name + "_output", units, *weights.as_vector(), *biases.as_vector());
}
tensors extract_biases(const tensors& saved_weights, bool use_bias)
{
return use_bias ? fplus::unweave(saved_weights).second : tensors();
}
tensor apply_head(
const tensor& query_raw,
const tensor& value_raw,
const tensor& key_raw,
std::size_t head_index) const
{
assertion(
query_raw.shape().rank() == 2 &&
value_raw.shape().rank() == 2 &&
key_raw.shape().rank() == 2 &&
query_raw.shape().depth_ == value_raw.shape().depth_ &&
query_raw.shape().depth_ == key_raw.shape().depth_ &&
value_raw.shape().width_ == key_raw.shape().width_,
"Invalid shapes; need a query tensor of shape (B, T, dim) and a value/key tensor of shape (B, S, dim)."
);
const tensor query = query_dense_[head_index].apply({query_raw}).front();
const tensor value = value_dense_[head_index].apply({value_raw}).front();
const tensor key = key_dense_[head_index].apply({key_raw}).front();

// https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853
// https://dmol.pub/dl/attention.html#multi-head-attention-block
// https://github.com/keras-team/keras/blob/v2.14.0/keras/layers/attention/multi_head_attention.py
// https://gist.github.com/sevagh/b71d253a347a9b59c026580625452fc5
const tensor scores = dot_product_tensors(query, transpose(key), std::vector<int>({2, 1}), false);
const std::size_t query_size = query.shape().depth_;
const tensor distribution = softmax(transform_tensor(fplus::multiply_with(1 / std::sqrt(query_size)), scores));
return dot_product_tensors(distribution, value, std::vector<int>({2, 1}), false);
}
protected:
tensors apply_impl(const tensors& input) const override
{
assertion(input.size() == 2 || input.size() == 3, "Invalid number of inputs for MultiHeadAttention layer.");
const tensor query_raw = input[0];
const tensor value_raw = input[1];
const tensor key_raw = input.size() > 2 ? input[2] : value_raw;
const auto outputs = fplus::transform([&](const std::size_t head_idx)
{
return apply_head(query_raw, value_raw, key_raw, head_idx);
}, fplus::numbers<std::size_t>(0, num_heads_));
const tensor merged = concatenate_tensors_depth(outputs);
return output_dense_.apply({merged});
}
std::size_t num_heads_;
std::size_t key_dim_;
std::size_t value_dim_;
std::vector<dense_layer> query_dense_;
std::vector<dense_layer> value_dense_;
std::vector<dense_layer> key_dense_;
dense_layer output_dense_;
};

} } // namespace fdeep, namespace internal
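
In equation form, apply_head above computes standard scaled dot-product attention per head, with the dense layers acting as the projections (plus biases when use_bias is set) and d_k = key_dim being the projected query/key depth used for the scaling; apply_impl then concatenates the heads along the depth axis and applies the output projection:

$$
\mathrm{head}_i = \operatorname{softmax}\!\left(\frac{(Q W_i^Q)\,(K W_i^K)^\top}{\sqrt{d_k}}\right)(V W_i^V),
\qquad
\operatorname{MHA}(Q, K, V) = \left[\mathrm{head}_1; \dots; \mathrm{head}_h\right] W^O
$$
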
13 changes: 13 additions & 0 deletions keras_export/convert_model.py
@@ -561,6 +561,18 @@ def show_additive_attention_layer(layer):
return data


def show_multi_head_attention_layer(layer):
"""Serialize MultiHeadAttention layer to dict"""
assert len(layer.input_shape) == 3
assert layer.input_shape[0] is None
assert layer._output_shape is None
assert layer._attention_axes == (1,), "MultiHeadAttention supported only with attention_axes=None"
return {
'weight_shapes': list(map(lambda w: list(w.shape), layer.weights)),
'weights': list(map(lambda w: encode_floats(w.numpy()), layer.weights)),
}


def get_layer_functions_dict():
return {
'Conv1D': show_conv_1d_layer,
@@ -588,6 +600,7 @@ def get_layer_functions_dict():
'CategoryEncoding': show_category_encoding_layer,
'Attention': show_attention_layer,
'AdditiveAttention': show_additive_attention_layer,
'MultiHeadAttention': show_multi_head_attention_layer,
}


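
For reference (not part of the diff), a small sketch of the layer.weights list that show_multi_head_attention_layer serializes via weight_shapes and weights. The ordering and shapes in the comments are assumptions for Keras 2.x with use_bias=True; they illustrate the query/key/value/output order that the C++ loader expects.

```python
from tensorflow.keras.layers import Input, MultiHeadAttention

query = Input(shape=(4, 8))
value = Input(shape=(6, 8))
mha = MultiHeadAttention(num_heads=2, key_dim=3, value_dim=5)
mha(query, value)  # calling the layer on symbolic inputs builds its weights

for w in mha.weights:
    print(w.name, w.shape)
# Assumed order and shapes (input dim 8, num_heads=2, key_dim=3, value_dim=5):
#   query kernel (8, 2, 3), query bias (2, 3)
#   key   kernel (8, 2, 3), key   bias (2, 3)
#   value kernel (8, 2, 5), value bias (2, 5)
#   output kernel (2, 5, 8), output bias (8,)
```
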
35 changes: 35 additions & 0 deletions keras_export/generate_test_models.py
@@ -23,6 +23,7 @@
from tensorflow.keras.layers import MaxPooling1D, AveragePooling1D, UpSampling1D
from tensorflow.keras.layers import MaxPooling2D, AveragePooling2D, UpSampling2D
from tensorflow.keras.layers import MaxPooling3D, AveragePooling3D
from tensorflow.keras.layers import MultiHeadAttention
from tensorflow.keras.layers import Multiply, Add, Subtract, Average, Maximum, Minimum, Dot
from tensorflow.keras.layers import Permute, Reshape, RepeatVector
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
@@ -435,6 +436,40 @@ def get_test_model_exhaustive():
outputs.append(AdditiveAttention(use_scale=True)([inputs[49], inputs[50]]))
outputs.append(AdditiveAttention(use_scale=True)([inputs[49], inputs[50], inputs[51]]))

outputs.append(MultiHeadAttention(
num_heads=1, key_dim=1, value_dim=None,
use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=1, key_dim=1, value_dim=None,
use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=1, key_dim=2, value_dim=None,
use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=1, key_dim=2, value_dim=None,
use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=1, key_dim=1, value_dim=2,
use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=1, key_dim=1, value_dim=2,
use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=3, key_dim=1, value_dim=None,
use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=3, key_dim=1, value_dim=None,
use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=1, key_dim=1, value_dim=None,
use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51]))
outputs.append(MultiHeadAttention(
num_heads=2, key_dim=3, value_dim=5,
use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51]))
outputs.append(MultiHeadAttention(
num_heads=2, key_dim=3, value_dim=5,
use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51]))

shared_conv = Conv2D(1, (1, 1),
padding='valid', name='shared_conv', activation='relu')
