Skip to content

Commit

Permalink
Convert weights to tensors for ctor
Browse files (browse the repository at this point in the history)
  • Loading branch information
Dobiasd committed Dec 26, 2023
1 parent 78c0d4e commit 09d5868
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
10 changes: 8 additions & 2 deletions include/fdeep/import_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1083,8 +1083,14 @@ inline layer_ptr create_multi_head_attention_layer(
create_vector<std::vector<std::size_t>>(fplus::bind_1st_of_2(
create_vector<std::size_t, decltype(create_size_t)>, create_size_t),
get_param(name, "weight_shapes"));
const auto weights = create_vector<float_vec>(decode_floats, get_param(name, "weights"));
// todo: Convert weight_shapes and weights to Tensors before passing to ctor?
const auto weight_values = create_vector<float_vec>(decode_floats, get_param(name, "weights"));
const auto weights = fplus::zip_with(
[](const std::vector<std::size_t>& shape, const float_vec& values) -> tensor
{
return tensor(
create_tensor_shape_from_dims(shape),
fplus::convert_container<float_vec>(values));
}, weight_shapes, weight_values);
return std::make_shared<multi_head_attention_layer>(name,
num_heads, key_dim, value_dim, use_bias, attention_axes, weights);
}
Expand Down
4 changes: 2 additions & 2 deletions include/fdeep/layers/multi_head_attention_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class multi_head_attention_layer : public layer
explicit multi_head_attention_layer(const std::string& name,
std::size_t num_heads, std::size_t key_dim, std::size_t value_dim,
bool use_bias, const std::vector<std::size_t>& attention_axes,
const std::vector<float_vec>& weights)
const std::vector<tensor>& weights)
: layer(name), num_heads_(num_heads), key_dim_(key_dim),
value_dim_(value_dim), use_bias_(use_bias), attention_axes_(attention_axes),
weights_(weights)
Expand All @@ -44,7 +44,7 @@ class multi_head_attention_layer : public layer
std::size_t value_dim_;
bool use_bias_;
std::vector<std::size_t> attention_axes_;
std::vector<float_vec> weights_;
std::vector<tensor> weights_;
};

} } // namespace fdeep, namespace internal

0 comments on commit 09d5868

Please sign in to comment.