From 4a8fa9e23009fa33b1e82fb642c2bf2cb2d85863 Mon Sep 17 00:00:00 2001 From: luxincn <2391719912@qq.com> Date: Sat, 24 Aug 2024 23:48:28 +0800 Subject: [PATCH 01/28] add annotation --- tools/pnnx/src/main.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index c25128032d9e..8bf6fcc869ea 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -378,6 +378,8 @@ int main(int argc, char** argv) pnnx::save_ncnn(pnnx_graph, ncnnparampath, ncnnbinpath, ncnnpypath, fp16); } + // flops count + // pnnx::Graph pnnx_graph2; // pnnx_graph2.load("pnnx.param", "pnnx.bin"); From e48e094541875d65a7f1b71b7f39bf07cfb77fcb Mon Sep 17 00:00:00 2001 From: SZUwishion <2559916473@qq.com> Date: Sun, 1 Sep 2024 16:39:21 +0800 Subject: [PATCH 02/28] pnnx print flops memops count --- tools/pnnx/src/ir.cpp | 27 +++++++++++++++++++++++++++ tools/pnnx/src/ir.h | 3 +++ tools/pnnx/src/main.cpp | 3 +++ 3 files changed, 33 insertions(+) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 8b2b6dfd2d7f..994371954d6d 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include "storezip.h" #include "utils.h" @@ -1441,6 +1442,32 @@ static std::string make_index_expression(const Operator* op) return index_expr; } +int Graph::calculate_flops() +{ + int flops = 0; + for(auto op:ops) { + if(expand_expression(op) == "*") + { + int m = op->inputs[0]->shape[0]; + int k = op->inputs[0]->shape[1]; + int n = op->inputs[1]->shape[1]; + flops += 2 * m * k * n; + } + else if(expand_expression(op) == "+") { + int m = op->inputs[0]->shape[0]; + int n = op->inputs[0]->shape[1]; + flops += m * n; + } + } + return flops; +} + +int Graph::calculate_memops() +{ + int mem = sizeof(Operator) * ops.size() + sizeof(Operand) * operands.size(); + return mem; +} + int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) { FILE* pyfp = fopen(pypath.c_str(), "wb"); diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index 779c2eec9f10..91e0e2a69fe3 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -346,6 +346,9 @@ class Graph std::vector ops; std::vector operands; + int calculate_flops(); + int calculate_memops(); + private: Graph(const Graph& rhs); Graph& operator=(const Graph& rhs); diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index c25128032d9e..dda54b1932dd 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -361,6 +361,9 @@ int main(int argc, char** argv) pnnx_graph.save(pnnxparampath, pnnxbinpath); + fprintf(stderr, "float ops = %dM\n", pnnx_graph.calculate_flops()); + fprintf(stderr, "memory ops = %dM\n", pnnx_graph.calculate_memops()); + pnnx_graph.python(pnnxpypath, pnnxbinpath); #if BUILD_PNNX2ONNX From db91abd606ff00d7ebf9bdb2d3349720eb07fd21 Mon Sep 17 00:00:00 2001 From: SZUwishion <2559916473@qq.com> Date: Sun, 1 Sep 2024 16:57:58 +0800 Subject: [PATCH 03/28] pnnx print flops memops count --- tools/pnnx/src/ir.cpp | 28 ++++++++++++++++++++++------ tools/pnnx/src/ir.h | 4 ++-- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 994371954d6d..7c5923acac1e 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1442,9 +1442,9 @@ static std::string make_index_expression(const Operator* op) return index_expr; } -int Graph::calculate_flops() +int Graph::calculate_flops_M() { - int flops = 0; + long long flops = 0; for(auto op:ops) { if(expand_expression(op) == 
"*") { @@ -1459,13 +1459,29 @@ int Graph::calculate_flops() flops += m * n; } } - return flops; + return int(flops / 1e6); } -int Graph::calculate_memops() +int Graph::calculate_memops_M() { - int mem = sizeof(Operator) * ops.size() + sizeof(Operand) * operands.size(); - return mem; + long long mem = 0; + for(auto op : ops) + { + if(expand_expression(op) == "*") + { + int m = op->inputs[0]->shape[0]; + int k = op->inputs[0]->shape[1]; + int n = op->inputs[1]->shape[1]; + mem += m * k + k * n + m * n; + } + else if(expand_expression(op) == "+") + { + int m = op->inputs[0]->shape[0]; + int n = op->inputs[0]->shape[1]; + mem += 3 * m * n; + } + } + return int(mem / 1e6); } int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index 91e0e2a69fe3..bc1f0089591d 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -346,8 +346,8 @@ class Graph std::vector ops; std::vector operands; - int calculate_flops(); - int calculate_memops(); + int calculate_flops_M(); + int calculate_memops_M(); private: Graph(const Graph& rhs); From 4af97a8c0cb75e0ed171e264b176d77f099850dc Mon Sep 17 00:00:00 2001 From: SZUwishion <2559916473@qq.com> Date: Mon, 2 Sep 2024 10:49:01 +0800 Subject: [PATCH 04/28] pnnx print flops memops count --- tools/pnnx/src/main.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index dda54b1932dd..32e628be8d6f 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -361,8 +361,8 @@ int main(int argc, char** argv) pnnx_graph.save(pnnxparampath, pnnxbinpath); - fprintf(stderr, "float ops = %dM\n", pnnx_graph.calculate_flops()); - fprintf(stderr, "memory ops = %dM\n", pnnx_graph.calculate_memops()); + fprintf(stderr, "float ops = %dM\n", pnnx_graph.calculate_flops_M()); + fprintf(stderr, "memory ops = %dM\n", pnnx_graph.calculate_memops_M()); pnnx_graph.python(pnnxpypath, pnnxbinpath); From a4fd3191d66a8d96e679eb7ae8fdd8b8a6e80d49 Mon Sep 17 00:00:00 2001 From: SZUwishion <2559916473@qq.com> Date: Mon, 2 Sep 2024 17:47:02 +0800 Subject: [PATCH 05/28] pnnx print flops memops count --- tools/pnnx/src/main.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index 32e628be8d6f..5ef47b2409ac 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -313,6 +313,8 @@ int main(int argc, char** argv) std::string foldable_constants_zippath = ptbase + ".foldable_constants.zip"; pnnx::Graph pnnx_graph; + fprintf(stderr, "float ops = %dM\n", pnnx_graph.calculate_flops_M()); + fprintf(stderr, "memory ops = %dM\n", pnnx_graph.calculate_memops_M()); #if BUILD_ONNX2PNNX if (!model_file_maybe_torchscript(ptpath)) { @@ -361,9 +363,6 @@ int main(int argc, char** argv) pnnx_graph.save(pnnxparampath, pnnxbinpath); - fprintf(stderr, "float ops = %dM\n", pnnx_graph.calculate_flops_M()); - fprintf(stderr, "memory ops = %dM\n", pnnx_graph.calculate_memops_M()); - pnnx_graph.python(pnnxpypath, pnnxbinpath); #if BUILD_PNNX2ONNX From 54659e042f843547f18b4b908da300ae2af89f8d Mon Sep 17 00:00:00 2001 From: SZUwishion <2559916473@qq.com> Date: Mon, 2 Sep 2024 17:53:29 +0800 Subject: [PATCH 06/28] pnnx print flops memops count --- tools/pnnx/src/main.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index 5ef47b2409ac..23fdc0102224 100644 --- a/tools/pnnx/src/main.cpp +++ 
b/tools/pnnx/src/main.cpp @@ -313,8 +313,6 @@ int main(int argc, char** argv) std::string foldable_constants_zippath = ptbase + ".foldable_constants.zip"; pnnx::Graph pnnx_graph; - fprintf(stderr, "float ops = %dM\n", pnnx_graph.calculate_flops_M()); - fprintf(stderr, "memory ops = %dM\n", pnnx_graph.calculate_memops_M()); #if BUILD_ONNX2PNNX if (!model_file_maybe_torchscript(ptpath)) { @@ -384,6 +382,7 @@ int main(int argc, char** argv) // pnnx_graph2.load("pnnx.param", "pnnx.bin"); // pnnx_graph2.save("pnnx2.param", "pnnx2.bin"); - + fprintf(stderr, "float ops = %dM\n", pnnx_graph.calculate_flops_M()); + fprintf(stderr, "memory ops = %dM\n", pnnx_graph.calculate_memops_M()); return 0; } From 2c80f272cbbe9e2e0ae09321bb1af7eac19e01b0 Mon Sep 17 00:00:00 2001 From: SZUwishion <2559916473@qq.com> Date: Tue, 3 Sep 2024 08:38:27 +0800 Subject: [PATCH 07/28] pnnx print flops memops count --- tools/pnnx/src/ir.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 7c5923acac1e..6dd429ebbfb4 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1446,14 +1446,14 @@ int Graph::calculate_flops_M() { long long flops = 0; for(auto op:ops) { - if(expand_expression(op) == "*") + if(op->type == "aten::matmul") { int m = op->inputs[0]->shape[0]; int k = op->inputs[0]->shape[1]; int n = op->inputs[1]->shape[1]; flops += 2 * m * k * n; } - else if(expand_expression(op) == "+") { + else if(op->type == "aten::add") { int m = op->inputs[0]->shape[0]; int n = op->inputs[0]->shape[1]; flops += m * n; @@ -1467,14 +1467,14 @@ int Graph::calculate_memops_M() long long mem = 0; for(auto op : ops) { - if(expand_expression(op) == "*") + if(op->type == "aten::matmul") { int m = op->inputs[0]->shape[0]; int k = op->inputs[0]->shape[1]; int n = op->inputs[1]->shape[1]; mem += m * k + k * n + m * n; } - else if(expand_expression(op) == "+") + else if(op->type == "aten::add") { int m = op->inputs[0]->shape[0]; int n = op->inputs[0]->shape[1]; From a89c0f7ce11c58fdef3c292bb198e6c0c741c8cb Mon Sep 17 00:00:00 2001 From: SZUwishion <2559916473@qq.com> Date: Tue, 3 Sep 2024 09:39:43 +0800 Subject: [PATCH 08/28] pnnx print flops memops count --- tools/pnnx/src/ir.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 6dd429ebbfb4..985cf47a3b97 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1459,7 +1459,7 @@ int Graph::calculate_flops_M() flops += m * n; } } - return int(flops / 1e6); + return int(flops); } int Graph::calculate_memops_M() @@ -1481,7 +1481,7 @@ int Graph::calculate_memops_M() mem += 3 * m * n; } } - return int(mem / 1e6); + return int(mem); } int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) From d247bb9aba088946b0ea86b0fff0bb62fa0f6c82 Mon Sep 17 00:00:00 2001 From: SZUwishion <2559916473@qq.com> Date: Tue, 3 Sep 2024 09:42:58 +0800 Subject: [PATCH 09/28] pnnx print flops memops count --- tools/pnnx/src/ir.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 985cf47a3b97..59f26b17f054 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -14,6 +14,7 @@ #include "ir.h" +#include #include #include #include @@ -1467,6 +1468,7 @@ int Graph::calculate_memops_M() long long mem = 0; for(auto op : ops) { + fprintf(stderr, "%s\n", op->type.c_str()); if(op->type == "aten::matmul") { int m = op->inputs[0]->shape[0]; From 
8172355f9f592fecdde081056e120d4b15db5aa4 Mon Sep 17 00:00:00 2001 From: luxincn <2391719912@qq.com> Date: Fri, 6 Sep 2024 20:41:18 +0800 Subject: [PATCH 10/28] add flops_mem_count() --- tools/pnnx/src/ir.cpp | 57 +++++++++++++++++++++++++++++++++++++++++ tools/pnnx/src/ir.h | 12 +++++++++ tools/pnnx/src/main.cpp | 6 ++++- 3 files changed, 74 insertions(+), 1 deletion(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 8b2b6dfd2d7f..7467dc1e512d 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -2862,4 +2862,61 @@ const Operand* Graph::get_operand(const std::string& name) const return 0; } +pnnx::ModelInfo Graph::flops_mem_count() +{ + pnnx::ModelInfo m; + for (const Operator* op : ops) + { + if(op->type == "nn.Conv2d") + { + if(op->inputs[0]->type != 0) + { + int ci = op->inputs[0]->shape[1]; + int kw = op->params.at("kernel_size").ai[0]; + int kh = op->params.at("kernel_size").ai[1]; + int co = op->params.at("out_channels").i; + int w = op->outputs[0]->shape[3]; + int h = op->outputs[0]->shape[2]; + int bias = op->params.at("bias").b ? 1 : 0; + int wi = op->inputs[0]->shape[2]; + int hi = op->inputs[0]->shape[3]; + int g = op->params.at("groups").i; + if(bias == 1) + { + m.flops += 2 * ci * kw * kh * co * w * h; + } + else + { + m.flops += (2 * ci * kw * kh -1) * co * w * h; + } + int input_m = wi * hi * ci; + int output_m = w * h * co; + int weights_m = kw * kh * ci * co; + m.memory_access += input_m + output_m + weights_m; + } + } + else if(op->type == "nn.Linear") + { + int in = op->params.at("in_features").i; + int out = op->params.at("out_features").i; + int bias = op->params.at("bias").b ? 1 : 0; + if(bias == 1) + { + m.flops += 2 * in * out; + } + else + { + m.flops += (2 * in - 1) * out; + } + m.memory_access += in + out + in * out; + } + else + { + + } + } + + return m; +} + } // namespace pnnx diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index 779c2eec9f10..a4486f6c7b50 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -51,6 +51,16 @@ class OnnxAttributeProxy; namespace pnnx { +struct ModelInfo { + ModelInfo() + : flops(0), memory_access(0) + { + } + + long long flops; + long long memory_access; +}; + class Parameter { public: @@ -324,6 +334,8 @@ class Graph int parse(const std::string& param); + ModelInfo flops_mem_count(); + Operator* new_operator(const std::string& type, const std::string& name); Operator* new_operator_before(const std::string& type, const std::string& name, const Operator* cur); diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index 8bf6fcc869ea..04d4e23a98e2 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -363,6 +363,11 @@ int main(int argc, char** argv) pnnx_graph.python(pnnxpypath, pnnxbinpath); + // count float + pnnx::ModelInfo md = pnnx_graph.flops_mem_count(); + fprintf(stderr, "float ops: %lld\n", md.flops); + fprintf(stderr, "memory ops: %lld\n", md.memory_access); + #if BUILD_PNNX2ONNX pnnx::save_onnx(pnnx_graph, pnnxonnxpath.c_str(), fp16); #else @@ -378,7 +383,6 @@ int main(int argc, char** argv) pnnx::save_ncnn(pnnx_graph, ncnnparampath, ncnnbinpath, ncnnpypath, fp16); } - // flops count // pnnx::Graph pnnx_graph2; From 103111985ac3969c8203d5fdbffe2c9e28187d41 Mon Sep 17 00:00:00 2001 From: luxincn Date: Fri, 6 Sep 2024 13:01:35 +0000 Subject: [PATCH 11/28] apply code-format changes --- tools/pnnx/src/ir.cpp | 13 ++++++------- tools/pnnx/src/ir.h | 3 ++- tools/pnnx/src/main.cpp | 1 - 3 files changed, 8 insertions(+), 9 deletions(-) diff --git 
a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 7467dc1e512d..d9d5203b1bf1 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -2867,9 +2867,9 @@ pnnx::ModelInfo Graph::flops_mem_count() pnnx::ModelInfo m; for (const Operator* op : ops) { - if(op->type == "nn.Conv2d") + if (op->type == "nn.Conv2d") { - if(op->inputs[0]->type != 0) + if (op->inputs[0]->type != 0) { int ci = op->inputs[0]->shape[1]; int kw = op->params.at("kernel_size").ai[0]; @@ -2881,13 +2881,13 @@ pnnx::ModelInfo Graph::flops_mem_count() int wi = op->inputs[0]->shape[2]; int hi = op->inputs[0]->shape[3]; int g = op->params.at("groups").i; - if(bias == 1) + if (bias == 1) { m.flops += 2 * ci * kw * kh * co * w * h; } else { - m.flops += (2 * ci * kw * kh -1) * co * w * h; + m.flops += (2 * ci * kw * kh - 1) * co * w * h; } int input_m = wi * hi * ci; int output_m = w * h * co; @@ -2895,12 +2895,12 @@ pnnx::ModelInfo Graph::flops_mem_count() m.memory_access += input_m + output_m + weights_m; } } - else if(op->type == "nn.Linear") + else if (op->type == "nn.Linear") { int in = op->params.at("in_features").i; int out = op->params.at("out_features").i; int bias = op->params.at("bias").b ? 1 : 0; - if(bias == 1) + if (bias == 1) { m.flops += 2 * in * out; } @@ -2912,7 +2912,6 @@ pnnx::ModelInfo Graph::flops_mem_count() } else { - } } diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index a4486f6c7b50..aefe6d598026 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -51,7 +51,8 @@ class OnnxAttributeProxy; namespace pnnx { -struct ModelInfo { +struct ModelInfo +{ ModelInfo() : flops(0), memory_access(0) { diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index 04d4e23a98e2..0a0e71ebdc15 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -383,7 +383,6 @@ int main(int argc, char** argv) pnnx::save_ncnn(pnnx_graph, ncnnparampath, ncnnbinpath, ncnnpypath, fp16); } - // pnnx::Graph pnnx_graph2; // pnnx_graph2.load("pnnx.param", "pnnx.bin"); From 62eb64898435ab9294df16a1ace61e8fd953e858 Mon Sep 17 00:00:00 2001 From: luxincn <2391719912@qq.com> Date: Tue, 10 Sep 2024 22:58:57 +0800 Subject: [PATCH 12/28] update flops_mem_count --- tools/pnnx/src/ir.cpp | 100 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index d9d5203b1bf1..35f66cb98f85 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -2910,6 +2910,106 @@ pnnx::ModelInfo Graph::flops_mem_count() } m.memory_access += in + out + in * out; } + else if (op->type == "nn.MultiheadAttention") + { + int in_size = op->inputs.size(); + + if (std::find(op->nputnames.begin(), op->inputnames.end(), "attn_mask") != op->inputnames.end()) + { + in_size -= 1; + } + + int q_l, k_s, v_s; + bool batch_first = op->params.find("batch_first") != op->params.end() && op->params.at("batch_first").b; + + if (in_size == 3) + { + q_l = op->inputs[0]->shape[batch_first ? 1 : 0]; + k_s = op->inputs[1]->shape[batch_first ? 1 : 0]; + v_s = op->inputs[2]->shape[batch_first ? 1 : 0]; + } + else if (in_size == 2) + { + q_l = op->inputs[0]->shape[batch_first ? 1 : 0]; + k_s = op->inputs[1]->shape[batch_first ? 1 : 0]; + v_s = k_s; + } + else + { + q_l = op->inputs[0]->shape[batch_first ? 
1 : 0]; + k_s = q_l; + v_s = q_l; + } + + int num_heads = op->params.at("num_heads").i; + int embed_dim = op->params.at("embed_dim").i; + int Kdim = op->params.at("kdim").i; + int vdim = op->params.at("vdim").i; + + long long linear1 = q_l * embed_dim * embed_dim + k_s * embed_dim * Kdim + v_s * embed_dim * vdim; + long long attention = q_l * k_s * embed_dim + 2 * q_l * k_s * num_heads + q_l * v_s * embed_dim; + long long linerar2 = q_l* embed_dim * embed_dim; + m.flops += linear1 + attention + linerar2; + + long long weights = embed_dim * embed_dim + embed_dim * Kdim + embed_dim * vdim + num_heads * vdim * embed_dim; + long long in = q_l * embed_dim + k_s * Kdim + v_s * vdim; + long long attention_m = q_l * embed_dim + k_s * Kdim + 2 * q_l * k_s + v_s * vdim; + long long out = q_l * embed_dim; + m.memory_access += weights + in + attention_m + out; + } + else if (op->type == "nn.MaxPool2d") + { + int num_o = op->params.at("return_indices").b ? 2 : 1; + int batch_size, in_c, in_h, in_w, out_h, out_w; + if (op->inputs[0]->shape.size() == 4) + { + batch_size = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_h = op->inputs[0]->shape[2]; + in_w = op->inputs[0]->shape[3]; + out_h = op->outputs[0]->shape[2]; + out_w = op->outputs[0]->shape[3]; + } + else if (op->inputs[0]->shape.size() == 3) + { + batch_size = 1; + in_c = op->inputs[0]->shape[0]; + in_h = op->inputs[0]->shape[1]; + in_w = op->inputs[0]->shape[2]; + out_h = op->outputs[0]->shape[1]; + out_w = op->outputs[0]->shape[2]; + } + m.memory_access += batch_size * in_c * ( in_h * in_w + out_h * out_w * num_o ) + } + else if (op->type == "nn.AvgPool2d") + { + int batch_size, in_c, in_h, in_w, out_h, out_w, k_h, k_w, kernel_add, kernel_avg; + if (op->inputs[0]->shape.size() == 4) + { + batch_size = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_h = op->inputs[0]->shape[2]; + in_w = op->inputs[0]->shape[3]; + out_h = op->outputs[0]->shape[2]; + out_w = op->outputs[0]->shape[3]; + } + else if (op->inputs[0]->shape.size() == 3) + { + batch_size = 1; + in_c = op->inputs[0]->shape[0]; + in_h = op->inputs[0]->shape[1]; + in_w = op->inputs[0]->shape[2]; + out_h = op->outputs[0]->shape[1]; + out_w = op->outputs[0]->shape[2]; + } + k_h = op->params.at("kernel_size").ai[0]; + k_w = op->params.at("kernel_size").ai[1]; + + kernel_add = k_h * k_w - 1; + kernel_avg = 1; + m.flops += ( kernel_add + kernel_avg ) * ( out_h * out_w ) * in_c; + m.memory_access += batch_size * in_c * ( in_h * in_w + out_h * out_w ) + } else { } From 48c6f18a8e9ebd39b1c12b2264ce29d87d2784bd Mon Sep 17 00:00:00 2001 From: luxincn Date: Tue, 10 Sep 2024 15:00:40 +0000 Subject: [PATCH 13/28] apply code-format changes --- tools/pnnx/src/ir.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 35f66cb98f85..16d7c96f5164 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -2948,7 +2948,7 @@ pnnx::ModelInfo Graph::flops_mem_count() long long linear1 = q_l * embed_dim * embed_dim + k_s * embed_dim * Kdim + v_s * embed_dim * vdim; long long attention = q_l * k_s * embed_dim + 2 * q_l * k_s * num_heads + q_l * v_s * embed_dim; - long long linerar2 = q_l* embed_dim * embed_dim; + long long linerar2 = q_l * embed_dim * embed_dim; m.flops += linear1 + attention + linerar2; long long weights = embed_dim * embed_dim + embed_dim * Kdim + embed_dim * vdim + num_heads * vdim * embed_dim; @@ -2979,7 +2979,7 @@ pnnx::ModelInfo Graph::flops_mem_count() out_h = 
op->outputs[0]->shape[1]; out_w = op->outputs[0]->shape[2]; } - m.memory_access += batch_size * in_c * ( in_h * in_w + out_h * out_w * num_o ) + m.memory_access += batch_size * in_c * (in_h * in_w + out_h * out_w * num_o) } else if (op->type == "nn.AvgPool2d") { @@ -3007,8 +3007,8 @@ pnnx::ModelInfo Graph::flops_mem_count() kernel_add = k_h * k_w - 1; kernel_avg = 1; - m.flops += ( kernel_add + kernel_avg ) * ( out_h * out_w ) * in_c; - m.memory_access += batch_size * in_c * ( in_h * in_w + out_h * out_w ) + m.flops += (kernel_add + kernel_avg) * (out_h * out_w) * in_c; + m.memory_access += batch_size * in_c * (in_h * in_w + out_h * out_w) } else { From b10faaf661c085179e3a62b33c24cf4826c12e38 Mon Sep 17 00:00:00 2001 From: SZUwishion <2559916473@qq.com> Date: Wed, 11 Sep 2024 17:14:53 +0800 Subject: [PATCH 14/28] test --- tools/pnnx/src/ir.cpp | 148 ++++++++++++++++++++++++++++++++-------- tools/pnnx/src/main.cpp | 5 +- 2 files changed, 123 insertions(+), 30 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 59f26b17f054..6fd139627ad7 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1446,21 +1446,75 @@ static std::string make_index_expression(const Operator* op) int Graph::calculate_flops_M() { long long flops = 0; - for(auto op:ops) { - if(op->type == "aten::matmul") - { - int m = op->inputs[0]->shape[0]; - int k = op->inputs[0]->shape[1]; - int n = op->inputs[1]->shape[1]; - flops += 2 * m * k * n; - } - else if(op->type == "aten::add") { - int m = op->inputs[0]->shape[0]; - int n = op->inputs[0]->shape[1]; - flops += m * n; - } - } - return int(flops); + for(auto op:ops) + { + fprintf(stderr, "op->type %s\n", op->type.c_str()); + if(op->type[0] == 'F') + { + std::string sub_type = op->type.substr(2); + if(sub_type == "adaptive_avg_pool1d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int l = op->inputs[0]->shape[2]; + int o = op->params.at("output_size").ai[0]; + flops += n * c * l * o; + } + else if(sub_type == "adaptive_avg_pool2d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int h = op->inputs[0]->shape[2]; + int w = op->inputs[0]->shape[3]; + int oh = op->params.at("output_size").ai[0]; + int ow = op->params.at("output_size").ai[1]; + flops += n * c * h * w * oh * ow; + } + else if(sub_type == "adaptive_avg_pool3d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int d = op->inputs[0]->shape[2]; + int h = op->inputs[0]->shape[3]; + int w = op->inputs[0]->shape[4]; + int od = op->params.at("output_size").ai[0]; + int oh = op->params.at("output_size").ai[1]; + int ow = op->params.at("output_size").ai[2]; + flops += n * c * d * h * w * od * oh * ow; + } + else if(sub_type == "adaptive_max_pool1d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int l = op->inputs[0]->shape[2]; + int o = op->params.at("output_size").ai[0]; + flops += n * c * l * o; + } + else if(sub_type == "adaptive_max_pool2d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int h = op->inputs[0]->shape[2]; + int w = op->inputs[0]->shape[3]; + int oh = op->params.at("output_size").ai[0]; + int ow = op->params.at("output_size").ai[1]; + flops += n * c * h * w * oh * ow; + } + else if(sub_type == "adaptive_max_pool3d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int d = op->inputs[0]->shape[2]; + int h = op->inputs[0]->shape[3]; + int w = op->inputs[0]->shape[4]; + int od = 
op->params.at("output_size").ai[0]; + int oh = op->params.at("output_size").ai[1]; + int ow = op->params.at("output_size").ai[2]; + flops += n * c * d * h * w * od * oh * ow; + } + } + } + return int(flops / 1e6); } int Graph::calculate_memops_M() @@ -1468,22 +1522,60 @@ int Graph::calculate_memops_M() long long mem = 0; for(auto op : ops) { - fprintf(stderr, "%s\n", op->type.c_str()); - if(op->type == "aten::matmul") + if(op->type[0] == 'F') { - int m = op->inputs[0]->shape[0]; - int k = op->inputs[0]->shape[1]; - int n = op->inputs[1]->shape[1]; - mem += m * k + k * n + m * n; - } - else if(op->type == "aten::add") - { - int m = op->inputs[0]->shape[0]; - int n = op->inputs[0]->shape[1]; - mem += 3 * m * n; + std::string sub_type = op->type.substr(2); + if(sub_type == "adaptive_avg_pool1d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int l = op->inputs[0]->shape[2]; + int o = op->params.at("output_size").ai[0]; + mem += n * c * l * o; + } + else if(sub_type == "adaptive_avg_pool2d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int h = op->inputs[0]->shape[2]; + int w = op->inputs[0]->shape[3]; + int oh = op->params.at("output_size").ai[0]; + int ow = op->params.at("output_size").ai[1]; + mem += n * c * h * w * oh * ow; + } + else if(sub_type == "adaptive_avg_pool3d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int d = op->inputs[0]->shape[2]; + int h = op->inputs[0]->shape[3]; + int w = op->inputs[0]->shape[4]; + int od = op->params.at("output_size").ai[0]; + int oh = op->params.at("output_size").ai[1]; + int ow = op->params.at("output_size").ai[2]; + mem += n * c * d * h * w * od * oh * ow; + } + else if(sub_type == "adaptive_max_pool1d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int l = op->inputs[0]->shape[2]; + int o = op->params.at("output_size").ai[0]; + mem += n * c * l * o; + } + else if(sub_type == "adaptive_max_pool2d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int h = op->inputs[0]->shape[2]; + int w = op->inputs[0]->shape[3]; + int oh = op->params.at("output_size").ai[0]; + int ow = op->params.at("output_size").ai[1]; + mem += n * c * h * w * oh * ow; + } } } - return int(mem); + return int(mem / 1e6); } int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index 23fdc0102224..5f5cb3aa7fcd 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -362,6 +362,9 @@ int main(int argc, char** argv) pnnx_graph.save(pnnxparampath, pnnxbinpath); pnnx_graph.python(pnnxpypath, pnnxbinpath); + + fprintf(stderr, "float ops = %dM\n", pnnx_graph.calculate_flops_M()); + fprintf(stderr, "memory ops = %dM\n", pnnx_graph.calculate_memops_M()); #if BUILD_PNNX2ONNX pnnx::save_onnx(pnnx_graph, pnnxonnxpath.c_str(), fp16); @@ -382,7 +385,5 @@ int main(int argc, char** argv) // pnnx_graph2.load("pnnx.param", "pnnx.bin"); // pnnx_graph2.save("pnnx2.param", "pnnx2.bin"); - fprintf(stderr, "float ops = %dM\n", pnnx_graph.calculate_flops_M()); - fprintf(stderr, "memory ops = %dM\n", pnnx_graph.calculate_memops_M()); return 0; } From ef1e8dfcfd82b201c6c811870577bc309df147d8 Mon Sep 17 00:00:00 2001 From: SZUwishion <2559916473@qq.com> Date: Wed, 11 Sep 2024 20:17:11 +0800 Subject: [PATCH 15/28] test --- tools/pnnx/src/ir.cpp | 66 ++++++++--------------------------------- tools/pnnx/src/ir.h | 5 ++-- tools/pnnx/src/main.cpp | 5 ++-- 3 files 
changed, 19 insertions(+), 57 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 6fd139627ad7..c01b66b2c31a 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1443,12 +1443,10 @@ static std::string make_index_expression(const Operator* op) return index_expr; } -int Graph::calculate_flops_M() +void Graph::flops_memops_sum() { - long long flops = 0; for(auto op:ops) { - fprintf(stderr, "op->type %s\n", op->type.c_str()); if(op->type[0] == 'F') { std::string sub_type = op->type.substr(2); @@ -1459,6 +1457,7 @@ int Graph::calculate_flops_M() int l = op->inputs[0]->shape[2]; int o = op->params.at("output_size").ai[0]; flops += n * c * l * o; + memops += n * c * l + n * c * o; } else if(sub_type == "adaptive_avg_pool2d") { @@ -1469,6 +1468,7 @@ int Graph::calculate_flops_M() int oh = op->params.at("output_size").ai[0]; int ow = op->params.at("output_size").ai[1]; flops += n * c * h * w * oh * ow; + memops += n * c * h * w + n * c * oh * ow; } else if(sub_type == "adaptive_avg_pool3d") { @@ -1481,6 +1481,7 @@ int Graph::calculate_flops_M() int oh = op->params.at("output_size").ai[1]; int ow = op->params.at("output_size").ai[2]; flops += n * c * d * h * w * od * oh * ow; + memops += n * c * d * h * w + n * c * od * oh * ow; } else if(sub_type == "adaptive_max_pool1d") { @@ -1489,6 +1490,7 @@ int Graph::calculate_flops_M() int l = op->inputs[0]->shape[2]; int o = op->params.at("output_size").ai[0]; flops += n * c * l * o; + memops += n * c * l + n * c * o; } else if(sub_type == "adaptive_max_pool2d") { @@ -1499,6 +1501,7 @@ int Graph::calculate_flops_M() int oh = op->params.at("output_size").ai[0]; int ow = op->params.at("output_size").ai[1]; flops += n * c * h * w * oh * ow; + memops += n * c * h * w + n * c * oh * ow; } else if(sub_type == "adaptive_max_pool3d") { @@ -1511,71 +1514,28 @@ int Graph::calculate_flops_M() int oh = op->params.at("output_size").ai[1]; int ow = op->params.at("output_size").ai[2]; flops += n * c * d * h * w * od * oh * ow; + memops += n * c * d * h * w + n * c * od * oh * ow; } - } - } - return int(flops / 1e6); -} - -int Graph::calculate_memops_M() -{ - long long mem = 0; - for(auto op : ops) - { - if(op->type[0] == 'F') - { - std::string sub_type = op->type.substr(2); - if(sub_type == "adaptive_avg_pool1d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int l = op->inputs[0]->shape[2]; - int o = op->params.at("output_size").ai[0]; - mem += n * c * l * o; - } - else if(sub_type == "adaptive_avg_pool2d") + else if(sub_type == "celu") { int n = op->inputs[0]->shape[0]; int c = op->inputs[0]->shape[1]; int h = op->inputs[0]->shape[2]; int w = op->inputs[0]->shape[3]; - int oh = op->params.at("output_size").ai[0]; - int ow = op->params.at("output_size").ai[1]; - mem += n * c * h * w * oh * ow; - } - else if(sub_type == "adaptive_avg_pool3d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int d = op->inputs[0]->shape[2]; - int h = op->inputs[0]->shape[3]; - int w = op->inputs[0]->shape[4]; - int od = op->params.at("output_size").ai[0]; - int oh = op->params.at("output_size").ai[1]; - int ow = op->params.at("output_size").ai[2]; - mem += n * c * d * h * w * od * oh * ow; + flops += n * c * h * w; + memops += 2 * n * c * h * w; } - else if(sub_type == "adaptive_max_pool1d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int l = op->inputs[0]->shape[2]; - int o = op->params.at("output_size").ai[0]; - mem += n * c * l * o; - } - else if(sub_type == 
"adaptive_max_pool2d") + else if(sub_type == "elu") { int n = op->inputs[0]->shape[0]; int c = op->inputs[0]->shape[1]; int h = op->inputs[0]->shape[2]; int w = op->inputs[0]->shape[3]; - int oh = op->params.at("output_size").ai[0]; - int ow = op->params.at("output_size").ai[1]; - mem += n * c * h * w * oh * ow; + flops += n * c * h * w; + memops += 2 * n * c * h * w; } } } - return int(mem / 1e6); } int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index bc1f0089591d..c66141d7324c 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -346,8 +346,9 @@ class Graph std::vector ops; std::vector operands; - int calculate_flops_M(); - int calculate_memops_M(); + long long flops = 0; + long long memops = 0; + void flops_memops_sum(); private: Graph(const Graph& rhs); diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index 5f5cb3aa7fcd..f75af022cdeb 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -363,8 +363,9 @@ int main(int argc, char** argv) pnnx_graph.python(pnnxpypath, pnnxbinpath); - fprintf(stderr, "float ops = %dM\n", pnnx_graph.calculate_flops_M()); - fprintf(stderr, "memory ops = %dM\n", pnnx_graph.calculate_memops_M()); + pnnx_graph.flops_memops_sum(); + fprintf(stderr, "float ops = %.3fM\n", double(pnnx_graph.flops) / 1e6); + fprintf(stderr, "mem ops = %.3fM\n", double(pnnx_graph.memops) / 1e6); #if BUILD_PNNX2ONNX pnnx::save_onnx(pnnx_graph, pnnxonnxpath.c_str(), fp16); From b977f730f5ea119315711d891f22234ec91b3251 Mon Sep 17 00:00:00 2001 From: SZUwishion <2559916473@qq.com> Date: Wed, 11 Sep 2024 20:35:12 +0800 Subject: [PATCH 16/28] test --- tools/pnnx/src/ir.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index c01b66b2c31a..64495974a639 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1447,6 +1447,7 @@ void Graph::flops_memops_sum() { for(auto op:ops) { + fprintf(stderr, "op->type: %s\n", op->type.c_str()); if(op->type[0] == 'F') { std::string sub_type = op->type.substr(2); From 9f4180002f838924958fed329d5831b8b248ba1a Mon Sep 17 00:00:00 2001 From: SZUwishion <2559916473@qq.com> Date: Thu, 12 Sep 2024 17:35:19 +0800 Subject: [PATCH 17/28] test --- tools/pnnx/src/ir.cpp | 565 ++++++++++++++++++++++++++++++++++++++-- tools/pnnx/src/ir.h | 6 +- tools/pnnx/src/main.cpp | 2 + 3 files changed, 548 insertions(+), 25 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 64495974a639..5125ac1d1ad7 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -23,7 +23,6 @@ #include #include #include -#include #include "storezip.h" #include "utils.h" @@ -1445,13 +1444,13 @@ static std::string make_index_expression(const Operator* op) void Graph::flops_memops_sum() { - for(auto op:ops) + for (auto op : ops) { fprintf(stderr, "op->type: %s\n", op->type.c_str()); - if(op->type[0] == 'F') + if (op->type[0] == 'F') { std::string sub_type = op->type.substr(2); - if(sub_type == "adaptive_avg_pool1d") + if (sub_type == "adaptive_avg_pool1d") { int n = op->inputs[0]->shape[0]; int c = op->inputs[0]->shape[1]; @@ -1460,7 +1459,7 @@ void Graph::flops_memops_sum() flops += n * c * l * o; memops += n * c * l + n * c * o; } - else if(sub_type == "adaptive_avg_pool2d") + else if (sub_type == "adaptive_avg_pool2d") { int n = op->inputs[0]->shape[0]; int c = op->inputs[0]->shape[1]; @@ -1471,7 +1470,7 @@ void Graph::flops_memops_sum() flops += n * c * h * w * oh * ow; memops += n * c * 
h * w + n * c * oh * ow; } - else if(sub_type == "adaptive_avg_pool3d") + else if (sub_type == "adaptive_avg_pool3d") { int n = op->inputs[0]->shape[0]; int c = op->inputs[0]->shape[1]; @@ -1484,7 +1483,58 @@ void Graph::flops_memops_sum() flops += n * c * d * h * w * od * oh * ow; memops += n * c * d * h * w + n * c * od * oh * ow; } - else if(sub_type == "adaptive_max_pool1d") + else if (sub_type == "avg_pool1d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int l = op->inputs[0]->shape[2]; + int k = op->params.at("kernel_size").ai[0]; + int s = op->params.at("stride").ai[0]; + int p = op->params.at("padding").ai[0]; + int o = (l + 2 * p - k) / s + 1; + flops += n * c * l * k; + memops += n * c * l + n * c * o; + } + else if (sub_type == "avg_pool2d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int h = op->inputs[0]->shape[2]; + int w = op->inputs[0]->shape[3]; + int kh = op->params.at("kernel_size").ai[0]; + int kw = op->params.at("kernel_size").ai[1]; + int sh = op->params.at("stride").ai[0]; + int sw = op->params.at("stride").ai[1]; + int ph = op->params.at("padding").ai[0]; + int pw = op->params.at("padding").ai[1]; + int oh = (h + 2 * ph - kh) / sh + 1; + int ow = (w + 2 * pw - kw) / sw + 1; + flops += n * c * h * w * kh * kw; + memops += n * c * h * w + n * c * oh * ow; + } + else if (sub_type == "avg_pool3d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int d = op->inputs[0]->shape[2]; + int h = op->inputs[0]->shape[3]; + int w = op->inputs[0]->shape[4]; + int kd = op->params.at("kernel_size").ai[0]; + int kh = op->params.at("kernel_size").ai[1]; + int kw = op->params.at("kernel_size").ai[2]; + int sd = op->params.at("stride").ai[0]; + int sh = op->params.at("stride").ai[1]; + int sw = op->params.at("stride").ai[2]; + int pd = op->params.at("padding").ai[0]; + int ph = op->params.at("padding").ai[1]; + int pw = op->params.at("padding").ai[2]; + int od = (d + 2 * pd - kd) / sd + 1; + int oh = (h + 2 * ph - kh) / sh + 1; + int ow = (w + 2 * pw - kw) / sw + 1; + flops += n * c * d * h * w * kd * kh * kw; + memops += n * c * d * h * w + n * c * od * oh * ow; + } + else if (sub_type == "adaptive_max_pool1d") { int n = op->inputs[0]->shape[0]; int c = op->inputs[0]->shape[1]; @@ -1493,7 +1543,7 @@ void Graph::flops_memops_sum() flops += n * c * l * o; memops += n * c * l + n * c * o; } - else if(sub_type == "adaptive_max_pool2d") + else if (sub_type == "adaptive_max_pool2d") { int n = op->inputs[0]->shape[0]; int c = op->inputs[0]->shape[1]; @@ -1504,7 +1554,7 @@ void Graph::flops_memops_sum() flops += n * c * h * w * oh * ow; memops += n * c * h * w + n * c * oh * ow; } - else if(sub_type == "adaptive_max_pool3d") + else if (sub_type == "adaptive_max_pool3d") { int n = op->inputs[0]->shape[0]; int c = op->inputs[0]->shape[1]; @@ -1517,23 +1567,492 @@ void Graph::flops_memops_sum() flops += n * c * d * h * w * od * oh * ow; memops += n * c * d * h * w + n * c * od * oh * ow; } - else if(sub_type == "celu") + else if (sub_type == "max_pool1d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int l = op->inputs[0]->shape[2]; + int k = op->params.at("kernel_size").ai[0]; + int s = op->params.at("stride").ai[0]; + int p = op->params.at("padding").ai[0]; + int o = (l + 2 * p - k) / s + 1; + flops += n * c * l * k; + memops += n * c * l + n * c * o; + } + else if (sub_type == "max_pool2d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int h = 
op->inputs[0]->shape[2]; + int w = op->inputs[0]->shape[3]; + int kh = op->params.at("kernel_size").ai[0]; + int kw = op->params.at("kernel_size").ai[1]; + int sh = op->params.at("stride").ai[0]; + int sw = op->params.at("stride").ai[1]; + int ph = op->params.at("padding").ai[0]; + int pw = op->params.at("padding").ai[1]; + int oh = (h + 2 * ph - kh) / sh + 1; + int ow = (w + 2 * pw - kw) / sw + 1; + flops += n * c * h * w * kh * kw; + memops += n * c * h * w + n * c * oh * ow; + } + else if (sub_type == "max_pool3d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int d = op->inputs[0]->shape[2]; + int h = op->inputs[0]->shape[3]; + int w = op->inputs[0]->shape[4]; + int kd = op->params.at("kernel_size").ai[0]; + int kh = op->params.at("kernel_size").ai[1]; + int kw = op->params.at("kernel_size").ai[2]; + int sd = op->params.at("stride").ai[0]; + int sh = op->params.at("stride").ai[1]; + int sw = op->params.at("stride").ai[2]; + int pd = op->params.at("padding").ai[0]; + int ph = op->params.at("padding").ai[1]; + int pw = op->params.at("padding").ai[2]; + int od = (d + 2 * pd - kd) / sd + 1; + int oh = (h + 2 * ph - kh) / sh + 1; + int ow = (w + 2 * pw - kw) / sw + 1; + flops += n * c * d * h * w * kd * kh * kw; + memops += n * c * d * h * w + n * c * od * oh * ow; + } + else if (sub_type == "lp_pool1d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int l = op->inputs[0]->shape[2]; + int k = op->params.at("kernel_size").i; + int p = op->params.at("p").i; + if (p == 1) + { + extra_flops += 2 * n * c * l * k; + } + else if (p == 2) + { + extra_flops += 3 * n * c * l * k; + } + extra_memops += 2 * n * c * l; + } + else if (sub_type == "lp_pool2d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int h = op->inputs[0]->shape[2]; + int w = op->inputs[0]->shape[3]; + int kh = op->params.at("kernel_size").ai[0]; + int kw = op->params.at("kernel_size").ai[1]; + int p = op->params.at("p").i; + if (p == 1) + { + extra_flops += 2 * n * c * h * w * kh * kw; + } + else if (p == 2) + { + extra_flops += 3 * n * c * h * w * kh * kw; + } + extra_memops += 2 * n * c * h * w; + } + else if (sub_type == "lp_pool3d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int d = op->inputs[0]->shape[2]; + int h = op->inputs[0]->shape[3]; + int w = op->inputs[0]->shape[4]; + int kd = op->params.at("kernel_size").ai[0]; + int kh = op->params.at("kernel_size").ai[1]; + int kw = op->params.at("kernel_size").ai[2]; + int p = op->params.at("p").i; + if (p == 1) + { + extra_flops += 2 * n * c * d * h * w * kd * kh * kw; + } + else if (p == 2) + { + extra_flops += 3 * n * c * d * h * w * kd * kh * kw; + } + extra_memops += 2 * n * c * d * h * w; + } + else if ( + sub_type == "elu" || + sub_type == "celu" || + sub_type == "gelu" || + sub_type == "glu" || + sub_type == "hardshrink" || + sub_type == "hardsigmoid" || + sub_type == "hardswish" || + sub_type == "hardtanh" || + sub_type == "leaky_relu" || + sub_type == "prelu" || + sub_type == "relu" || + sub_type == "relu6" || + sub_type == "rrelu" || + sub_type == "mish" || + sub_type == "normalize" || + sub_type == "batch_norm" || + sub_type == "group_norm" || + sub_type == "instance_norm" || + sub_type == "layer_norm" + ) + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int num_elements = 1; + for (size_t i = 2; i < op->inputs[0]->shape.size(); ++i) + { + num_elements *= op->inputs[0]->shape[i]; + } + if(sub_type == "elu") + { + 
extra_flops += 2 * n * c * num_elements; + extra_memops += 2 * n * c * num_elements; + } + else if(sub_type == "celu") + { + extra_flops += 3 * n * c * num_elements; + extra_memops += 2 * n * c * num_elements; + } + else if(sub_type == "gelu") + { + extra_flops += 3 * n * c * num_elements; + extra_memops += 2 * n * c * num_elements; + } + else if(sub_type == "glu") + { + int l = op->inputs[0]->shape[2]; + int o = op->outputs[0]->shape[2]; + extra_flops += n * c * l * o; + extra_memops += 2 * n * c * l + n * o; + } + else if(sub_type == "hardshrink") + { + extra_flops += 2 * n * c * num_elements; + extra_memops += 2 * n * c * num_elements; + } + else if(sub_type == "hardsigmoid") + { + extra_flops += 6 * n * c * num_elements; + extra_memops += 2 * n * c * num_elements; + } + else if(sub_type == "hardswish") + { + extra_flops += 5 * n * c * num_elements; + extra_memops += 2 * n * c * num_elements; + } + else if(sub_type == "hardtanh") + { + extra_memops += 2 * n * c * num_elements; + } + else if(sub_type == "leaky_relu") + { + extra_flops += 2 * n * c * num_elements; + extra_memops += 2 * n * c * num_elements; + } + else if(sub_type == "prelu") + { + extra_flops += 2 * n * c * num_elements; + extra_memops += 2 * n * c * num_elements; + } + else if(sub_type == "relu") + { + extra_flops += n * c * num_elements; + extra_memops += n * c * num_elements; + } + else if(sub_type == "relu6") + { + extra_memops += n * c * num_elements; + } + else if(sub_type == "rrelu") + { + extra_flops += n * c * num_elements; + extra_memops += 2 * n * c * num_elements; + } + else if(sub_type == "mish") + { + extra_flops += 2 * n * c * num_elements; + extra_memops += 2 * n * c * num_elements; + } + else if(sub_type == "normalize") + { + extra_flops += 7 * n * c * num_elements + 3; + extra_memops += 2 * n * c * num_elements; + } + else if( + sub_type == "batch_norm" || + sub_type == "group_norm" || + sub_type == "instance_norm" || + sub_type == "layer_norm" + ) + { + extra_flops += 7 * n * c * num_elements; + extra_memops += 2 * n * c * num_elements; + } + } + else if (sub_type == "conv1d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int l = op->inputs[0]->shape[2]; + int k = op->inputs[1]->shape[0]; + int o = op->outputs[0]->shape[2]; + flops += 2 * n * c * l * k * o; + memops += 2 * n * c * l * k + n * o; + } + else if (sub_type == "conv2d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int h = op->inputs[0]->shape[2]; + int w = op->inputs[0]->shape[3]; + int kh = op->inputs[1]->shape[2]; + int kw = op->inputs[1]->shape[3]; + int o = op->outputs[0]->shape[2]; + int s = op->params.at("stride").ai[0]; + int p = op->params.at("padding").ai[0]; + int g = op->params.at("groups").i; + flops += 2 * n * c * h * w * kh * kw * o / g; + memops += 2 * n * c * h * w * kh * kw / g + n * o * h * w; + } + else if (sub_type == "conv3d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int d = op->inputs[0]->shape[2]; + int h = op->inputs[0]->shape[3]; + int w = op->inputs[0]->shape[4]; + int kd = op->inputs[1]->shape[2]; + int kh = op->inputs[1]->shape[3]; + int kw = op->inputs[1]->shape[4]; + int o = op->outputs[0]->shape[2]; + int s = op->params.at("stride").ai[0]; + int p = op->params.at("padding").ai[0]; + int g = op->params.at("groups").i; + flops += 2 * n * c * d * h * w * kd * kh * kw * o / g; + memops += 2 * n * c * d * h * w * kd * kh * kw / g + n * o * d * h * w; + } + else if (sub_type == "conv_transpose1d") + { + int n = 
op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int l = op->inputs[0]->shape[2]; + int k = op->inputs[1]->shape[0]; + int o = op->outputs[0]->shape[2]; + flops += 2 * n * c * l * k * o; + memops += 2 * n * c * l * k + n * o; + } + else if (sub_type == "conv_transpose2d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int h = op->inputs[0]->shape[2]; + int w = op->inputs[0]->shape[3]; + int kh = op->inputs[1]->shape[2]; + int kw = op->inputs[1]->shape[3]; + int o = op->outputs[0]->shape[2]; + int s = op->params.at("stride").ai[0]; + int p = op->params.at("padding").ai[0]; + int g = op->params.at("groups").i; + flops += 2 * n * c * h * w * kh * kw * o / g; + memops += 2 * n * c * h * w * kh * kw / g + n * o * h * w; + } + else if (sub_type == "conv_transpose3d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int d = op->inputs[0]->shape[2]; + int h = op->inputs[0]->shape[3]; + int w = op->inputs[0]->shape[4]; + int kd = op->inputs[1]->shape[2]; + int kh = op->inputs[1]->shape[3]; + int kw = op->inputs[1]->shape[4]; + int o = op->outputs[0]->shape[2]; + int s = op->params.at("stride").ai[0]; + int p = op->params.at("padding").ai[0]; + int g = op->params.at("groups").i; + flops += 2 * n * c * d * h * w * kd * kh * kw * o / g; + memops += 2 * n * c * d * h * w * kd * kh * kw / g + n * o * d * h * w; + } + else if (sub_type == "embedding") + { + int n = op->inputs[0]->shape[0]; + int l = op->inputs[0]->shape[1]; + int c = op->params.at("num_embeddings").i; + int e = op->params.at("embedding_dim").i; + extra_flops += n * l * e; + extra_memops += n * l + n * e; + } + else if (sub_type == "linear") + { + int n = op->inputs[0]->shape[0]; + int i = op->inputs[0]->shape[1]; + int o = op->outputs[0]->shape[1]; + flops += 2 * n * i * o; + memops += 2 * n * i + n * o; + } + else if (sub_type == "log_softmax") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int l = op->inputs[0]->shape[2]; + extra_flops += 2 * n * c * l; + extra_memops += 2 * n * c * l; + } + else if (sub_type == "logsigmoid") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int l = op->inputs[0]->shape[2]; + extra_flops += 2 * n * c * l; + extra_memops += 2 * n * c * l; + } + else if (sub_type == "scaled_dot_product_attention") + { + int n = op->inputs[0]->shape[0]; + int l = op->inputs[0]->shape[1]; + int d = op->inputs[0]->shape[2]; + flops += 2 * n * l * l + n * l * d + n * l * l * d; + memops += 2 * n * l * d + 3 * n * l * l + n * l; + } + } + + else if (op->type.substr(0, 2) == "nn") + { + std::string sub_type = op->type.substr(3); + if ( + sub_type == "BatchNorm1d" || + sub_type == "BatchNorm2d" || + sub_type == "BatchNorm3d" + ) + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int num_elements = 1; + for (size_t i = 2; i < op->inputs[0]->shape.size(); ++i) + { + num_elements *= op->inputs[0]->shape[i]; + } + extra_flops += 7 * n * c * num_elements; + extra_memops += 2 * n * c * num_elements; + } + else if (sub_type == "Conv1d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int l = op->inputs[0]->shape[2]; + int k = op->inputs[1]->shape[0]; + int o = op->outputs[0]->shape[2]; + flops += 2 * n * c * l * k * o; + memops += 2 * n * c * l * k + n * o; + } + else if (sub_type == "Conv2d") { int n = op->inputs[0]->shape[0]; int c = op->inputs[0]->shape[1]; int h = op->inputs[0]->shape[2]; int w = op->inputs[0]->shape[3]; - flops += n * c * h * w; - memops 
+= 2 * n * c * h * w; + int kh = op->inputs[1]->shape[2]; + int kw = op->inputs[1]->shape[3]; + int o = op->outputs[0]->shape[2]; + int s = op->params.at("stride").ai[0]; + int p = op->params.at("padding").ai[0]; + int g = op->params.at("groups").i; + flops += 2 * n * c * h * w * kh * kw * o / g; + memops += 2 * n * c * h * w * kh * kw / g + n * o * h * w; + } + else if (sub_type == "Conv3d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int d = op->inputs[0]->shape[2]; + int h = op->inputs[0]->shape[3]; + int w = op->inputs[0]->shape[4]; + int kd = op->inputs[1]->shape[2]; + int kh = op->inputs[1]->shape[3]; + int kw = op->inputs[1]->shape[4]; + int o = op->outputs[0]->shape[2]; + int s = op->params.at("stride").ai[0]; + int p = op->params.at("padding").ai[0]; + int g = op->params.at("groups").i; + flops += 2 * n * c * d * h * w * kd * kh * kw * o / g; + memops += 2 * n * c * d * h * w * kd * kh * kw / g + n * o * d * h * w; + } + else if (sub_type == "ConvTranspose1d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int l = op->inputs[0]->shape[2]; + int k = op->inputs[1]->shape[0]; + int o = op->outputs[0]->shape[2]; + flops += 2 * n * c * l * k * o; + memops += 2 * n * c * l * k + n * o; } - else if(sub_type == "elu") + else if (sub_type == "ConvTranspose2d") { int n = op->inputs[0]->shape[0]; int c = op->inputs[0]->shape[1]; int h = op->inputs[0]->shape[2]; int w = op->inputs[0]->shape[3]; - flops += n * c * h * w; - memops += 2 * n * c * h * w; + int kh = op->inputs[1]->shape[2]; + int kw = op->inputs[1]->shape[3]; + int o = op->outputs[0]->shape[2]; + int s = op->params.at("stride").ai[0]; + int p = op->params.at("padding").ai[0]; + int g = op->params.at("groups").i; + flops += 2 * n * c * h * w * kh * kw * o / g; + memops += 2 * n * c * h * w * kh * kw / g + n * o * h * w; + } + else if (sub_type == "PReLU") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int num_elements = 1; + for (size_t i = 2; i < op->inputs[0]->shape.size(); ++i) + { + num_elements *= op->inputs[0]->shape[i]; + } + extra_flops += 2 * n * c * num_elements; + extra_memops += 2 * n * c * num_elements; + } + else if (sub_type == "ConvTranspose3d") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int d = op->inputs[0]->shape[2]; + int h = op->inputs[0]->shape[3]; + int w = op->inputs[0]->shape[4]; + int kd = op->inputs[1]->shape[2]; + int kh = op->inputs[1]->shape[3]; + int kw = op->inputs[1]->shape[4]; + int o = op->outputs[0]->shape[2]; + int s = op->params.at("stride").ai[0]; + int p = op->params.at("padding").ai[0]; + int g = op->params.at("groups").i; + flops += 2 * n * c * d * h * w * kd * kh * kw * o / g; + memops += 2 * n * c * d * h * w * kd * kh * kw / g + n * o * d * h * w; + } + else if (sub_type == "Embedding") + { + int n = op->inputs[0]->shape[0]; + int l = op->inputs[0]->shape[1]; + int c = op->params.at("num_embeddings").i; + int e = op->params.at("embedding_dim").i; + extra_flops += 2 * n * l * e; + extra_memops += 2 * n * l + n * e; + } + else if (sub_type == "GroupNorm" || sub_type == "InstanceNorm" || sub_type == "LayerNorm") + { + int n = op->inputs[0]->shape[0]; + int c = op->inputs[0]->shape[1]; + int num_elements = 1; + for (size_t i = 2; i < op->inputs[0]->shape.size(); ++i) + { + num_elements *= op->inputs[0]->shape[i]; + } + + extra_flops += 7 * n * c * num_elements; + extra_memops += 2 * n * c * num_elements; } } } @@ -1630,10 +2149,10 @@ int Graph::python(const std::string& 
pypath, const std::string& pnnxbinpath) for (size_t i = 0; i < param.ai.size(); i++) { if ((op->type == "nn.AdaptiveAvgPool2d" - || op->type == "nn.AdaptiveAvgPool3d" - || op->type == "nn.AdaptiveMaxPool2d" - || op->type == "nn.AdaptiveMaxPool3d") - && it.first == "output_size" && param.ai[i] == 0) + || op->type == "nn.AdaptiveAvgPool3d" + || op->type == "nn.AdaptiveMaxPool2d" + || op->type == "nn.AdaptiveMaxPool3d") + && it.first == "output_size" && param.ai[i] == 0) { fprintf(pyfp, "None"); } @@ -2386,10 +2905,10 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) for (size_t i = 0; i < param.ai.size(); i++) { if ((op->type == "F.adaptive_avg_pool2d" - || op->type == "F.adaptive_avg_pool3d" - || op->type == "F.adaptive_max_pool2d" - || op->type == "F.adaptive_max_pool3d") - && it.first == "output_size" && param.ai[i] == 0) + || op->type == "F.adaptive_avg_pool3d" + || op->type == "F.adaptive_max_pool2d" + || op->type == "F.adaptive_max_pool3d") + && it.first == "output_size" && param.ai[i] == 0) { fprintf(pyfp, "None"); } diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index c66141d7324c..37ee81e0a6b5 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -346,8 +346,10 @@ class Graph std::vector ops; std::vector operands; - long long flops = 0; - long long memops = 0; + unsigned long long flops = 0; + unsigned long long memops = 0; + unsigned long long extra_flops = 0; + unsigned long long extra_memops = 0; void flops_memops_sum(); private: diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index f75af022cdeb..949680faab82 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -366,6 +366,8 @@ int main(int argc, char** argv) pnnx_graph.flops_memops_sum(); fprintf(stderr, "float ops = %.3fM\n", double(pnnx_graph.flops) / 1e6); fprintf(stderr, "mem ops = %.3fM\n", double(pnnx_graph.memops) / 1e6); + fprintf(stderr, "extra float ops = %.3fM\n", double(pnnx_graph.extra_flops) / 1e6); + fprintf(stderr, "extra mem ops = %.3fM\n", double(pnnx_graph.extra_memops) / 1e6); #if BUILD_PNNX2ONNX pnnx::save_onnx(pnnx_graph, pnnxonnxpath.c_str(), fp16); From a91dc5ce90bcd677a92f92916b70f094dbdfc23b Mon Sep 17 00:00:00 2001 From: SZUwishion <2559916473@qq.com> Date: Mon, 16 Sep 2024 20:51:36 +0800 Subject: [PATCH 18/28] nn part finished --- tools/pnnx/src/ir.cpp | 722 +++++++++++++++++++++++------------------- 1 file changed, 388 insertions(+), 334 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 5125ac1d1ad7..6cbf320acf41 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -14,8 +14,11 @@ #include "ir.h" +#include #include +#include #include +#include #include #include #include @@ -23,6 +26,7 @@ #include #include #include +#include #include "storezip.h" #include "utils.h" @@ -1488,9 +1492,9 @@ void Graph::flops_memops_sum() int n = op->inputs[0]->shape[0]; int c = op->inputs[0]->shape[1]; int l = op->inputs[0]->shape[2]; - int k = op->params.at("kernel_size").ai[0]; - int s = op->params.at("stride").ai[0]; - int p = op->params.at("padding").ai[0]; + int k = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[0] : 1; + int s = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; + int p = op->has_param("padding") ? 
op->params.at("padding").ai[0] : 0; int o = (l + 2 * p - k) / s + 1; flops += n * c * l * k; memops += n * c * l + n * c * o; @@ -1501,12 +1505,12 @@ void Graph::flops_memops_sum() int c = op->inputs[0]->shape[1]; int h = op->inputs[0]->shape[2]; int w = op->inputs[0]->shape[3]; - int kh = op->params.at("kernel_size").ai[0]; - int kw = op->params.at("kernel_size").ai[1]; - int sh = op->params.at("stride").ai[0]; - int sw = op->params.at("stride").ai[1]; - int ph = op->params.at("padding").ai[0]; - int pw = op->params.at("padding").ai[1]; + int kh = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[0] : 1; + int kw = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[1] : 1; + int sh = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; + int sw = op->has_param("stride") ? op->params.at("stride").ai[1] : 1; + int ph = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; + int pw = op->has_param("padding") ? op->params.at("padding").ai[1] : 0; int oh = (h + 2 * ph - kh) / sh + 1; int ow = (w + 2 * pw - kw) / sw + 1; flops += n * c * h * w * kh * kw; @@ -1519,15 +1523,15 @@ void Graph::flops_memops_sum() int d = op->inputs[0]->shape[2]; int h = op->inputs[0]->shape[3]; int w = op->inputs[0]->shape[4]; - int kd = op->params.at("kernel_size").ai[0]; - int kh = op->params.at("kernel_size").ai[1]; - int kw = op->params.at("kernel_size").ai[2]; - int sd = op->params.at("stride").ai[0]; - int sh = op->params.at("stride").ai[1]; - int sw = op->params.at("stride").ai[2]; - int pd = op->params.at("padding").ai[0]; - int ph = op->params.at("padding").ai[1]; - int pw = op->params.at("padding").ai[2]; + int kd = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[0] : 1; + int kh = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[1] : 1; + int kw = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[2] : 1; + int sd = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; + int sh = op->has_param("stride") ? op->params.at("stride").ai[1] : 1; + int sw = op->has_param("stride") ? op->params.at("stride").ai[2] : 1; + int pd = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; + int ph = op->has_param("padding") ? op->params.at("padding").ai[1] : 0; + int pw = op->has_param("padding") ? op->params.at("padding").ai[2] : 0; int od = (d + 2 * pd - kd) / sd + 1; int oh = (h + 2 * ph - kh) / sh + 1; int ow = (w + 2 * pw - kw) / sw + 1; @@ -1572,9 +1576,9 @@ void Graph::flops_memops_sum() int n = op->inputs[0]->shape[0]; int c = op->inputs[0]->shape[1]; int l = op->inputs[0]->shape[2]; - int k = op->params.at("kernel_size").ai[0]; - int s = op->params.at("stride").ai[0]; - int p = op->params.at("padding").ai[0]; + int k = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[0] : 1; + int s = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; + int p = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; int o = (l + 2 * p - k) / s + 1; flops += n * c * l * k; memops += n * c * l + n * c * o; @@ -1585,12 +1589,12 @@ void Graph::flops_memops_sum() int c = op->inputs[0]->shape[1]; int h = op->inputs[0]->shape[2]; int w = op->inputs[0]->shape[3]; - int kh = op->params.at("kernel_size").ai[0]; - int kw = op->params.at("kernel_size").ai[1]; - int sh = op->params.at("stride").ai[0]; - int sw = op->params.at("stride").ai[1]; - int ph = op->params.at("padding").ai[0]; - int pw = op->params.at("padding").ai[1]; + int kh = op->has_param("kernel_size") ? 
op->params.at("kernel_size").ai[0] : 1; + int kw = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[1] : 1; + int sh = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; + int sw = op->has_param("stride") ? op->params.at("stride").ai[1] : 1; + int ph = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; + int pw = op->has_param("padding") ? op->params.at("padding").ai[1] : 0; int oh = (h + 2 * ph - kh) / sh + 1; int ow = (w + 2 * pw - kw) / sw + 1; flops += n * c * h * w * kh * kw; @@ -1603,192 +1607,24 @@ void Graph::flops_memops_sum() int d = op->inputs[0]->shape[2]; int h = op->inputs[0]->shape[3]; int w = op->inputs[0]->shape[4]; - int kd = op->params.at("kernel_size").ai[0]; - int kh = op->params.at("kernel_size").ai[1]; - int kw = op->params.at("kernel_size").ai[2]; - int sd = op->params.at("stride").ai[0]; - int sh = op->params.at("stride").ai[1]; - int sw = op->params.at("stride").ai[2]; - int pd = op->params.at("padding").ai[0]; - int ph = op->params.at("padding").ai[1]; - int pw = op->params.at("padding").ai[2]; + int kd = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[0] : 1; + int kh = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[1] : 1; + int kw = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[2] : 1; + int sd = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; + int sh = op->has_param("stride") ? op->params.at("stride").ai[1] : 1; + int sw = op->has_param("stride") ? op->params.at("stride").ai[2] : 1; + int pd = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; + int ph = op->has_param("padding") ? op->params.at("padding").ai[1] : 0; + int pw = op->has_param("padding") ? op->params.at("padding").ai[2] : 0; int od = (d + 2 * pd - kd) / sd + 1; int oh = (h + 2 * ph - kh) / sh + 1; int ow = (w + 2 * pw - kw) / sw + 1; flops += n * c * d * h * w * kd * kh * kw; memops += n * c * d * h * w + n * c * od * oh * ow; } - else if (sub_type == "lp_pool1d") + else if (sub_type == "prelu" || sub_type == "leaky_relu") { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int l = op->inputs[0]->shape[2]; - int k = op->params.at("kernel_size").i; - int p = op->params.at("p").i; - if (p == 1) - { - extra_flops += 2 * n * c * l * k; - } - else if (p == 2) - { - extra_flops += 3 * n * c * l * k; - } - extra_memops += 2 * n * c * l; - } - else if (sub_type == "lp_pool2d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int h = op->inputs[0]->shape[2]; - int w = op->inputs[0]->shape[3]; - int kh = op->params.at("kernel_size").ai[0]; - int kw = op->params.at("kernel_size").ai[1]; - int p = op->params.at("p").i; - if (p == 1) - { - extra_flops += 2 * n * c * h * w * kh * kw; - } - else if (p == 2) - { - extra_flops += 3 * n * c * h * w * kh * kw; - } - extra_memops += 2 * n * c * h * w; - } - else if (sub_type == "lp_pool3d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int d = op->inputs[0]->shape[2]; - int h = op->inputs[0]->shape[3]; - int w = op->inputs[0]->shape[4]; - int kd = op->params.at("kernel_size").ai[0]; - int kh = op->params.at("kernel_size").ai[1]; - int kw = op->params.at("kernel_size").ai[2]; - int p = op->params.at("p").i; - if (p == 1) - { - extra_flops += 2 * n * c * d * h * w * kd * kh * kw; - } - else if (p == 2) - { - extra_flops += 3 * n * c * d * h * w * kd * kh * kw; - } - extra_memops += 2 * n * c * d * h * w; - } - else if ( - sub_type == "elu" || - sub_type == "celu" || - sub_type == 
"gelu" || - sub_type == "glu" || - sub_type == "hardshrink" || - sub_type == "hardsigmoid" || - sub_type == "hardswish" || - sub_type == "hardtanh" || - sub_type == "leaky_relu" || - sub_type == "prelu" || - sub_type == "relu" || - sub_type == "relu6" || - sub_type == "rrelu" || - sub_type == "mish" || - sub_type == "normalize" || - sub_type == "batch_norm" || - sub_type == "group_norm" || - sub_type == "instance_norm" || - sub_type == "layer_norm" - ) - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int num_elements = 1; - for (size_t i = 2; i < op->inputs[0]->shape.size(); ++i) - { - num_elements *= op->inputs[0]->shape[i]; - } - if(sub_type == "elu") - { - extra_flops += 2 * n * c * num_elements; - extra_memops += 2 * n * c * num_elements; - } - else if(sub_type == "celu") - { - extra_flops += 3 * n * c * num_elements; - extra_memops += 2 * n * c * num_elements; - } - else if(sub_type == "gelu") - { - extra_flops += 3 * n * c * num_elements; - extra_memops += 2 * n * c * num_elements; - } - else if(sub_type == "glu") - { - int l = op->inputs[0]->shape[2]; - int o = op->outputs[0]->shape[2]; - extra_flops += n * c * l * o; - extra_memops += 2 * n * c * l + n * o; - } - else if(sub_type == "hardshrink") - { - extra_flops += 2 * n * c * num_elements; - extra_memops += 2 * n * c * num_elements; - } - else if(sub_type == "hardsigmoid") - { - extra_flops += 6 * n * c * num_elements; - extra_memops += 2 * n * c * num_elements; - } - else if(sub_type == "hardswish") - { - extra_flops += 5 * n * c * num_elements; - extra_memops += 2 * n * c * num_elements; - } - else if(sub_type == "hardtanh") - { - extra_memops += 2 * n * c * num_elements; - } - else if(sub_type == "leaky_relu") - { - extra_flops += 2 * n * c * num_elements; - extra_memops += 2 * n * c * num_elements; - } - else if(sub_type == "prelu") - { - extra_flops += 2 * n * c * num_elements; - extra_memops += 2 * n * c * num_elements; - } - else if(sub_type == "relu") - { - extra_flops += n * c * num_elements; - extra_memops += n * c * num_elements; - } - else if(sub_type == "relu6") - { - extra_memops += n * c * num_elements; - } - else if(sub_type == "rrelu") - { - extra_flops += n * c * num_elements; - extra_memops += 2 * n * c * num_elements; - } - else if(sub_type == "mish") - { - extra_flops += 2 * n * c * num_elements; - extra_memops += 2 * n * c * num_elements; - } - else if(sub_type == "normalize") - { - extra_flops += 7 * n * c * num_elements + 3; - extra_memops += 2 * n * c * num_elements; - } - else if( - sub_type == "batch_norm" || - sub_type == "group_norm" || - sub_type == "instance_norm" || - sub_type == "layer_norm" - ) - { - extra_flops += 7 * n * c * num_elements; - extra_memops += 2 * n * c * num_elements; - } + } else if (sub_type == "conv1d") { @@ -1809,9 +1645,9 @@ void Graph::flops_memops_sum() int kh = op->inputs[1]->shape[2]; int kw = op->inputs[1]->shape[3]; int o = op->outputs[0]->shape[2]; - int s = op->params.at("stride").ai[0]; - int p = op->params.at("padding").ai[0]; - int g = op->params.at("groups").i; + int s = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; + int p = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; + int g = op->has_param("groups") ? 
op->params.at("groups").i : 1; flops += 2 * n * c * h * w * kh * kw * o / g; memops += 2 * n * c * h * w * kh * kw / g + n * o * h * w; } @@ -1826,9 +1662,9 @@ void Graph::flops_memops_sum() int kh = op->inputs[1]->shape[3]; int kw = op->inputs[1]->shape[4]; int o = op->outputs[0]->shape[2]; - int s = op->params.at("stride").ai[0]; - int p = op->params.at("padding").ai[0]; - int g = op->params.at("groups").i; + int s = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; + int p = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; + int g = op->has_param("groups") ? op->params.at("groups").i : 1; flops += 2 * n * c * d * h * w * kd * kh * kw * o / g; memops += 2 * n * c * d * h * w * kd * kh * kw / g + n * o * d * h * w; } @@ -1851,9 +1687,9 @@ void Graph::flops_memops_sum() int kh = op->inputs[1]->shape[2]; int kw = op->inputs[1]->shape[3]; int o = op->outputs[0]->shape[2]; - int s = op->params.at("stride").ai[0]; - int p = op->params.at("padding").ai[0]; - int g = op->params.at("groups").i; + int s = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; + int p = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; + int g = op->has_param("groups") ? op->params.at("groups").i : 1; flops += 2 * n * c * h * w * kh * kw * o / g; memops += 2 * n * c * h * w * kh * kw / g + n * o * h * w; } @@ -1868,20 +1704,15 @@ void Graph::flops_memops_sum() int kh = op->inputs[1]->shape[3]; int kw = op->inputs[1]->shape[4]; int o = op->outputs[0]->shape[2]; - int s = op->params.at("stride").ai[0]; - int p = op->params.at("padding").ai[0]; - int g = op->params.at("groups").i; + int s = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; + int p = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; + int g = op->has_param("groups") ? 
op->params.at("groups").i : 1; flops += 2 * n * c * d * h * w * kd * kh * kw * o / g; memops += 2 * n * c * d * h * w * kd * kh * kw / g + n * o * d * h * w; } else if (sub_type == "embedding") { - int n = op->inputs[0]->shape[0]; - int l = op->inputs[0]->shape[1]; - int c = op->params.at("num_embeddings").i; - int e = op->params.at("embedding_dim").i; - extra_flops += n * l * e; - extra_memops += n * l + n * e; + /*todo*/ } else if (sub_type == "linear") { @@ -1920,139 +1751,362 @@ void Graph::flops_memops_sum() else if (op->type.substr(0, 2) == "nn") { std::string sub_type = op->type.substr(3); - if ( - sub_type == "BatchNorm1d" || - sub_type == "BatchNorm2d" || - sub_type == "BatchNorm3d" - ) - { + if (sub_type == "BatchNorm1d" + || sub_type == "BatchNorm2d" + || sub_type == "BatchNorm3d" + || sub_type == "GroupNorm" + || sub_type == "LayerNorm" + || sub_type == "InstanceNorm1d" + || sub_type == "InstanceNorm2d" + || sub_type == "InstanceNorm3d") + { + std::vector shape = op->inputs[0]->shape; int n = op->inputs[0]->shape[0]; int c = op->inputs[0]->shape[1]; - int num_elements = 1; - for (size_t i = 2; i < op->inputs[0]->shape.size(); ++i) + int num_elements = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + if((op->has_param("affine") && op->params.at("affine").b) + || (op->has_param("elementwise_affine") && op->params.at("elementwise_affine").b)) { - num_elements *= op->inputs[0]->shape[i]; + extra_flops += 2 * num_elements; + extra_memops += 2 * (num_elements + n * c); + } + else + { + extra_flops += num_elements; + extra_memops += num_elements; + } + } + else if (sub_type == "Conv1d" + || sub_type == "Conv2d" + || sub_type == "Conv3d" + || sub_type == "ConvTranspose1d" + || sub_type == "ConvTranspose2d" + || sub_type == "ConvTranspose3d") + { + int c = op->params.at("in_channels").i; + std::vector k = op->params.at("kernel_size").ai; + std::vector input_shape = op->inputs[0]->shape; + std::vector output_shape = op->outputs[0]->shape; + int g = op->params["groups"].i; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + int kernel_size = std::accumulate(k.begin() + 2, k.end(), 1, std::multiplies()); + flops += output_size * c * kernel_size / g; + memops += input_size + output_size + std::accumulate(k.begin(), k.end(), 1, std::multiplies()) * c / g; + if(op->has_param("bias")) + { + flops += output_size; + memops += output_size; + } + } + else if (sub_type == "AvgPool1d" + || sub_type == "AvgPool2d" + || sub_type == "AvgPool3d") + { + std::vector input_shape = op->inputs[0]->shape; + std::vector output_shape = op->outputs[0]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + flops += input_size; + memops += input_size + output_size; + } + else if (sub_type == "AdaptiveAvgPool1d" + || sub_type == "AdaptiveAvgPool2d" + || sub_type == "AdaptiveAvgPool3d") + { + std::vector input_shape = op->inputs[0]->shape; + std::vector output_shape = op->outputs[0]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + std::vector kernel_size; + for(size_t i = 2; i < input_shape.size(); i++) + { + 
kernel_size.emplace_back(output_shape[i] / input_shape[i]); + } + flops += (std::accumulate(kernel_size.begin(), kernel_size.end(), 1, std::multiplies()) + 1) * output_size; + memops += input_size + output_size; + } + else if(sub_type == "PReLU" + || sub_type == "ELU" + || sub_type == "LeakyReLU" + || sub_type == "GELU") + { + std::vector shape = op->outputs[0]->shape; + int n = shape[0]; + int num_elements = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + extra_flops += num_elements; + if(sub_type == "PReLU") + { + extra_memops += 2 * num_elements + n * op->params["num_parameters"].i; + } + else + { + extra_memops += 2 * num_elements; } - extra_flops += 7 * n * c * num_elements; - extra_memops += 2 * n * c * num_elements; - } - else if (sub_type == "Conv1d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int l = op->inputs[0]->shape[2]; - int k = op->inputs[1]->shape[0]; - int o = op->outputs[0]->shape[2]; - flops += 2 * n * c * l * k * o; - memops += 2 * n * c * l * k + n * o; - } - else if (sub_type == "Conv2d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int h = op->inputs[0]->shape[2]; - int w = op->inputs[0]->shape[3]; - int kh = op->inputs[1]->shape[2]; - int kw = op->inputs[1]->shape[3]; - int o = op->outputs[0]->shape[2]; - int s = op->params.at("stride").ai[0]; - int p = op->params.at("padding").ai[0]; - int g = op->params.at("groups").i; - flops += 2 * n * c * h * w * kh * kw * o / g; - memops += 2 * n * c * h * w * kh * kw / g + n * o * h * w; } - else if (sub_type == "Conv3d") + else if(sub_type == "Tanh") { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int d = op->inputs[0]->shape[2]; - int h = op->inputs[0]->shape[3]; - int w = op->inputs[0]->shape[4]; - int kd = op->inputs[1]->shape[2]; - int kh = op->inputs[1]->shape[3]; - int kw = op->inputs[1]->shape[4]; - int o = op->outputs[0]->shape[2]; - int s = op->params.at("stride").ai[0]; - int p = op->params.at("padding").ai[0]; - int g = op->params.at("groups").i; - flops += 2 * n * c * d * h * w * kd * kh * kw * o / g; - memops += 2 * n * c * d * h * w * kd * kh * kw / g + n * o * d * h * w; + std::vector shape = op->outputs[0]->shape; + int num_elements = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + extra_flops += 2 * num_elements; + extra_memops += 2 * num_elements; } - else if (sub_type == "ConvTranspose1d") + else if (sub_type == "Linear") { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int l = op->inputs[0]->shape[2]; - int k = op->inputs[1]->shape[0]; - int o = op->outputs[0]->shape[2]; - flops += 2 * n * c * l * k * o; - memops += 2 * n * c * l * k + n * o; + std::vector input_shape = op->inputs[0]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + std::vector output_shape = op->outputs[0]->shape; + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + int in_features = op->params.at("in_features").i; + int out_features = op->params.at("out_features").i; + int bias = op->has_param("bias") ? out_features : 0; + flops += (in_features * out_features + bias) * input_size / in_features; + memops += input_size + output_size + output_size * (bias ? 
1 : 0); } - else if (sub_type == "ConvTranspose2d") + else if (sub_type == "Upsample" + || sub_type == "UpsamplingBilinear2d" + || sub_type == "UpsamplingNearest2d") { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int h = op->inputs[0]->shape[2]; - int w = op->inputs[0]->shape[3]; - int kh = op->inputs[1]->shape[2]; - int kw = op->inputs[1]->shape[3]; - int o = op->outputs[0]->shape[2]; - int s = op->params.at("stride").ai[0]; - int p = op->params.at("padding").ai[0]; - int g = op->params.at("groups").i; - flops += 2 * n * c * h * w * kh * kw * o / g; - memops += 2 * n * c * h * w * kh * kw / g + n * o * h * w; + std::vector input_shape = op->inputs[0]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + std::vector output_shape = op->outputs[0]->shape; + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + std::string mode; + if(sub_type == "Upsample") + { + mode = op->has_param("mode") ? op->params.at("mode").s : "nearest"; + } + else if(sub_type == "UpsamplingBilinear2d") + { + mode = "bilinear"; + } + else if(sub_type == "UpsamplingNearest2d") + { + mode = "nearest"; + } + + if(mode == "nearest") + { + extra_flops += input_size; + extra_memops += input_size + output_size; + } + else if(mode == "linear") + { + extra_flops += 5 * output_size; + extra_memops += 2 * input_size + output_size; + } + else if(mode == "bilinear") + { + extra_flops += 11 * output_size; + extra_memops += 4 * input_size + output_size; + } + else if(mode == "bicubic") + { + extra_flops += (224 + 35) * output_size; + extra_memops += 16 * input_size + output_size; + } + else if(mode == "trilinear") + { + extra_flops += (13 * 2 + 5) * input_size; + extra_memops += 8 * input_size + output_size; + } } - else if (sub_type == "PReLU") + else if(sub_type == "RNN") { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int num_elements = 1; - for (size_t i = 2; i < op->inputs[0]->shape.size(); ++i) + bool bi = op->has_param("bidirectional") && op->params.at("bidirectional").b; + bool bias = op->has_param("bias") && op->params.at("bias").b; + int input_size = op->params.at("input_size").i; + int hidden_size = op->params.at("hidden_size").i; + int flops1 = hidden_size * (input_size + hidden_size) + hidden_size; + if(bias) + { + flops1 += 2 * hidden_size; + } + if(bi) + { + flops1 *= 2; + } + + int num_layers = op->params.at("num_layers").i; + int flops2 = 0; + if(bi) + { + flops2 = 3 * hidden_size * hidden_size + hidden_size; + if(bias) + { + flops2 += 2 * hidden_size; + } + flops2 *= 2 * num_layers; + } + else { - num_elements *= op->inputs[0]->shape[i]; + flops2 = 2 * hidden_size * hidden_size + hidden_size; + if(bias) + { + flops2 += 2 * hidden_size; + } + flops2 *= num_layers; + } + bool batch_first = op->has_param("batch_first") && op->params.at("batch_first").b; + int batch_size = batch_first ? op->inputs[0]->shape[0] : op->inputs[0]->shape[1]; + int num_steps = batch_first ? op->inputs[0]->shape[1] : op->inputs[0]->shape[0]; + flops += (flops1 + flops2) * num_steps * batch_size; + memops += num_steps * batch_size * input_size; + memops += 2 * num_steps * batch_size * hidden_size * num_layers * (bi ? 2 : 1); + if(bias) + { + memops += 2 * hidden_size * num_layers * (bi ? 
2 : 1); } - extra_flops += 2 * n * c * num_elements; - extra_memops += 2 * n * c * num_elements; } - else if (sub_type == "ConvTranspose3d") + else if(sub_type == "LSTM") { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int d = op->inputs[0]->shape[2]; - int h = op->inputs[0]->shape[3]; - int w = op->inputs[0]->shape[4]; - int kd = op->inputs[1]->shape[2]; - int kh = op->inputs[1]->shape[3]; - int kw = op->inputs[1]->shape[4]; - int o = op->outputs[0]->shape[2]; - int s = op->params.at("stride").ai[0]; - int p = op->params.at("padding").ai[0]; - int g = op->params.at("groups").i; - flops += 2 * n * c * d * h * w * kd * kh * kw * o / g; - memops += 2 * n * c * d * h * w * kd * kh * kw / g + n * o * d * h * w; + bool bi = op->has_param("bidirectional") && op->params.at("bidirectional").b; + bool bias = op->has_param("bias") && op->params.at("bias").b; + int input_size = op->params.at("input_size").i; + int hidden_size = op->params.at("hidden_size").i; + int flops1 = 4 * hidden_size * (input_size + hidden_size) + 4 * hidden_size; + if(bias) + { + flops1 += 8 * hidden_size; + } + if(bi) + { + flops1 *= 2; + } + flops1 += 4 * hidden_size; + + int num_layers = op->params.at("num_layers").i; + int flops2 = 0; + if(bi) + { + flops2 = 12 * hidden_size * hidden_size + 4 * hidden_size; + if(bias) + { + flops2 += 8 * hidden_size; + } + flops2 += 4 * hidden_size; + flops2 *= 2 * num_layers; + } + else + { + flops2 = 4 * hidden_size * hidden_size + 4 * hidden_size; + if(bias) + { + flops2 += 8 * hidden_size; + } + flops2 += 4 * hidden_size; + flops2 *= num_layers; + } + bool batch_first = op->has_param("batch_first") && op->params.at("batch_first").b; + int batch_size = batch_first ? op->inputs[0]->shape[0] : op->inputs[0]->shape[1]; + int num_steps = batch_first ? op->inputs[0]->shape[1] : op->inputs[0]->shape[0]; + flops += (flops1 + flops2) * num_steps * batch_size; + memops += num_steps * batch_size * input_size; + memops += 2 * num_steps * batch_size * hidden_size * num_layers * (bi ? 2 : 1); + if(bias) + { + memops += 8 * hidden_size * num_layers * (bi ? 2 : 1); + } } - else if (sub_type == "Embedding") + else if (sub_type == "GRU") { - int n = op->inputs[0]->shape[0]; - int l = op->inputs[0]->shape[1]; - int c = op->params.at("num_embeddings").i; - int e = op->params.at("embedding_dim").i; - extra_flops += 2 * n * l * e; - extra_memops += 2 * n * l + n * e; + bool bi = op->has_param("bidirectional") && op->params.at("bidirectional").b; + bool bias = op->has_param("bias") && op->params.at("bias").b; + int input_size = op->params.at("input_size").i; + int hidden_size = op->params.at("hidden_size").i; + int flops1 = 3 * hidden_size * (input_size + hidden_size) + 3 * hidden_size; + if(bias) + { + flops1 += 6 * hidden_size; + } + flops1 += 4 * hidden_size; + if(bi) + { + flops1 *= 2; + } + + int num_layers = op->params.at("num_layers").i; + int flops2 = 0; + if(bi) + { + flops2 = 9 * hidden_size * hidden_size + 3 * hidden_size; + if(bias) + { + flops2 += 6 * hidden_size; + } + flops2 += 4 * hidden_size; + flops2 *= 2 * num_layers; + } + else + { + flops2 = 6 * hidden_size * hidden_size + 3 * hidden_size; + if(bias) + { + flops2 += 6 * hidden_size; + } + flops2 += 4 * hidden_size; + flops2 *= num_layers; + } + bool batch_first = op->has_param("batch_first") && op->params.at("batch_first").b; + int batch_size = batch_first ? op->inputs[0]->shape[0] : op->inputs[0]->shape[1]; + int num_steps = batch_first ? 
op->inputs[0]->shape[1] : op->inputs[0]->shape[0]; + flops += (flops1 + flops2) * num_steps * batch_size; + memops += num_steps * batch_size * input_size; + memops += 2 * num_steps * batch_size * hidden_size * num_layers * (bi ? 2 : 1); + if(bias) + { + memops += 6 * hidden_size * num_layers * (bi ? 2 : 1); + } } - else if (sub_type == "GroupNorm" || sub_type == "InstanceNorm" || sub_type == "LayerNorm") + else if(sub_type == "MultiheadAttention") { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int num_elements = 1; - for (size_t i = 2; i < op->inputs[0]->shape.size(); ++i) + bool batch_first = op->has_param("batch_first") && op->params.at("batch_first").b; + int batch_size = batch_first ? op->inputs[0]->shape[0] : op->inputs[0]->shape[1]; + int qlen = batch_first ? op->inputs[0]->shape[1] : op->inputs[0]->shape[0]; + int klen = batch_first ? op->inputs[1]->shape[1] : op->inputs[1]->shape[0]; + int d_model = op->params.at("embed_dim").i; + int num_heads = op->params.at("num_heads").i; + int head_dim = d_model / num_heads; + bool bias = op->params.at("bias").b; + + // Linear transformations for Q, K, V + int flops_qkv = 3 * batch_size * qlen * d_model * d_model; + if (bias) + { + flops_qkv += 3 * batch_size * qlen * d_model; + } + + // Scaled dot-product attention + int flops_attention = batch_size * num_heads * qlen * klen * head_dim; + + // Linear transformation for output + int flops_output = batch_size * qlen * d_model * d_model; + if (bias) + { + flops_output += batch_size * qlen * d_model; + } + + flops += flops_qkv + flops_attention + flops_output; + + // Memory operations for Q, K, V + int memops_qkv = 3 * batch_size * qlen * d_model; + if (bias) + { + memops_qkv += 3 * d_model; + } + + // Memory operations for attention weights + int memops_attention = batch_size * num_heads * qlen * klen; + + // Memory operations for output + int memops_output = batch_size * qlen * d_model; + if (bias) { - num_elements *= op->inputs[0]->shape[i]; + memops_output += d_model; } - extra_flops += 7 * n * c * num_elements; - extra_memops += 2 * n * c * num_elements; + // Total memory operations + memops += memops_qkv + memops_attention + memops_output; } } } From 4adf254c729edc63ef0e7deb86b3bbb3cbf079c9 Mon Sep 17 00:00:00 2001 From: SZUwishion <2559916473@qq.com> Date: Thu, 19 Sep 2024 16:02:54 +0800 Subject: [PATCH 19/28] functional finished --- tools/pnnx/src/ir.cpp | 335 +++++++----------------------------------- 1 file changed, 53 insertions(+), 282 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 6cbf320acf41..f05ca65a3c59 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1454,297 +1454,68 @@ void Graph::flops_memops_sum() if (op->type[0] == 'F') { std::string sub_type = op->type.substr(2); - if (sub_type == "adaptive_avg_pool1d") + if (sub_type == "linear") { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int l = op->inputs[0]->shape[2]; - int o = op->params.at("output_size").ai[0]; - flops += n * c * l * o; - memops += n * c * l + n * c * o; - } - else if (sub_type == "adaptive_avg_pool2d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int h = op->inputs[0]->shape[2]; - int w = op->inputs[0]->shape[3]; - int oh = op->params.at("output_size").ai[0]; - int ow = op->params.at("output_size").ai[1]; - flops += n * c * h * w * oh * ow; - memops += n * c * h * w + n * c * oh * ow; - } - else if (sub_type == "adaptive_avg_pool3d") - { - int n = op->inputs[0]->shape[0]; - int c 
= op->inputs[0]->shape[1]; - int d = op->inputs[0]->shape[2]; - int h = op->inputs[0]->shape[3]; - int w = op->inputs[0]->shape[4]; - int od = op->params.at("output_size").ai[0]; - int oh = op->params.at("output_size").ai[1]; - int ow = op->params.at("output_size").ai[2]; - flops += n * c * d * h * w * od * oh * ow; - memops += n * c * d * h * w + n * c * od * oh * ow; - } - else if (sub_type == "avg_pool1d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int l = op->inputs[0]->shape[2]; - int k = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[0] : 1; - int s = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; - int p = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; - int o = (l + 2 * p - k) / s + 1; - flops += n * c * l * k; - memops += n * c * l + n * c * o; - } - else if (sub_type == "avg_pool2d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int h = op->inputs[0]->shape[2]; - int w = op->inputs[0]->shape[3]; - int kh = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[0] : 1; - int kw = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[1] : 1; - int sh = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; - int sw = op->has_param("stride") ? op->params.at("stride").ai[1] : 1; - int ph = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; - int pw = op->has_param("padding") ? op->params.at("padding").ai[1] : 0; - int oh = (h + 2 * ph - kh) / sh + 1; - int ow = (w + 2 * pw - kw) / sw + 1; - flops += n * c * h * w * kh * kw; - memops += n * c * h * w + n * c * oh * ow; - } - else if (sub_type == "avg_pool3d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int d = op->inputs[0]->shape[2]; - int h = op->inputs[0]->shape[3]; - int w = op->inputs[0]->shape[4]; - int kd = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[0] : 1; - int kh = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[1] : 1; - int kw = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[2] : 1; - int sd = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; - int sh = op->has_param("stride") ? op->params.at("stride").ai[1] : 1; - int sw = op->has_param("stride") ? op->params.at("stride").ai[2] : 1; - int pd = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; - int ph = op->has_param("padding") ? op->params.at("padding").ai[1] : 0; - int pw = op->has_param("padding") ? 
op->params.at("padding").ai[2] : 0; - int od = (d + 2 * pd - kd) / sd + 1; - int oh = (h + 2 * ph - kh) / sh + 1; - int ow = (w + 2 * pw - kw) / sw + 1; - flops += n * c * d * h * w * kd * kh * kw; - memops += n * c * d * h * w + n * c * od * oh * ow; - } - else if (sub_type == "adaptive_max_pool1d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int l = op->inputs[0]->shape[2]; - int o = op->params.at("output_size").ai[0]; - flops += n * c * l * o; - memops += n * c * l + n * c * o; - } - else if (sub_type == "adaptive_max_pool2d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int h = op->inputs[0]->shape[2]; - int w = op->inputs[0]->shape[3]; - int oh = op->params.at("output_size").ai[0]; - int ow = op->params.at("output_size").ai[1]; - flops += n * c * h * w * oh * ow; - memops += n * c * h * w + n * c * oh * ow; - } - else if (sub_type == "adaptive_max_pool3d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int d = op->inputs[0]->shape[2]; - int h = op->inputs[0]->shape[3]; - int w = op->inputs[0]->shape[4]; - int od = op->params.at("output_size").ai[0]; - int oh = op->params.at("output_size").ai[1]; - int ow = op->params.at("output_size").ai[2]; - flops += n * c * d * h * w * od * oh * ow; - memops += n * c * d * h * w + n * c * od * oh * ow; - } - else if (sub_type == "max_pool1d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int l = op->inputs[0]->shape[2]; - int k = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[0] : 1; - int s = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; - int p = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; - int o = (l + 2 * p - k) / s + 1; - flops += n * c * l * k; - memops += n * c * l + n * c * o; - } - else if (sub_type == "max_pool2d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int h = op->inputs[0]->shape[2]; - int w = op->inputs[0]->shape[3]; - int kh = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[0] : 1; - int kw = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[1] : 1; - int sh = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; - int sw = op->has_param("stride") ? op->params.at("stride").ai[1] : 1; - int ph = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; - int pw = op->has_param("padding") ? op->params.at("padding").ai[1] : 0; - int oh = (h + 2 * ph - kh) / sh + 1; - int ow = (w + 2 * pw - kw) / sw + 1; - flops += n * c * h * w * kh * kw; - memops += n * c * h * w + n * c * oh * ow; - } - else if (sub_type == "max_pool3d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int d = op->inputs[0]->shape[2]; - int h = op->inputs[0]->shape[3]; - int w = op->inputs[0]->shape[4]; - int kd = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[0] : 1; - int kh = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[1] : 1; - int kw = op->has_param("kernel_size") ? op->params.at("kernel_size").ai[2] : 1; - int sd = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; - int sh = op->has_param("stride") ? op->params.at("stride").ai[1] : 1; - int sw = op->has_param("stride") ? op->params.at("stride").ai[2] : 1; - int pd = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; - int ph = op->has_param("padding") ? op->params.at("padding").ai[1] : 0; - int pw = op->has_param("padding") ? 
op->params.at("padding").ai[2] : 0; - int od = (d + 2 * pd - kd) / sd + 1; - int oh = (h + 2 * ph - kh) / sh + 1; - int ow = (w + 2 * pw - kw) / sw + 1; - flops += n * c * d * h * w * kd * kh * kw; - memops += n * c * d * h * w + n * c * od * oh * ow; - } - else if (sub_type == "prelu" || sub_type == "leaky_relu") - { - - } - else if (sub_type == "conv1d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int l = op->inputs[0]->shape[2]; - int k = op->inputs[1]->shape[0]; - int o = op->outputs[0]->shape[2]; - flops += 2 * n * c * l * k * o; - memops += 2 * n * c * l * k + n * o; - } - else if (sub_type == "conv2d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int h = op->inputs[0]->shape[2]; - int w = op->inputs[0]->shape[3]; - int kh = op->inputs[1]->shape[2]; - int kw = op->inputs[1]->shape[3]; - int o = op->outputs[0]->shape[2]; - int s = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; - int p = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; - int g = op->has_param("groups") ? op->params.at("groups").i : 1; - flops += 2 * n * c * h * w * kh * kw * o / g; - memops += 2 * n * c * h * w * kh * kw / g + n * o * h * w; - } - else if (sub_type == "conv3d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int d = op->inputs[0]->shape[2]; - int h = op->inputs[0]->shape[3]; - int w = op->inputs[0]->shape[4]; - int kd = op->inputs[1]->shape[2]; - int kh = op->inputs[1]->shape[3]; - int kw = op->inputs[1]->shape[4]; - int o = op->outputs[0]->shape[2]; - int s = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; - int p = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; - int g = op->has_param("groups") ? op->params.at("groups").i : 1; - flops += 2 * n * c * d * h * w * kd * kh * kw * o / g; - memops += 2 * n * c * d * h * w * kd * kh * kw / g + n * o * d * h * w; - } - else if (sub_type == "conv_transpose1d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int l = op->inputs[0]->shape[2]; - int k = op->inputs[1]->shape[0]; - int o = op->outputs[0]->shape[2]; - flops += 2 * n * c * l * k * o; - memops += 2 * n * c * l * k + n * o; - } - else if (sub_type == "conv_transpose2d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int h = op->inputs[0]->shape[2]; - int w = op->inputs[0]->shape[3]; - int kh = op->inputs[1]->shape[2]; - int kw = op->inputs[1]->shape[3]; - int o = op->outputs[0]->shape[2]; - int s = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; - int p = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; - int g = op->has_param("groups") ? op->params.at("groups").i : 1; - flops += 2 * n * c * h * w * kh * kw * o / g; - memops += 2 * n * c * h * w * kh * kw / g + n * o * h * w; - } - else if (sub_type == "conv_transpose3d") - { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int d = op->inputs[0]->shape[2]; - int h = op->inputs[0]->shape[3]; - int w = op->inputs[0]->shape[4]; - int kd = op->inputs[1]->shape[2]; - int kh = op->inputs[1]->shape[3]; - int kw = op->inputs[1]->shape[4]; - int o = op->outputs[0]->shape[2]; - int s = op->has_param("stride") ? op->params.at("stride").ai[0] : 1; - int p = op->has_param("padding") ? op->params.at("padding").ai[0] : 0; - int g = op->has_param("groups") ? 
op->params.at("groups").i : 1; - flops += 2 * n * c * d * h * w * kd * kh * kw * o / g; - memops += 2 * n * c * d * h * w * kd * kh * kw / g + n * o * d * h * w; - } - else if (sub_type == "embedding") - { - /*todo*/ + std::vector input_shape = op->inputs[0]->shape; + std::vector output_shape = op->outputs[0]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + int out_features = op->attrs.at("data").shape[0]; + flops += input_size * out_features; + if(op->has_param("bias")) + { + flops += out_features; + } + memops += input_size + output_size; } - else if (sub_type == "linear") + else if (sub_type == "avgpool1d" + || sub_type == "avgpool2d" + || sub_type == "avgpool3d" + || sub_type == "adaptive_avgpool1d" + || sub_type == "adaptive_avgpool2d" + || sub_type == "adaptive_avgpool3d") { - int n = op->inputs[0]->shape[0]; - int i = op->inputs[0]->shape[1]; - int o = op->outputs[0]->shape[1]; - flops += 2 * n * i * o; - memops += 2 * n * i + n * o; + std::vector input_shape = op->inputs[0]->shape; + std::vector output_shape = op->outputs[0]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + flops += input_size; + memops += input_size + output_size; } - else if (sub_type == "log_softmax") + else if (sub_type == "prelu" + || sub_type == "elu" + || sub_type == "leaky_relu" + || sub_type == "gelu" + || sub_type == "silu" + || sub_type == "softmax") { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int l = op->inputs[0]->shape[2]; - extra_flops += 2 * n * c * l; - extra_memops += 2 * n * c * l; + std::vector input_shape = op->inputs[0]->shape; + std::vector output_shape = op->outputs[0]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + extra_flops += input_size; + extra_memops += input_size + output_size; } - else if (sub_type == "logsigmoid") + else if (sub_type == "unsample" + || sub_type == "upsample_nearest" + || sub_type == "upsample_bilinear") { - int n = op->inputs[0]->shape[0]; - int c = op->inputs[0]->shape[1]; - int l = op->inputs[0]->shape[2]; - extra_flops += 2 * n * c * l; - extra_memops += 2 * n * c * l; + std::vector input_shape = op->inputs[0]->shape; + std::vector output_shape = op->outputs[0]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + extra_flops += output_size; + extra_memops += input_size + output_size; } - else if (sub_type == "scaled_dot_product_attention") + else if (sub_type == "interpolate") { - int n = op->inputs[0]->shape[0]; - int l = op->inputs[0]->shape[1]; - int d = op->inputs[0]->shape[2]; - flops += 2 * n * l * l + n * l * d + n * l * l * d; - memops += 2 * n * l * d + 3 * n * l * l + n * l; + std::vector input_shape = op->inputs[0]->shape; + std::vector output_shape = op->outputs[0]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + std::vector scale_factor = 
op->params.at("scale_factor").ai; + extra_flops += input_size * std::accumulate(scale_factor.begin(), scale_factor.end(), 1, std::multiplies()); + extra_memops += input_size + output_size; } } From 296954dee81b2b2bc8c1b87906c121031d538191 Mon Sep 17 00:00:00 2001 From: SZUwishion <2559916473@qq.com> Date: Fri, 20 Sep 2024 12:56:07 +0800 Subject: [PATCH 20/28] all finished --- tools/pnnx/src/ir.cpp | 45 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index f05ca65a3c59..e4974f6b7c87 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1450,7 +1450,6 @@ void Graph::flops_memops_sum() { for (auto op : ops) { - fprintf(stderr, "op->type: %s\n", op->type.c_str()); if (op->type[0] == 'F') { std::string sub_type = op->type.substr(2); @@ -1880,6 +1879,50 @@ void Graph::flops_memops_sum() memops += memops_qkv + memops_attention + memops_output; } } + + else if (op->type.substr(0, 5) == "torch") + { + std::string sub_type = op->type.substr(6); + if(sub_type == "matmul" + || sub_type == "mm" + || sub_type == "bmm") + { + std::vector input_shape_1 = op->inputs[0]->shape; + std::vector input_shape_2 = op->inputs[1]->shape; + int input_size_1 = std::accumulate(input_shape_1.begin(), input_shape_1.end(), 1, std::multiplies()); + int input_size_2 = std::accumulate(input_shape_2.begin(), input_shape_2.end(), 1, std::multiplies()); + std::vector output_shape = op->outputs[0]->shape; + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + flops += input_size_1 * input_shape_2.back(); + memops += input_size_1 + input_size_2 + output_size; + } + else if (sub_type == "addmm" + || sub_type == "baddbmm") + { + std::vector input_shape = op->inputs[0]->shape; + std::vector mat_shape_1 = op->inputs[1]->shape; + std::vector mat_shape_2 = op->inputs[2]->shape; + int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); + int mat_size_1 = std::accumulate(mat_shape_1.begin(), mat_shape_1.end(), 1, std::multiplies()); + int mat_size_2 = std::accumulate(mat_shape_2.begin(), mat_shape_2.end(), 1, std::multiplies()); + std::vector output_shape = op->outputs[0]->shape; + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + flops += input_size + mat_size_1 * mat_shape_2.back(); + memops += input_size + mat_size_1 + mat_size_2 + output_size; + } + else if (sub_type == "mul" + || sub_type == "add") + { + std::vector input_shape_1 = op->inputs[0]->shape; + std::vector input_shape_2 = op->inputs[1]->shape; + int input_size_1 = std::accumulate(input_shape_1.begin(), input_shape_1.end(), 1, std::multiplies()); + int input_size_2 = std::accumulate(input_shape_2.begin(), input_shape_2.end(), 1, std::multiplies()); + std::vector output_shape = op->outputs[0]->shape; + int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + flops += output_size; + memops += input_size_1 + input_size_2 + output_size; + } + } } } From ec94d1aef522ff464c892d9b54705f24cff857ef Mon Sep 17 00:00:00 2001 From: luxincn <2391719912@qq.com> Date: Mon, 7 Oct 2024 23:27:59 +0800 Subject: [PATCH 21/28] update flops_mem_count --- tools/pnnx/src/ir.cpp | 541 +++++++++++++++++++++++++++++++++++++++- tools/pnnx/src/ir.h | 1 + tools/pnnx/src/main.cpp | 6 +- 3 files changed, 544 insertions(+), 4 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 
16d7c96f5164..3fa407470ba1 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -2864,7 +2864,6 @@ const Operand* Graph::get_operand(const std::string& name) const pnnx::ModelInfo Graph::flops_mem_count() { - pnnx::ModelInfo m; for (const Operator* op : ops) { if (op->type == "nn.Conv2d") @@ -3010,6 +3009,546 @@ pnnx::ModelInfo Graph::flops_mem_count() m.flops += (kernel_add + kernel_avg) * (out_h * out_w) * in_c; m.memory_access += batch_size * in_c * (in_h * in_w + out_h * out_w) } + else if (op->type == "nn.BatchNorm2d") + { + int in_n, in_c, in_h, in_w; + bool affine; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_h = op->inputs[0]->shape[2]; + in_w = op->inputs[0]->shape[3]; + affine = op->params.at("affine").b; + if (affine) + { + m.flops += 7 * in_n * in_c * in_h * in_w; + m.memory_access += 2 * in_n * in_c * in_h * in_w; + } + else + { + m.flops += 5 * in_n * in_c * in_h * in_w; + m.memory_access += 2 * in_n * in_c * in_h * in_w; + } + } + else if (op->type == "nn.AdaptiveAvgPool2d") + { + int in_n, in_c, in_h, in_w, out_h, out_w, k_h, k_w, kernel_add, kernel_avg; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_h = op->inputs[0]->shape[2]; + in_w = op->inputs[0]->shape[3]; + out_h = op->params.at("output_size").ai[0]; + out_w = op->params.at("output_size").ai[1]; + if (out_h == 0){ + k_h = in_h; + } + else + { + k_h = (in_h + out_h -1) / out_h; + } + + if (out_w == 0){ + k_w = in_w; + } + else + { + k_w = (in_w + out_w -1) / out_w; + } + kernel_add = k_h * k_w - 1; + kernel_avg = 1; + m.flops += (kernel_add + kernel_avg) * out_h * out_w * in_c; + m.memory_access += in_n * in_c * (in_h * in_w + out_h * out_w); + } + else if (op->type == "nn.AdaptiveMaxPool2d") + { + int num_o, in_n, in_c, in_h, in_w, out_h, out_w; + num_o = op->params.at("return_indices").b ? 
2 : 1; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_h = op->inputs[0]->shape[2]; + in_w = op->inputs[0]->shape[3]; + out_h = op->params.at("output_size").ai[0]; + out_w = op->params.at("output_size").ai[1]; + m.memory_access += in_n * in_c * (in_h * in_w + out_h * out_w * num_o); + } + else if (op->type == "nn.CELU") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += 2 * mem; + } + else if (op->type == "nn.ELU") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += 2 * mem; + } + else if (op->type == "nn.Embedding") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += 2 * mem; + } + else if (op->type == "nn.Fold") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += 2 * mem; + } + else if (op->type == "nn.GELU") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += 12 * mem; + m.memory_access += 2 * mem; + } + else if (op->type == "nn.GLU") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += ( 5 * mem ) / 2; + m.memory_access += std::round(1.5 * mem); + } + else if (op->type == "nn.GroupNorm") + { + int num_g, in_n, in_c, in_size, mem = 1; + num_g = op->params.at("num_groups").i; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_size = op->inputs[0]->shape.size(); + for (int index = 2; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += 9 * in_n * in_c * mem + 2 * in_n * num_g; + m.memory_access += 2 * in_n * in_c * mem + 2 * in_n * num_g; + } + else if (op->type == "nn.GRU") + { + int in_size, h_size, num_layers, batch, seq; + bool batch_first, bidirectional; + in_size = op->params.at("input_size").i; + h_size = op->params.at("hidden_size").i; + num_layers = op->params.at("num_layers").i; + batch_first = op->params.at("batch_first").b; + bidirectional = op->params.at("bidirectional").b; + batch = op->inputs[0]->shape[batch_first ? 0 : 1]; + seq = op->inputs[0]->shape[batch_first ? 
1 : 0]; + if (bidirectional) + { + m.flops += 2 * num_layers * batch * seq * h_size * (3 * in_size + 7); + } + else + { + m.flops += num_layers * batch * seq * h_size * (3 * in_size + 7); + } + m.memory_access += num_layers * batch * seq * in_size; + } + else if (op->type == "nn.Hardsigmoid") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += 2 * mem; + m.memory_access += 2 * mem; + } + else if (op->type == "nn.Hardswish") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += 3 * mem; + m.memory_access += 2 * mem; + } + else if (op->type == "nn.Hardtanh") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += 2 * mem; + } + else if (op->type == "nn.Identity") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += 2 * mem; + } + else if (op->type == "nn.InstanceNorm2d") + { + if ( op->inputs[0]->shape.size() == 4) + { + int in_b, in_c, in_h, in_w; + bool affine; + in_b = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_h = op->inputs[0]->shape[2]; + in_w = op->inputs[0]->shape[3]; + affine = op->params.at("affine").b; + if (affine) + { + m.flops += 7 * in_b * in_c * in_h * in_w; + m.memory_access += 2 * in_b * in_c * in_h * in_w; + } + else + { + m.flops += 5 * in_b * in_c * in_h * in_w; + m.memory_access += 2 * in_b * in_c * in_h * in_w; + } + } + } + else if (op->type == "nn.LeakyReLU") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += 2 * mem; + } + else if (op->type == "nn.LocalResponseNorm") + { + int in_size, mem = 1, size, in_n, in_c; + size = op->params.at("size").i; + in_size = op->inputs[0]->shape.size(); + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + for (int index = 2; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += (size + 4) * in_n * in_c * mem; + m.memory_access += (2 + size) * in_n * in_c * mem; + } + else if (op->type == "nn.LogSigmoid") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += 10 * mem; + m.memory_access += 2 * mem; + } + else if (op->type == "nn.LogSoftmax") + { + int in_size, mem = 1, dim, in_n, in_c, in_h; + dim = op->params.at("dim").i; + in_size = op->inputs[0]->shape.size(); + if (dim == 0) + { + in_n = op->inputs[0]->shape[0]; + for (int index = 1; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += ( 7 * in_n + 4 ) * mem; + m.memory_access += 2 * in_n * mem; + } + else if (dim == 1) + { + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + for (int index = 2; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += ( 7 * in_c + 4 ) * in_n * mem; + m.memory_access += 2 * in_n * in_c * mem; + } + else if (dim == 2) + { + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_h = op->inputs[0]->shape[2]; + for (int index = 3; index < in_size; 
index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += ( 7 * in_h + 4 ) * in_n * in_c * mem; + m.memory_access += 2 * in_n * in_c * in_h * mem; + } + } + else if (op->type == "nn.LSTM") + { + int hidden_size, num_layers, batch, seq, in_d; + bool batch_first; + hidden_size = op->params.at("hidden_size").i; + num_layers = op->params.at("num_layers").i; + batch_first = op->params.at("batch_first").b; + batch = op->inputs[0]->shape[batch_first ? 0 : 1]; + seq = op->inputs[0]->shape[batch_first ? 1 : 0]; + in_d = op->inputs[0]->shape[2]; + m.flops += num_layers * batch * seq * hidden_size * (8 * (in_d + hidden_size) + 23); + m.memory_access += num_layers * batch * (seq * in_d + 12 * seq * hidden_size); + } + else if (op->type == "nn.Mish") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += 5 * mem; + m.memory_access += 2 * mem; + } + else if (op->type == "nn.PixelShuffle") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += 2 * mem; + } + else if (op->type == "nn.PixelUnshuffle") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += 2 * mem; + } + else if (op->type == "nn.ReflectionPad1d") + { + int pad_left, pad_right, in_size, in_w, mem = 1; + pad_left = op->params.at("padding").ai[0]; + pad_right = op->params.at("padding").ai[1]; + in_size = op->inputs[0]->shape.size(); + in_w = op->inputs[0]->shape[-1]; + for (int index = 0; index < in_size - 1; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += mem * in_w + mem * (in_w + pad_left + pad_right); + } + else if (op->type == "nn.ReflectionPad2d") + { + int pad_left, pad_right, pad_top, pad_bottom, in_size, in_w, in_h, mem = 1; + pad_left = op->params.at("padding").ai[0]; + pad_right = op->params.at("padding").ai[1]; + pad_top = op->params.at("padding").ai[2]; + pad_bottom = op->params.at("padding").ai[3]; + in_size = op->inputs[0]->shape.size(); + in_h = op->inputs[0]->shape[-1]; + in_w = op->inputs[0]->shape[-2]; + for (int index = 0; index < in_size - 2; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += mem * in_w * in_h + mem * (in_w + pad_left + pad_right) * (in_h + pad_top + pad_bottom); + } + else if (op->type == "nn.ReLU") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += 2 * mem; + } + else if (op->type == "nn.ReLU6") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += 2 * mem; + } + else if (op->type == "nn.ReplicationPad2d") + { + int pad_left, pad_right, pad_top, pad_bottom, in_size, in_w, in_h, mem = 1; + pad_left = op->params.at("padding").ai[0]; + pad_right = op->params.at("padding").ai[1]; + pad_top = op->params.at("padding").ai[2]; + pad_bottom = op->params.at("padding").ai[3]; + in_size = op->inputs[0]->shape.size(); + in_h = op->inputs[0]->shape[-1]; + in_w = op->inputs[0]->shape[-2]; + for (int index = 0; index < in_size - 2; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += mem * in_w * in_h + 
mem * (in_w + pad_left + pad_right) * (in_h + pad_top + pad_bottom); + } + else if (op->type == "nn.RNN") + { + int in_size, h_size, num_layers, batch, seq; + bool batch_first, bidirectional; + in_size = op->params.at("input_size").i; + h_size = op->params.at("hidden_size").i; + num_layers = op->params.at("num_layers").i; + batch_first = op->params.at("batch_first").b; + bidirectional = op->params.at("bidirectional").b; + batch = op->inputs[0]->shape[batch_first ? 0 : 1]; + seq = op->inputs[0]->shape[batch_first ? 1 : 0]; + if (bidirectional) + { + m.flops += 2 * batch * seq * 2 * (in_size * h_size + (num_layers - 1) * (h_size * h_size)); + } + else + { + m.flops += batch * seq * 2 * (in_size * h_size + (num_layers - 1) * (h_size * h_size)); + } + m.memory_access += batch * seq * (in_size + num_layers * h_size + h_size); + + } + else if (op->type == "nn.SELU") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += 2 * mem; + } + else if (op->type == "nn.Sigmoid") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += 7 * mem; + m.memory_access += 2 * mem; + } + else if (op->type == "nn.SiLU") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += 8 * mem; + m.memory_access += 2 * mem; + } + else if (op->type == "nn.Softmax") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += 7 * mem - 1; + m.memory_access += 2 * mem; + } + else if (op->type == "nn.Softmax2d") + { + int in_n, in_c, in_h, in_w; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_h = op->inputs[0]->shape[2]; + in_w = op->inputs[0]->shape[3]; + m.flops += in_n * in_c * (7 * in_h * in_w -1); + m.memory_access += 2 * in_n * in_c * in_h * in_w; + } + else if (op->type == "nn.Tanh") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += 9 * mem; + m.memory_access += 2 * mem; + } + else if (op->type == "nn.Unfold") + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += 2 * mem; + } + else if (op->type == "nn.UpsamplingBilinear2d") + { + int in_n, in_c, in_h, in_w; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_h = op->inputs[0]->shape[2]; + in_w = op->inputs[0]->shape[3]; + if (op->params.find("size") != op->params.end()) + { + int size_h = op->params.at("size").ai[0]; + int size_w = op->params.at("size").ai[1]; + m.flops += 4 * in_c * size_h * size_w; + m.memory_access += 5 * in_c * size_h * size_w; + } + else + { + int scale_h = op->params.at("scale_factor").ai[0]; + int scale_w = op->params.at("scale_factor").ai[1]; + int out_h = in_h * scale_h; + int out_w = in_w * scale_w; + m.flops += 4 * in_c * out_h * out_w; + m.memory_access += 5 * in_c * out_h * out_w; + } + } + else if (op->type == "torch.mm") + { + int first_h, first_w, second_h, second_w; + first_h = op->inputs[0]->shape[0]; + first_w = op->inputs[0]->shape[1]; + second_h = 
op->inputs[1]->shape[0]; + second_w = op->inputs[1]->shape[1]; + fprintf(stderr, "first_h: %d\n", first_h);//debug + fprintf(stderr, "first_w: %d\n", first_w);//debug + fprintf(stderr, "second_h: %d\n", second_h);//debug + fprintf(stderr, "second_w: %d\n", second_w);//debug + m.flops += first_h * second_w * (2 * first_w - 1); + m.memory_access += first_h * first_w + second_h * second_w + first_h * second_w; + } else { } diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index aefe6d598026..46dd7323e041 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -356,6 +356,7 @@ class Graph Operand* get_operand(const std::string& name); const Operand* get_operand(const std::string& name) const; + ModelInfo m; std::vector ops; std::vector operands; diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index 0a0e71ebdc15..fe1c05663019 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -364,9 +364,9 @@ int main(int argc, char** argv) pnnx_graph.python(pnnxpypath, pnnxbinpath); // count float - pnnx::ModelInfo md = pnnx_graph.flops_mem_count(); - fprintf(stderr, "float ops: %lld\n", md.flops); - fprintf(stderr, "memory ops: %lld\n", md.memory_access); + pnnx_graph.flops_mem_count(); + fprintf(stderr, "float ops: %lld\n", pnnx_graph.m.flops); + fprintf(stderr, "memory ops: %lld\n", pnnx_graph.m.memory_access); #if BUILD_PNNX2ONNX pnnx::save_onnx(pnnx_graph, pnnxonnxpath.c_str(), fp16); From dfc83d0e3bc95e2e993241f644acc1dd182eaa75 Mon Sep 17 00:00:00 2001 From: luxincn Date: Mon, 7 Oct 2024 15:46:43 +0000 Subject: [PATCH 22/28] apply code-format changes --- tools/pnnx/src/ir.cpp | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 3fa407470ba1..531a624551e9 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -3038,20 +3038,22 @@ pnnx::ModelInfo Graph::flops_mem_count() in_w = op->inputs[0]->shape[3]; out_h = op->params.at("output_size").ai[0]; out_w = op->params.at("output_size").ai[1]; - if (out_h == 0){ + if (out_h == 0) + { k_h = in_h; } else { - k_h = (in_h + out_h -1) / out_h; + k_h = (in_h + out_h - 1) / out_h; } - if (out_w == 0){ + if (out_w == 0) + { k_w = in_w; } else { - k_w = (in_w + out_w -1) / out_w; + k_w = (in_w + out_w - 1) / out_w; } kernel_add = k_h * k_w - 1; kernel_avg = 1; @@ -3129,7 +3131,7 @@ pnnx::ModelInfo Graph::flops_mem_count() { mem *= op->inputs[0]->shape[index]; } - m.flops += ( 5 * mem ) / 2; + m.flops += (5 * mem) / 2; m.memory_access += std::round(1.5 * mem); } else if (op->type == "nn.GroupNorm") @@ -3211,7 +3213,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } else if (op->type == "nn.InstanceNorm2d") { - if ( op->inputs[0]->shape.size() == 4) + if (op->inputs[0]->shape.size() == 4) { int in_b, in_c, in_h, in_w; bool affine; @@ -3279,7 +3281,7 @@ pnnx::ModelInfo Graph::flops_mem_count() { mem *= op->inputs[0]->shape[index]; } - m.flops += ( 7 * in_n + 4 ) * mem; + m.flops += (7 * in_n + 4) * mem; m.memory_access += 2 * in_n * mem; } else if (dim == 1) @@ -3290,7 +3292,7 @@ pnnx::ModelInfo Graph::flops_mem_count() { mem *= op->inputs[0]->shape[index]; } - m.flops += ( 7 * in_c + 4 ) * in_n * mem; + m.flops += (7 * in_c + 4) * in_n * mem; m.memory_access += 2 * in_n * in_c * mem; } else if (dim == 2) @@ -3302,7 +3304,7 @@ pnnx::ModelInfo Graph::flops_mem_count() { mem *= op->inputs[0]->shape[index]; } - m.flops += ( 7 * in_h + 4 ) * in_n * in_c * mem; + m.flops += (7 * in_h + 4) * in_n * in_c * mem; 
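// Editor's note (illustration only, not part of the patch): a worked instance of the
// LogSoftmax dim == 2 branch above, assuming a hypothetical input of shape (8, 10, 32, 32):
//   in_n = 8, in_c = 10, in_h = 32, mem = 32 (product of the dims after dim 2)
//   flops         += (7 * 32 + 4) * 8 * 10 * 32 = 583680
//   memory_access += 2 * 8 * 10 * 32 * 32       = 163840
// The arithmetic only restates the formula in this hunk; the example shape is an assumption.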
m.memory_access += 2 * in_n * in_c * in_h * mem; } } @@ -3435,7 +3437,6 @@ pnnx::ModelInfo Graph::flops_mem_count() m.flops += batch * seq * 2 * (in_size * h_size + (num_layers - 1) * (h_size * h_size)); } m.memory_access += batch * seq * (in_size + num_layers * h_size + h_size); - } else if (op->type == "nn.SELU") { @@ -3487,7 +3488,7 @@ pnnx::ModelInfo Graph::flops_mem_count() in_c = op->inputs[0]->shape[1]; in_h = op->inputs[0]->shape[2]; in_w = op->inputs[0]->shape[3]; - m.flops += in_n * in_c * (7 * in_h * in_w -1); + m.flops += in_n * in_c * (7 * in_h * in_w - 1); m.memory_access += 2 * in_n * in_c * in_h * in_w; } else if (op->type == "nn.Tanh") @@ -3523,7 +3524,7 @@ pnnx::ModelInfo Graph::flops_mem_count() int size_h = op->params.at("size").ai[0]; int size_w = op->params.at("size").ai[1]; m.flops += 4 * in_c * size_h * size_w; - m.memory_access += 5 * in_c * size_h * size_w; + m.memory_access += 5 * in_c * size_h * size_w; } else { @@ -3532,7 +3533,7 @@ pnnx::ModelInfo Graph::flops_mem_count() int out_h = in_h * scale_h; int out_w = in_w * scale_w; m.flops += 4 * in_c * out_h * out_w; - m.memory_access += 5 * in_c * out_h * out_w; + m.memory_access += 5 * in_c * out_h * out_w; } } else if (op->type == "torch.mm") @@ -3542,12 +3543,12 @@ pnnx::ModelInfo Graph::flops_mem_count() first_w = op->inputs[0]->shape[1]; second_h = op->inputs[1]->shape[0]; second_w = op->inputs[1]->shape[1]; - fprintf(stderr, "first_h: %d\n", first_h);//debug - fprintf(stderr, "first_w: %d\n", first_w);//debug - fprintf(stderr, "second_h: %d\n", second_h);//debug - fprintf(stderr, "second_w: %d\n", second_w);//debug + fprintf(stderr, "first_h: %d\n", first_h); //debug + fprintf(stderr, "first_w: %d\n", first_w); //debug + fprintf(stderr, "second_h: %d\n", second_h); //debug + fprintf(stderr, "second_w: %d\n", second_w); //debug m.flops += first_h * second_w * (2 * first_w - 1); - m.memory_access += first_h * first_w + second_h * second_w + first_h * second_w; + m.memory_access += first_h * first_w + second_h * second_w + first_h * second_w; } else { From 4a1ee56be6a25baf08debbd001ade9c8b0fd5c5a Mon Sep 17 00:00:00 2001 From: luxincn <2391719912@qq.com> Date: Sat, 12 Oct 2024 11:50:38 +0800 Subject: [PATCH 23/28] update flops_mem_count --- tools/pnnx/src/ir.cpp | 843 +++++++++++++++++++++++++++++++++--------- 1 file changed, 667 insertions(+), 176 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 531a624551e9..f360572ec227 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -2866,61 +2866,258 @@ pnnx::ModelInfo Graph::flops_mem_count() { for (const Operator* op : ops) { - if (op->type == "nn.Conv2d") + if (op->type == "nn.Conv1d" || op->type == "nn.ConvTranspose1d" && op->inputs[0]->shape.size() == 3) { - if (op->inputs[0]->type != 0) + if(op->inputs[0]->type != 0) { - int ci = op->inputs[0]->shape[1]; - int kw = op->params.at("kernel_size").ai[0]; - int kh = op->params.at("kernel_size").ai[1]; - int co = op->params.at("out_channels").i; - int w = op->outputs[0]->shape[3]; - int h = op->outputs[0]->shape[2]; - int bias = op->params.at("bias").b ? 
1 : 0; - int wi = op->inputs[0]->shape[2]; - int hi = op->inputs[0]->shape[3]; - int g = op->params.at("groups").i; - if (bias == 1) + int in_n, in_c, in_l, out_c, out_l, k_s, g; + bool bias; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_l = op->inputs[0]->shape[2]; + k_s = op->params.at("kernel_size").i; + out_c = op->params.at("out_channels").i; + out_l = op->outputs[0]->shape[2]; + bias = op->params.at("bias").b;//bias + if(bias) { - m.flops += 2 * ci * kw * kh * co * w * h; + m.flops += in_n * 2 * in_c * k_s * out_c * out_l; + m.memory_access += in_n * (in_l * in_c + out_l * out_c + in_c * k_s * out_c + out_c); } else { - m.flops += (2 * ci * kw * kh - 1) * co * w * h; + m.flops += in_n * (2 * in_c * k_s -1) * out_c * out_l; + m.memory_access += in_n * (in_l * in_c + out_l * out_c + in_c * k_s * out_c); } - int input_m = wi * hi * ci; - int output_m = w * h * co; - int weights_m = kw * kh * ci * co; - m.memory_access += input_m + output_m + weights_m; } } - else if (op->type == "nn.Linear") + else if (op->type == "F.conv1d" || op->type == "F.conv_transpose1d" && op->inputs[0]->shape.size() == 3) { - int in = op->params.at("in_features").i; - int out = op->params.at("out_features").i; - int bias = op->params.at("bias").b ? 1 : 0; - if (bias == 1) + if(op->inputs[0]->type != 0) { - m.flops += 2 * in * out; + int in_n, in_c, in_l, out_c, out_l, k_s, g; + bool bias = true; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_l = op->inputs[0]->shape[2]; + k_s = op->inputs[1]->shape[2]; + out_c = op->outputs[0]->shape[1]; + out_l = op->outputs[0]->shape[2]; + if (op->params.find("bias") != op->params.end()) + { + std::string val = Parameter::encode_to_string(op->params.at("bias")); + if (val == "None") + { + bias = false; + } + } + if(bias) + { + m.flops += in_n * 2 * in_c * k_s * out_c * out_l; + m.memory_access += in_n * (in_l * in_c + out_l * out_c + in_c * k_s * out_c + out_c); + } + else + { + m.flops += in_n * (2 * in_c * k_s -1) * out_c * out_l; + m.memory_access += in_n * (in_l * in_c + out_l * out_c + in_c * k_s * out_c); + } + } + } + else if (op->type == "nn.Conv2d" || op->type == "nn.ConvTranspose2d" && op->inputs[0]->shape.size() == 4) + { + if(op->inputs[0]->type != 0) + { + int in_n, in_c, in_h, in_w, out_c, out_h, out_w, k_h, k_w, g; + bool bias; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_h = op->inputs[0]->shape[2]; + in_w = op->inputs[0]->shape[3]; + k_h = op->params.at("kernel_size").ai[0]; + k_w = op->params.at("kernel_size").ai[1]; + out_c = op->params.at("out_channels").i; + out_h = op->outputs[0]->shape[2]; + out_w = op->outputs[0]->shape[3]; + bias = op->params.at("bias").b;//bias + if(bias) + { + m.flops += in_n * 2 * in_c * k_h * k_w * out_c * out_w * out_h; + m.memory_access += in_n * (in_w * in_h * in_c + out_w * out_h * out_c + in_c * k_h * k_w * out_c + out_c); + } + else + { + m.flops += in_n * (2 * in_c * k_h * k_w -1) * out_c * out_w * out_h; + m.memory_access += in_n * (in_w * in_h * in_c + out_w * out_h * out_c + k_h * k_w * in_c * out_c); + } + } + } + else if (op->type == "F.conv2d" || op->type == "F.conv_transpose2d" && op->inputs[0]->shape.size() == 4) + { + if(op->inputs[0]->type != 0) + { + int in_n, in_c, in_h, in_w, out_c, out_h, out_w, k_h, k_w, g; + bool bias = true; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_h = op->inputs[0]->shape[2]; + in_w = op->inputs[0]->shape[3]; + k_h = op->inputs[1]->shape[2]; + k_w = op->inputs[1]->shape[3]; + 
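// Editor's note (illustration only, not part of the patch): a worked instance of the
// convolution cost model used in this F.conv2d branch, assuming a hypothetical case with
// bias, batch 1, 3 input channels, 64 output channels, and a stride-1, padding-1 3x3 kernel
// so that a 224x224 input keeps its 224x224 spatial size:
//   flops         += 1 * 2 * 3 * 3 * 3 * 64 * 224 * 224            = 173408256  (~173.4 MFLOPs)
//   memory_access += 1 * (224*224*3 + 224*224*64 + 3*3*3*64 + 64)  = 3363584
// Without bias the per-output-element op count drops from 2*in_c*k_h*k_w to
// 2*in_c*k_h*k_w - 1, matching the else path below. All shapes here are assumptions.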
out_c = op->outputs[0]->shape[1]; + out_h = op->outputs[0]->shape[2]; + out_w = op->outputs[0]->shape[3]; + if (op->params.find("bias") != op->params.end()) + { + std::string val = Parameter::encode_to_string(op->params.at("bias")); + if (val == "None") + { + bias = false; + } + } + if(bias) + { + m.flops += in_n * 2 * in_c * k_h * k_w * out_c * out_w * out_h; + m.memory_access += in_n * (in_w * in_h * in_c + out_w * out_h * out_c + in_c * k_h * k_w * out_c + out_c); + } + else + { + m.flops += in_n * (2 * in_c * k_h * k_w -1) * out_c * out_w * out_h; + m.memory_access += in_n * (in_w * in_h * in_c + out_w * out_h * out_c + k_h * k_w * in_c * out_c); + } + } + } + else if (op->type == "nn.Conv3d" || op->type == "nn.ConvTranspose3d" && op->inputs[0]->shape.size() == 5) + { + if(op->inputs[0]->type != 0) + { + int in_n, in_c, in_d, in_h, in_w, out_c, out_d, out_h, out_w, k_d, k_h, k_w, g; + bool bias; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_d = op->inputs[0]->shape[2]; + in_h = op->inputs[0]->shape[3]; + in_w = op->inputs[0]->shape[4]; + k_d = op->params.at("kernel_size").ai[0]; + k_h = op->params.at("kernel_size").ai[1]; + k_w = op->params.at("kernel_size").ai[2]; + out_c = op->outputs[0]->shape[1]; + out_d = op->outputs[0]->shape[2]; + out_h = op->outputs[0]->shape[3]; + out_w = op->outputs[0]->shape[4]; + bias = op->params.at("bias").b;//bias + if(bias) + { + m.flops += in_n * 2 * in_c * k_d * k_h * k_w * out_c * out_d * out_w * out_h; + m.memory_access += in_n * (in_d * in_w * in_h * in_c + out_d * out_w * out_h * out_c + in_c * k_d * k_h * k_w * out_c + out_c); + } + else + { + m.flops += in_n * (2 * in_c * k_d * k_h * k_w -1) * out_c * out_d * out_w * out_h; + m.memory_access += in_n * (in_d * in_w * in_h * in_c + out_d * out_w * out_h * out_c + in_c * k_d * k_h * k_w * out_c); + } + } + } + else if (op->type == "F.conv3d" || op->type == "F.conv_transpose3d" && op->inputs[0]->shape.size() == 5) + { + if(op->inputs[0]->type != 0) + { + int in_n, in_c, in_d, in_h, in_w, out_c, out_d, out_h, out_w, k_d, k_h, k_w, g; + bool bias = true; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_d = op->inputs[0]->shape[2]; + in_h = op->inputs[0]->shape[3]; + in_w = op->inputs[0]->shape[4]; + k_d = op->inputs[1]->shape[2]; + k_h = op->inputs[1]->shape[3]; + k_w = op->inputs[1]->shape[4]; + out_c = op->outputs[0]->shape[1]; + out_d = op->outputs[0]->shape[2]; + out_h = op->outputs[0]->shape[3]; + out_w = op->outputs[0]->shape[4]; + if (op->params.find("bias") != op->params.end()) + { + std::string val = Parameter::encode_to_string(op->params.at("bias")); + if (val == "None") + { + bias = false; + } + } + if(bias) + { + m.flops += in_n * 2 * in_c * k_d * k_h * k_w * out_c * out_d * out_w * out_h; + m.memory_access += in_n * (in_d * in_w * in_h * in_c + out_d * out_w * out_h * out_c + in_c * k_d * k_h * k_w * out_c + out_c); + } + else + { + m.flops += in_n * (2 * in_c * k_d * k_h * k_w -1) * out_c * out_d * out_w * out_h; + m.memory_access += in_n * (in_d * in_w * in_h * in_c + out_d * out_w * out_h * out_c + in_c * k_d * k_h * k_w * out_c); + } + } + } + else if (op->type == "nn.Linear" && op->inputs[0]->shape.size() >= 1) + { + int in_size, in, out, mem = 1; + bool bias; + in = op->params.at("in_features").i; + out = op->params.at("out_features").i; + bias = op->params.at("bias").b; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size - 1; index++) + { + mem *= op->inputs[0]->shape[index]; + } + if(bias) + { + m.flops 
+= mem * 2 * in * out; + m.memory_access += mem * (in + out + in * out + 1); } else { - m.flops += (2 * in - 1) * out; + m.flops += mem * (2 * in - 1) * out; + m.memory_access += mem * (in + out + in * out); } - m.memory_access += in + out + in * out; } - else if (op->type == "nn.MultiheadAttention") + else if (op->type == "F.linear" && op->inputs[0]->shape.size() >= 1) { - int in_size = op->inputs.size(); - - if (std::find(op->nputnames.begin(), op->inputnames.end(), "attn_mask") != op->inputnames.end()) + int in_size, in, out, mem = 1; + bool bias = true; + in = op->inputs[1]->shape[1]; + out = op->inputs[1]->shape[0]; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size - 1; index++) { - in_size -= 1; + mem *= op->inputs[0]->shape[index]; } - - int q_l, k_s, v_s; + if (op->params.find("bias") != op->params.end()) + { + std::string val = Parameter::encode_to_string(op->params.at("bias")); + if (val == "None") + { + bias = false; + } + } + if(bias) + { + m.flops += mem * 2 * in * out; + m.memory_access += mem * (in + out + in * out + 1); + } + else + { + m.flops += mem * (2 * in - 1) * out; + m.memory_access += mem * (in + out + in * out); + } + } + else if (op->type == "nn.MultiheadAttention" && op->inputs[0]->shape.size() == 3) + { + int in_size, q_l, k_s, v_s, num_heads, embed_dim, Kdim, vdim; + long long linear1, attention, linerar2, weights, in, attention_m, out; bool batch_first = op->params.find("batch_first") != op->params.end() && op->params.at("batch_first").b; - + in_size = op->inputs.size(); + if (std::find(op->inputnames.begin(), op->inputnames.end(), "attn_mask") != op->inputnames.end()) + { + in_size -= 1; + } if (in_size == 3) { q_l = op->inputs[0]->shape[batch_first ? 1 : 0]; @@ -2939,77 +3136,148 @@ pnnx::ModelInfo Graph::flops_mem_count() k_s = q_l; v_s = q_l; } - - int num_heads = op->params.at("num_heads").i; - int embed_dim = op->params.at("embed_dim").i; - int Kdim = op->params.at("kdim").i; - int vdim = op->params.at("vdim").i; - - long long linear1 = q_l * embed_dim * embed_dim + k_s * embed_dim * Kdim + v_s * embed_dim * vdim; - long long attention = q_l * k_s * embed_dim + 2 * q_l * k_s * num_heads + q_l * v_s * embed_dim; - long long linerar2 = q_l * embed_dim * embed_dim; + num_heads = op->params.at("num_heads").i; + embed_dim = op->params.at("embed_dim").i; + Kdim = op->params.at("kdim").i; + vdim = op->params.at("vdim").i; + linear1 = q_l * embed_dim * embed_dim + k_s * embed_dim * Kdim + v_s * embed_dim * vdim; + attention = q_l * k_s * embed_dim + 2 * q_l * k_s * num_heads + q_l * v_s * embed_dim; + linerar2 = q_l* embed_dim * embed_dim; m.flops += linear1 + attention + linerar2; - - long long weights = embed_dim * embed_dim + embed_dim * Kdim + embed_dim * vdim + num_heads * vdim * embed_dim; - long long in = q_l * embed_dim + k_s * Kdim + v_s * vdim; - long long attention_m = q_l * embed_dim + k_s * Kdim + 2 * q_l * k_s + v_s * vdim; - long long out = q_l * embed_dim; + weights = embed_dim * embed_dim + embed_dim * Kdim + embed_dim * vdim + num_heads * vdim * embed_dim; + in = q_l * embed_dim + k_s * Kdim + v_s * vdim; + attention_m = q_l * embed_dim + k_s * Kdim + 2 * q_l * k_s + v_s * vdim; + out = q_l * embed_dim; m.memory_access += weights + in + attention_m + out; } - else if (op->type == "nn.MaxPool2d") + else if (op->type == "nn.MaxPool1d" || op->type == "F.max_pool1d" && op->inputs[0]->shape.size() >= 1) { - int num_o = op->params.at("return_indices").b ? 
2 : 1; - int batch_size, in_c, in_h, in_w, out_h, out_w; - if (op->inputs[0]->shape.size() == 4) + int num_o, in_size, out_size, in_l, out_l, mem = 1; + num_o = op->params.at("return_indices").b ? 2 : 1; + in_size = op->inputs[0]->shape.size(); + in_l = op->inputs[0]->shape[in_size - 1]; + out_size = op->outputs[0]->shape.size(); + out_l = op->outputs[0]->shape[out_size - 1]; + for (int index = 0; index < in_size - 1; index++) { - batch_size = op->inputs[0]->shape[0]; - in_c = op->inputs[0]->shape[1]; - in_h = op->inputs[0]->shape[2]; - in_w = op->inputs[0]->shape[3]; - out_h = op->outputs[0]->shape[2]; - out_w = op->outputs[0]->shape[3]; + mem *= op->inputs[0]->shape[index]; } - else if (op->inputs[0]->shape.size() == 3) + m.memory_access += mem * (in_l + out_l * num_o); + } + else if (op->type == "nn.MaxPool2d" || op->type == "F.max_pool2d" && op->inputs[0]->shape.size() >= 2) + { + int num_o, in_size, out_size, in_h, in_w, out_h, out_w, mem = 1; + num_o = op->params.at("return_indices").b ? 2 : 1; + in_size = op->inputs[0]->shape.size(); + in_h = op->inputs[0]->shape[in_size - 2]; + in_w = op->inputs[0]->shape[in_size - 1]; + out_size = op->outputs[0]->shape.size(); + out_h = op->outputs[0]->shape[out_size - 2]; + out_w = op->outputs[0]->shape[out_size - 1]; + for (int index = 0; index < in_size - 2; index++) { - batch_size = 1; - in_c = op->inputs[0]->shape[0]; - in_h = op->inputs[0]->shape[1]; - in_w = op->inputs[0]->shape[2]; - out_h = op->outputs[0]->shape[1]; - out_w = op->outputs[0]->shape[2]; + mem *= op->inputs[0]->shape[index]; } - m.memory_access += batch_size * in_c * (in_h * in_w + out_h * out_w * num_o) + m.memory_access += mem * (in_h * in_w + out_h * out_w * num_o); } - else if (op->type == "nn.AvgPool2d") + else if (op->type == "nn.MaxPool3d" || op->type == "F.max_pool3d" && op->inputs[0]->shape.size() >= 3) { - int batch_size, in_c, in_h, in_w, out_h, out_w, k_h, k_w, kernel_add, kernel_avg; - if (op->inputs[0]->shape.size() == 4) + int num_o, in_size, out_size, in_d, in_h, in_w, out_d, out_h, out_w, mem = 1; + num_o = op->params.at("return_indices").b ? 
2 : 1; + in_size = op->inputs[0]->shape.size(); + in_d = op->inputs[0]->shape[in_size - 3]; + in_h = op->inputs[0]->shape[in_size - 2]; + in_w = op->inputs[0]->shape[in_size - 1]; + out_size = op->outputs[0]->shape.size(); + out_d = op->outputs[0]->shape[out_size - 3]; + out_h = op->outputs[0]->shape[out_size - 2]; + out_w = op->outputs[0]->shape[out_size - 1]; + for (int index = 0; index < in_size - 3; index++) { - batch_size = op->inputs[0]->shape[0]; - in_c = op->inputs[0]->shape[1]; - in_h = op->inputs[0]->shape[2]; - in_w = op->inputs[0]->shape[3]; - out_h = op->outputs[0]->shape[2]; - out_w = op->outputs[0]->shape[3]; + mem *= op->inputs[0]->shape[index]; } - else if (op->inputs[0]->shape.size() == 3) + m.memory_access += mem * (in_d * in_h * in_w + out_d * out_h * out_w * num_o); + } + else if (op->type == "nn.AvgPool1d" || op->type == "F.avg_pool1d" && op->inputs[0]->shape.size() >= 1) + { + int in_size, out_size, in_l, out_l, k_l, kernel_add, kernel_avg, mem = 1; + in_size = op->inputs[0]->shape.size(); + in_l = op->inputs[0]->shape[in_size - 1]; + out_size = op->outputs[0]->shape.size(); + out_l = op->outputs[0]->shape[out_size - 1]; + k_l = op->params.at("kernel_size").i; + kernel_add = k_l - 1; + kernel_avg = 1; + for (int index = 0; index < in_size - 1; index++) { - batch_size = 1; - in_c = op->inputs[0]->shape[0]; - in_h = op->inputs[0]->shape[1]; - in_w = op->inputs[0]->shape[2]; - out_h = op->outputs[0]->shape[1]; - out_w = op->outputs[0]->shape[2]; + mem *= op->inputs[0]->shape[index]; } + m.flops += (kernel_add + kernel_avg) * out_l * mem; + m.memory_access += mem * (in_l+ out_l); + } + else if (op->type == "nn.AvgPool2d" || op->type == "F.avg_pool2d" && op->inputs[0]->shape.size() >= 2) + { + int in_size, out_size, in_h, in_w, out_h, out_w, k_h, k_w, kernel_add, kernel_avg, mem = 1; + in_size = op->inputs[0]->shape.size(); + in_h = op->inputs[0]->shape[in_size - 2]; + in_w = op->inputs[0]->shape[in_size - 1]; + out_size = op->outputs[0]->shape.size(); + out_h = op->outputs[0]->shape[out_size - 2]; + out_w = op->outputs[0]->shape[out_size - 1]; k_h = op->params.at("kernel_size").ai[0]; k_w = op->params.at("kernel_size").ai[1]; - kernel_add = k_h * k_w - 1; kernel_avg = 1; - m.flops += (kernel_add + kernel_avg) * (out_h * out_w) * in_c; - m.memory_access += batch_size * in_c * (in_h * in_w + out_h * out_w) + for (int index = 0; index < in_size - 2; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += (kernel_add + kernel_avg) * (out_h * out_w) * mem; + m.memory_access += mem * (in_h * in_w + out_h * out_w); + } + else if (op->type == "nn.AvgPool3d" || op->type == "F.avg_pool3d" && op->inputs[0]->shape.size() >= 3) + { + int in_size, out_size, in_d, in_h, in_w, out_d, out_h, out_w, k_d, k_h, k_w, kernel_add, kernel_avg, mem = 1; + in_size = op->inputs[0]->shape.size(); + in_d = op->inputs[0]->shape[in_size - 3]; + in_h = op->inputs[0]->shape[in_size - 2]; + in_w = op->inputs[0]->shape[in_size - 1]; + out_size = op->outputs[0]->shape.size(); + out_d = op->outputs[0]->shape[out_size - 3]; + out_h = op->outputs[0]->shape[out_size - 2]; + out_w = op->outputs[0]->shape[out_size - 1]; + k_d = op->params.at("kernel_size").ai[0]; + k_h = op->params.at("kernel_size").ai[1]; + k_w = op->params.at("kernel_size").ai[2]; + kernel_add = k_d * k_h * k_w - 1; + kernel_avg = 1; + for (int index = 0; index < in_size - 3; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += (kernel_add + kernel_avg) * (out_d * out_h * out_w) * mem; + m.memory_access += mem * 
(in_d * in_h * in_w + out_d * out_h * out_w); + } + else if (op->type == "nn.BatchNorm1d" && op->inputs[0]->shape.size() == 3) + { + int in_n, in_c, in_l; + bool affine; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_l = op->inputs[0]->shape[2]; + affine = op->params.at("affine").b; + if (affine) + { + m.flops += 7 * in_n * in_c * in_l; + m.memory_access += 2 * in_n * in_c * in_l; + } + else + { + m.flops += 5 * in_n * in_c * in_l; + m.memory_access += 2 * in_n * in_c * in_l; + } } - else if (op->type == "nn.BatchNorm2d") + else if (op->type == "nn.BatchNorm2d" && op->inputs[0]->shape.size() == 4) { int in_n, in_c, in_h, in_w; bool affine; @@ -3029,7 +3297,60 @@ pnnx::ModelInfo Graph::flops_mem_count() m.memory_access += 2 * in_n * in_c * in_h * in_w; } } - else if (op->type == "nn.AdaptiveAvgPool2d") + else if (op->type == "nn.BatchNorm3d" && op->inputs[0]->shape.size() == 5) + { + int in_n, in_c, in_d, in_h, in_w; + bool affine; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_d = op->inputs[0]->shape[2]; + in_h = op->inputs[0]->shape[3]; + in_w = op->inputs[0]->shape[4]; + affine = op->params.at("affine").b; + if (affine) + { + m.flops += 7 * in_n * in_c * in_d * in_h * in_w; + m.memory_access += 2 * in_n * in_c * in_d * in_h * in_w; + } + else + { + m.flops += 5 * in_n * in_c * in_d * in_h * in_w; + m.memory_access += 2 * in_n * in_c * in_d * in_h * in_w; + } + } + else if (op->type == "F.batch_norm" && op->inputs[0]->shape.size() >= 1) + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += 5 * mem; + m.memory_access += 2 * mem; + } + else if (op->type == "nn.AdaptiveAvgPool1d" || op->type == "F.adaptive_avg_pool1d" && op->inputs[0]->shape.size() == 3) + { + int in_n, in_c, in_l, out_l, k_l, kernel_add, kernel_avg; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_l = op->inputs[0]->shape[2]; + out_l = op->params.at("output_size").i; + if (out_l == 0) + { + k_l = 1; + out_l = in_l; + } + else + { + k_l = (in_l + out_l -1) / out_l; + } + kernel_add = k_l - 1; + kernel_avg = 1; + m.flops += (kernel_add + kernel_avg) * out_l * in_c * in_n; + m.memory_access += in_n * in_c * (in_l + out_l); + } + else if (op->type == "nn.AdaptiveAvgPool2d" || op->type == "F.adaptive_avg_pool2d" && op->inputs[0]->shape.size() == 4) { int in_n, in_c, in_h, in_w, out_h, out_w, k_h, k_w, kernel_add, kernel_avg; in_n = op->inputs[0]->shape[0]; @@ -3040,27 +3361,85 @@ pnnx::ModelInfo Graph::flops_mem_count() out_w = op->params.at("output_size").ai[1]; if (out_h == 0) { - k_h = in_h; + k_h = 1; + out_h = in_h; } else { - k_h = (in_h + out_h - 1) / out_h; + k_h = (in_h + out_h -1) / out_h; } - if (out_w == 0) { - k_w = in_w; + k_w = 1; + out_w = in_w; } else { - k_w = (in_w + out_w - 1) / out_w; + k_w = (in_w + out_w -1) / out_w; } kernel_add = k_h * k_w - 1; kernel_avg = 1; m.flops += (kernel_add + kernel_avg) * out_h * out_w * in_c; m.memory_access += in_n * in_c * (in_h * in_w + out_h * out_w); } - else if (op->type == "nn.AdaptiveMaxPool2d") + else if (op->type == "nn.AdaptiveAvgPool3d" || op->type == "F.adaptive_avg_pool3d" && op->inputs[0]->shape.size() == 5) + { + int in_n, in_c, in_d, in_h, in_w,out_d, out_h, out_w, k_d, k_h, k_w, kernel_add, kernel_avg; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_d = op->inputs[0]->shape[2]; + in_h = op->inputs[0]->shape[3]; + in_w = 
op->inputs[0]->shape[4]; + out_d = op->params.at("output_size").ai[0]; + out_h = op->params.at("output_size").ai[1]; + out_w = op->params.at("output_size").ai[2]; + if (out_d == 0) + { + k_d = 1; + out_d = in_d; + } + else + { + k_d = (in_d + out_d -1) / out_d; + } + if (out_h == 0) + { + k_h = 1; + out_h = in_h; + } + else + { + k_h = (in_h + out_h -1) / out_h; + } + if (out_w == 0) + { + k_w = 1; + out_w = in_w; + } + else + { + k_w = (in_w + out_w -1) / out_w; + } + kernel_add = k_d * k_h * k_w - 1; + kernel_avg = 1; + m.flops += (kernel_add + kernel_avg) * out_d * out_h * out_w * in_c * in_n; + m.memory_access += in_n * in_c * (in_d * in_h * in_w + out_d * out_h * out_w); + } + else if (op->type == "nn.AdaptiveMaxPool1d" || op->type == "F.adaptive_max_pool1d" && op->inputs[0]->shape.size() == 3) + { + int num_o, in_n, in_c, in_l, out_l; + num_o = op->params.at("return_indices").b ? 2 : 1; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_l = op->inputs[0]->shape[2]; + out_l = op->params.at("output_size").i; + if (out_l == 0) + { + out_l = in_l; + } + m.memory_access += in_n * in_c * (in_l + out_l * num_o); + } + else if (op->type == "nn.AdaptiveMaxPool2d" || op->type == "F.adaptive_max_pool2d" && op->inputs[0]->shape.size() == 4) { int num_o, in_n, in_c, in_h, in_w, out_h, out_w; num_o = op->params.at("return_indices").b ? 2 : 1; @@ -3070,9 +3449,43 @@ pnnx::ModelInfo Graph::flops_mem_count() in_w = op->inputs[0]->shape[3]; out_h = op->params.at("output_size").ai[0]; out_w = op->params.at("output_size").ai[1]; + if (out_h == 0) + { + out_h = in_h; + } + if (out_w == 0) + { + out_w = in_w; + } m.memory_access += in_n * in_c * (in_h * in_w + out_h * out_w * num_o); } - else if (op->type == "nn.CELU") + else if (op->type == "nn.AdaptiveMaxPool3d" || op->type == "F.adaptive_max_pool3d" && op->inputs[0]->shape.size() == 5) + { + int num_o, in_n, in_c,in_d, in_h, in_w, out_d, out_h, out_w; + num_o = op->params.at("return_indices").b ? 
2 : 1; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_d = op->inputs[0]->shape[2]; + in_h = op->inputs[0]->shape[3]; + in_w = op->inputs[0]->shape[4]; + out_d = op->params.at("output_size").ai[0]; + out_h = op->params.at("output_size").ai[1]; + out_w = op->params.at("output_size").ai[2]; + if (out_d == 0) + { + out_d = in_d; + } + if (out_h == 0) + { + out_h = in_h; + } + if (out_w == 0) + { + out_w = in_w; + } + m.memory_access += in_n * in_c * (in_d * in_h * in_w + out_d * out_h * out_w * num_o); + } + else if (op->type == "nn.CELU" || op->type == "F.celu" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3082,7 +3495,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } m.memory_access += 2 * mem; } - else if (op->type == "nn.ELU") + else if (op->type == "nn.ELU" || op->type == "F.elu" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3092,7 +3505,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } m.memory_access += 2 * mem; } - else if (op->type == "nn.Embedding") + else if (op->type == "nn.Embedding" || op->type == "F.embedding" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3102,7 +3515,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } m.memory_access += 2 * mem; } - else if (op->type == "nn.Fold") + else if (op->type == "nn.Fold" || op->type == "F.fold" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3112,7 +3525,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } m.memory_access += 2 * mem; } - else if (op->type == "nn.GELU") + else if (op->type == "nn.GELU" || op->type == "F.gelu" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3123,7 +3536,7 @@ pnnx::ModelInfo Graph::flops_mem_count() m.flops += 12 * mem; m.memory_access += 2 * mem; } - else if (op->type == "nn.GLU") + else if (op->type == "nn.GLU" || op->type == "F.glu" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3131,10 +3544,10 @@ pnnx::ModelInfo Graph::flops_mem_count() { mem *= op->inputs[0]->shape[index]; } - m.flops += (5 * mem) / 2; + m.flops += std::round(2.5 * mem); m.memory_access += std::round(1.5 * mem); } - else if (op->type == "nn.GroupNorm") + else if (op->type == "nn.GroupNorm" && op->inputs[0]->shape.size() >= 2) { int num_g, in_n, in_c, in_size, mem = 1; num_g = op->params.at("num_groups").i; @@ -3148,7 +3561,21 @@ pnnx::ModelInfo Graph::flops_mem_count() m.flops += 9 * in_n * in_c * mem + 2 * in_n * num_g; m.memory_access += 2 * in_n * in_c * mem + 2 * in_n * num_g; } - else if (op->type == "nn.GRU") + else if (op->type == "F.group_norm" && op->inputs[0]->shape.size() >= 2) + { + int num_g, in_n, in_c, in_size, mem = 1; + num_g = op->params.at("num_groups").i; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_size = op->inputs[0]->shape.size(); + for (int index = 2; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.flops += 9 * in_n * in_c * mem + 2 * in_n * num_g; + m.memory_access += 2 * in_n * in_c * mem + 2 * in_n * num_g; + } + else if (op->type == "nn.GRU" && op->inputs[0]->shape.size() == 3) { int in_size, h_size, num_layers, batch, seq; bool batch_first, bidirectional; @@ -3169,7 +3596,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } m.memory_access += num_layers * batch * seq * in_size; } - else if (op->type == 
"nn.Hardsigmoid") + else if (op->type == "nn.Hardsigmoid" || op->type == "F.hardsigmoid" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3180,7 +3607,7 @@ pnnx::ModelInfo Graph::flops_mem_count() m.flops += 2 * mem; m.memory_access += 2 * mem; } - else if (op->type == "nn.Hardswish") + else if (op->type == "nn.Hardswish" || op->type == "F.hardswish" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3191,7 +3618,7 @@ pnnx::ModelInfo Graph::flops_mem_count() m.flops += 3 * mem; m.memory_access += 2 * mem; } - else if (op->type == "nn.Hardtanh") + else if (op->type == "nn.Hardtanh" || op->type == "F.hardtanh" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3201,7 +3628,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } m.memory_access += 2 * mem; } - else if (op->type == "nn.Identity") + else if (op->type == "nn.Identity" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3211,30 +3638,67 @@ pnnx::ModelInfo Graph::flops_mem_count() } m.memory_access += 2 * mem; } - else if (op->type == "nn.InstanceNorm2d") + else if (op->type == "nn.InstanceNorm1d" && op->inputs[0]->shape.size() == 3) { - if (op->inputs[0]->shape.size() == 4) + int in_n, in_c, in_l; + bool affine; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_l = op->inputs[0]->shape[2]; + affine = op->params.at("affine").b; + if (affine) { - int in_b, in_c, in_h, in_w; - bool affine; - in_b = op->inputs[0]->shape[0]; - in_c = op->inputs[0]->shape[1]; - in_h = op->inputs[0]->shape[2]; - in_w = op->inputs[0]->shape[3]; - affine = op->params.at("affine").b; - if (affine) - { - m.flops += 7 * in_b * in_c * in_h * in_w; - m.memory_access += 2 * in_b * in_c * in_h * in_w; - } - else - { - m.flops += 5 * in_b * in_c * in_h * in_w; - m.memory_access += 2 * in_b * in_c * in_h * in_w; - } + m.flops += 7 * in_n * in_c * in_l; + m.memory_access += 2 * in_n * in_c * in_l; + } + else + { + m.flops += 5 * in_n * in_c * in_l; + m.memory_access += 2 * in_n * in_c * in_l; + } + } + else if (op->type == "nn.InstanceNorm2d" && op->inputs[0]->shape.size() == 4) + { + int in_n, in_c, in_h, in_w; + bool affine; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_h = op->inputs[0]->shape[2]; + in_w = op->inputs[0]->shape[3]; + affine = op->params.at("affine").b; + if (affine) + { + m.flops += 7 * in_n * in_c * in_h * in_w; + m.memory_access += 2 * in_n * in_c * in_h * in_w; + } + else + { + m.flops += 5 * in_n * in_c * in_h * in_w; + m.memory_access += 2 * in_n * in_c * in_h * in_w; + } + } + else if (op->type == "nn.InstanceNorm3d" && op->inputs[0]->shape.size() == 5) + { + int in_n, in_c, in_d, in_h, in_w; + bool affine; + in_n = op->inputs[0]->shape[0]; + in_c = op->inputs[0]->shape[1]; + in_d = op->inputs[0]->shape[2]; + in_h = op->inputs[0]->shape[3]; + in_w = op->inputs[0]->shape[4]; + affine = op->params.at("affine").b; + if (affine) + { + m.flops += 7 * in_n * in_c * in_d * in_h * in_w; + m.memory_access += 2 * in_n * in_c * in_d * in_h * in_w; + } + else + { + m.flops += 5 * in_n * in_c * in_d * in_h * in_w; + m.memory_access += 2 * in_n * in_c * in_d * in_h * in_w; } } - else if (op->type == "nn.LeakyReLU") + else if (op->type == "F.instance_norm" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3242,9 +3706,20 @@ pnnx::ModelInfo 
Graph::flops_mem_count() { mem *= op->inputs[0]->shape[index]; } + m.flops += 5 * mem; m.memory_access += 2 * mem; } - else if (op->type == "nn.LocalResponseNorm") + else if (op->type == "nn.LeakyReLU" || op->type == "F.leaky_relu" && op->inputs[0]->shape.size() >= 1) + { + int in_size, mem = 1; + in_size = op->inputs[0]->shape.size(); + for (int index = 0; index < in_size; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += 2 * mem; + } + else if (op->type == "nn.LocalResponseNorm" && op->inputs[0]->shape.size() >= 2) { int in_size, mem = 1, size, in_n, in_c; size = op->params.at("size").i; @@ -3258,7 +3733,7 @@ pnnx::ModelInfo Graph::flops_mem_count() m.flops += (size + 4) * in_n * in_c * mem; m.memory_access += (2 + size) * in_n * in_c * mem; } - else if (op->type == "nn.LogSigmoid") + else if (op->type == "nn.LogSigmoid" || op->type == "F.logsigmoid" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3269,7 +3744,7 @@ pnnx::ModelInfo Graph::flops_mem_count() m.flops += 10 * mem; m.memory_access += 2 * mem; } - else if (op->type == "nn.LogSoftmax") + else if (op->type == "nn.LogSoftmax" || op->type == "F.log_softmax" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1, dim, in_n, in_c, in_h; dim = op->params.at("dim").i; @@ -3281,7 +3756,7 @@ pnnx::ModelInfo Graph::flops_mem_count() { mem *= op->inputs[0]->shape[index]; } - m.flops += (7 * in_n + 4) * mem; + m.flops += ( 7 * in_n + 4 ) * mem; m.memory_access += 2 * in_n * mem; } else if (dim == 1) @@ -3292,7 +3767,7 @@ pnnx::ModelInfo Graph::flops_mem_count() { mem *= op->inputs[0]->shape[index]; } - m.flops += (7 * in_c + 4) * in_n * mem; + m.flops += ( 7 * in_c + 4 ) * in_n * mem; m.memory_access += 2 * in_n * in_c * mem; } else if (dim == 2) @@ -3304,11 +3779,11 @@ pnnx::ModelInfo Graph::flops_mem_count() { mem *= op->inputs[0]->shape[index]; } - m.flops += (7 * in_h + 4) * in_n * in_c * mem; + m.flops += ( 7 * in_h + 4 ) * in_n * in_c * mem; m.memory_access += 2 * in_n * in_c * in_h * mem; } } - else if (op->type == "nn.LSTM") + else if (op->type == "nn.LSTM" && op->inputs[0]->shape.size() == 3) { int hidden_size, num_layers, batch, seq, in_d; bool batch_first; @@ -3321,7 +3796,7 @@ pnnx::ModelInfo Graph::flops_mem_count() m.flops += num_layers * batch * seq * hidden_size * (8 * (in_d + hidden_size) + 23); m.memory_access += num_layers * batch * (seq * in_d + 12 * seq * hidden_size); } - else if (op->type == "nn.Mish") + else if (op->type == "nn.Mish" || op->type == "F.mish" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3332,7 +3807,7 @@ pnnx::ModelInfo Graph::flops_mem_count() m.flops += 5 * mem; m.memory_access += 2 * mem; } - else if (op->type == "nn.PixelShuffle") + else if (op->type == "nn.PixelShuffle" || op->type == "F.pixel_shuffle" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3342,7 +3817,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } m.memory_access += 2 * mem; } - else if (op->type == "nn.PixelUnshuffle") + else if (op->type == "nn.PixelUnshuffle" || op->type == "F.pixel_unshuffle" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3352,20 +3827,20 @@ pnnx::ModelInfo Graph::flops_mem_count() } m.memory_access += 2 * mem; } - else if (op->type == "nn.ReflectionPad1d") + else if (op->type == "nn.ReflectionPad1d" && op->inputs[0]->shape.size() >= 1) { int pad_left, 
pad_right, in_size, in_w, mem = 1; pad_left = op->params.at("padding").ai[0]; pad_right = op->params.at("padding").ai[1]; in_size = op->inputs[0]->shape.size(); - in_w = op->inputs[0]->shape[-1]; + in_w = op->inputs[0]->shape[in_size - 1]; for (int index = 0; index < in_size - 1; index++) { mem *= op->inputs[0]->shape[index]; } m.memory_access += mem * in_w + mem * (in_w + pad_left + pad_right); } - else if (op->type == "nn.ReflectionPad2d") + else if (op->type == "nn.ReflectionPad2d" && op->inputs[0]->shape.size() >= 2) { int pad_left, pad_right, pad_top, pad_bottom, in_size, in_w, in_h, mem = 1; pad_left = op->params.at("padding").ai[0]; @@ -3373,15 +3848,15 @@ pnnx::ModelInfo Graph::flops_mem_count() pad_top = op->params.at("padding").ai[2]; pad_bottom = op->params.at("padding").ai[3]; in_size = op->inputs[0]->shape.size(); - in_h = op->inputs[0]->shape[-1]; - in_w = op->inputs[0]->shape[-2]; + in_h = op->inputs[0]->shape[in_size - 2]; + in_w = op->inputs[0]->shape[in_size - 1]; for (int index = 0; index < in_size - 2; index++) { mem *= op->inputs[0]->shape[index]; } m.memory_access += mem * in_w * in_h + mem * (in_w + pad_left + pad_right) * (in_h + pad_top + pad_bottom); } - else if (op->type == "nn.ReLU") + else if (op->type == "nn.ReLU" || op->type == "F.relu" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3391,7 +3866,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } m.memory_access += 2 * mem; } - else if (op->type == "nn.ReLU6") + else if (op->type == "nn.ReLU6" || op->type == "F.relu6" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3401,7 +3876,20 @@ pnnx::ModelInfo Graph::flops_mem_count() } m.memory_access += 2 * mem; } - else if (op->type == "nn.ReplicationPad2d") + else if (op->type == "nn.ReplicationPad1d" && op->inputs[0]->shape.size() == 3) + { + int pad_left, pad_right, in_size, in_l, mem = 1; + pad_left = op->params.at("padding").ai[0]; + pad_right = op->params.at("padding").ai[1]; + in_size = op->inputs[0]->shape.size(); + in_l = op->inputs[0]->shape[in_size - 1]; + for (int index = 0; index < in_size - 1; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += mem * in_l + mem * (in_l + pad_left + pad_right); + } + else if (op->type == "nn.ReplicationPad2d" && op->inputs[0]->shape.size() == 4) { int pad_left, pad_right, pad_top, pad_bottom, in_size, in_w, in_h, mem = 1; pad_left = op->params.at("padding").ai[0]; @@ -3409,15 +3897,34 @@ pnnx::ModelInfo Graph::flops_mem_count() pad_top = op->params.at("padding").ai[2]; pad_bottom = op->params.at("padding").ai[3]; in_size = op->inputs[0]->shape.size(); - in_h = op->inputs[0]->shape[-1]; - in_w = op->inputs[0]->shape[-2]; + in_h = op->inputs[0]->shape[in_size - 2]; + in_w = op->inputs[0]->shape[in_size - 1]; for (int index = 0; index < in_size - 2; index++) { mem *= op->inputs[0]->shape[index]; } m.memory_access += mem * in_w * in_h + mem * (in_w + pad_left + pad_right) * (in_h + pad_top + pad_bottom); } - else if (op->type == "nn.RNN") + else if (op->type == "nn.ReplicationPad3d" && op->inputs[0]->shape.size() == 5) + { + int pad_front, pad_back, pad_left, pad_right, pad_top, pad_bottom, in_size, in_d, in_h, in_w, mem = 1; + pad_left = op->params.at("padding").ai[0]; + pad_right = op->params.at("padding").ai[1]; + pad_top = op->params.at("padding").ai[2]; + pad_bottom = op->params.at("padding").ai[3]; + pad_front = op->params.at("padding").ai[4]; + pad_back = op->params.at("padding").ai[5]; 
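// Editor's note (illustration only, not part of the patch): a worked instance of the
// padding memory model used below, assuming a hypothetical input of shape (1, 4, 8, 16, 16)
// padded by 1 on every side (left/right/top/bottom/front/back):
//   mem = 1 * 4 = 4 (product of the dims before the last three)
//   input elements read    = 4 * 8 * 16 * 16   = 8192
//   output elements written = 4 * 10 * 18 * 18 = 12960
//   memory_access += 8192 + 12960              = 21152
// Only element counts are modeled; no flops are added for pure padding, which is
// consistent with the other Pad branches in this function. The example shape is an assumption.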
+ in_size = op->inputs[0]->shape.size(); + in_d = op->inputs[0]->shape[in_size - 3]; + in_h = op->inputs[0]->shape[in_size - 2]; + in_w = op->inputs[0]->shape[in_size - 1]; + for (int index = 0; index < in_size - 3; index++) + { + mem *= op->inputs[0]->shape[index]; + } + m.memory_access += mem * in_d * in_h * in_w + mem * (in_d + pad_front + pad_back) * (in_h + pad_top + pad_bottom) * (in_w + pad_left + pad_right); + } + else if (op->type == "nn.RNN" && op->inputs[0]->shape.size() == 3) { int in_size, h_size, num_layers, batch, seq; bool batch_first, bidirectional; @@ -3437,8 +3944,9 @@ pnnx::ModelInfo Graph::flops_mem_count() m.flops += batch * seq * 2 * (in_size * h_size + (num_layers - 1) * (h_size * h_size)); } m.memory_access += batch * seq * (in_size + num_layers * h_size + h_size); + } - else if (op->type == "nn.SELU") + else if (op->type == "nn.SELU" || op->type == "F.selu" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3448,7 +3956,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } m.memory_access += 2 * mem; } - else if (op->type == "nn.Sigmoid") + else if (op->type == "nn.Sigmoid" || op->type == "F.sigmoid" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3459,7 +3967,7 @@ pnnx::ModelInfo Graph::flops_mem_count() m.flops += 7 * mem; m.memory_access += 2 * mem; } - else if (op->type == "nn.SiLU") + else if (op->type == "nn.SiLU" || op->type == "F.silu" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3470,7 +3978,7 @@ pnnx::ModelInfo Graph::flops_mem_count() m.flops += 8 * mem; m.memory_access += 2 * mem; } - else if (op->type == "nn.Softmax") + else if (op->type == "nn.Softmax" || op->type == "F.softmax" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3481,17 +3989,17 @@ pnnx::ModelInfo Graph::flops_mem_count() m.flops += 7 * mem - 1; m.memory_access += 2 * mem; } - else if (op->type == "nn.Softmax2d") + else if (op->type == "nn.Softmax2d" && op->inputs[0]->shape.size() == 4) { int in_n, in_c, in_h, in_w; in_n = op->inputs[0]->shape[0]; in_c = op->inputs[0]->shape[1]; in_h = op->inputs[0]->shape[2]; in_w = op->inputs[0]->shape[3]; - m.flops += in_n * in_c * (7 * in_h * in_w - 1); + m.flops += in_n * in_c * (7 * in_h * in_w -1); m.memory_access += 2 * in_n * in_c * in_h * in_w; } - else if (op->type == "nn.Tanh") + else if (op->type == "nn.Tanh" || op->type == "F.tanh" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3502,7 +4010,7 @@ pnnx::ModelInfo Graph::flops_mem_count() m.flops += 9 * mem; m.memory_access += 2 * mem; } - else if (op->type == "nn.Unfold") + else if (op->type == "nn.Unfold" || op->type == "F.unfold" && op->inputs[0]->shape.size() >= 1) { int in_size, mem = 1; in_size = op->inputs[0]->shape.size(); @@ -3512,7 +4020,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } m.memory_access += 2 * mem; } - else if (op->type == "nn.UpsamplingBilinear2d") + else if (op->type == "nn.UpsamplingBilinear2d" || op->type == "F.upsample_bilinear" && op->inputs[0]->shape.size() == 4) { int in_n, in_c, in_h, in_w; in_n = op->inputs[0]->shape[0]; @@ -3524,7 +4032,7 @@ pnnx::ModelInfo Graph::flops_mem_count() int size_h = op->params.at("size").ai[0]; int size_w = op->params.at("size").ai[1]; m.flops += 4 * in_c * size_h * size_w; - m.memory_access += 5 * in_c * size_h * size_w; + m.memory_access += 5 * in_c * size_h 
* size_w; } else { @@ -3533,26 +4041,9 @@ pnnx::ModelInfo Graph::flops_mem_count() int out_h = in_h * scale_h; int out_w = in_w * scale_w; m.flops += 4 * in_c * out_h * out_w; - m.memory_access += 5 * in_c * out_h * out_w; + m.memory_access += 5 * in_c * out_h * out_w; } } - else if (op->type == "torch.mm") - { - int first_h, first_w, second_h, second_w; - first_h = op->inputs[0]->shape[0]; - first_w = op->inputs[0]->shape[1]; - second_h = op->inputs[1]->shape[0]; - second_w = op->inputs[1]->shape[1]; - fprintf(stderr, "first_h: %d\n", first_h); //debug - fprintf(stderr, "first_w: %d\n", first_w); //debug - fprintf(stderr, "second_h: %d\n", second_h); //debug - fprintf(stderr, "second_w: %d\n", second_w); //debug - m.flops += first_h * second_w * (2 * first_w - 1); - m.memory_access += first_h * first_w + second_h * second_w + first_h * second_w; - } - else - { - } } return m; From d75af72db29362a172a4103e767ce484278caa19 Mon Sep 17 00:00:00 2001 From: luxincn Date: Sat, 12 Oct 2024 03:53:04 +0000 Subject: [PATCH 24/28] apply code-format changes --- tools/pnnx/src/ir.cpp | 79 +++++++++++++++++++++---------------------- 1 file changed, 39 insertions(+), 40 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index f360572ec227..cfd15de042db 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -2868,7 +2868,7 @@ pnnx::ModelInfo Graph::flops_mem_count() { if (op->type == "nn.Conv1d" || op->type == "nn.ConvTranspose1d" && op->inputs[0]->shape.size() == 3) { - if(op->inputs[0]->type != 0) + if (op->inputs[0]->type != 0) { int in_n, in_c, in_l, out_c, out_l, k_s, g; bool bias; @@ -2878,22 +2878,22 @@ pnnx::ModelInfo Graph::flops_mem_count() k_s = op->params.at("kernel_size").i; out_c = op->params.at("out_channels").i; out_l = op->outputs[0]->shape[2]; - bias = op->params.at("bias").b;//bias - if(bias) + bias = op->params.at("bias").b; //bias + if (bias) { m.flops += in_n * 2 * in_c * k_s * out_c * out_l; m.memory_access += in_n * (in_l * in_c + out_l * out_c + in_c * k_s * out_c + out_c); } else { - m.flops += in_n * (2 * in_c * k_s -1) * out_c * out_l; + m.flops += in_n * (2 * in_c * k_s - 1) * out_c * out_l; m.memory_access += in_n * (in_l * in_c + out_l * out_c + in_c * k_s * out_c); } } } else if (op->type == "F.conv1d" || op->type == "F.conv_transpose1d" && op->inputs[0]->shape.size() == 3) { - if(op->inputs[0]->type != 0) + if (op->inputs[0]->type != 0) { int in_n, in_c, in_l, out_c, out_l, k_s, g; bool bias = true; @@ -2911,21 +2911,21 @@ pnnx::ModelInfo Graph::flops_mem_count() bias = false; } } - if(bias) + if (bias) { m.flops += in_n * 2 * in_c * k_s * out_c * out_l; m.memory_access += in_n * (in_l * in_c + out_l * out_c + in_c * k_s * out_c + out_c); } else { - m.flops += in_n * (2 * in_c * k_s -1) * out_c * out_l; + m.flops += in_n * (2 * in_c * k_s - 1) * out_c * out_l; m.memory_access += in_n * (in_l * in_c + out_l * out_c + in_c * k_s * out_c); } } } else if (op->type == "nn.Conv2d" || op->type == "nn.ConvTranspose2d" && op->inputs[0]->shape.size() == 4) { - if(op->inputs[0]->type != 0) + if (op->inputs[0]->type != 0) { int in_n, in_c, in_h, in_w, out_c, out_h, out_w, k_h, k_w, g; bool bias; @@ -2938,22 +2938,22 @@ pnnx::ModelInfo Graph::flops_mem_count() out_c = op->params.at("out_channels").i; out_h = op->outputs[0]->shape[2]; out_w = op->outputs[0]->shape[3]; - bias = op->params.at("bias").b;//bias - if(bias) + bias = op->params.at("bias").b; //bias + if (bias) { m.flops += in_n * 2 * in_c * k_h * k_w * out_c * out_w * out_h; m.memory_access += 
in_n * (in_w * in_h * in_c + out_w * out_h * out_c + in_c * k_h * k_w * out_c + out_c); } else { - m.flops += in_n * (2 * in_c * k_h * k_w -1) * out_c * out_w * out_h; + m.flops += in_n * (2 * in_c * k_h * k_w - 1) * out_c * out_w * out_h; m.memory_access += in_n * (in_w * in_h * in_c + out_w * out_h * out_c + k_h * k_w * in_c * out_c); } } } else if (op->type == "F.conv2d" || op->type == "F.conv_transpose2d" && op->inputs[0]->shape.size() == 4) { - if(op->inputs[0]->type != 0) + if (op->inputs[0]->type != 0) { int in_n, in_c, in_h, in_w, out_c, out_h, out_w, k_h, k_w, g; bool bias = true; @@ -2974,21 +2974,21 @@ pnnx::ModelInfo Graph::flops_mem_count() bias = false; } } - if(bias) + if (bias) { m.flops += in_n * 2 * in_c * k_h * k_w * out_c * out_w * out_h; m.memory_access += in_n * (in_w * in_h * in_c + out_w * out_h * out_c + in_c * k_h * k_w * out_c + out_c); } else { - m.flops += in_n * (2 * in_c * k_h * k_w -1) * out_c * out_w * out_h; + m.flops += in_n * (2 * in_c * k_h * k_w - 1) * out_c * out_w * out_h; m.memory_access += in_n * (in_w * in_h * in_c + out_w * out_h * out_c + k_h * k_w * in_c * out_c); } } } else if (op->type == "nn.Conv3d" || op->type == "nn.ConvTranspose3d" && op->inputs[0]->shape.size() == 5) { - if(op->inputs[0]->type != 0) + if (op->inputs[0]->type != 0) { int in_n, in_c, in_d, in_h, in_w, out_c, out_d, out_h, out_w, k_d, k_h, k_w, g; bool bias; @@ -3004,22 +3004,22 @@ pnnx::ModelInfo Graph::flops_mem_count() out_d = op->outputs[0]->shape[2]; out_h = op->outputs[0]->shape[3]; out_w = op->outputs[0]->shape[4]; - bias = op->params.at("bias").b;//bias - if(bias) + bias = op->params.at("bias").b; //bias + if (bias) { m.flops += in_n * 2 * in_c * k_d * k_h * k_w * out_c * out_d * out_w * out_h; m.memory_access += in_n * (in_d * in_w * in_h * in_c + out_d * out_w * out_h * out_c + in_c * k_d * k_h * k_w * out_c + out_c); } else { - m.flops += in_n * (2 * in_c * k_d * k_h * k_w -1) * out_c * out_d * out_w * out_h; + m.flops += in_n * (2 * in_c * k_d * k_h * k_w - 1) * out_c * out_d * out_w * out_h; m.memory_access += in_n * (in_d * in_w * in_h * in_c + out_d * out_w * out_h * out_c + in_c * k_d * k_h * k_w * out_c); } } } else if (op->type == "F.conv3d" || op->type == "F.conv_transpose3d" && op->inputs[0]->shape.size() == 5) { - if(op->inputs[0]->type != 0) + if (op->inputs[0]->type != 0) { int in_n, in_c, in_d, in_h, in_w, out_c, out_d, out_h, out_w, k_d, k_h, k_w, g; bool bias = true; @@ -3043,14 +3043,14 @@ pnnx::ModelInfo Graph::flops_mem_count() bias = false; } } - if(bias) + if (bias) { m.flops += in_n * 2 * in_c * k_d * k_h * k_w * out_c * out_d * out_w * out_h; m.memory_access += in_n * (in_d * in_w * in_h * in_c + out_d * out_w * out_h * out_c + in_c * k_d * k_h * k_w * out_c + out_c); } else { - m.flops += in_n * (2 * in_c * k_d * k_h * k_w -1) * out_c * out_d * out_w * out_h; + m.flops += in_n * (2 * in_c * k_d * k_h * k_w - 1) * out_c * out_d * out_w * out_h; m.memory_access += in_n * (in_d * in_w * in_h * in_c + out_d * out_w * out_h * out_c + in_c * k_d * k_h * k_w * out_c); } } @@ -3067,7 +3067,7 @@ pnnx::ModelInfo Graph::flops_mem_count() { mem *= op->inputs[0]->shape[index]; } - if(bias) + if (bias) { m.flops += mem * 2 * in * out; m.memory_access += mem * (in + out + in * out + 1); @@ -3097,7 +3097,7 @@ pnnx::ModelInfo Graph::flops_mem_count() bias = false; } } - if(bias) + if (bias) { m.flops += mem * 2 * in * out; m.memory_access += mem * (in + out + in * out + 1); @@ -3142,7 +3142,7 @@ pnnx::ModelInfo Graph::flops_mem_count() vdim = 
op->params.at("vdim").i; linear1 = q_l * embed_dim * embed_dim + k_s * embed_dim * Kdim + v_s * embed_dim * vdim; attention = q_l * k_s * embed_dim + 2 * q_l * k_s * num_heads + q_l * v_s * embed_dim; - linerar2 = q_l* embed_dim * embed_dim; + linerar2 = q_l * embed_dim * embed_dim; m.flops += linear1 + attention + linerar2; weights = embed_dim * embed_dim + embed_dim * Kdim + embed_dim * vdim + num_heads * vdim * embed_dim; in = q_l * embed_dim + k_s * Kdim + v_s * vdim; @@ -3213,7 +3213,7 @@ pnnx::ModelInfo Graph::flops_mem_count() mem *= op->inputs[0]->shape[index]; } m.flops += (kernel_add + kernel_avg) * out_l * mem; - m.memory_access += mem * (in_l+ out_l); + m.memory_access += mem * (in_l + out_l); } else if (op->type == "nn.AvgPool2d" || op->type == "F.avg_pool2d" && op->inputs[0]->shape.size() >= 2) { @@ -3343,7 +3343,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } else { - k_l = (in_l + out_l -1) / out_l; + k_l = (in_l + out_l - 1) / out_l; } kernel_add = k_l - 1; kernel_avg = 1; @@ -3366,7 +3366,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } else { - k_h = (in_h + out_h -1) / out_h; + k_h = (in_h + out_h - 1) / out_h; } if (out_w == 0) { @@ -3375,7 +3375,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } else { - k_w = (in_w + out_w -1) / out_w; + k_w = (in_w + out_w - 1) / out_w; } kernel_add = k_h * k_w - 1; kernel_avg = 1; @@ -3384,7 +3384,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } else if (op->type == "nn.AdaptiveAvgPool3d" || op->type == "F.adaptive_avg_pool3d" && op->inputs[0]->shape.size() == 5) { - int in_n, in_c, in_d, in_h, in_w,out_d, out_h, out_w, k_d, k_h, k_w, kernel_add, kernel_avg; + int in_n, in_c, in_d, in_h, in_w, out_d, out_h, out_w, k_d, k_h, k_w, kernel_add, kernel_avg; in_n = op->inputs[0]->shape[0]; in_c = op->inputs[0]->shape[1]; in_d = op->inputs[0]->shape[2]; @@ -3400,7 +3400,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } else { - k_d = (in_d + out_d -1) / out_d; + k_d = (in_d + out_d - 1) / out_d; } if (out_h == 0) { @@ -3409,7 +3409,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } else { - k_h = (in_h + out_h -1) / out_h; + k_h = (in_h + out_h - 1) / out_h; } if (out_w == 0) { @@ -3418,7 +3418,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } else { - k_w = (in_w + out_w -1) / out_w; + k_w = (in_w + out_w - 1) / out_w; } kernel_add = k_d * k_h * k_w - 1; kernel_avg = 1; @@ -3461,7 +3461,7 @@ pnnx::ModelInfo Graph::flops_mem_count() } else if (op->type == "nn.AdaptiveMaxPool3d" || op->type == "F.adaptive_max_pool3d" && op->inputs[0]->shape.size() == 5) { - int num_o, in_n, in_c,in_d, in_h, in_w, out_d, out_h, out_w; + int num_o, in_n, in_c, in_d, in_h, in_w, out_d, out_h, out_w; num_o = op->params.at("return_indices").b ? 
2 : 1; in_n = op->inputs[0]->shape[0]; in_c = op->inputs[0]->shape[1]; @@ -3756,7 +3756,7 @@ pnnx::ModelInfo Graph::flops_mem_count() { mem *= op->inputs[0]->shape[index]; } - m.flops += ( 7 * in_n + 4 ) * mem; + m.flops += (7 * in_n + 4) * mem; m.memory_access += 2 * in_n * mem; } else if (dim == 1) @@ -3767,7 +3767,7 @@ pnnx::ModelInfo Graph::flops_mem_count() { mem *= op->inputs[0]->shape[index]; } - m.flops += ( 7 * in_c + 4 ) * in_n * mem; + m.flops += (7 * in_c + 4) * in_n * mem; m.memory_access += 2 * in_n * in_c * mem; } else if (dim == 2) @@ -3779,7 +3779,7 @@ pnnx::ModelInfo Graph::flops_mem_count() { mem *= op->inputs[0]->shape[index]; } - m.flops += ( 7 * in_h + 4 ) * in_n * in_c * mem; + m.flops += (7 * in_h + 4) * in_n * in_c * mem; m.memory_access += 2 * in_n * in_c * in_h * mem; } } @@ -3944,7 +3944,6 @@ pnnx::ModelInfo Graph::flops_mem_count() m.flops += batch * seq * 2 * (in_size * h_size + (num_layers - 1) * (h_size * h_size)); } m.memory_access += batch * seq * (in_size + num_layers * h_size + h_size); - } else if (op->type == "nn.SELU" || op->type == "F.selu" && op->inputs[0]->shape.size() >= 1) { @@ -3996,7 +3995,7 @@ pnnx::ModelInfo Graph::flops_mem_count() in_c = op->inputs[0]->shape[1]; in_h = op->inputs[0]->shape[2]; in_w = op->inputs[0]->shape[3]; - m.flops += in_n * in_c * (7 * in_h * in_w -1); + m.flops += in_n * in_c * (7 * in_h * in_w - 1); m.memory_access += 2 * in_n * in_c * in_h * in_w; } else if (op->type == "nn.Tanh" || op->type == "F.tanh" && op->inputs[0]->shape.size() >= 1) @@ -4032,7 +4031,7 @@ pnnx::ModelInfo Graph::flops_mem_count() int size_h = op->params.at("size").ai[0]; int size_w = op->params.at("size").ai[1]; m.flops += 4 * in_c * size_h * size_w; - m.memory_access += 5 * in_c * size_h * size_w; + m.memory_access += 5 * in_c * size_h * size_w; } else { @@ -4041,7 +4040,7 @@ pnnx::ModelInfo Graph::flops_mem_count() int out_h = in_h * scale_h; int out_w = in_w * scale_w; m.flops += 4 * in_c * out_h * out_w; - m.memory_access += 5 * in_c * out_h * out_w; + m.memory_access += 5 * in_c * out_h * out_w; } } } From 259ba3364be2d29abf119c20918bcc54b2a1a061 Mon Sep 17 00:00:00 2001 From: luxincn <2391719912@qq.com> Date: Sat, 12 Oct 2024 11:57:55 +0800 Subject: [PATCH 25/28] update flops_mem_count --- tools/pnnx/src/main.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index fe1c05663019..f8b48f9de6e8 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -365,8 +365,8 @@ int main(int argc, char** argv) // count float pnnx_graph.flops_mem_count(); - fprintf(stderr, "float ops: %lld\n", pnnx_graph.m.flops); - fprintf(stderr, "memory ops: %lld\n", pnnx_graph.m.memory_access); + fprintf(stderr, "float ops: %.2f M\n", pnnx_graph.m.flops / 1e6); + fprintf(stderr, "memory ops: %.2f M\n", pnnx_graph.m.memory_access / 1e6); #if BUILD_PNNX2ONNX pnnx::save_onnx(pnnx_graph, pnnxonnxpath.c_str(), fp16); From 5ff6210585d6d8d48a7050245c979501b2829b49 Mon Sep 17 00:00:00 2001 From: SZUwishion <2559916473@qq.com> Date: Tue, 15 Oct 2024 23:09:30 +0800 Subject: [PATCH 26/28] code format fix --- tools/pnnx/src/ir.cpp | 132 ++++++++++++++++++++-------------------- tools/pnnx/src/main.cpp | 2 +- 2 files changed, 67 insertions(+), 67 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index e4974f6b7c87..c81944c12052 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1461,18 +1461,18 @@ void Graph::flops_memops_sum() int output_size = 
std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); int out_features = op->attrs.at("data").shape[0]; flops += input_size * out_features; - if(op->has_param("bias")) + if (op->has_param("bias")) { flops += out_features; } memops += input_size + output_size; } else if (sub_type == "avgpool1d" - || sub_type == "avgpool2d" - || sub_type == "avgpool3d" - || sub_type == "adaptive_avgpool1d" - || sub_type == "adaptive_avgpool2d" - || sub_type == "adaptive_avgpool3d") + || sub_type == "avgpool2d" + || sub_type == "avgpool3d" + || sub_type == "adaptive_avgpool1d" + || sub_type == "adaptive_avgpool2d" + || sub_type == "adaptive_avgpool3d") { std::vector input_shape = op->inputs[0]->shape; std::vector output_shape = op->outputs[0]->shape; @@ -1482,11 +1482,11 @@ void Graph::flops_memops_sum() memops += input_size + output_size; } else if (sub_type == "prelu" - || sub_type == "elu" - || sub_type == "leaky_relu" - || sub_type == "gelu" - || sub_type == "silu" - || sub_type == "softmax") + || sub_type == "elu" + || sub_type == "leaky_relu" + || sub_type == "gelu" + || sub_type == "silu" + || sub_type == "softmax") { std::vector input_shape = op->inputs[0]->shape; std::vector output_shape = op->outputs[0]->shape; @@ -1496,8 +1496,8 @@ void Graph::flops_memops_sum() extra_memops += input_size + output_size; } else if (sub_type == "unsample" - || sub_type == "upsample_nearest" - || sub_type == "upsample_bilinear") + || sub_type == "upsample_nearest" + || sub_type == "upsample_bilinear") { std::vector input_shape = op->inputs[0]->shape; std::vector output_shape = op->outputs[0]->shape; @@ -1534,7 +1534,7 @@ void Graph::flops_memops_sum() int n = op->inputs[0]->shape[0]; int c = op->inputs[0]->shape[1]; int num_elements = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); - if((op->has_param("affine") && op->params.at("affine").b) + if ((op->has_param("affine") && op->params.at("affine").b) || (op->has_param("elementwise_affine") && op->params.at("elementwise_affine").b)) { extra_flops += 2 * num_elements; @@ -1547,11 +1547,11 @@ void Graph::flops_memops_sum() } } else if (sub_type == "Conv1d" - || sub_type == "Conv2d" - || sub_type == "Conv3d" - || sub_type == "ConvTranspose1d" - || sub_type == "ConvTranspose2d" - || sub_type == "ConvTranspose3d") + || sub_type == "Conv2d" + || sub_type == "Conv3d" + || sub_type == "ConvTranspose1d" + || sub_type == "ConvTranspose2d" + || sub_type == "ConvTranspose3d") { int c = op->params.at("in_channels").i; std::vector k = op->params.at("kernel_size").ai; @@ -1561,17 +1561,17 @@ void Graph::flops_memops_sum() int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); int kernel_size = std::accumulate(k.begin() + 2, k.end(), 1, std::multiplies()); - flops += output_size * c * kernel_size / g; + flops += output_size * c * kernel_size / g; memops += input_size + output_size + std::accumulate(k.begin(), k.end(), 1, std::multiplies()) * c / g; - if(op->has_param("bias")) + if (op->has_param("bias")) { flops += output_size; memops += output_size; } } else if (sub_type == "AvgPool1d" - || sub_type == "AvgPool2d" - || sub_type == "AvgPool3d") + || sub_type == "AvgPool2d" + || sub_type == "AvgPool3d") { std::vector input_shape = op->inputs[0]->shape; std::vector output_shape = op->outputs[0]->shape; @@ -1581,31 +1581,31 @@ void Graph::flops_memops_sum() memops += input_size + output_size; } else if 
(sub_type == "AdaptiveAvgPool1d" - || sub_type == "AdaptiveAvgPool2d" - || sub_type == "AdaptiveAvgPool3d") + || sub_type == "AdaptiveAvgPool2d" + || sub_type == "AdaptiveAvgPool3d") { std::vector input_shape = op->inputs[0]->shape; std::vector output_shape = op->outputs[0]->shape; int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); std::vector kernel_size; - for(size_t i = 2; i < input_shape.size(); i++) + for (size_t i = 2; i < input_shape.size(); i++) { kernel_size.emplace_back(output_shape[i] / input_shape[i]); } flops += (std::accumulate(kernel_size.begin(), kernel_size.end(), 1, std::multiplies()) + 1) * output_size; memops += input_size + output_size; } - else if(sub_type == "PReLU" - || sub_type == "ELU" - || sub_type == "LeakyReLU" - || sub_type == "GELU") + else if (sub_type == "PReLU" + || sub_type == "ELU" + || sub_type == "LeakyReLU" + || sub_type == "GELU") { std::vector shape = op->outputs[0]->shape; int n = shape[0]; int num_elements = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); extra_flops += num_elements; - if(sub_type == "PReLU") + if (sub_type == "PReLU") { extra_memops += 2 * num_elements + n * op->params["num_parameters"].i; } @@ -1614,7 +1614,7 @@ void Graph::flops_memops_sum() extra_memops += 2 * num_elements; } } - else if(sub_type == "Tanh") + else if (sub_type == "Tanh") { std::vector shape = op->outputs[0]->shape; int num_elements = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); @@ -1634,75 +1634,75 @@ void Graph::flops_memops_sum() memops += input_size + output_size + output_size * (bias ? 1 : 0); } else if (sub_type == "Upsample" - || sub_type == "UnsampleBilinear2d" - || sub_type == "UnsampleNearest2d") + || sub_type == "UnsampleBilinear2d" + || sub_type == "UnsampleNearest2d") { std::vector input_shape = op->inputs[0]->shape; int input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies()); std::vector output_shape = op->outputs[0]->shape; int output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); std::string mode; - if(sub_type == "Unsample") + if (sub_type == "Unsample") { mode = op->has_param("mode") ? 
op->params.at("mode").s : "nearest"; } - else if(sub_type == "UnsampleBilinear2d") + else if (sub_type == "UnsampleBilinear2d") { mode = "bilinear"; } - else if(sub_type == "UnsampleNearest2d") + else if (sub_type == "UnsampleNearest2d") { mode = "nearest"; } - if(mode == "nearest") + if (mode == "nearest") { extra_flops += input_size; extra_memops += input_size + output_size; } - else if(mode == "linear") + else if (mode == "linear") { extra_flops += 5 * output_size; extra_memops += 2 * input_size + output_size; } - else if(mode == "bilinear") + else if (mode == "bilinear") { extra_flops += 11 * output_size; extra_memops += 4 * input_size + output_size; } - else if(mode == "bicubic") + else if (mode == "bicubic") { extra_flops += (224 + 35) * output_size; extra_memops += 16 * input_size + output_size; } - else if(mode == "trilinear") + else if (mode == "trilinear") { extra_flops += (13 * 2 + 5) * input_size; extra_memops += 8 * input_size + output_size; } } - else if(sub_type == "RNN") + else if (sub_type == "RNN") { bool bi = op->has_param("bidirectional") && op->params.at("bidirectional").b; bool bias = op->has_param("bias") && op->params.at("bias").b; int input_size = op->params.at("input_size").i; int hidden_size = op->params.at("hidden_size").i; int flops1 = hidden_size * (input_size + hidden_size) + hidden_size; - if(bias) + if (bias) { flops1 += 2 * hidden_size; } - if(bi) + if (bi) { flops1 *= 2; } int num_layers = op->params.at("num_layers").i; int flops2 = 0; - if(bi) + if (bi) { flops2 = 3 * hidden_size * hidden_size + hidden_size; - if(bias) + if (bias) { flops2 += 2 * hidden_size; } @@ -1711,7 +1711,7 @@ void Graph::flops_memops_sum() else { flops2 = 2 * hidden_size * hidden_size + hidden_size; - if(bias) + if (bias) { flops2 += 2 * hidden_size; } @@ -1723,23 +1723,23 @@ void Graph::flops_memops_sum() flops += (flops1 + flops2) * num_steps * batch_size; memops += num_steps * batch_size * input_size; memops += 2 * num_steps * batch_size * hidden_size * num_layers * (bi ? 2 : 1); - if(bias) + if (bias) { memops += 2 * hidden_size * num_layers * (bi ? 2 : 1); } } - else if(sub_type == "LSTM") + else if (sub_type == "LSTM") { bool bi = op->has_param("bidirectional") && op->params.at("bidirectional").b; bool bias = op->has_param("bias") && op->params.at("bias").b; int input_size = op->params.at("input_size").i; int hidden_size = op->params.at("hidden_size").i; int flops1 = 4 * hidden_size * (input_size + hidden_size) + 4 * hidden_size; - if(bias) + if (bias) { flops1 += 8 * hidden_size; } - if(bi) + if (bi) { flops1 *= 2; } @@ -1747,10 +1747,10 @@ void Graph::flops_memops_sum() int num_layers = op->params.at("num_layers").i; int flops2 = 0; - if(bi) + if (bi) { flops2 = 12 * hidden_size * hidden_size + 4 * hidden_size; - if(bias) + if (bias) { flops2 += 8 * hidden_size; } @@ -1760,7 +1760,7 @@ void Graph::flops_memops_sum() else { flops2 = 4 * hidden_size * hidden_size + 4 * hidden_size; - if(bias) + if (bias) { flops2 += 8 * hidden_size; } @@ -1773,7 +1773,7 @@ void Graph::flops_memops_sum() flops += (flops1 + flops2) * num_steps * batch_size; memops += num_steps * batch_size * input_size; memops += 2 * num_steps * batch_size * hidden_size * num_layers * (bi ? 2 : 1); - if(bias) + if (bias) { memops += 8 * hidden_size * num_layers * (bi ? 
2 : 1); } @@ -1785,22 +1785,22 @@ void Graph::flops_memops_sum() int input_size = op->params.at("input_size").i; int hidden_size = op->params.at("hidden_size").i; int flops1 = 3 * hidden_size * (input_size + hidden_size) + 3 * hidden_size; - if(bias) + if (bias) { flops1 += 6 * hidden_size; } flops1 += 4 * hidden_size; - if(bi) + if (bi) { flops1 *= 2; } int num_layers = op->params.at("num_layers").i; int flops2 = 0; - if(bi) + if (bi) { flops2 = 9 * hidden_size * hidden_size + 3 * hidden_size; - if(bias) + if (bias) { flops2 += 6 * hidden_size; } @@ -1810,7 +1810,7 @@ void Graph::flops_memops_sum() else { flops2 = 6 * hidden_size * hidden_size + 3 * hidden_size; - if(bias) + if (bias) { flops2 += 6 * hidden_size; } @@ -1823,12 +1823,12 @@ void Graph::flops_memops_sum() flops += (flops1 + flops2) * num_steps * batch_size; memops += num_steps * batch_size * input_size; memops += 2 * num_steps * batch_size * hidden_size * num_layers * (bi ? 2 : 1); - if(bias) + if (bias) { memops += 6 * hidden_size * num_layers * (bi ? 2 : 1); } } - else if(sub_type == "MultiheadAttention") + else if (sub_type == "MultiheadAttention") { bool batch_first = op->has_param("batch_first") && op->params.at("batch_first").b; int batch_size = batch_first ? op->inputs[0]->shape[0] : op->inputs[0]->shape[1]; @@ -1883,7 +1883,7 @@ void Graph::flops_memops_sum() else if (op->type.substr(0, 5) == "torch") { std::string sub_type = op->type.substr(6); - if(sub_type == "matmul" + if (sub_type == "matmul" || sub_type == "mm" || sub_type == "bmm") { @@ -1897,7 +1897,7 @@ void Graph::flops_memops_sum() memops += input_size_1 + input_size_2 + output_size; } else if (sub_type == "addmm" - || sub_type == "baddbmm") + || sub_type == "baddbmm") { std::vector input_shape = op->inputs[0]->shape; std::vector mat_shape_1 = op->inputs[1]->shape; @@ -1911,7 +1911,7 @@ void Graph::flops_memops_sum() memops += input_size + mat_size_1 + mat_size_2 + output_size; } else if (sub_type == "mul" - || sub_type == "add") + || sub_type == "add") { std::vector input_shape_1 = op->inputs[0]->shape; std::vector input_shape_2 = op->inputs[1]->shape; diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index 949680faab82..a50ca679fbc6 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -362,7 +362,7 @@ int main(int argc, char** argv) pnnx_graph.save(pnnxparampath, pnnxbinpath); pnnx_graph.python(pnnxpypath, pnnxbinpath); - + pnnx_graph.flops_memops_sum(); fprintf(stderr, "float ops = %.3fM\n", double(pnnx_graph.flops) / 1e6); fprintf(stderr, "mem ops = %.3fM\n", double(pnnx_graph.memops) / 1e6); From 0e3791cb329f1414dc5c23e16bb1e481a8db5980 Mon Sep 17 00:00:00 2001 From: luxincn <2391719912@qq.com> Date: Thu, 17 Oct 2024 17:06:17 +0800 Subject: [PATCH 27/28] ci tests fix --- tools/pnnx/src/ir.cpp | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index cfd15de042db..e51cb543a3e9 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -4022,11 +4022,24 @@ pnnx::ModelInfo Graph::flops_mem_count() else if (op->type == "nn.UpsamplingBilinear2d" || op->type == "F.upsample_bilinear" && op->inputs[0]->shape.size() == 4) { int in_n, in_c, in_h, in_w; + bool has_size = false; in_n = op->inputs[0]->shape[0]; in_c = op->inputs[0]->shape[1]; in_h = op->inputs[0]->shape[2]; in_w = op->inputs[0]->shape[3]; if (op->params.find("size") != op->params.end()) + { + std::string val = Parameter::encode_to_string(op->params.at("size")); + if (val == "None") + 
{ + has_size = false; + } + else + { + has_size = true; + } + } + if (has_size) { int size_h = op->params.at("size").ai[0]; int size_w = op->params.at("size").ai[1]; @@ -4035,8 +4048,8 @@ pnnx::ModelInfo Graph::flops_mem_count() } else { - int scale_h = op->params.at("scale_factor").ai[0]; - int scale_w = op->params.at("scale_factor").ai[1]; + int scale_h = op->params.at("scale_factor").af[0]; + int scale_w = op->params.at("scale_factor").af[1]; int out_h = in_h * scale_h; int out_w = in_w * scale_w; m.flops += 4 * in_c * out_h * out_w; From 1f5b8416bdfa2c0da337912cb458dd4d22859053 Mon Sep 17 00:00:00 2001 From: nihui Date: Fri, 20 Dec 2024 03:14:01 +0000 Subject: [PATCH 28/28] apply code-format changes --- tools/pnnx/src/ir.cpp | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 0d518f53875a..13727a30246f 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1522,20 +1522,20 @@ void Graph::flops_memops_sum() { std::string sub_type = op->type.substr(3); if (sub_type == "BatchNorm1d" - || sub_type == "BatchNorm2d" - || sub_type == "BatchNorm3d" - || sub_type == "GroupNorm" - || sub_type == "LayerNorm" - || sub_type == "InstanceNorm1d" - || sub_type == "InstanceNorm2d" - || sub_type == "InstanceNorm3d") + || sub_type == "BatchNorm2d" + || sub_type == "BatchNorm3d" + || sub_type == "GroupNorm" + || sub_type == "LayerNorm" + || sub_type == "InstanceNorm1d" + || sub_type == "InstanceNorm2d" + || sub_type == "InstanceNorm3d") { std::vector shape = op->inputs[0]->shape; int n = op->inputs[0]->shape[0]; int c = op->inputs[0]->shape[1]; int num_elements = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); if ((op->has_param("affine") && op->params.at("affine").b) - || (op->has_param("elementwise_affine") && op->params.at("elementwise_affine").b)) + || (op->has_param("elementwise_affine") && op->params.at("elementwise_affine").b)) { extra_flops += 2 * num_elements; extra_memops += 2 * (num_elements + n * c); @@ -1884,8 +1884,8 @@ void Graph::flops_memops_sum() { std::string sub_type = op->type.substr(6); if (sub_type == "matmul" - || sub_type == "mm" - || sub_type == "bmm") + || sub_type == "mm" + || sub_type == "bmm") { std::vector input_shape_1 = op->inputs[0]->shape; std::vector input_shape_2 = op->inputs[1]->shape; @@ -2018,10 +2018,10 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) for (size_t i = 0; i < param.ai.size(); i++) { if ((op->type == "nn.AdaptiveAvgPool2d" - || op->type == "nn.AdaptiveAvgPool3d" - || op->type == "nn.AdaptiveMaxPool2d" - || op->type == "nn.AdaptiveMaxPool3d") - && it.first == "output_size" && param.ai[i] == 0) + || op->type == "nn.AdaptiveAvgPool3d" + || op->type == "nn.AdaptiveMaxPool2d" + || op->type == "nn.AdaptiveMaxPool3d") + && it.first == "output_size" && param.ai[i] == 0) { fprintf(pyfp, "None"); } @@ -2780,10 +2780,10 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) for (size_t i = 0; i < param.ai.size(); i++) { if ((op->type == "F.adaptive_avg_pool2d" - || op->type == "F.adaptive_avg_pool3d" - || op->type == "F.adaptive_max_pool2d" - || op->type == "F.adaptive_max_pool3d") - && it.first == "output_size" && param.ai[i] == 0) + || op->type == "F.adaptive_avg_pool3d" + || op->type == "F.adaptive_max_pool2d" + || op->type == "F.adaptive_max_pool3d") + && it.first == "output_size" && param.ai[i] == 0) { fprintf(pyfp, "None"); }
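
A minimal, self-contained sketch of the counting rule the "ci tests fix" patch above settles on for nn.UpsamplingBilinear2d / F.upsample_bilinear: use an explicit size parameter when present, and fall back to the float-valued scale_factor (read via .af rather than .ai) when size is absent or encodes to "None". The UpsampleParams struct and its fields are invented stand-ins for illustration only, not the actual pnnx::Operator parameter API; the per-element costs (4 float ops and 5 memory accesses per output pixel, per channel) mirror the figures used in the flops_mem_count() hunk, and the summary line follows the "%.3fM" style printed by main.cpp.

#include <cstdio>
#include <cstdint>
#include <vector>
#include <optional>

struct UpsampleParams
{
    std::optional<std::vector<int> > size; // output size (H, W); absent models params["size"] == "None"
    std::vector<float> scale_factor;       // float scales, hence the .af (not .ai) access in the fix
};

int main()
{
    // hypothetical nn.UpsamplingBilinear2d input of shape (N, 64, 32, 32)
    int in_c = 64, in_h = 32, in_w = 32;

    UpsampleParams p;
    p.scale_factor = {2.0f, 2.0f}; // no explicit size given

    int out_h, out_w;
    if (p.size && p.size->size() == 2)
    {
        // explicit output size wins
        out_h = (*p.size)[0];
        out_w = (*p.size)[1];
    }
    else
    {
        // fall back to scale_factor, which PyTorch stores as floats
        out_h = (int)(in_h * p.scale_factor[0]);
        out_w = (int)(in_w * p.scale_factor[1]);
    }

    // per-output-element costs mirroring the flops_mem_count() hunk:
    // 4 float ops and 5 memory accesses per output pixel, per channel
    int64_t flops = (int64_t)4 * in_c * out_h * out_w;
    int64_t memops = (int64_t)5 * in_c * out_h * out_w;

    fprintf(stderr, "float ops = %.3fM\n", (double)flops / 1e6);
    fprintf(stderr, "mem ops = %.3fM\n", (double)memops / 1e6);
    return 0;
}

Built standalone with a C++17 compiler, this prints roughly "float ops = 1.049M" and "mem ops = 1.311M" for the assumed shapes; it is only a sketch of the cost model, not a drop-in replacement for the graph pass.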