Skip to content

Commit

Permalink
Implement special treatment for the inbound_nodes format of MultiHead…
Browse files Browse the repository at this point in the history
…Attention
  • Loading branch information
Dobiasd committed Nov 7, 2023
1 parent 14aa03a commit e8efb24
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions include/fdeep/import_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1137,11 +1137,35 @@ inline node create_node(const nlohmann::json& inbound_nodes_data)
inbound_nodes_data));
}

inline nodes create_multi_head_attention_nodes(const std::vector<nlohmann::json> inbound_nodes_data)
{
assertion(inbound_nodes_data.size() == 1 && inbound_nodes_data.front().size() == 1,
"multi_head_attention needs to have exactly one primary inbound node; see https://stackoverflow.com/q/77400589/1866775");
const auto inbound_node_data = inbound_nodes_data.front().front();
const auto value = inbound_node_data[3]["value"];
if (json_obj_has_member(inbound_node_data[3], "key")) {
return {
node({
create_node_connection(inbound_node_data),
create_node_connection(value),
create_node_connection(inbound_node_data[3]["key"])
})};
}
return {
node({
create_node_connection(inbound_node_data),
create_node_connection(value)
})};
}

inline nodes create_nodes(const nlohmann::json& data)
{
assertion(data["inbound_nodes"].is_array(), "no inbound nodes");
const std::vector<nlohmann::json> inbound_nodes_data =
data["inbound_nodes"];
if (data["class_name"] == "MultiHeadAttention") {
return create_multi_head_attention_nodes(inbound_nodes_data);
}
return fplus::transform(create_node, inbound_nodes_data);
}

Expand Down

0 comments on commit e8efb24

Please sign in to comment.