From e8efb244a42da3d7d2e22432dea25aa316a16e09 Mon Sep 17 00:00:00 2001 From: Dobiasd Date: Tue, 7 Nov 2023 08:24:30 +0100 Subject: [PATCH] Implement special treatment for the inbound_nodes format of MultiHeadAttention --- include/fdeep/import_model.hpp | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/include/fdeep/import_model.hpp b/include/fdeep/import_model.hpp index 5a01da47..9db2c3bf 100644 --- a/include/fdeep/import_model.hpp +++ b/include/fdeep/import_model.hpp @@ -1137,11 +1137,35 @@ inline node create_node(const nlohmann::json& inbound_nodes_data) inbound_nodes_data)); } +inline nodes create_multi_head_attention_nodes(const std::vector inbound_nodes_data) +{ + assertion(inbound_nodes_data.size() == 1 && inbound_nodes_data.front().size() == 1, + "multi_head_attention needs to have exactly one primary inbound node; see https://stackoverflow.com/q/77400589/1866775"); + const auto inbound_node_data = inbound_nodes_data.front().front(); + const auto value = inbound_node_data[3]["value"]; + if (json_obj_has_member(inbound_node_data[3], "key")) { + return { + node({ + create_node_connection(inbound_node_data), + create_node_connection(value), + create_node_connection(inbound_node_data[3]["key"]) + })}; + } + return { + node({ + create_node_connection(inbound_node_data), + create_node_connection(value) + })}; +} + inline nodes create_nodes(const nlohmann::json& data) { assertion(data["inbound_nodes"].is_array(), "no inbound nodes"); const std::vector inbound_nodes_data = data["inbound_nodes"]; + if (data["class_name"] == "MultiHeadAttention") { + return create_multi_head_attention_nodes(inbound_nodes_data); + } return fplus::transform(create_node, inbound_nodes_data); }