diff --git a/paddle/fluid/pir/transforms/xpu/fc_xpu_fuse_pass.cc b/paddle/fluid/pir/transforms/xpu/fc_xpu_fuse_pass.cc index 33471068ecbcb4..6030cea716494b 100644 --- a/paddle/fluid/pir/transforms/xpu/fc_xpu_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/xpu/fc_xpu_fuse_pass.cc @@ -25,32 +25,6 @@ #include "paddle/pir/include/pass/pass_registry.h" #include "paddle/pir/include/pattern_rewrite/pattern_match.h" -/* -fuse malmul + add to fc_xpu -For example: -graph: - - x w - \ / - | - mul - | - | - bias --- add - | - | - output ------------------------------------------------------- -After the pass is applied: - x w - \ / - | - bias--- fc_xpu - | - | - Output -*/ - namespace { int ConvertActivationType(const std::string &act_type) { @@ -76,6 +50,8 @@ int ConvertActivationType(const std::string &act_type) { return static_cast(xpu::Activation_t::SWISH); } else if (act_type == "relu6") { return static_cast(xpu::Activation_t::RELU6); + } else if (act_type == "swish_glu") { + return static_cast(xpu::Activation_t::SWISH_GLU); } else { PADDLE_THROW(common::errors::Unimplemented( "Not support convert activation_type(%s).", act_type)); @@ -83,13 +59,38 @@ int ConvertActivationType(const std::string &act_type) { return -1; } -class FCXpuFusePattern : public paddle::drr::DrrPatternBase { +/* +fuse matmul + add to fc_xpu +For example: +graph: + + x w + \ / + | + mul + | + | + bias --- add + | + | + output +------------------------------------------------------ +After the pass is applied: + x w + \ / + | + bias--- fc_xpu + | + | + Output +*/ +class FcXpuFuseAddPattern : public paddle::drr::DrrPatternBase { private: bool transpose_w_; public: - explicit FCXpuFusePattern(bool transpose_w) : transpose_w_(transpose_w) {} - std::string name() const override { return "FCXpuFusePattern"; } + explicit FcXpuFuseAddPattern(bool transpose_w) : transpose_w_(transpose_w) {} + std::string name() const override { return "FcXpuFuseAddPattern"; } void operator()(paddle::drr::DrrPatternContext *ctx) 
const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -185,14 +186,277 @@ class FCXpuFusePattern : public paddle::drr::DrrPatternBase { } }; +/* +fuse matmul + add + act (swiglu) to fc_xpu +For example: +graph: + + x w + \ / + | + mul + | + | + bias --- add + | + | + act + | + | + output +------------------------------------------------------ +After the pass is applied: + x w + \ / + | + bias--- fc_xpu + | + | + Output +*/ +class FcXpuFuseAddActPattern : public paddle::drr::DrrPatternBase { + private: + bool transpose_w_; + + public: + explicit FcXpuFuseAddActPattern(bool transpose_w) + : transpose_w_(transpose_w) {} + std::string name() const override { return "FcXpuFuseAddActPattern"; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &mul = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("transpose_x")}, + {"transpose_y", pat.Attr("transpose_y")}}); + mul({&pat.Tensor("x"), &pat.Tensor("w")}, {&pat.Tensor("mul_out")}); + + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + add({&pat.Tensor("mul_out"), &pat.Tensor("bias")}, + {&pat.Tensor("add_out")}); + const auto &swiglu = pat.Op(paddle::dialect::SwigluOp::name()); + swiglu({&pat.Tensor("add_out"), &pat.InputNoneTensor()}, + {&pat.Tensor("act_out")}); + + // Constraints: transpose_y == transpose_w_, w is 2-D, x is >= 2-D, bias 1-D + pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) { + auto w_shape = pir::GetShapeFromValue(match_ctx.Tensor("w")); + auto x_shape = pir::GetShapeFromValue(match_ctx.Tensor("x")); + auto bias_shape = pir::GetShapeFromValue(match_ctx.Tensor("bias")); + if (transpose_w_ != match_ctx.Attr("transpose_y")) { + return false; + } + return (w_shape.size() == 2 && x_shape.size() >= 2 && + bias_shape.size() == 1); + }); + + // Result pattern + paddle::drr::ResultPattern res = pat.ResultPattern(); + + const auto &in_num_col_dims_attr = + res.ComputeAttr([&](const paddle::drr::MatchContext &match_ctx) -> 
int { + auto x_shape = pir::GetShapeFromValue(match_ctx.Tensor("x")); + return x_shape.size() - 1; + }); + + if (!transpose_w_) { + // w is not pre-transposed: insert a transpose with perm {1, 0} + const auto &perm_attr = res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> std::vector { + auto w_shape = pir::GetShapeFromValue(match_ctx.Tensor("w")); + if (w_shape.size() == 2) { + return {1, 0}; + } else { + PADDLE_THROW(common::errors::Unimplemented( + "Not support convert w_shape.size()(%d).", w_shape.size())); + } + }); + const auto &transpose_op = + res.Op(paddle::dialect::TransposeOp::name(), {{"perm", perm_attr}}); + res.Tensor("w_trans") = transpose_op(res.Tensor("w")); + VLOG(3) << "transpose weight for fc_xpu op"; + } + + const auto &out_dtype_attr = res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> phi::DataType { + auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x")); + // Currently only the following non-quantized cases are supported + if (x_dtype.isa()) { + return phi::DataType::FLOAT32; + } else if (x_dtype.isa()) { + return phi::DataType::FLOAT16; + } else if (x_dtype.isa()) { + return phi::DataType::BFLOAT16; + } else { + return phi::DataType::UNDEFINED; + } + }); + // only float32 bias is supported now, so cast bias to float32 + const auto &cast_op = res.Op(paddle::dialect::CastOp::name(), + {{"dtype", res.DataTypeAttr("float32")}}); + res.Tensor("bias_fp32") = cast_op(res.Tensor("bias")); + + const auto &fc_xpu = res.Op( + paddle::dialect::FcXpuOp::name(), + {{ + {"in_num_col_dims", in_num_col_dims_attr}, + {"transpose_x", pat.Attr("transpose_x")}, + {"alpha", res.Float32Attr(1.0f)}, + {"beta", res.Float32Attr(0.f)}, + {"act_type", res.Int32Attr(ConvertActivationType("swish_glu"))}, + {"act_alpha", res.Float32Attr(0.0f)}, + {"out_dtype", out_dtype_attr}, + }}); + fc_xpu( + { + &res.Tensor("x"), + &res.InputNoneTensor(), + transpose_w_ ? 
&res.Tensor("w") : &res.Tensor("w_trans"), + &res.InputNoneTensor(), + &res.Tensor("bias_fp32"), + &res.InputNoneTensor(), + &res.InputNoneTensor(), + }, + {&res.Tensor("act_out"), &res.Tensor("out_max")}); + } +}; + +/* +fuse matmul + act (swiglu) to fc_xpu, no bias +For example: +graph: + + x w + \ / + | + mul + | + | + act + | + | + output +------------------------------------------------------ +After the pass is applied: + x w + \ / + | + fc_xpu + | + | + Output +*/ + +class FcXpuFuseActPattern : public paddle::drr::DrrPatternBase { + private: + bool transpose_w_; + + public: + explicit FcXpuFuseActPattern(bool transpose_w) : transpose_w_(transpose_w) {} + std::string name() const override { return "FcXpuFuseActPattern"; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &mul = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("transpose_x")}, + {"transpose_y", pat.Attr("transpose_y")}}); + mul({&pat.Tensor("x"), &pat.Tensor("w")}, {&pat.Tensor("mul_out")}); + + const auto &swiglu = pat.Op(paddle::dialect::SwigluOp::name()); + swiglu({&pat.Tensor("mul_out"), &pat.InputNoneTensor()}, + {&pat.Tensor("act_out")}); + + // Constraints: transpose_y == transpose_w_, w is 2-D, x is >= 2-D + pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) { + auto w_shape = pir::GetShapeFromValue(match_ctx.Tensor("w")); + auto x_shape = pir::GetShapeFromValue(match_ctx.Tensor("x")); + if (transpose_w_ != match_ctx.Attr("transpose_y")) { + return false; + } + return (w_shape.size() == 2 && x_shape.size() >= 2); + }); + + // Result pattern + paddle::drr::ResultPattern res = pat.ResultPattern(); + + const auto &in_num_col_dims_attr = + res.ComputeAttr([&](const paddle::drr::MatchContext &match_ctx) -> int { + auto x_shape = pir::GetShapeFromValue(match_ctx.Tensor("x")); + return x_shape.size() - 1; + }); + + if (!transpose_w_) { + // w is not pre-transposed: insert a transpose with perm {1, 0} + const auto &perm_attr = res.ComputeAttr( + 
[](const paddle::drr::MatchContext &match_ctx) -> std::vector { + auto w_shape = pir::GetShapeFromValue(match_ctx.Tensor("w")); + if (w_shape.size() == 2) { + return {1, 0}; + } else { + PADDLE_THROW(common::errors::Unimplemented( + "Not support convert w_shape.size()(%d).", w_shape.size())); + } + }); + const auto &transpose_op = + res.Op(paddle::dialect::TransposeOp::name(), {{"perm", perm_attr}}); + res.Tensor("w_trans") = transpose_op(res.Tensor("w")); + VLOG(3) << "transpose weight for fc_xpu op"; + } + + const auto &out_dtype_attr = res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> phi::DataType { + auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x")); + // Currently only the following non-quantized cases are supported + if (x_dtype.isa()) { + return phi::DataType::FLOAT32; + } else if (x_dtype.isa()) { + return phi::DataType::FLOAT16; + } else if (x_dtype.isa()) { + return phi::DataType::BFLOAT16; + } else { + return phi::DataType::UNDEFINED; + } + }); + // NOTE(review): cast_op is never applied here (no bias input) - confirm + const auto &cast_op = res.Op(paddle::dialect::CastOp::name(), + {{"dtype", res.DataTypeAttr("float32")}}); + + const auto &fc_xpu = res.Op( + paddle::dialect::FcXpuOp::name(), + {{ + {"in_num_col_dims", in_num_col_dims_attr}, + {"transpose_x", pat.Attr("transpose_x")}, + {"alpha", res.Float32Attr(1.0f)}, + {"beta", res.Float32Attr(0.f)}, + {"act_type", res.Int32Attr(ConvertActivationType("swish_glu"))}, + {"act_alpha", res.Float32Attr(0.0f)}, + {"out_dtype", out_dtype_attr}, + }}); + fc_xpu( + { + &res.Tensor("x"), + &res.InputNoneTensor(), + transpose_w_ ? 
&res.Tensor("w") : &res.Tensor("w_trans"), + &res.InputNoneTensor(), + &res.InputNoneTensor(), + &res.InputNoneTensor(), + &res.InputNoneTensor(), + }, + {&res.Tensor("act_out"), &res.Tensor("out_max")}); + } +}; + class FCXpuFusePass : public pir::PatternRewritePass { public: FCXpuFusePass() : pir::PatternRewritePass("fc_xpu_fuse_pass", 2) {} pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { pir::RewritePatternSet ps(context); - ps.Add(paddle::drr::Create(context, false)); - ps.Add(paddle::drr::Create(context, true)); + // ps.Add(paddle::drr::Create(context, false)); + // ps.Add(paddle::drr::Create(context, true)); + ps.Add(paddle::drr::Create(context, false)); + ps.Add(paddle::drr::Create(context, true)); + ps.Add(paddle::drr::Create(context, false)); + ps.Add(paddle::drr::Create(context, true)); return ps; } }; diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 4ce91b1d918643..f062ad0f1624cc 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -842,6 +842,9 @@ void FcXPUInferMeta(const MetaTensor& x, out_shape[i] = static_cast(x.dims()[i]); } out_shape[in_num_col_dims] = static_cast(w.dims()[0]); + if (act_type == 23 /*phi::backends::xpu::Activation_t::SWISH_GLU*/) { + out_shape[in_num_col_dims] = out_shape[in_num_col_dims] / 2; + } out->set_dims(DDim(out_shape.data(), static_cast(out_shape.size()))); out->set_dtype(out_dtype); out->set_layout(x.layout()); diff --git a/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc index b2047d6ec99c7e..ea2aefe6608f4d 100644 --- a/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc @@ -156,34 +156,111 @@ void FcXPUKernelImpl(const Context& ctx, w_len); PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te"); } - int r = - xblas::fc_fusion( - ctx.x_context(), - x_data_fp16, - w_data_fp16, - out_data, - m, - n, - k, - transpose_x, - true, - x_max_data ? 
x_max_data : xte_x_maxptr, - w_max_data ? w_max_data : xte_w_maxptr, - out_max_data, - transpose_x ? m : k, - k, - n, - alpha, - beta, - bias_data, - act, - xte_scale_x, - xte_scale_w); + baidu::xpu::xblas::FcFusionTensor tensor_a1{ + x_data_fp16, + x_max_data ? x_max_data : xte_x_maxptr, + transpose_x ? k : m, + transpose_x ? m : k, + transpose_x ? m : k, + transpose_x}; + baidu::xpu::xblas::FcFusionTensor tensor_b1{ + w_data_fp16, w_max_data ? w_max_data : xte_w_maxptr, n, k, k, true}; + baidu::xpu::xblas::FcFusionTensor tensor_c1{ + out_data, nullptr, m, n, n, false}; + baidu::xpu::xblas::FcFusionTensor tensor_d1{ + out_data, nullptr, m, n, n, false}; + baidu::xpu::xblas::FcFusionDesc desc{alpha, + beta}; + baidu::xpu::xblas::FcFusionEpilogue epilogue1{ + act, bias_data, xte_scale_x, xte_scale_w, 0, 0, out_max_data}; + + if (act_type == xpu::Activation_t::SWISH_GLU) { + tensor_d1 = baidu::xpu::xblas::FcFusionTensor{ + out_data, nullptr, m, n / 2, n / 2, false}; + } else { + tensor_d1 = baidu::xpu::xblas::FcFusionTensor{ + out_data, nullptr, m, n, n, false}; + } + + int r = baidu::xpu::xblas::fc_fusion(ctx.x_context(), + tensor_a1, + tensor_b1, + tensor_c1, + tensor_d1, + desc, + epilogue1); + + // int r = + // xblas::fc_fusion( + // ctx.x_context(), + // x_data_fp16, + // w_data_fp16, + // out_data, + // m, + // n, + // k, + // transpose_x, + // true, + // x_max_data ? x_max_data : xte_x_maxptr, + // w_max_data ? w_max_data : xte_w_maxptr, + // out_max_data, + // transpose_x ? m : k, + // k, + // n, + // alpha, + // beta, + // bias_data, + // act, + // xte_scale_x, + // xte_scale_w); PADDLE_ENFORCE_XDNN_SUCCESS(r, "xblas_fc_fusion"); } } if (std::getenv("XPU_PADDLE_FC_BFLOAT16_XTE") == nullptr) { + // baidu::xpu::xblas::FcFusionTensor tensor_a1{ + // x_data, + // x_max_data, + // transpose_x ? k : m, + // transpose_x ? m : k, + // transpose_x ? 
m : k, + // transpose_x}; + // baidu::xpu::xblas::FcFusionTensor tensor_b1{ + // w_data, w_max_data, n, k, k, true}; + // baidu::xpu::xblas::FcFusionTensor + // tensor_c1{out_data, out_max_data, m, n, n, false}; + // baidu::xpu::xblas::FcFusionTensor tensor_d1{ + // out_data, out_max_data, m, n, n, false}; + // baidu::xpu::xblas::FcFusionDesc desc{alpha, + // beta}; + + // baidu::xpu::xblas::FcFusionEpilogue epilogue1{ + // act, bias_data, scale_max_data, nullptr, 0, 0, out_max_data}; + + // int r = baidu::xpu::xblas::fc_fusion(ctx.x_context(), + // tensor_a1, + // tensor_b1, + // tensor_c1, + // tensor_d1, + // desc, + // epilogue1); int r = xpu:: fc_fusion( // TX/TW/TY/TGEMM ctx.x_context(), // ctx diff --git a/test/ir/pir/fused_pass/xpu/test_fc_xpu_fuse_pass.py b/test/ir/pir/fused_pass/xpu/test_fc_xpu_fuse_pass.py index ca1f6e6df4f920..54b93cbe1938b6 100644 --- a/test/ir/pir/fused_pass/xpu/test_fc_xpu_fuse_pass.py +++ b/test/ir/pir/fused_pass/xpu/test_fc_xpu_fuse_pass.py @@ -23,7 +23,7 @@ paddle.enable_static() -class TestFCXpuFusePattern(PassTest): +class TestFcXpuFuseAddPattern(PassTest): r""" x w \ /