Skip to content

Commit a9b24ec

Browse files
committed
wip7
1 parent 3696436 commit a9b24ec

3 files changed

Lines changed: 184 additions & 60 deletions

File tree

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 35 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,8 @@ private module Input3 implements InputSig3 {
286286
(exists(resolveTupleFieldExpr(_, _)) implies any())
287287
}
288288

289+
predicate inferType = M3::inferType/2;
290+
289291
class BoolType extends DataType {
290292
BoolType() { this.getTypeItem() instanceof Builtins::Bool }
291293
}
@@ -366,41 +368,54 @@ private module Input3 implements InputSig3 {
366368
override AstNode getRightOperand() { result = this.getInitializer() }
367369
}
368370

369-
class CallTarget extends FunctionCallMatchingInput::Declaration {
371+
class CallResolutionContext = FunctionCallMatchingInput::AccessEnvironment;
372+
373+
class TypePosition = FunctionPosition;
374+
375+
class Callable extends FunctionCallMatchingInput::Declaration {
370376
TypeMention getAdditionalTypeParameterConstraint(TypeParameter tp) {
371377
result =
372378
tp.(TypeParamTypeParameter)
373379
.getTypeParam()
374380
.getAdditionalTypeBound(this.getFunction(), _)
375381
.getTypeRepr()
376382
}
377-
378-
Type getReturnType(TypePath path) {
379-
exists(FunctionPosition pos |
380-
pos.isReturn() and
381-
result = super.getDeclaredType(pos, path)
382-
)
383-
}
384-
385-
Type getParameterType(int index, TypePath path) {
386-
none() // todo
387-
}
388383
}
389384

390385
class Call extends Expr instanceof FunctionCallMatchingInput::Access {
391386
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
392387
result = super.getTypeArgument(apos, path)
393388
}
394389

390+
AstNode getNodeAt(TypePosition pos) { result = super.getNodeAt(pos) }
391+
395392
/** Gets the target of this call. */
396-
CallTarget getTargetCertain() {
393+
Callable getTargetCertain() {
397394
exists(ImplOrTraitItemNodeOption i, FunctionDeclaration f, Path p |
398395
result.isFunction(i, f) and
399396
p = CallExprImpl::getFunctionPath(this) and
400397
f = resolvePath(p) and
401398
f.isDirectlyFor(i)
402399
)
403400
}
401+
402+
Callable getTarget(string derefChainBorrow) { result = super.getTarget(derefChainBorrow) }
403+
}
404+
405+
bindingset[derefChainBorrow]
406+
Type inferCallTypeIn(Call call, string derefChainBorrow, FunctionPosition pos, TypePath path) {
407+
result = call.(FunctionCallMatchingInput::Access).getInferredType(derefChainBorrow, pos, path)
408+
}
409+
410+
Type inferCallTypeOut(AstNode n, TypePosition pos, TypePath path) {
411+
result = inferFunctionCallTypeNonSelf(n, pos, path)
412+
or
413+
exists(FunctionCallMatchingInput::Access a |
414+
result = inferFunctionCallTypeSelf(a, n, DerefChain::nil(), path) and
415+
if a.(AssocFunctionResolution::AssocFunctionCall).hasReceiver()
416+
then not path.isEmpty()
417+
else any()
418+
)
404419
}
405420

406421
predicate certainTypeEqualityInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
@@ -553,8 +568,6 @@ private module Input3 implements InputSig3 {
553568
Type inferTypeInput(AstNode n, TypePath path) {
554569
result = inferAssignmentOperationType(n, path)
555570
or
556-
result = inferFunctionCallType(n, path)
557-
or
558571
result = inferConstructionType(n, path)
559572
or
560573
result = inferOperationType(n, path)
@@ -2836,22 +2849,20 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput
28362849
}
28372850
}
28382851

2839-
private module FunctionCallMatching = MatchingWithEnvironment<FunctionCallMatchingInput>;
2840-
28412852
pragma[nomagic]
28422853
private Type inferFunctionCallType0(
28432854
FunctionCallMatchingInput::Access call, FunctionPosition pos, AstNode n, DerefChain derefChain,
28442855
BorrowKind borrow, TypePath path
28452856
) {
28462857
exists(TypePath path0 |
2847-
n = call.getNodeAt(pos) and
28482858
exists(string derefChainBorrow |
28492859
FunctionCallMatchingInput::decodeDerefChainBorrow(derefChainBorrow, derefChain, borrow)
28502860
|
2851-
result = FunctionCallMatching::inferAccessType(call, derefChainBorrow, pos, path0)
2852-
or
2861+
n = call.getNodeAt(pos) and
28532862
call.hasUnknownTypeAt(derefChainBorrow, pos, path0) and
28542863
result = TUnknownType()
2864+
or
2865+
result = inferCallTypeOut(call, pos, n, derefChainBorrow, path0)
28552866
)
28562867
|
28572868
if
@@ -2919,31 +2930,6 @@ private Type inferFunctionCallTypeSelf(
29192930
)
29202931
}
29212932

2922-
private Type inferFunctionCallTypePreCheck(
2923-
AstNode n, ContextTyping::FunctionPositionKind kind, TypePath path
2924-
) {
2925-
exists(FunctionPosition pos |
2926-
result = inferFunctionCallTypeNonSelf(n, pos, path) and
2927-
if pos.isPosition()
2928-
then kind = ContextTyping::PositionalKind()
2929-
else kind = ContextTyping::ReturnKind()
2930-
)
2931-
or
2932-
exists(FunctionCallMatchingInput::Access a |
2933-
result = inferFunctionCallTypeSelf(a, n, DerefChain::nil(), path) and
2934-
if a.(AssocFunctionResolution::AssocFunctionCall).hasReceiver()
2935-
then kind = ContextTyping::SelfKind()
2936-
else kind = ContextTyping::PositionalKind()
2937-
)
2938-
}
2939-
2940-
/**
2941-
* Gets the type of `n` at `path`, where `n` is either a function call or an
2942-
* argument/receiver of a function call.
2943-
*/
2944-
private predicate inferFunctionCallType =
2945-
ContextTyping::CheckContextTyping<inferFunctionCallTypePreCheck/3>::check/2;
2946-
29472933
abstract private class Constructor extends Addressable {
29482934
final TypeParameter getTypeParameter(TypeParameterPosition ppos) {
29492935
typeParamMatchPosition(this.getTypeItem().getGenericParamList().getATypeParam(), result, ppos)
@@ -3815,11 +3801,10 @@ private module Debug {
38153801
t = self.getTypeAt(path)
38163802
}
38173803

3818-
predicate debugInferFunctionCallType(AstNode n, TypePath path, Type t) {
3819-
n = getRelevantLocatable() and
3820-
t = inferFunctionCallType(n, path)
3821-
}
3822-
3804+
// predicate debugInferFunctionCallType(AstNode n, TypePath path, Type t) {
3805+
// n = getRelevantLocatable() and
3806+
// t = inferFunctionCallType(n, path)
3807+
// }
38233808
predicate debugInferConstructionType(AstNode n, TypePath path, Type t) {
38243809
n = getRelevantLocatable() and
38253810
t = inferConstructionType(n, path)

rust/ql/test/library-tests/type-inference/type-inference.expected

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10252,6 +10252,7 @@ inferType
1025210252
| main.rs:1412:17:1412:20 | self | TRef.TSlice | main.rs:1410:14:1410:23 | T |
1025310253
| main.rs:1412:17:1412:27 | self.get(...) | | {EXTERNAL LOCATION} | Option |
1025410254
| main.rs:1412:17:1412:27 | self.get(...) | T | {EXTERNAL LOCATION} | & |
10255+
| main.rs:1412:17:1412:27 | self.get(...) | T.TRef | main.rs:1410:14:1410:23 | T |
1025510256
| main.rs:1412:17:1412:36 | ... .unwrap() | | {EXTERNAL LOCATION} | & |
1025610257
| main.rs:1412:17:1412:36 | ... .unwrap() | TRef | main.rs:1410:14:1410:23 | T |
1025710258
| main.rs:1412:26:1412:26 | 0 | | {EXTERNAL LOCATION} | i32 |
@@ -11600,6 +11601,8 @@ inferType
1160011601
| main.rs:2221:18:2221:21 | true | | {EXTERNAL LOCATION} | bool |
1160111602
| main.rs:2223:9:2223:15 | S(...) | | main.rs:2107:5:2107:19 | S |
1160211603
| main.rs:2223:9:2223:15 | S(...) | T | {EXTERNAL LOCATION} | i64 |
11604+
| main.rs:2223:9:2223:15 | S(...) | T | main.rs:2107:5:2107:19 | S |
11605+
| main.rs:2223:9:2223:15 | S(...) | T.T | {EXTERNAL LOCATION} | i64 |
1160311606
| main.rs:2223:9:2223:31 | ... .my_add(...) | | main.rs:2107:5:2107:19 | S |
1160411607
| main.rs:2223:9:2223:31 | ... .my_add(...) | T | {EXTERNAL LOCATION} | i64 |
1160511608
| main.rs:2223:9:2223:31 | ... .my_add(...) | T | main.rs:2107:5:2107:19 | S |
@@ -11618,6 +11621,8 @@ inferType
1161811621
| main.rs:2224:24:2224:27 | 3i64 | | {EXTERNAL LOCATION} | i64 |
1161911622
| main.rs:2225:9:2225:15 | S(...) | | main.rs:2107:5:2107:19 | S |
1162011623
| main.rs:2225:9:2225:15 | S(...) | T | {EXTERNAL LOCATION} | i64 |
11624+
| main.rs:2225:9:2225:15 | S(...) | T | {EXTERNAL LOCATION} | & |
11625+
| main.rs:2225:9:2225:15 | S(...) | T.TRef | {EXTERNAL LOCATION} | i64 |
1162111626
| main.rs:2225:9:2225:29 | ... .my_add(...) | | main.rs:2107:5:2107:19 | S |
1162211627
| main.rs:2225:9:2225:29 | ... .my_add(...) | T | {EXTERNAL LOCATION} | i64 |
1162311628
| main.rs:2225:11:2225:14 | 1i64 | | {EXTERNAL LOCATION} | i64 |

shared/typeinference/codeql/typeinference/internal/TypeInference.qll

Lines changed: 144 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2124,6 +2124,13 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
21242124
*/
21252125
default predicate cachedStageRevRef() { none() }
21262126

2127+
/**
2128+
* Point this predicate to the `inferType` predicate in the output of this module.
2129+
*
2130+
* Needed to be able to refer to `inferType` in default signature implementations.
2131+
*/
2132+
Type inferType(AstNode n, TypePath path);
2133+
21272134
/** A boolean type. */
21282135
class BoolType extends Type;
21292136

@@ -2221,29 +2228,64 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
22212228
AstNode getRightOperand();
22222229
}
22232230

2224-
class CallTarget {
2231+
/**
2232+
* A position where a callable can have a declared type.
2233+
*/
2234+
class TypePosition {
2235+
/** Holds if this position represents the return type of a callable. */
2236+
predicate isReturn();
2237+
2238+
/** Gets a textual representation of this position. */
2239+
string toString();
2240+
}
2241+
2242+
/** A context needed to resolve calls. */
2243+
bindingset[this]
2244+
class CallResolutionContext {
2245+
/** Gets a textual representation of this context. */
2246+
bindingset[this]
2247+
string toString();
2248+
}
2249+
2250+
/** A callable. */
2251+
class Callable {
22252252
TypeParameter getTypeParameter(TypeParameterPosition ppos);
22262253

22272254
TypeMention getAdditionalTypeParameterConstraint(TypeParameter tp);
22282255

2229-
Type getReturnType(TypePath path);
2256+
/* Gets the declared type of this callable at `path` for position `pos`. */
2257+
Type getDeclaredType(TypePosition pos, TypePath path);
22302258

2231-
Type getParameterType(int index, TypePath path);
2232-
2233-
/** Gets a textual representation of this element. */
2259+
/** Gets a textual representation of this callable. */
22342260
string toString();
22352261

2236-
/** Gets the location of this element. */
2262+
/** Gets the location of this callable. */
22372263
Location getLocation();
22382264
}
22392265

22402266
class Call extends Expr {
22412267
Type getTypeArgument(TypeArgumentPosition apos, TypePath path);
22422268

2269+
AstNode getNodeAt(TypePosition pos);
2270+
22432271
/** Gets the target of this call. */
2244-
CallTarget getTargetCertain();
2272+
Callable getTargetCertain();
2273+
2274+
/** Gets the target of this call. */
2275+
Callable getTarget(CallResolutionContext ctx);
22452276
}
22462277

2278+
/** Gets the inferred type `call` at `path` for position `pos` in context `ctx`. */
2279+
bindingset[ctx]
2280+
default Type inferCallTypeIn(
2281+
Call call, CallResolutionContext ctx, TypePosition pos, TypePath path
2282+
) {
2283+
result = inferType(call.getNodeAt(pos), path) and
2284+
exists(ctx)
2285+
}
2286+
2287+
Type inferCallTypeOut(AstNode n, TypePosition pos, TypePath path);
2288+
22472289
/**
22482290
* Holds if the types of `n1` at `path1` and `n2` at `path2` are certainly equal.
22492291
*/
@@ -2330,16 +2372,19 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
23302372

23312373
pragma[nomagic]
23322374
private Type getCertainCallExprType(Call call, TypePath path) {
2333-
forex(CallTarget target | target = call.getTargetCertain() |
2334-
result = target.getReturnType(path)
2375+
exists(TypePosition ret |
2376+
ret.isReturn() and
2377+
forex(Callable target | target = call.getTargetCertain() |
2378+
result = target.getDeclaredType(ret, path)
2379+
)
23352380
)
23362381
}
23372382

23382383
pragma[nomagic]
23392384
private Type inferCertainCallExprType(Call call, TypePath path) {
23402385
exists(Type ty, TypePath prefix | ty = getCertainCallExprType(call, prefix) |
23412386
exists(
2342-
CallTarget target, TypePath suffix, TypeParameterPosition tppos,
2387+
Callable target, TypePath suffix, TypeParameterPosition tppos,
23432388
TypeArgumentPosition tapos
23442389
|
23452390
ty = target.getTypeParameter(tppos) and
@@ -2497,15 +2542,104 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
24972542
(
24982543
result = inferTypeEquality(n, path)
24992544
or
2545+
result = CheckContextTyping<inferCallTypeOut/3>::check(n, path)
2546+
or
25002547
result = inferTypeInput(n, path)
25012548
)
25022549
}
25032550

2551+
private module TypePositionMatchingInput {
2552+
class DeclarationPosition = TypePosition;
2553+
2554+
class AccessPosition = DeclarationPosition;
2555+
2556+
predicate accessDeclarationPositionMatch(AccessPosition apos, DeclarationPosition dpos) {
2557+
apos = dpos
2558+
}
2559+
}
2560+
2561+
/**
2562+
* A matching configuration for resolving types of calls.
2563+
*/
2564+
private module CallMatchingInput implements MatchingWithEnvironmentInputSig {
2565+
import TypePositionMatchingInput
2566+
2567+
class Declaration = Callable;
2568+
2569+
bindingset[decl]
2570+
TypeMention getATypeParameterConstraint(TypeParameter tp, Declaration decl) {
2571+
result = Input2::getATypeParameterConstraint(tp) and
2572+
exists(decl)
2573+
or
2574+
result = decl.getAdditionalTypeParameterConstraint(tp)
2575+
}
2576+
2577+
class AccessEnvironment = CallResolutionContext;
2578+
2579+
final private class CallFinal = Call;
2580+
2581+
class Access extends CallFinal {
2582+
bindingset[e]
2583+
Type getInferredType(AccessEnvironment e, AccessPosition apos, TypePath path) {
2584+
result = inferCallTypeIn(this, e, apos, path)
2585+
}
2586+
}
2587+
}
2588+
2589+
private module CallMatching = MatchingWithEnvironment<CallMatchingInput>;
2590+
2591+
pragma[nomagic]
2592+
Type inferCallTypeOut(
2593+
Call call, TypePosition pos, AstNode n, CallResolutionContext ctx, TypePath path
2594+
) {
2595+
n = call.getNodeAt(pos) and
2596+
result = CallMatching::inferAccessType(call, ctx, pos, path)
2597+
}
2598+
2599+
pragma[nomagic]
2600+
private predicate hasUnknownTypeAt(AstNode n, TypePath path) {
2601+
inferType(n, path) instanceof UnknownType
2602+
}
2603+
2604+
pragma[nomagic]
2605+
private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) }
2606+
2607+
private signature Type inferCallTypeSig(AstNode n, TypePosition pos, TypePath path);
2608+
2609+
/**
2610+
* Given a predicate `inferCallType` for inferring the type of a call at a given
2611+
* position, this module exposes the predicate `check`, which wraps the input
2612+
* predicate and checks that types are only propagated into arguments when they
2613+
* are context-typed.
2614+
*/
2615+
module CheckContextTyping<inferCallTypeSig/3 inferCallType> {
2616+
pragma[nomagic]
2617+
private Type inferCallNonReturnType(AstNode n, TypePath prefix, TypePath path) {
2618+
exists(TypePosition pos |
2619+
result = inferCallType(n, pos, path) and
2620+
hasUnknownType(n) and
2621+
not pos.isReturn() and
2622+
prefix = path.getAPrefix()
2623+
)
2624+
}
2625+
2626+
pragma[nomagic]
2627+
Type check(AstNode n, TypePath path) {
2628+
result = inferCallType(n, any(TypePosition pos | pos.isReturn()), path)
2629+
or
2630+
exists(TypePath prefix |
2631+
result = inferCallNonReturnType(n, prefix, path) and
2632+
hasUnknownTypeAt(n, prefix)
2633+
)
2634+
}
2635+
}
2636+
25042637
/**
25052638
* Gets the inferred root type of `n`, if any.
25062639
*/
25072640
Type inferType(AstNode n) { result = inferType(n, TypePath::nil()) }
25082641

2642+
// todo: consistency checks
25092643
/** The cached stage of type inference. */
25102644
cached
25112645
module CachedStage {

0 commit comments

Comments
 (0)