Separate weights and biases
Dobiasd committed Dec 26, 2023
1 parent 09d5868 commit 212d609
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions include/fdeep/layers/multi_head_attention_layer.hpp
@@ -20,12 +20,22 @@ class multi_head_attention_layer : public layer
     explicit multi_head_attention_layer(const std::string& name,
         std::size_t num_heads, std::size_t key_dim, std::size_t value_dim,
         bool use_bias, const std::vector<std::size_t>& attention_axes,
-        const std::vector<tensor>& weights)
+        const std::vector<tensor>& saved_weights)
         : layer(name), num_heads_(num_heads), key_dim_(key_dim),
-        value_dim_(value_dim), use_bias_(use_bias), attention_axes_(attention_axes),
-        weights_(weights)
+        value_dim_(value_dim), attention_axes_(attention_axes),
+        weights_(extract_weights(saved_weights, use_bias)),
+        biases_(extract_biases(saved_weights, use_bias))
     {
     }
+private:
+    tensors extract_weights(const tensors& saved_weights, bool use_bias)
+    {
+        return use_bias ? fplus::unweave(saved_weights).first : saved_weights;
+    }
+    tensors extract_biases(const tensors& saved_weights, bool use_bias)
+    {
+        return use_bias ? fplus::unweave(saved_weights).second : tensors(); // todo: create biases with zeroes in right shape
+    }
 protected:
     tensors apply_impl(const tensors& input) const override
     {
@@ -42,9 +42,9 @@ class multi_head_attention_layer : public layer
     std::size_t num_heads_;
     std::size_t key_dim_;
     std::size_t value_dim_;
-    bool use_bias_;
     std::vector<std::size_t> attention_axes_;
     std::vector<tensor> weights_;
+    std::vector<tensor> biases_;
 };
 
 } } // namespace fdeep, namespace internal
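
The split relies on fplus::unweave from the FunctionalPlus library, which partitions a sequence into its even-indexed and odd-indexed elements (the inverse of fplus::weave). Below is a minimal standalone sketch, not part of the commit, using ints in place of tensors to show how an interleaved [weight, bias, weight, bias, ...] list separates:

#include <fplus/fplus.hpp>
#include <cassert>
#include <vector>

int main()
{
    // Saved parameters as exported: kernels interleaved with biases
    // (ints stand in for fdeep tensors here).
    const std::vector<int> saved = {10, 1, 20, 2, 30, 3};
    const auto parts = fplus::unweave(saved);
    assert((parts.first == std::vector<int>{10, 20, 30})); // even indices -> weights_
    assert((parts.second == std::vector<int>{1, 2, 3}));   // odd indices  -> biases_
}

When use_bias is false, no biases are stored in the model file, so saved_weights is taken as-is and biases_ stays empty; the todo comment notes that zero-filled biases of the right shape would be the cleaner fallback.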
