Adjust comment in multi_head_attention_layer
Dobiasd committed Nov 7, 2023
1 parent e8efb24 commit 7a7928d
Showing 1 changed file with 5 additions and 3 deletions.
include/fdeep/layers/multi_head_attention_layer.hpp
@@ -29,9 +29,11 @@ class multi_head_attention_layer : public layer
 protected:
     tensors apply_impl(const tensors& input) const override
     {
-        // input.size() is 1. How shall the other tensors passed here? How is it in TF?
-        // https://stackoverflow.com/questions/77400589/what-is-the-reason-for-multiheadattention-having-a-different-call-convention-tha
-        // todo: implement
+        assertion(input.size() == 2 || input.size() == 3, "Invalid number of inputs for MultiHeadAttention layer.");
+        //const tensor& query = input[0];
+        //const tensor& value = input[1];
+        //const tensor& key = input.size() > 2 ? input[2] : value;
+        // https://github.com/keras-team/keras/blob/v2.14.0/keras/layers/attention/multi_head_attention.py
         return input;
     }
     std::size_t num_heads_;
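For context: the new assertion mirrors the Keras call convention for MultiHeadAttention, where the layer is invoked as layer(query, value) or layer(query, value, key), and an omitted key defaults to value. Below is a minimal standalone sketch of the input unpacking the commented-out lines suggest; it is a hypothetical illustration, not fdeep code (the tensor alias stands in for fdeep::tensor, and unpack_attention_inputs is an invented name):

    #include <cassert>
    #include <vector>

    using tensor = std::vector<float>; // stand-in for fdeep::tensor

    // Hypothetical: unpack query/value/key following Keras'
    // MultiHeadAttention.call(query, value, key=None), where a
    // missing key defaults to value (self-attention-style reuse).
    void unpack_attention_inputs(const std::vector<tensor>& input)
    {
        assert(input.size() == 2 || input.size() == 3);
        const tensor& query = input[0];
        const tensor& value = input[1];
        const tensor& key = input.size() > 2 ? input[2] : value;
        (void)query; (void)value; (void)key; // the attention math itself is still todo
    }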
