
Commit 118f663

apply dense layers to query, value and key

1 parent 212d609 commit 118f663

File tree: 2 files changed (+42, -14 lines)

include/fdeep/import_model.hpp (2 additions, 2 deletions)

@@ -1084,15 +1084,15 @@ inline layer_ptr create_multi_head_attention_layer(
             create_vector<std::size_t, decltype(create_size_t)>, create_size_t),
         get_param(name, "weight_shapes"));
     const auto weight_values = create_vector<float_vec>(decode_floats, get_param(name, "weights"));
-    const auto weights = fplus::zip_with(
+    const auto weights_and_biases = fplus::zip_with(
         [](const std::vector<std::size_t>& shape, const float_vec& values) -> tensor
         {
             return tensor(
                 create_tensor_shape_from_dims(shape),
                 fplus::convert_container<float_vec>(values));
         }, weight_shapes, weight_values);
     return std::make_shared<multi_head_attention_layer>(name,
-        num_heads, key_dim, value_dim, use_bias, attention_axes, weights);
+        num_heads, key_dim, value_dim, use_bias, attention_axes, weights_and_biases);
 }
 
 inline std::string get_activation_type(const nlohmann::json& data)
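Note on the importer change: weight_shapes and weights arrive as parallel lists in the exported model JSON, and the fplus::zip_with call pairs each shape with its flat value vector to build one tensor per entry. With use_bias enabled, the layer constructor then reads the resulting weights_and_biases vector as kernel/bias pairs for the query, value, key, and output projections, in that order. Below is a minimal standalone sketch of that pairing step using only the standard library; the Entry struct and make_entries helper are illustrative stand-ins, not fdeep API.

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Illustrative stand-in for fdeep's tensor: a shape plus flat values.
    struct Entry {
        std::vector<std::size_t> shape;
        std::vector<float> values;
    };

    // Pairs each shape with its flat value vector, like the fplus::zip_with call above.
    std::vector<Entry> make_entries(
        const std::vector<std::vector<std::size_t>>& shapes,
        const std::vector<std::vector<float>>& values)
    {
        std::vector<Entry> result;
        for (std::size_t i = 0; i < shapes.size() && i < values.size(); ++i) {
            result.push_back(Entry{shapes[i], values[i]});
        }
        return result;
    }

    int main() {
        // With use_bias == true, the expected order of entries is:
        // query kernel, query bias, value kernel, value bias,
        // key kernel, key bias, output kernel, output bias.
        const std::vector<std::vector<std::size_t>> shapes = {{4, 2}, {2}};
        const std::vector<std::vector<float>> values = {
            std::vector<float>(8, 0.1f), std::vector<float>(2, 0.0f)};
        const auto entries = make_entries(shapes, values);
        std::cout << entries.size() << " tensors created\n"; // prints "2 tensors created"
    }

Without biases, the list holds only the four kernels, which is why create_dense_layer in the header below scales its index by 2 or 1 depending on use_bias.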

include/fdeep/layers/multi_head_attention_layer.hpp (40 additions, 12 deletions)

@@ -7,6 +7,7 @@
 #pragma once
 
 #include "fdeep/layers/layer.hpp"
+#include "fdeep/layers/dense_layer.hpp"
 #include "fdeep/layers/softmax_layer.hpp"
 
 #include <string>
@@ -20,41 +21,68 @@ class multi_head_attention_layer : public layer
     explicit multi_head_attention_layer(const std::string& name,
         std::size_t num_heads, std::size_t key_dim, std::size_t value_dim,
         bool use_bias, const std::vector<std::size_t>& attention_axes,
-        const std::vector<tensor>& saved_weights)
+        const std::vector<tensor>& weights_and_biases)
         : layer(name), num_heads_(num_heads), key_dim_(key_dim),
         value_dim_(value_dim), attention_axes_(attention_axes),
-        weights_(extract_weights(saved_weights, use_bias)),
-        biases_(extract_biases(saved_weights, use_bias))
+        query_dense_(create_dense_layer(weights_and_biases, use_bias, 0, name + "_query_dense")),
+        value_dense_(create_dense_layer(weights_and_biases, use_bias, 1, name + "_value_dense")),
+        key_dense_(create_dense_layer(weights_and_biases, use_bias, 2, name + "_key_dense")),
+        output_dense_(create_dense_layer(weights_and_biases, use_bias, 3, name + "_output_dense"))
     {
     }
 private:
-    tensors extract_weights(const tensors& saved_weights, bool use_bias)
+    dense_layer create_dense_layer(
+        const tensors& weights_and_biases, bool use_bias,
+        std::size_t index, const std::string& name)
     {
-        return use_bias ? fplus::unweave(saved_weights).first : saved_weights;
+        const std::size_t index_factor = use_bias ? 2 : 1;
+        const tensor weights = weights_and_biases[index_factor * index];
+        const std::size_t n = weights.shape().width_ * weights.shape().depth_;
+        const tensor biases = use_bias ?
+            weights_and_biases[index_factor * index + 1] :
+            tensor(tensor_shape(n), 1);
+        return dense_layer(name, n, *weights.as_vector(), *biases.as_vector());
     }
     tensors extract_biases(const tensors& saved_weights, bool use_bias)
     {
-        return use_bias ? fplus::unweave(saved_weights).second : tensors(); // todo: create biases with zeroes in right shape
+        return use_bias ? fplus::unweave(saved_weights).second : tensors();
     }
 protected:
     tensors apply_impl(const tensors& input) const override
     {
         assertion(input.size() == 2 || input.size() == 3, "Invalid number of inputs for MultiHeadAttention layer.");
-        //const tensor& query = input[0];
-        //const tensor& value = input[1];
-        //const tensor& key = input.size() > 2 ? input[2] : value;
+        const tensor query_raw = input[0];
+        const tensor value_raw = input[1];
+        const tensor key_raw = input.size() > 2 ? input[2] : value_raw;
+        const tensor query = query_dense_.apply({query_raw}).front();
+        const tensor value = value_dense_.apply({value_raw}).front();
+        const tensor key = key_dense_.apply({key_raw}).front();
+        assertion(
+            query.shape().rank() == 2 &&
+            value.shape().rank() == 2 &&
+            key.shape().rank() == 2 &&
+            query.shape().depth_ == value.shape().depth_ &&
+            query.shape().depth_ == key.shape().depth_ &&
+            value.shape().width_ == key.shape().width_,
+            "Invalid shapes; need a query tensor of shape (B, T, dim) and a value/key tensor of shape (B, S, dim)."
+        );
         // https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853
         // https://dmol.pub/dl/attention.html#multi-head-attention-block
         // https://github.com/keras-team/keras/blob/v2.14.0/keras/layers/attention/multi_head_attention.py
         // https://gist.github.com/sevagh/b71d253a347a9b59c026580625452fc5
-        return input;
+        const tensor scores = dot_product_tensors(query, transpose(key), std::vector<int>({2, 1}), false);
+        const tensor distribution = softmax(scores);
+        const tensor output = dot_product_tensors(distribution, value, std::vector<int>({2, 1}), false);
+        return output_dense_.apply({output});
     }
     std::size_t num_heads_;
     std::size_t key_dim_;
     std::size_t value_dim_;
     std::vector<std::size_t> attention_axes_;
-    std::vector<tensor> weights_;
-    std::vector<tensor> biases_;
+    dense_layer query_dense_;
+    dense_layer value_dense_;
+    dense_layer key_dense_;
+    dense_layer output_dense_;
 };
 
 } } // namespace fdeep, namespace internal
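Taken together, the new apply_impl amounts to (single-head) dot-product attention: query, value, and key are first passed through their dense projections, scores are computed as Q * K^T, a softmax turns the scores into an attention distribution, the distribution is applied to V, and the result goes through the output projection. Below is a minimal self-contained C++ sketch of that computation on plain row-major matrices; the Mat type and the helper functions are illustrative only, not fdeep's tensor API.

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Minimal row-major matrix; illustrative stand-in for a rank-2 tensor.
    struct Mat {
        std::size_t rows;
        std::size_t cols;
        std::vector<float> data;
        float& at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
        float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
    };

    Mat matmul(const Mat& a, const Mat& b) {
        Mat out{a.rows, b.cols, std::vector<float>(a.rows * b.cols, 0.0f)};
        for (std::size_t i = 0; i < a.rows; ++i)
            for (std::size_t k = 0; k < a.cols; ++k)
                for (std::size_t j = 0; j < b.cols; ++j)
                    out.at(i, j) += a.at(i, k) * b.at(k, j);
        return out;
    }

    Mat transpose(const Mat& a) {
        Mat out{a.cols, a.rows, std::vector<float>(a.rows * a.cols, 0.0f)};
        for (std::size_t i = 0; i < a.rows; ++i)
            for (std::size_t j = 0; j < a.cols; ++j)
                out.at(j, i) = a.at(i, j);
        return out;
    }

    // Softmax over each row of scores, turning them into attention weights.
    Mat softmax_rows(const Mat& a) {
        Mat out = a;
        for (std::size_t i = 0; i < a.rows; ++i) {
            float mx = a.at(i, 0);
            for (std::size_t j = 1; j < a.cols; ++j) mx = std::max(mx, a.at(i, j));
            float sum = 0.0f;
            for (std::size_t j = 0; j < a.cols; ++j) {
                out.at(i, j) = std::exp(a.at(i, j) - mx);
                sum += out.at(i, j);
            }
            for (std::size_t j = 0; j < a.cols; ++j) out.at(i, j) /= sum;
        }
        return out;
    }

    // Unscaled dot-product attention: softmax(Q * K^T) * V,
    // mirroring the scores/distribution/output steps in apply_impl.
    Mat attention(const Mat& query, const Mat& key, const Mat& value) {
        const Mat scores = matmul(query, transpose(key));  // shape (T, S)
        const Mat distribution = softmax_rows(scores);     // attention weights
        return matmul(distribution, value);                // shape (T, dim)
    }

    int main() {
        // T = 2 query positions, S = 2 key/value positions, dim = 2.
        const Mat q{2, 2, {1.0f, 0.0f, 0.0f, 1.0f}};
        const Mat k{2, 2, {1.0f, 0.0f, 0.0f, 1.0f}};
        const Mat v{2, 2, {1.0f, 2.0f, 3.0f, 4.0f}};
        const Mat out = attention(q, k, v);
        for (std::size_t i = 0; i < out.rows; ++i) {
            for (std::size_t j = 0; j < out.cols; ++j) std::cout << out.at(i, j) << ' ';
            std::cout << '\n';
        }
    }

Note that Keras' MultiHeadAttention additionally scales the query by 1/sqrt(key_dim) and splits the projections into num_heads separate heads; in this commit the stored num_heads_ and key_dim_ members are not yet used in apply_impl, so the sketch above omits those steps as well.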
