Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expr div #70376

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open

Expr div #70376

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions paddle/cinn/adt/dim_expr_match_trait.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,6 @@ template <typename T0>
struct MatchTrait<DimExpr, ::symbol::Negative<T0>> final
: public UnaryDimExprMatchTrait<::symbol::Negative, T0> {};

template <typename T0>
struct MatchTrait<DimExpr, ::symbol::Reciprocal<T0>> final
: public UnaryDimExprMatchTrait<::symbol::Reciprocal, T0> {};

template <typename T0>
struct MatchTrait<DimExpr, ::symbol::Add<T0>> final
: public ListDimExprMatchTrait<::symbol::Add, T0> {};
Expand All @@ -77,6 +73,10 @@ template <typename T0>
struct MatchTrait<DimExpr, ::symbol::Mul<T0>> final
: public ListDimExprMatchTrait<::symbol::Mul, T0> {};

template <typename T0>
struct MatchTrait<DimExpr, ::symbol::Div<T0>> final
: public ListDimExprMatchTrait<::symbol::Div, T0> {};

template <typename T0>
struct MatchTrait<DimExpr, ::symbol::Broadcast<T0>> final
: public ListDimExprMatchTrait<::symbol::Broadcast, T0> {};
Expand Down
12 changes: 6 additions & 6 deletions paddle/cinn/common/broadcast_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,6 @@ bool SearchBroadcastImpl(const symbol::Negative<symbol::DimExpr>& unary,
return SearchBroadcastImplForUnary(unary, DoEach);
}

template <typename DoEachT>
bool SearchBroadcastImpl(const symbol::Reciprocal<symbol::DimExpr>& unary,
const DoEachT& DoEach) {
return SearchBroadcastImplForUnary(unary, DoEach);
}

template <typename T, typename DoEachT>
bool SearchBroadcastImplForVariadic(const T& variadic, const DoEachT& DoEach) {
const auto& operands = *(variadic.operands);
Expand All @@ -76,6 +70,12 @@ bool SearchBroadcastImpl(const symbol::Mul<symbol::DimExpr>& variadic,
return SearchBroadcastImplForVariadic(variadic, DoEach);
}

template <typename DoEachT>
bool SearchBroadcastImpl(const symbol::Div<symbol::DimExpr>& variadic,
const DoEachT& DoEach) {
return SearchBroadcastImplForVariadic(variadic, DoEach);
}

template <typename DoEachT>
bool SearchBroadcastImpl(const symbol::Max<symbol::DimExpr>& variadic,
const DoEachT& DoEach) {
Expand Down
29 changes: 13 additions & 16 deletions paddle/cinn/common/dim_expr_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,6 @@ struct DimExprToIrExprVisitor {
return ir::Sub::Make(ir::Expr(std::int64_t(0)), ConvertToIrExpr(operand));
}

ir::Expr operator()(const Reciprocal<DimExpr>& dim_expr) {
const auto& [operand] = *dim_expr;
return ir::Div::Make(ir::Expr(std::int64_t(1)), ConvertToIrExpr(operand));
}

ir::Expr operator()(const Add<DimExpr>& dim_expr) {
const auto& [operands] = dim_expr;
if (operands->empty()) {
Expand All @@ -69,17 +64,19 @@ struct DimExprToIrExprVisitor {
}
ir::Expr product = ConvertToIrExpr(operands->at(0));
for (std::size_t i = 1; i < operands->size(); ++i) {
// Convert Reciprocal<DimExpr>(S0) to (1 / S0) will result in precision
// error. For example, (S0 * S1 / S2) != (S0 * S1 * (1 / S2)). So we
// should use Div instead of Reciprocal here.
if (operands->at(i).isa<Reciprocal<DimExpr>>()) {
product = ir::Div::Make(
product,
ConvertToIrExpr(
operands->at(i).dyn_cast<Reciprocal<DimExpr>>()->data));
} else {
product = ir::Mul::Make(product, ConvertToIrExpr(operands->at(i)));
}
product = ir::Mul::Make(product, ConvertToIrExpr(operands->at(i)));
}
return product;
}

ir::Expr operator()(const Div<DimExpr>& dim_expr) {
const auto& [operands] = dim_expr;
if (operands->empty()) {
return ir::Expr(std::int64_t(1));
}
ir::Expr product = ConvertToIrExpr(operands->at(0));
for (std::size_t i = 1; i < operands->size(); ++i) {
product = ir::Div::Make(product, ConvertToIrExpr(operands->at(i)));
}
return product;
}
Expand Down
45 changes: 23 additions & 22 deletions paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@ std::string GetSerializedTag<Negative<DimExpr>>() {
return "Negative";
}

template <>
std::string GetSerializedTag<Reciprocal<DimExpr>>() {
return "Reciprocal";
}

template <>
std::string GetSerializedTag<Add<DimExpr>>() {
return "Add";
Expand All @@ -45,6 +40,11 @@ std::string GetSerializedTag<Mul<DimExpr>>() {
return "Mul";
}

template <>
std::string GetSerializedTag<Div<DimExpr>>() {
return "Div";
}

template <>
std::string GetSerializedTag<Max<DimExpr>>() {
return "Max";
Expand Down Expand Up @@ -85,11 +85,6 @@ ::pir::Attribute ConvertDimExprToAttributeImpl(
return ConvertUnaryDimExprToAttributeImpl(ctx, dim_expr);
}

::pir::Attribute ConvertDimExprToAttributeImpl(
::pir::IrContext* ctx, const Reciprocal<DimExpr>& dim_expr) {
return ConvertUnaryDimExprToAttributeImpl(ctx, dim_expr);
}

template <typename T>
::pir::Attribute ConvertVariadicDimExprToAttribute(::pir::IrContext* ctx,
const T& dim_expr) {
Expand All @@ -112,6 +107,11 @@ ::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
return ConvertVariadicDimExprToAttribute(ctx, dim_expr);
}

::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
const Div<DimExpr>& dim_expr) {
return ConvertVariadicDimExprToAttribute(ctx, dim_expr);
}

::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
const Max<DimExpr>& dim_expr) {
return ConvertVariadicDimExprToAttribute(ctx, dim_expr);
Expand Down Expand Up @@ -175,12 +175,12 @@ std::optional<ArrayAttributeConverterT> GetArrayAttributeConverter(
static std::unordered_map<std::string, ArrayAttributeConverterT> map{
{GetSerializedTag<Negative<DimExpr>>(),
&ConvertArrayAttributeToUnaryDimExpr<Negative<DimExpr>>},
{GetSerializedTag<Reciprocal<DimExpr>>(),
&ConvertArrayAttributeToUnaryDimExpr<Reciprocal<DimExpr>>},
{GetSerializedTag<Add<DimExpr>>(),
&ConvertArrayAttributeToVariadicDimExpr<Add<DimExpr>>},
{GetSerializedTag<Mul<DimExpr>>(),
&ConvertArrayAttributeToVariadicDimExpr<Mul<DimExpr>>},
{GetSerializedTag<Div<DimExpr>>(),
&ConvertArrayAttributeToVariadicDimExpr<Div<DimExpr>>},
{GetSerializedTag<Max<DimExpr>>(),
&ConvertArrayAttributeToVariadicDimExpr<Max<DimExpr>>},
{GetSerializedTag<Min<DimExpr>>(),
Expand Down Expand Up @@ -276,9 +276,6 @@ class SubstituteDimExprHelper final {
std::optional<DimExpr> SubstituteImpl(const Negative<DimExpr>& dim_expr) {
return SubstituteUnary(dim_expr);
}
std::optional<DimExpr> SubstituteImpl(const Reciprocal<DimExpr>& dim_expr) {
return SubstituteUnary(dim_expr);
}

template <typename T>
std::optional<DimExpr> SubstituteUnary(const T& dim_expr) {
Expand All @@ -298,6 +295,10 @@ class SubstituteDimExprHelper final {
return SubstituteVariadic(dim_expr);
}

std::optional<DimExpr> SubstituteImpl(const Div<DimExpr>& dim_expr) {
return SubstituteVariadic(dim_expr);
}

std::optional<DimExpr> SubstituteImpl(const Max<DimExpr>& dim_expr) {
return SubstituteVariadic(dim_expr);
}
Expand Down Expand Up @@ -412,12 +413,12 @@ bool IsAtomicImpl(const std::string&) { return true; }

bool IsAtomicImpl(const symbol::Negative<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Reciprocal<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Add<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Mul<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Div<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Max<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Min<symbol::DimExpr>&) { return false; }
Expand Down Expand Up @@ -484,11 +485,6 @@ void CollectSymbolNamesImpl(const symbol::Negative<symbol::DimExpr>& dim_expr,
CollectSymbolNamesImplForUnary(dim_expr, ret);
}

void CollectSymbolNamesImpl(const symbol::Reciprocal<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForUnary(dim_expr, ret);
}

template <typename T>
void CollectSymbolNamesImplForVariadic(const T& dim_expr,
std::set<std::string>* ret) {
Expand All @@ -508,6 +504,11 @@ void CollectSymbolNamesImpl(const symbol::Mul<symbol::DimExpr>& dim_expr,
CollectSymbolNamesImplForVariadic(dim_expr, ret);
}

void CollectSymbolNamesImpl(const symbol::Div<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForVariadic(dim_expr, ret);
}

void CollectSymbolNamesImpl(const symbol::Max<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForVariadic(dim_expr, ret);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,6 @@ struct ShapeSignatureGenerator {
[&](const symbol::Negative<symbol::DimExpr>& negative) {
GetSymbolsForOneDimExpr(negative->data, symbols);
},
[&](const symbol::Reciprocal<symbol::DimExpr>& reciprocal) {
GetSymbolsForOneDimExpr(reciprocal->data, symbols);
},
[&](const symbol::Add<symbol::DimExpr>& add) {
for (const auto& dim_expr : *add.operands) {
GetSymbolsForOneDimExpr(dim_expr, symbols);
Expand All @@ -150,6 +147,11 @@ struct ShapeSignatureGenerator {
GetSymbolsForOneDimExpr(dim_expr, symbols);
}
},
[&](const symbol::Div<symbol::DimExpr>& div) {
for (const auto& dim_expr : *div.operands) {
GetSymbolsForOneDimExpr(dim_expr, symbols);
}
},
[&](const symbol::Max<symbol::DimExpr>& max) {
for (const auto& dim_expr : *max.operands) {
GetSymbolsForOneDimExpr(dim_expr, symbols);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,6 @@ struct StaticDimToDynamicConverter {
return AppliedOnceUnaryImpl(dim_expr, symbol);
}

bool AppliedOnceImpl(const symbol::Reciprocal<symbol::DimExpr>& dim_expr,
const std::string& symbol) {
return AppliedOnceUnaryImpl(dim_expr, symbol);
}

template <typename T>
bool AppliedOnceListImpl(const T& dim_expr, const std::string& symbol) {
const auto& [operands] = dim_expr;
Expand All @@ -222,6 +217,11 @@ struct StaticDimToDynamicConverter {
return AppliedOnceListImpl(dim_expr, symbol);
}

bool AppliedOnceImpl(const symbol::Div<symbol::DimExpr>& dim_expr,
const std::string& symbol) {
return AppliedOnceListImpl(dim_expr, symbol);
}

bool AppliedOnceImpl(const symbol::Min<symbol::DimExpr>& dim_expr,
const std::string& symbol) {
return AppliedOnceListImpl(dim_expr, symbol);
Expand Down Expand Up @@ -297,21 +297,21 @@ struct StaticDimToDynamicConverter {
}

std::optional<symbol::DimExpr> ConvertDimExprImpl(
const symbol::Reciprocal<symbol::DimExpr>& dim_expr,
const symbol::Add<symbol::DimExpr>& dim_expr,
int64_t c,
const std::string& symbol) {
return ConvertUnaryDimExprImpl(dim_expr, c, symbol);
return ConvertListDimExprImpl(dim_expr, c, symbol);
}

std::optional<symbol::DimExpr> ConvertDimExprImpl(
const symbol::Add<symbol::DimExpr>& dim_expr,
const symbol::Mul<symbol::DimExpr>& dim_expr,
int64_t c,
const std::string& symbol) {
return ConvertListDimExprImpl(dim_expr, c, symbol);
}

std::optional<symbol::DimExpr> ConvertDimExprImpl(
const symbol::Mul<symbol::DimExpr>& dim_expr,
const symbol::Div<symbol::DimExpr>& dim_expr,
int64_t c,
const std::string& symbol) {
return ConvertListDimExprImpl(dim_expr, c, symbol);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ bool IsComplicatedDimExpr(const symbol::DimExpr& dim_expr) {
[](std::int64_t dim_expr) { return false; },
[](const std::string& dim_expr) { return false; },
[](const symbol::Negative<symbol::DimExpr>& dim_expr) { return true; },
[](const symbol::Reciprocal<symbol::DimExpr>& dim_expr) { return true; },
[](const symbol::Add<symbol::DimExpr>& dim_expr) { return true; },
[](const symbol::Mul<symbol::DimExpr>& dim_expr) { return true; },
[](const symbol::Div<symbol::DimExpr>& dim_expr) { return true; },
[](const symbol::Max<symbol::DimExpr>& dim_expr) { return true; },
[](const symbol::Min<symbol::DimExpr>& dim_expr) { return true; },
[](const symbol::Broadcast<symbol::DimExpr>& dim_expr) { return true; }};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -811,14 +811,6 @@ struct PirToPyCodeConverterHelper {
ss << ")";
return ss.str();
},
[](const symbol::Reciprocal<symbol::DimExpr>& reciprocal) {
std::ostringstream ss;
const auto& [operand] = *reciprocal;
ss << "self.s_reciprocal(";
ss << PirToPyCodeConverterHelper::ConvertDimExpr(operand);
ss << ")";
return ss.str();
},
[](const symbol::Add<symbol::DimExpr>& add) {
std::ostringstream ss;
ss << "self.s_add(";
Expand Down Expand Up @@ -847,6 +839,20 @@ struct PirToPyCodeConverterHelper {
ss << ")";
return ss.str();
},
[](const symbol::Div<symbol::DimExpr>& div) {
std::ostringstream ss;
ss << "self.s_div(";
const auto& operands = div.operands;
int i = 0;
for (const auto& operand : *operands) {
if (i++ > 0) {
ss << ", ";
}
ss << PirToPyCodeConverterHelper::ConvertDimExpr(operand);
}
ss << ")";
return ss.str();
},
[](const symbol::Max<symbol::DimExpr>& max) {
std::ostringstream ss;
ss << "self.s_max(";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,6 @@ struct CachedDimExprToValueConverter {
"ConvertToValueImpl(symbol::Add<symbol::DimExpr>)"));
}

pir::Value ConvertToValueImpl(
const symbol::Reciprocal<symbol::DimExpr>& dim_expr) {
PADDLE_THROW(::common::errors::Fatal(
"Dead code. This logical should handled by "
"ConvertToValueImpl(symbol::Mul<symbol::DimExpr>)"));
}

pir::Value ConvertToValueImpl(const symbol::Add<symbol::DimExpr>& dim_expr) {
const auto& [operands] = dim_expr;
PADDLE_ENFORCE_GT(operands->size(),
Expand Down Expand Up @@ -201,19 +194,26 @@ struct CachedDimExprToValueConverter {
operands->size()));
pir::Value prod = ConvertToValue(operands->at(0));
for (int i = 1; i < operands->size(); ++i) {
if (operands->at(i).isa<symbol::Reciprocal<symbol::DimExpr>>()) {
const auto& operand =
operands->at(i)
.dyn_cast<symbol::Reciprocal<symbol::DimExpr>>()
->data;
pir::Value operand_value = ConvertToValue(operand);
prod = rewriter->Build<paddle::dialect::DivideOp>(prod, operand_value)
.out();
} else {
pir::Value operand_value = ConvertToValue(operands->at(i));
prod = rewriter->Build<paddle::dialect::MultiplyOp>(prod, operand_value)
.out();
}
pir::Value operand_value = ConvertToValue(operands->at(i));
prod = rewriter->Build<paddle::dialect::MultiplyOp>(prod, operand_value)
.out();
}
return prod;
}

pir::Value ConvertToValueImpl(const symbol::Div<symbol::DimExpr>& dim_expr) {
const auto& [operands] = dim_expr;
PADDLE_ENFORCE_GT(operands->size(),
0,
::common::errors::InvalidArgument(
"The size of operands is incorrect."
"Expected size is larger than 0, but receive %d.",
operands->size()));
pir::Value prod = ConvertToValue(operands->at(0));
for (int i = 1; i < operands->size(); ++i) {
pir::Value operand_value = ConvertToValue(operands->at(i));
prod =
rewriter->Build<paddle::dialect::DivideOp>(prod, operand_value).out();
}
return prod;
}
Expand Down
Loading