Skip to content

Commit 56a5854

Browse files
committed
CSHARP-5628: Add new boolean expression simplifications to PartialEvaluator
1 parent 55c13cd commit 56a5854

File tree

3 files changed

+314
-34
lines changed

3 files changed

+314
-34
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/PartialEvaluator.cs

Lines changed: 120 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -75,60 +75,137 @@ public override Expression Visit(Expression expression)
7575

7676
protected override Expression VisitBinary(BinaryExpression node)
7777
{
78-
if (node.NodeType == ExpressionType.AndAlso)
78+
var leftExpression = node.Left;
79+
var rightExpression = node.Right;
80+
81+
if (leftExpression.Type == typeof(bool) && rightExpression.Type == typeof(bool))
7982
{
80-
var leftExpression = Visit(node.Left);
81-
if (leftExpression is ConstantExpression constantLeftExpression )
83+
if (node.NodeType == ExpressionType.AndAlso)
8284
{
83-
var value = (bool)constantLeftExpression.Value;
84-
return value ? Visit(node.Right) : Expression.Constant(false);
85+
leftExpression = Visit(leftExpression);
86+
if (IsConstant<bool>(leftExpression, out var leftValue))
87+
{
88+
// true && Q => Q
89+
// false && Q => false
90+
return leftValue ? Visit(rightExpression) : Expression.Constant(false);
91+
}
92+
93+
rightExpression = Visit(rightExpression);
94+
if (IsConstant<bool>(rightExpression, out var rightValue))
95+
{
96+
// P && true => P
97+
// P && false => false
98+
return rightValue ? leftExpression : Expression.Constant(false);
99+
}
100+
101+
return node.Update(leftExpression, conversion: null, rightExpression);
85102
}
86103

87-
var rightExpression = Visit(node.Right);
88-
if (rightExpression is ConstantExpression constantRightExpression)
104+
if (node.NodeType == ExpressionType.OrElse)
89105
{
90-
var value = (bool)constantRightExpression.Value;
91-
return value ? leftExpression : Expression.Constant(false);
106+
leftExpression = Visit(leftExpression);
107+
if (IsConstant<bool>(leftExpression, out var leftValue))
108+
{
109+
// true || Q => true
110+
// false || Q => Q
111+
return leftValue ? Expression.Constant(true) : Visit(rightExpression);
112+
}
113+
114+
rightExpression = Visit(rightExpression);
115+
if (IsConstant<bool>(rightExpression, out var rightValue))
116+
{
117+
// P || true => true
118+
// P || false => P
119+
return rightValue ? Expression.Constant(true) : leftExpression;
120+
}
121+
122+
return node.Update(leftExpression, conversion: null, rightExpression);
92123
}
124+
}
125+
126+
return base.VisitBinary(node);
127+
}
93128

94-
return node.Update(leftExpression, conversion: null, rightExpression);
129+
protected override Expression VisitConditional(ConditionalExpression node)
130+
{
131+
var test = Visit(node.Test);
132+
133+
if (IsConstant<bool>(test, out var testValue))
134+
{
135+
// true ? A : B => A
136+
// false ? A : B => B
137+
return testValue ? Visit(node.IfTrue) : Visit(node.IfFalse);
95138
}
96139

97-
if (node.NodeType == ExpressionType.OrElse)
140+
var ifTrue = Visit(node.IfTrue);
141+
var ifFalse = Visit(node.IfFalse);
142+
143+
if (BothAreConstant<bool>(ifTrue, ifFalse, out var ifTrueValue, out var ifFalseValue))
98144
{
99-
var leftExpression = Visit(node.Left);
100-
if (leftExpression is ConstantExpression constantLeftExpression)
145+
return (ifTrueValue, ifFalseValue) switch
101146
{
102-
var value = (bool)constantLeftExpression.Value;
103-
return value ? Expression.Constant(true) : Visit(node.Right);
104-
}
147+
(false, false) => Expression.Constant(false), // T ? false : false => false
148+
(false, true) => Expression.Not(test), // T ? false : true => !T
149+
(true, false) => test, // T ? true : false => T
150+
(true, true) => Expression.Constant(true) // T ? true : true => true
151+
};
152+
}
153+
else if (IsConstant<bool>(ifTrue, out ifTrueValue))
154+
{
155+
// T ? true : Q => T || Q
156+
// T ? false : Q => !T && Q
157+
return ifTrueValue
158+
? Visit(Expression.Or(test, ifFalse))
159+
: Visit(Expression.And(Expression.Not(test), ifFalse));
160+
}
161+
else if (IsConstant<bool>(ifFalse, out ifFalseValue))
162+
{
163+
// T ? P : true => !T || P
164+
// T ? P : false => T && P
165+
return ifFalseValue
166+
? Visit(Expression.Or(Expression.Not(test), ifTrue))
167+
: Visit(Expression.And(test, ifTrue));
168+
}
105169

106-
var rightExpression = Visit(node.Right);
107-
if (rightExpression is ConstantExpression constantRightExpression)
170+
return node.Update(test, ifTrue, ifFalse);
171+
}
172+
173+
protected override Expression VisitUnary(UnaryExpression node)
174+
{
175+
var operand = Visit(node.Operand);
176+
177+
if (node.Type == typeof(bool) &&
178+
node.NodeType == ExpressionType.Not)
179+
{
180+
if (operand is UnaryExpression innerUnaryExpressionOperand &&
181+
innerUnaryExpressionOperand.NodeType == ExpressionType.Not)
108182
{
109-
var value = (bool)constantRightExpression.Value;
110-
return value ? Expression.Constant(true) : leftExpression;
183+
// !!P => P
184+
return innerUnaryExpressionOperand.Operand;
111185
}
112-
113-
return node.Update(leftExpression, conversion: null, rightExpression);
114186
}
115187

116-
return base.VisitBinary(node);
188+
return node.Update(operand);
117189
}
118190

119-
protected override Expression VisitConditional(ConditionalExpression node)
191+
// private methods
192+
private bool BothAreConstant<T>(Expression expression1, Expression expression2, out T constantValue1, out T constantValue2)
120193
{
121-
var test = Visit(node.Test);
122-
if (test is ConstantExpression constantTestExpression)
194+
if (expression1 is ConstantExpression constantExpression1 &&
195+
expression2 is ConstantExpression constantExpression2 &&
196+
constantExpression1.Type == typeof(T) &&
197+
constantExpression2.Type == typeof(T))
123198
{
124-
var value = (bool)constantTestExpression.Value;
125-
return value ? Visit(node.IfTrue) : Visit(node.IfFalse);
199+
constantValue1 = (T)constantExpression1.Value;
200+
constantValue2 = (T)constantExpression2.Value;
201+
return true;
126202
}
127203

128-
return node.Update(test, Visit(node.IfTrue), Visit(node.IfFalse));
204+
constantValue1 = default;
205+
constantValue2 = default;
206+
return false;
129207
}
130208

131-
// private methods
132209
private Expression Evaluate(Expression expression)
133210
{
134211
if (expression.NodeType == ExpressionType.Constant)
@@ -139,6 +216,19 @@ private Expression Evaluate(Expression expression)
139216
Delegate fn = lambda.Compile();
140217
return Expression.Constant(fn.DynamicInvoke(null), expression.Type);
141218
}
219+
220+
private bool IsConstant<T>(Expression expression, out T constantValue)
221+
{
222+
if (expression is ConstantExpression constantExpression1 &&
223+
constantExpression1.Type == typeof(T))
224+
{
225+
constantValue = (T)constantExpression1.Value;
226+
return true;
227+
}
228+
229+
constantValue = default;
230+
return false;
231+
}
142232
}
143233

144234
private class Nominator : ExpressionVisitor

tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4337Tests.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ public class CSharp4337Tests : LinqIntegrationTest<CSharp4337Tests.ClassFixture>
3232
{
3333
private static (Expression<Func<C, R<bool>>> Projection, string ExpectedStage, bool[] ExpectedResults)[] __predicate_should_use_correct_representation_test_cases = new (Expression<Func<C, R<bool>>> Projection, string ExpectedStage, bool[] ExpectedResults)[]
3434
{
35-
(d => new R<bool> { N = d.Id, V = d.I1 == E.E1 ? true : false }, "{ $project : { N : '$_id', V : { $cond : { if : { $eq : ['$I1', 1] }, then : true, else : false } }, _id : 0 } }", new[] { true, false }),
36-
(d => new R<bool> { N = d.Id, V = d.S1 == E.E1 ? true : false }, "{ $project : { N : '$_id', V : { $cond : { if : { $eq : ['$S1', 'E1'] }, then : true, else : false } }, _id : 0 } }", new[] { true, false }),
37-
(d => new R<bool> { N = d.Id, V = E.E1 == d.I1 ? true : false }, "{ $project : { N : '$_id', V : { $cond : { if : { $eq : [1, '$I1'] }, then : true, else : false } }, _id : 0 } }", new[] { true, false }),
38-
(d => new R<bool> { N = d.Id, V = E.E1 == d.S1 ? true : false }, "{ $project : { N : '$_id', V : { $cond : { if : { $eq : ['E1', '$S1'] }, then : true, else : false } }, _id : 0 } }", new[] { true, false })
35+
(d => new R<bool> { N = d.Id, V = d.I1 == E.E1 ? true : false }, "{ $project : { N : '$_id', V : { $eq : ['$I1', 1] }, _id : 0 } }", new[] { true, false }),
36+
(d => new R<bool> { N = d.Id, V = d.S1 == E.E1 ? true : false }, "{ $project : { N : '$_id', V : { $eq : ['$S1', 'E1'] }, _id : 0 } }", new[] { true, false }),
37+
(d => new R<bool> { N = d.Id, V = E.E1 == d.I1 ? true : false }, "{ $project : { N : '$_id', V : { $eq : [1, '$I1'] }, _id : 0 } }", new[] { true, false }),
38+
(d => new R<bool> { N = d.Id, V = E.E1 == d.S1 ? true : false }, "{ $project : { N : '$_id', V : { $eq : ['E1', '$S1'] }, _id : 0 } }", new[] { true, false })
3939
};
4040

4141
public CSharp4337Tests(ClassFixture fixture)

0 commit comments

Comments
 (0)