Skip to content

Commit 437246f

Browse files
authored
Merge pull request #19593 from paldepind/rust/operator-overloading
Rust: Type inference for operator overloading
2 parents 55be5fb + 6500ebf commit 437246f

File tree

7 files changed

+2065
-798
lines changed

7 files changed

+2065
-798
lines changed

rust/ql/lib/codeql/rust/elements/internal/BinaryExprImpl.qll

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ module Impl {
2828

2929
override string getOperatorName() { result = Generated::BinaryExpr.super.getOperatorName() }
3030

31-
override Expr getAnOperand() { result = [this.getLhs(), this.getRhs()] }
31+
override Expr getOperand(int n) {
32+
n = 0 and result = this.getLhs()
33+
or
34+
n = 1 and result = this.getRhs()
35+
}
3236
}
3337
}

rust/ql/lib/codeql/rust/elements/internal/OperationImpl.qll

Lines changed: 90 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,78 @@
77
private import rust
88
private import codeql.rust.elements.internal.ExprImpl::Impl as ExprImpl
99

10+
/**
11+
* Holds if the operator `op` is overloaded to a trait with the canonical path
12+
* `path` and the method name `method`.
13+
*/
14+
private predicate isOverloaded(string op, string path, string method) {
15+
// Negation
16+
op = "-" and path = "core::ops::arith::Neg" and method = "neg"
17+
or
18+
// Not
19+
op = "!" and path = "core::ops::bit::Not" and method = "not"
20+
or
21+
// Dereference
22+
op = "*" and path = "core::ops::Deref" and method = "deref"
23+
or
24+
// Comparison operators
25+
op = "==" and path = "core::cmp::PartialEq" and method = "eq"
26+
or
27+
op = "!=" and path = "core::cmp::PartialEq" and method = "ne"
28+
or
29+
op = "<" and path = "core::cmp::PartialOrd" and method = "lt"
30+
or
31+
op = "<=" and path = "core::cmp::PartialOrd" and method = "le"
32+
or
33+
op = ">" and path = "core::cmp::PartialOrd" and method = "gt"
34+
or
35+
op = ">=" and path = "core::cmp::PartialOrd" and method = "ge"
36+
or
37+
// Arithmetic operators
38+
op = "+" and path = "core::ops::arith::Add" and method = "add"
39+
or
40+
op = "-" and path = "core::ops::arith::Sub" and method = "sub"
41+
or
42+
op = "*" and path = "core::ops::arith::Mul" and method = "mul"
43+
or
44+
op = "/" and path = "core::ops::arith::Div" and method = "div"
45+
or
46+
op = "%" and path = "core::ops::arith::Rem" and method = "rem"
47+
or
48+
// Arithmetic assignment expressions
49+
op = "+=" and path = "core::ops::arith::AddAssign" and method = "add_assign"
50+
or
51+
op = "-=" and path = "core::ops::arith::SubAssign" and method = "sub_assign"
52+
or
53+
op = "*=" and path = "core::ops::arith::MulAssign" and method = "mul_assign"
54+
or
55+
op = "/=" and path = "core::ops::arith::DivAssign" and method = "div_assign"
56+
or
57+
op = "%=" and path = "core::ops::arith::RemAssign" and method = "rem_assign"
58+
or
59+
// Bitwise operators
60+
op = "&" and path = "core::ops::bit::BitAnd" and method = "bitand"
61+
or
62+
op = "|" and path = "core::ops::bit::BitOr" and method = "bitor"
63+
or
64+
op = "^" and path = "core::ops::bit::BitXor" and method = "bitxor"
65+
or
66+
op = "<<" and path = "core::ops::bit::Shl" and method = "shl"
67+
or
68+
op = ">>" and path = "core::ops::bit::Shr" and method = "shr"
69+
or
70+
// Bitwise assignment operators
71+
op = "&=" and path = "core::ops::bit::BitAndAssign" and method = "bitand_assign"
72+
or
73+
op = "|=" and path = "core::ops::bit::BitOrAssign" and method = "bitor_assign"
74+
or
75+
op = "^=" and path = "core::ops::bit::BitXorAssign" and method = "bitxor_assign"
76+
or
77+
op = "<<=" and path = "core::ops::bit::ShlAssign" and method = "shl_assign"
78+
or
79+
op = ">>=" and path = "core::ops::bit::ShrAssign" and method = "shr_assign"
80+
}
81+
1082
/**
1183
* INTERNAL: This module contains the customizable definition of `Operation` and should not
1284
* be referenced directly.
@@ -16,14 +88,28 @@ module Impl {
1688
* An operation, for example `&&`, `+=`, `!` or `*`.
1789
*/
1890
abstract class Operation extends ExprImpl::Expr {
91+
/** Gets the operator name of this operation, if it exists. */
92+
abstract string getOperatorName();
93+
94+
/** Gets the `n`th operand of this operation, if any. */
95+
abstract Expr getOperand(int n);
96+
1997
/**
20-
* Gets the operator name of this operation, if it exists.
98+
* Gets the number of operands of this operation.
99+
*
100+
* This is either 1 for prefix operations, or 2 for binary operations.
21101
*/
22-
abstract string getOperatorName();
102+
final int getNumberOfOperands() { result = strictcount(this.getAnOperand()) }
103+
104+
/** Gets an operand of this operation. */
105+
Expr getAnOperand() { result = this.getOperand(_) }
23106

24107
/**
25-
* Gets an operand of this operation.
108+
* Holds if this operation is overloaded to the method `methodName` of the
109+
* trait `trait`.
26110
*/
27-
abstract Expr getAnOperand();
111+
predicate isOverloaded(Trait trait, string methodName) {
112+
isOverloaded(this.getOperatorName(), trait.getCanonicalPath(), methodName)
113+
}
28114
}
29115
}

rust/ql/lib/codeql/rust/elements/internal/PrefixExprImpl.qll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,6 @@ module Impl {
2626

2727
override string getOperatorName() { result = Generated::PrefixExpr.super.getOperatorName() }
2828

29-
override Expr getAnOperand() { result = this.getExpr() }
29+
override Expr getOperand(int n) { n = 0 and result = this.getExpr() }
3030
}
3131
}

rust/ql/lib/codeql/rust/elements/internal/RefExprImpl.qll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ module Impl {
2929

3030
override string getOperatorName() { result = "&" }
3131

32-
override Expr getAnOperand() { result = this.getExpr() }
32+
override Expr getOperand(int n) { n = 0 and result = this.getExpr() }
3333

3434
private string getSpecPart(int index) {
3535
index = 0 and this.isRaw() and result = "raw"

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -643,20 +643,30 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
643643

644644
private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl
645645

646-
class Access extends CallExprBase {
646+
abstract class Access extends Expr {
647+
abstract Type getTypeArgument(TypeArgumentPosition apos, TypePath path);
648+
649+
abstract AstNode getNodeAt(AccessPosition apos);
650+
651+
abstract Type getInferredType(AccessPosition apos, TypePath path);
652+
653+
abstract Declaration getTarget();
654+
}
655+
656+
private class CallExprBaseAccess extends Access instanceof CallExprBase {
647657
private TypeMention getMethodTypeArg(int i) {
648658
result = this.(MethodCallExpr).getGenericArgList().getTypeArg(i)
649659
}
650660

651-
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
661+
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
652662
exists(TypeMention arg | result = arg.resolveTypeAt(path) |
653663
arg = getExplicitTypeArgMention(CallExprImpl::getFunctionPath(this), apos.asTypeParam())
654664
or
655665
arg = this.getMethodTypeArg(apos.asMethodTypeArgumentPosition())
656666
)
657667
}
658668

659-
AstNode getNodeAt(AccessPosition apos) {
669+
override AstNode getNodeAt(AccessPosition apos) {
660670
exists(int p, boolean isMethodCall |
661671
argPos(this, result, p, isMethodCall) and
662672
apos = TPositionalAccessPosition(p, isMethodCall)
@@ -669,17 +679,42 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
669679
apos = TReturnAccessPosition()
670680
}
671681

672-
Type getInferredType(AccessPosition apos, TypePath path) {
682+
override Type getInferredType(AccessPosition apos, TypePath path) {
673683
result = inferType(this.getNodeAt(apos), path)
674684
}
675685

676-
Declaration getTarget() {
686+
override Declaration getTarget() {
677687
result = CallExprImpl::getResolvedFunction(this)
678688
or
679689
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
680690
}
681691
}
682692

693+
private class OperationAccess extends Access instanceof Operation {
694+
OperationAccess() { super.isOverloaded(_, _) }
695+
696+
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
697+
// The syntax for operators does not allow type arguments.
698+
none()
699+
}
700+
701+
override AstNode getNodeAt(AccessPosition apos) {
702+
result = super.getOperand(0) and apos = TSelfAccessPosition()
703+
or
704+
result = super.getOperand(1) and apos = TPositionalAccessPosition(0, true)
705+
or
706+
result = this and apos = TReturnAccessPosition()
707+
}
708+
709+
override Type getInferredType(AccessPosition apos, TypePath path) {
710+
result = inferType(this.getNodeAt(apos), path)
711+
}
712+
713+
override Declaration getTarget() {
714+
result = inferMethodCallTarget(this) // mutual recursion; resolving method calls requires resolving types and vice versa
715+
}
716+
}
717+
683718
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
684719
apos.isSelf() and
685720
dpos.isSelf()
@@ -1059,6 +1094,26 @@ private module MethodCall {
10591094
pragma[nomagic]
10601095
override Type getTypeAt(TypePath path) { result = inferType(receiver, path) }
10611096
}
1097+
1098+
private class OperationMethodCall extends MethodCallImpl instanceof Operation {
1099+
TraitItemNode trait;
1100+
string methodName;
1101+
1102+
OperationMethodCall() { super.isOverloaded(trait, methodName) }
1103+
1104+
override string getMethodName() { result = methodName }
1105+
1106+
override int getArity() { result = this.(Operation).getNumberOfOperands() - 1 }
1107+
1108+
override Trait getTrait() { result = trait }
1109+
1110+
pragma[nomagic]
1111+
override Type getTypeAt(TypePath path) {
1112+
result = inferType(this.(BinaryExpr).getLhs(), path)
1113+
or
1114+
result = inferType(this.(PrefixExpr).getExpr(), path)
1115+
}
1116+
}
10621117
}
10631118

10641119
import MethodCall

0 commit comments

Comments
 (0)