double-check weights shapes
Dobiasd committed Dec 31, 2023
Parent: a95abf4 · Commit: 41ac53a
Showing 1 changed file with 7 additions and 5 deletions.

include/fdeep/layers/multi_head_attention_layer.hpp
@@ -23,25 +23,27 @@ class multi_head_attention_layer : public layer
         bool use_bias, const std::vector<tensor>& weights_and_biases)
         : layer(name), num_heads_(num_heads), key_dim_(key_dim),
             value_dim_(value_dim),
-            query_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 0, name + "_query_dense")),
-            value_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 2, name + "_value_dense")),
-            key_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 1, name + "_key_dense")),
+            query_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 0, key_dim, name + "_query_dense")),
+            value_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 2, value_dim, name + "_value_dense")),
+            key_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 1, key_dim, name + "_key_dense")),
             output_dense_(create_output_dense_layer(weights_and_biases, use_bias, name + "_output_dense"))
     {
     }
 private:
     std::vector<dense_layer> create_dense_layers(
         const tensors& weights_and_biases, bool use_bias, const std::size_t num_heads,
-        const std::size_t index, const std::string& name)
+        const std::size_t index, const std::size_t units, const std::string& name)
     {
         assertion(index <= 2, "Invalid dense layer index.");
         const std::size_t index_factor = use_bias ? 2 : 1;
         const tensor weights = weights_and_biases[index_factor * index];
-        const std::size_t units = weights.shape().depth_;
         const tensor biases = use_bias ?
             weights_and_biases[index_factor * index + 1] :
             tensor(tensor_shape(num_heads, units), 0);
 
+        assertion(weights.shape().depth_ == units, "Invalid weights shape for attention head dimension.");
+        assertion(biases.shape().depth_ == units, "Invalid biases shape for attention head dimension.");
+
         const auto weights_per_head = tensor_to_tensors_width_slices(weights);
         const auto biases_per_head = tensor_to_tensors_width_slices(biases);
         assertion(weights_per_head.size() == num_heads, "Invalid weights for number of heads.");
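In short, the commit passes the expected number of units into create_dense_layers explicitly (key_dim for the query and key projections, value_dim for the value projection) instead of inferring it from the imported weights' depth, and then asserts that the weights and biases actually have that depth. Below is a minimal standalone sketch of the invariant being enforced, assuming the convention visible in the diff (per-head projection weights are sliced along width, one slice per head, with depth equal to the projection units); the struct and function names are hypothetical, not fdeep's API:

// Standalone sketch, not fdeep code: a stand-in for the new shape checks.
#include <cstddef>
#include <stdexcept>

struct shape
{
    std::size_t width_; // one width slice per attention head
    std::size_t depth_; // projection units (key_dim or value_dim)
};

void check_projection(const shape& weights, std::size_t num_heads, std::size_t units)
{
    // Before this commit, units was inferred as weights.depth_, so a
    // mismatch could never be detected; passing the expected value in
    // makes malformed exported weights fail loudly.
    if (weights.depth_ != units)
        throw std::runtime_error("Invalid weights shape for attention head dimension.");
    if (weights.width_ != num_heads)
        throw std::runtime_error("Invalid weights for number of heads.");
}

int main()
{
    const std::size_t num_heads = 8;
    const std::size_t key_dim = 64;
    check_projection(shape{num_heads, key_dim}, num_heads, key_dim); // passes
    // check_projection(shape{num_heads, 32}, num_heads, key_dim);   // would throw
}

The design choice mirrored here is to fail at model-load time rather than silently adopt whatever depth the exported weights happen to have.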
