 #pragma once
 
 #include "fdeep/layers/layer.hpp"
+#include "fdeep/layers/dense_layer.hpp"
 #include "fdeep/layers/softmax_layer.hpp"
 
 #include <string>
@@ -20,41 +21,68 @@ class multi_head_attention_layer : public layer
     explicit multi_head_attention_layer(const std::string& name,
         std::size_t num_heads, std::size_t key_dim, std::size_t value_dim,
         bool use_bias, const std::vector<std::size_t>& attention_axes,
-        const std::vector<tensor>& saved_weights)
+        const std::vector<tensor>& weights_and_biases)
         : layer(name), num_heads_(num_heads), key_dim_(key_dim),
         value_dim_(value_dim), attention_axes_(attention_axes),
-        weights_(extract_weights(saved_weights, use_bias)),
-        biases_(extract_biases(saved_weights, use_bias))
+        query_dense_(create_dense_layer(weights_and_biases, use_bias, 0, name + "_query_dense")),
+        value_dense_(create_dense_layer(weights_and_biases, use_bias, 1, name + "_value_dense")),
+        key_dense_(create_dense_layer(weights_and_biases, use_bias, 2, name + "_key_dense")),
+        output_dense_(create_dense_layer(weights_and_biases, use_bias, 3, name + "_output_dense"))
     {
     }
 private:
-    tensors extract_weights(const tensors& saved_weights, bool use_bias)
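+    // Builds one of the four internal projections (index 0: query, 1: value, 2: key, 3: output).
+    // weights_and_biases is expected to hold one kernel tensor per projection,
+    // each followed by its bias tensor when use_bias is true.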
+    dense_layer create_dense_layer(
+        const tensors& weights_and_biases, bool use_bias,
+        std::size_t index, const std::string& name)
     {
-        return use_bias ? fplus::unweave(saved_weights).first : saved_weights;
+        const std::size_t index_factor = use_bias ? 2 : 1;
+        const tensor weights = weights_and_biases[index_factor * index];
+        const std::size_t n = weights.shape().width_ * weights.shape().depth_;
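+        // Without a stored bias, fall back to an all-zero bias vector of matching size.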
+        const tensor biases = use_bias ?
+            weights_and_biases[index_factor * index + 1] :
+            tensor(tensor_shape(n), 0);
+        return dense_layer(name, n, *weights.as_vector(), *biases.as_vector());
     }
     tensors extract_biases(const tensors& saved_weights, bool use_bias)
     {
-        return use_bias ? fplus::unweave(saved_weights).second : tensors(); // todo: create biases with zeroes in right shape
+        return use_bias ? fplus::unweave(saved_weights).second : tensors();
     }
 protected:
     tensors apply_impl(const tensors& input) const override
     {
         assertion(input.size() == 2 || input.size() == 3, "Invalid number of inputs for MultiHeadAttention layer.");
-        // const tensor& query = input[0];
-        // const tensor& value = input[1];
-        // const tensor& key = input.size() > 2 ? input[2] : value;
+        const tensor query_raw = input[0];
+        const tensor value_raw = input[1];
+        const tensor key_raw = input.size() > 2 ? input[2] : value_raw;
+        const tensor query = query_dense_.apply({query_raw}).front();
+        const tensor value = value_dense_.apply({value_raw}).front();
+        const tensor key = key_dense_.apply({key_raw}).front();
+        assertion(
+            query.shape().rank() == 2 &&
+            value.shape().rank() == 2 &&
+            key.shape().rank() == 2 &&
+            query.shape().depth_ == value.shape().depth_ &&
+            query.shape().depth_ == key.shape().depth_ &&
+            value.shape().width_ == key.shape().width_,
+            "Invalid shapes; need a query tensor of shape (B, T, dim) and a value/key tensor of shape (B, S, dim)."
+        );
         // https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853
         // https://dmol.pub/dl/attention.html#multi-head-attention-block
         // https://github.com/keras-team/keras/blob/v2.14.0/keras/layers/attention/multi_head_attention.py
         // https://gist.github.com/sevagh/b71d253a347a9b59c026580625452fc5
-        return input;
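+        // Single-head, unscaled dot-product attention on the projected tensors:
+        // softmax(Q * K^T) * V, followed by the output projection.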
+        const tensor scores = dot_product_tensors(query, transpose(key), std::vector<int>({2, 1}), false);
+        const tensor distribution = softmax(scores);
+        const tensor output = dot_product_tensors(distribution, value, std::vector<int>({2, 1}), false);
+        return output_dense_.apply({output});
     }
     std::size_t num_heads_;
     std::size_t key_dim_;
     std::size_t value_dim_;
     std::vector<std::size_t> attention_axes_;
-    std::vector<tensor> weights_;
-    std::vector<tensor> biases_;
+    dense_layer query_dense_;
+    dense_layer value_dense_;
+    dense_layer key_dense_;
+    dense_layer output_dense_;
 };
 
 } } // namespace fdeep, namespace internal