Skip to content

Commit 4e0e3e7

Browse files
committed
wip1
1 parent 74fa6d1 commit 4e0e3e7

File tree

14 files changed

+433
-219
lines changed

14 files changed

+433
-219
lines changed

rust/ql/lib/codeql/rust/internal/typeinference/FunctionOverloading.qll

Lines changed: 107 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -77,27 +77,23 @@ pragma[nomagic]
7777
private predicate implHasSibling(ImplItemNode impl, Trait trait) { implSiblings(trait, impl, _) }
7878

7979
/**
80-
* Holds if type parameter `tp` of `trait` occurs in the function `f` with the name
81-
* `functionName` at position `pos` and path `path`.
82-
*
83-
* Note that `pos` can also be the special `return` position, which is sometimes
84-
* needed to disambiguate associated function calls like `Default::default()`
85-
* (in this case, `tp` is the special `Self` type parameter).
80+
* Holds if `f` is a function declared inside `trait`, and the type of `f` at
81+
* `pos` and `path` is `traitTp`, which is a type parameter of `trait`.
8682
*/
87-
bindingset[trait]
88-
pragma[inline_late]
83+
pragma[nomagic]
8984
predicate traitTypeParameterOccurrence(
9085
TraitItemNode trait, Function f, string functionName, FunctionPosition pos, TypePath path,
91-
TypeParameter tp
86+
TypeParameter traitTp
9287
) {
93-
f = trait.getASuccessor(functionName) and
94-
tp = getAssocFunctionTypeAt(f, trait, pos, path) and
95-
tp = trait.(TraitTypeAbstraction).getATypeParameter()
88+
f = trait.getAssocItem(functionName) and
89+
traitTp = getAssocFunctionTypeAt(f, trait, pos, path) and
90+
traitTp = trait.(TraitTypeAbstraction).getATypeParameter()
9691
}
9792

9893
pragma[nomagic]
99-
private predicate functionResolutionDependsOnArgument0(
100-
ImplItemNode impl, Function f, FunctionPosition pos, TypePath path, Type type
94+
private predicate functionResolutionDependsOnArgumentCand(
95+
ImplItemNode impl, Function f, string functionName, TypeParameter traitTp, FunctionPosition pos,
96+
TypePath path
10197
) {
10298
/*
10399
* As seen in the example below, when an implementation has a sibling for a
@@ -124,36 +120,113 @@ private predicate functionResolutionDependsOnArgument0(
124120
* method. In that case we will still resolve several methods.
125121
*/
126122

127-
exists(TraitItemNode trait, string functionName |
123+
exists(TraitItemNode trait |
128124
implHasSibling(impl, trait) and
129-
traitTypeParameterOccurrence(trait, _, functionName, pos, path, _) and
130-
type = getAssocFunctionTypeAt(f, impl, pos, path) and
131-
f = impl.getASuccessor(functionName)
125+
traitTypeParameterOccurrence(trait, _, functionName, pos, path, traitTp) and
126+
f = impl.getASuccessor(functionName) and
127+
not pos.isSelf()
128+
)
129+
}
130+
131+
private predicate functionResolutionDependsOnPositionalArgument(
132+
ImplItemNode impl, Function f, string functionName, TypeParameter traitTp
133+
) {
134+
exists(FunctionPosition pos |
135+
functionResolutionDependsOnArgumentCand(impl, f, functionName, traitTp, pos, _) and
136+
pos.isPosition()
132137
)
133138
}
134139

140+
pragma[nomagic]
141+
private Type getAssocFunctionNonTypeParameterTypeAt(
142+
ImplItemNode impl, Function f, FunctionPosition pos, TypePath path
143+
) {
144+
result = getAssocFunctionTypeAt(f, impl, pos, path) and
145+
not result instanceof TypeParameter
146+
}
147+
135148
/**
136-
* Holds if resolving the function `f` in `impl` requires inspecting the type
137-
* of applied _arguments_ at position `pos` (including the return type) in
138-
* order to determine whether it is the correct resolution.
149+
* Holds if `f` inside `impl` has a sibling implementation inside `sibling`, where
150+
* those two implementations agree on the instantiation of `traitTp`, which occurs
151+
* in a positional position inside `f`.
139152
*/
140153
pragma[nomagic]
141-
predicate functionResolutionDependsOnArgument(
142-
ImplItemNode impl, Function f, FunctionPosition pos, TypePath path, Type type
154+
private predicate hasEquivalentPositionalSibling(
155+
ImplItemNode impl, ImplItemNode sibling, Function f, TypeParameter traitTp
143156
) {
144-
functionResolutionDependsOnArgument0(impl, f, pos, path, type) and
145-
(
157+
exists(string functionName, FunctionPosition pos, TypePath path |
158+
functionResolutionDependsOnArgumentCand(impl, f, functionName, traitTp, pos, path) and
146159
pos.isPosition()
147-
or
148-
// Only disambiguate based on return type when all other positions are trivially
149-
// satisfied for all arguments.
150-
pos.isReturn() and
151-
forall(FunctionPosition pos0, TypePath path0, Type type0 |
152-
pos0.isPosition() and
153-
functionResolutionDependsOnArgument0(impl, f, pos0, path0, type0)
160+
|
161+
exists(Function f1 |
162+
implSiblings(_, impl, sibling) and
163+
f1 = sibling.getASuccessor(functionName)
154164
|
155-
path0.isEmpty() and
156-
type0.(TypeParamTypeParameter).getTypeParam() = any(TypeParam tp | not tp.hasTypeBound())
165+
forall(TypePath path0, Type t |
166+
t = getAssocFunctionNonTypeParameterTypeAt(impl, f, pos, path0) and
167+
(path = path0.getAPrefix() or path = path0)
168+
|
169+
t = getAssocFunctionNonTypeParameterTypeAt(sibling, f1, pos, path0)
170+
) and
171+
forall(TypePath path0, Type t |
172+
t = getAssocFunctionNonTypeParameterTypeAt(sibling, f1, pos, path0) and
173+
(path = path0.getAPrefix() or path = path0)
174+
|
175+
t = getAssocFunctionNonTypeParameterTypeAt(impl, f, pos, path0)
176+
)
157177
)
158178
)
159179
}
180+
181+
/**
182+
* Holds if resolving the function `f` in `impl` requires inspecting the type
183+
* of applied _arguments_ or possibly knowing the return type.
184+
*
185+
* `traitTp` is a type parameter of the trait being implemented by `impl`, and
186+
* we need to check that the type of `f` corresponding to `traitTp` is satisfied
187+
* at any one of the positions `pos` in which that type occurs in `f`.
188+
*
189+
* Type parameters that only occur in return positions are only included when
190+
* all other type parameters that occur in a positional position are insufficient
191+
* to disambiguate.
192+
*
193+
* Example:
194+
*
195+
* ```rust
196+
* trait Trait1<T1> {
197+
* fn f(self, x: T1) -> T1;
198+
* }
199+
*
200+
* impl Trait1<i32> for i32 {
201+
* fn f(self, x: i32) -> i32 { 0 } // f1
202+
* }
203+
*
204+
* impl Trait1<i64> for i32 {
205+
* fn f(self, x: i64) -> i64 { 0 } // f2
206+
* }
207+
* ```
208+
*
209+
* The type for `T1` above occurs in both a positional position and a return position
210+
* in `f`, so both may be used to disambiguate between `f1` and `f2`. That is, `f(0i32)`
211+
* is sufficient to resolve to `f1`, and so is `let y: i64 = f(Default::default())`.
212+
*/
213+
pragma[nomagic]
214+
predicate functionResolutionDependsOnArgument(
215+
ImplItemNode impl, Function f, TypeParameter traitTp, FunctionPosition pos
216+
) {
217+
exists(string functionName, TypePath path |
218+
functionResolutionDependsOnArgumentCand(impl, f, functionName, traitTp, pos, path)
219+
|
220+
if functionResolutionDependsOnPositionalArgument(impl, f, functionName, traitTp)
221+
then any()
222+
else
223+
exists(ImplItemNode sibling |
224+
implSiblings(_, impl, sibling) and
225+
forall(TypeParameter otherTraitTp |
226+
functionResolutionDependsOnPositionalArgument(impl, f, functionName, otherTraitTp)
227+
|
228+
hasEquivalentPositionalSibling(impl, sibling, f, otherTraitTp)
229+
)
230+
)
231+
)
232+
}

rust/ql/lib/codeql/rust/internal/typeinference/FunctionType.qll

Lines changed: 67 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,14 @@ module ArgIsInstantiationOf<
294294
*/
295295
signature module ArgsAreInstantiationsOfInputSig {
296296
/**
297-
* Holds if types need to be matched against the type `t` at position `pos` of
298-
* `f` inside `i`.
297+
* Holds if `f` implements a trait function with type parameter `traitTp`, where
298+
* we need to check that the type of `f` for `traitTp` is satisfied.
299+
*
300+
* `pos` is one of the positions in `f` in which the relevant type occours.
301+
*
302+
* For example, i
299303
*/
300-
predicate toCheck(ImplOrTraitItemNode i, Function f, FunctionPosition pos, AssocFunctionType t);
304+
predicate toCheck(ImplOrTraitItemNode i, Function f, TypeParameter traitTp, FunctionPosition pos);
301305

302306
/** A call whose argument types are to be checked. */
303307
class Call {
@@ -318,23 +322,28 @@ signature module ArgsAreInstantiationsOfInputSig {
318322
*/
319323
module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
320324
pragma[nomagic]
321-
private predicate toCheckRanked(ImplOrTraitItemNode i, Function f, FunctionPosition pos, int rnk) {
322-
Input::toCheck(i, f, pos, _) and
323-
pos =
324-
rank[rnk + 1](FunctionPosition pos0, int j |
325-
Input::toCheck(i, f, pos0, _) and
326-
(
327-
j = pos0.asPosition()
328-
or
329-
pos0.isSelf() and j = -1
330-
or
331-
pos0.isReturn() and j = -2
332-
)
325+
private predicate toCheckRanked(
326+
ImplOrTraitItemNode i, Function f, TypeParameter traitTp, FunctionPosition pos, int rnk
327+
) {
328+
Input::toCheck(i, f, traitTp, pos) and
329+
traitTp =
330+
rank[rnk + 1](TypeParameter traitTp0, int j |
331+
Input::toCheck(i, f, traitTp0, _) and
332+
j = getTypeParameterId(traitTp0)
333333
|
334-
pos0 order by j
334+
traitTp0 order by j
335335
)
336336
}
337337

338+
pragma[nomagic]
339+
private predicate toCheck(
340+
ImplOrTraitItemNode i, Function f, TypeParameter traitTp, FunctionPosition pos,
341+
AssocFunctionType t
342+
) {
343+
Input::toCheck(i, f, traitTp, pos) and
344+
t.appliesTo(f, i, pos)
345+
}
346+
338347
private newtype TCallAndPos =
339348
MkCallAndPos(Input::Call call, FunctionPosition pos) { exists(call.getArgType(pos, _)) }
340349

@@ -356,36 +365,34 @@ module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
356365
string toString() { result = call.toString() + " [arg " + pos + "]" }
357366
}
358367

368+
pragma[nomagic]
369+
private predicate potentialInstantiationOf0(
370+
CallAndPos cp, Input::Call call, TypeParameter traitTp, FunctionPosition pos, Function f,
371+
TypeAbstraction abs, AssocFunctionType constraint
372+
) {
373+
cp = MkCallAndPos(call, pragma[only_bind_into](pos)) and
374+
call.hasTargetCand(abs, f) and
375+
toCheck(abs, f, traitTp, pragma[only_bind_into](pos), constraint)
376+
}
377+
359378
private module ArgIsInstantiationOfToIndexInput implements
360379
IsInstantiationOfInputSig<CallAndPos, AssocFunctionType>
361380
{
362-
pragma[nomagic]
363-
private predicate potentialInstantiationOf0(
364-
CallAndPos cp, Input::Call call, FunctionPosition pos, int rnk, Function f,
365-
TypeAbstraction abs, AssocFunctionType constraint
366-
) {
367-
cp = MkCallAndPos(call, pragma[only_bind_into](pos)) and
368-
call.hasTargetCand(abs, f) and
369-
toCheckRanked(abs, f, pragma[only_bind_into](pos), rnk) and
370-
Input::toCheck(abs, f, pragma[only_bind_into](pos), constraint)
371-
}
372-
373381
pragma[nomagic]
374382
predicate potentialInstantiationOf(
375383
CallAndPos cp, TypeAbstraction abs, AssocFunctionType constraint
376384
) {
377-
exists(Input::Call call, int rnk, Function f |
378-
potentialInstantiationOf0(cp, call, _, rnk, f, abs, constraint)
385+
exists(Input::Call call, TypeParameter traitTp, FunctionPosition pos, int rnk, Function f |
386+
potentialInstantiationOf0(cp, call, traitTp, pragma[only_bind_into](pos), f, abs, constraint) and
387+
toCheckRanked(abs, f, traitTp, pragma[only_bind_into](pos), rnk)
379388
|
380389
rnk = 0
381390
or
382391
argsAreInstantiationsOfToIndex(call, abs, f, rnk - 1)
383392
)
384393
}
385394

386-
predicate relevantConstraint(AssocFunctionType constraint) {
387-
Input::toCheck(_, _, _, constraint)
388-
}
395+
predicate relevantConstraint(AssocFunctionType constraint) { toCheck(_, _, _, _, constraint) }
389396
}
390397

391398
private module ArgIsInstantiationOfToIndex =
@@ -398,39 +405,63 @@ module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
398405
exists(FunctionPosition pos |
399406
ArgIsInstantiationOfToIndex::argIsInstantiationOf(MkCallAndPos(call, pos), i, _) and
400407
call.hasTargetCand(i, f) and
401-
toCheckRanked(i, f, pos, rnk)
408+
toCheckRanked(i, f, _, pos, rnk)
409+
|
410+
rnk = 0
411+
or
412+
argsAreInstantiationsOfToIndex(call, i, f, rnk - 1)
402413
)
403414
}
404415

405416
/**
406417
* Holds if all arguments of `call` have types that are instantiations of the
407418
* types of the corresponding parameters of `f` inside `i`.
419+
*
420+
* TODO: Check type parameter constraints as well.
408421
*/
409422
pragma[nomagic]
410423
predicate argsAreInstantiationsOf(Input::Call call, ImplOrTraitItemNode i, Function f) {
411424
exists(int rnk |
412425
argsAreInstantiationsOfToIndex(call, i, f, rnk) and
413-
rnk = max(int r | toCheckRanked(i, f, _, r))
426+
rnk = max(int r | toCheckRanked(i, f, _, _, r))
414427
)
415428
}
416429

430+
private module ArgsAreNotInstantiationOfInput implements
431+
IsInstantiationOfInputSig<CallAndPos, AssocFunctionType>
432+
{
433+
pragma[nomagic]
434+
predicate potentialInstantiationOf(
435+
CallAndPos cp, TypeAbstraction abs, AssocFunctionType constraint
436+
) {
437+
potentialInstantiationOf0(cp, _, _, _, _, abs, constraint)
438+
}
439+
440+
predicate relevantConstraint(AssocFunctionType constraint) { toCheck(_, _, _, _, constraint) }
441+
}
442+
443+
private module ArgsAreNotInstantiationOf =
444+
ArgIsInstantiationOf<CallAndPos, ArgsAreNotInstantiationOfInput>;
445+
417446
pragma[nomagic]
418447
private predicate argsAreNotInstantiationsOf0(
419448
Input::Call call, FunctionPosition pos, ImplOrTraitItemNode i
420449
) {
421-
ArgIsInstantiationOfToIndex::argIsNotInstantiationOf(MkCallAndPos(call, pos), i, _, _)
450+
ArgsAreNotInstantiationOf::argIsNotInstantiationOf(MkCallAndPos(call, pos), i, _, _)
422451
}
423452

424453
/**
425454
* Holds if _some_ argument of `call` has a type that is not an instantiation of the
426455
* type of the corresponding parameter of `f` inside `i`.
456+
*
457+
* TODO: Check type parameter constraints as well.
427458
*/
428459
pragma[nomagic]
429460
predicate argsAreNotInstantiationsOf(Input::Call call, ImplOrTraitItemNode i, Function f) {
430461
exists(FunctionPosition pos |
431462
argsAreNotInstantiationsOf0(call, pos, i) and
432463
call.hasTargetCand(i, f) and
433-
Input::toCheck(i, f, pos, _)
464+
Input::toCheck(i, f, _, pos)
434465
)
435466
}
436467
}

0 commit comments

Comments
 (0)