Skip to content

Commit f2fea03

Browse files
committed
wip4
1 parent 30be9c4 commit f2fea03

2 files changed

Lines changed: 123 additions & 45 deletions

File tree

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,10 @@ private module Input3 implements InputSig3 {
286286
(exists(resolveTupleFieldExpr(_, _)) implies any())
287287
}
288288

289+
class BoolType extends DataType {
290+
BoolType() { this.getTypeItem() instanceof Builtins::Bool }
291+
}
292+
289293
class AstNode = Rust::AstNode;
290294

291295
TypeMention getTypeAnnotation(AstNode n) {
@@ -304,32 +308,60 @@ private module Input3 implements InputSig3 {
304308
result = n.(ShorthandSelfParameterMention)
305309
}
306310

311+
class Expr = Rust::Expr;
312+
313+
class ConditionalExpr extends AstNode, IfExpr {
314+
Expr getCondition() { result = super.getCondition() }
315+
316+
Expr getThen() { result = super.getThen() }
317+
318+
Expr getElse() { result = super.getElse() }
319+
}
320+
321+
class BinaryExpr extends AstNode, Rust::BinaryExpr {
322+
Expr getLeftOperand() { result = super.getLhs() }
323+
324+
Expr getRightOperand() { result = super.getRhs() }
325+
}
326+
327+
class LogicalAndExpr extends BinaryExpr, Rust::LogicalAndExpr { }
328+
329+
class LogicalOrExpr extends BinaryExpr, Rust::LogicalOrExpr { }
330+
331+
abstract class Assignment extends BinaryExpr { }
332+
333+
class AssignExpr extends Assignment, Rust::AssignmentExpr { }
334+
335+
class ParenExpr extends AstNode, Rust::ParenExpr {
336+
AstNode getExpr() { result = super.getExpr() }
337+
}
338+
307339
class Variable extends Rust::Variable {
308340
AstNode getDefiningNode() {
309341
result = this.getPat().getName() or
310342
result = this.getParameter().(SelfParam)
311343
}
312344

313-
AstNode getAnAccess() { result = super.getAnAccess() }
345+
Expr getAnAccess() { result = super.getAnAccess() }
314346
}
315347

316-
abstract class Assignment extends AstNode {
348+
abstract class LetDeclaration extends AstNode {
317349
abstract predicate isCoercionSite();
318350

319351
abstract AstNode getLeftOperand();
320352

321353
abstract AstNode getRightOperand();
322354
}
323355

324-
private class LetExprAssignment extends Assignment, LetExpr {
356+
private class LetExprLetDeclaration extends LetDeclaration, LetExpr {
325357
override predicate isCoercionSite() { not this.getPat() instanceof IdentPat }
326358

327359
override AstNode getLeftOperand() { result = this.getPat() }
328360

329361
override AstNode getRightOperand() { result = this.getScrutinee() }
330362
}
331363

332-
private class LetStmtAssignment extends Assignment, LetStmt {
364+
private class LetStmtLetDeclaration extends LetDeclaration, LetStmt {
333365
override predicate isCoercionSite() {
334366
this.hasTypeRepr() or
335367
not identLetStmt(this, _, _)
@@ -340,18 +372,6 @@ private module Input3 implements InputSig3 {
340372
override AstNode getRightOperand() { result = this.getInitializer() }
341373
}
342374

343-
private class AssignmentExprAssignment extends Assignment, AssignmentExpr {
344-
override predicate isCoercionSite() { any() }
345-
346-
override AstNode getLeftOperand() { result = this.getLhs() }
347-
348-
override AstNode getRightOperand() { result = this.getRhs() }
349-
}
350-
351-
class ParenExpr extends AstNode, Rust::ParenExpr {
352-
AstNode getExpr() { result = super.getExpr() }
353-
}
354-
355375
predicate certainTypeEqualityInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
356376
n1 =
357377
any(IdentPat ip |
@@ -824,8 +844,6 @@ private module CertainTypeInferenceInput {
824844
result = inferRefExprType(n) and
825845
path.isEmpty()
826846
or
827-
result = inferLogicalOperationType(n, path)
828-
or
829847
result = inferCertainStructExprType(n, path)
830848
or
831849
result = inferCertainStructPatType(n, path)
@@ -857,14 +875,6 @@ private module CertainTypeInferenceInput {
857875
}
858876
}
859877

860-
private Type inferLogicalOperationType(AstNode n, TypePath path) {
861-
exists(Builtins::Bool t, BinaryLogicalOperation be |
862-
n = [be, be.getLhs(), be.getRhs()] and
863-
path.isEmpty() and
864-
result = TDataType(t)
865-
)
866-
}
867-
868878
private Type inferAssignmentOperationType(AstNode n, TypePath path) {
869879
n instanceof AssignmentOperation and
870880
path.isEmpty() and

shared/typeinference/codeql/typeinference/internal/TypeInference.qll

Lines changed: 87 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2114,6 +2114,8 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
21142114

21152115
/**
21162116
* Provides the input to `Make3`.
2117+
*
2118+
* TODO: Eventually align the AST signature with that of the shared CFG library.
21172119
*/
21182120
signature module InputSig3 {
21192121
/**
@@ -2122,6 +2124,9 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
21222124
*/
21232125
default predicate cachedStageRevRef() { none() }
21242126

2127+
/** A boolean type. */
2128+
class BoolType extends Type;
2129+
21252130
/** An AST node. */
21262131
class AstNode {
21272132
/** Gets a textual representation of this AST node. */
@@ -2134,13 +2139,63 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
21342139
/** Gets the type annotation that applies to `n`, if any. */
21352140
TypeMention getTypeAnnotation(AstNode n);
21362141

2142+
/** An expression. */
2143+
class Expr extends AstNode;
2144+
2145+
/** A ternary conditional expression. */
2146+
class ConditionalExpr extends Expr {
2147+
/** Gets the condition of this expression. */
2148+
Expr getCondition();
2149+
2150+
/** Gets the true branch of this expression. */
2151+
Expr getThen();
2152+
2153+
/** Gets the false branch of this expression. */
2154+
Expr getElse();
2155+
}
2156+
2157+
/** A binary expression. */
2158+
class BinaryExpr extends Expr {
2159+
/** Gets the left operand of this binary expression. */
2160+
Expr getLeftOperand();
2161+
2162+
/** Gets the right operand of this binary expression. */
2163+
Expr getRightOperand();
2164+
}
2165+
2166+
/** A short-circuiting logical AND expression. */
2167+
class LogicalAndExpr extends BinaryExpr;
2168+
2169+
/** A short-circuiting logical OR expression. */
2170+
class LogicalOrExpr extends BinaryExpr;
2171+
2172+
/**
2173+
* An assignment expression, either compound or simple.
2174+
*
2175+
* Examples:
2176+
*
2177+
* ```
2178+
* x = y
2179+
* sum += element
2180+
* ```
2181+
*/
2182+
class Assignment extends BinaryExpr;
2183+
2184+
/** A simple assignment expression, for example `x = y`. */
2185+
class AssignExpr extends Assignment;
2186+
2187+
/** A parenthesized expression. */
2188+
class ParenExpr extends AstNode {
2189+
AstNode getExpr();
2190+
}
2191+
21372192
/** A variable, for example a local variable or a field. */
21382193
class Variable {
21392194
/** Gets the AST node that defines this variable. */
21402195
AstNode getDefiningNode();
21412196

21422197
/** Gets an access to this variable. */
2143-
AstNode getAnAccess();
2198+
Expr getAnAccess();
21442199

21452200
/** Gets a textual representation of this element. */
21462201
string toString();
@@ -2150,28 +2205,22 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
21502205
}
21512206

21522207
/**
2153-
* An assignment where type information can flow from one operand to the
2154-
* other.
2208+
* A `let` declaration, for example a local variable declaration.
21552209
*/
2156-
class Assignment extends AstNode {
2210+
class LetDeclaration extends AstNode {
21572211
/**
2158-
* Holds if this assignment is a coercion site, meaning that the type of the right
2212+
* Holds if this declaration is a coercion site, meaning that the type of the right
21592213
* operand may have to be coerced to the type of the left operand.
21602214
*/
21612215
predicate isCoercionSite();
21622216

2163-
/** Gets the left operand of this binary expression. */
2217+
/** Gets the left operand of this declaration. */
21642218
AstNode getLeftOperand();
21652219

2166-
/** Gets the right operand of this binary expression. */
2220+
/** Gets the right operand of this declaration. */
21672221
AstNode getRightOperand();
21682222
}
21692223

2170-
/** A parenthesized expression. */
2171-
class ParenExpr extends AstNode {
2172-
AstNode getExpr();
2173-
}
2174-
21752224
/**
21762225
* Holds if the types of `n1` at `path1` and `n2` at `path2` are certainly equal.
21772226
*/
@@ -2222,10 +2271,10 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
22222271
(
22232272
exists(Variable v | n1 = v.getAnAccess() and n2 = v.getDefiningNode())
22242273
or
2225-
exists(Assignment a |
2226-
not a.isCoercionSite() and
2227-
n1 = a.getLeftOperand() and
2228-
n2 = a.getRightOperand()
2274+
exists(LetDeclaration let |
2275+
not let.isCoercionSite() and
2276+
n1 = let.getLeftOperand() and
2277+
n2 = let.getRightOperand()
22292278
)
22302279
or
22312280
n1 = n2.(ParenExpr).getExpr()
@@ -2246,6 +2295,16 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
22462295
)
22472296
}
22482297

2298+
private Type inferLogicalOperationType(AstNode n, TypePath path) {
2299+
(
2300+
exists(LogicalAndExpr lae | n = [lae, lae.getLeftOperand(), lae.getRightOperand()]) or
2301+
exists(LogicalOrExpr loe | n = [loe, loe.getLeftOperand(), loe.getRightOperand()]) //or
2302+
// exists(LogicalNotExpr lne | n = [lne, lne.getOperand()])
2303+
) and
2304+
result instanceof BoolType and
2305+
path.isEmpty()
2306+
}
2307+
22492308
/** Gets the inferred certain type of `n` at `path`. */
22502309
cached
22512310
Type inferCertainType(AstNode n, TypePath path) {
@@ -2256,6 +2315,8 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
22562315
or
22572316
result = inferCertainTypeInput(n, path)
22582317
or
2318+
result = inferLogicalOperationType(n, path)
2319+
or
22592320
infersCertainTypeAt(n, path, result.getATypeParameter())
22602321
}
22612322

@@ -2309,9 +2370,16 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
23092370
or
23102371
path1.isEmpty() and
23112372
path2.isEmpty() and
2312-
exists(Assignment a |
2313-
a.getLeftOperand() = n1 and
2314-
a.getRightOperand() = n2
2373+
(
2374+
exists(Assignment a |
2375+
a.getLeftOperand() = n1 and
2376+
a.getRightOperand() = n2
2377+
)
2378+
or
2379+
exists(LetDeclaration let |
2380+
let.getLeftOperand() = n1 and
2381+
let.getRightOperand() = n2
2382+
)
23152383
)
23162384
or
23172385
typeEqualityInput(n1, path1, n2, path2)

0 commit comments

Comments
 (0)