Skip to content

Commit 67e5694

Browse files
committed
wip3
1 parent ce5232b commit 67e5694

File tree

7 files changed

+276
-96
lines changed

7 files changed

+276
-96
lines changed

rust/ql/lib/codeql/rust/internal/typeinference/FunctionOverloading.qll

Lines changed: 107 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,24 @@ private predicate implSiblings(TraitItemNode trait, Impl impl1, Impl impl2) {
7676
pragma[nomagic]
7777
private predicate implHasSibling(ImplItemNode impl, Trait trait) { implSiblings(trait, impl, _) }
7878

79+
/**
80+
* Holds if `f` is a function declared inside `trait`, and the type of `f` at
81+
* `pos` and `path` is `traitTp`, which is a type parameter of `trait`.
82+
*/
7983
pragma[nomagic]
8084
predicate traitTypeParameterOccurrence(
8185
TraitItemNode trait, Function f, string functionName, FunctionPosition pos, TypePath path,
82-
TypeParameter tp
86+
TypeParameter traitTp
8387
) {
8488
f = trait.getAssocItem(functionName) and
85-
tp = getAssocFunctionTypeAt(f, trait, pos, path) and
86-
tp = trait.(TraitTypeAbstraction).getATypeParameter()
89+
traitTp = getAssocFunctionTypeAt(f, trait, pos, path) and
90+
traitTp = trait.(TraitTypeAbstraction).getATypeParameter()
8791
}
8892

8993
pragma[nomagic]
90-
private predicate functionResolutionDependsOnArgument0(
94+
private predicate functionResolutionDependsOnArgumentCand(
9195
ImplItemNode impl, Function f, string functionName, TypeParameter traitTp, FunctionPosition pos,
92-
TypePath path, Type type
96+
TypePath path
9397
) {
9498
/*
9599
* As seen in the example below, when an implementation has a sibling for a
@@ -119,57 +123,114 @@ private predicate functionResolutionDependsOnArgument0(
119123
exists(TraitItemNode trait |
120124
implHasSibling(impl, trait) and
121125
traitTypeParameterOccurrence(trait, _, functionName, pos, path, traitTp) and
122-
type = getAssocFunctionTypeAt(f, impl, pos, path) and
123-
f = impl.getASuccessor(functionName)
126+
f = impl.getASuccessor(functionName) and
127+
not pos.isSelf()
128+
)
129+
}
130+
131+
private predicate functionResolutionDependsOnPositionalArgument(
132+
ImplItemNode impl, Function f, string functionName, TypeParameter traitTp
133+
) {
134+
exists(FunctionPosition pos |
135+
functionResolutionDependsOnArgumentCand(impl, f, functionName, traitTp, pos, _) and
136+
pos.isPosition()
124137
)
125138
}
126139

140+
pragma[nomagic]
141+
private Type getAssocFunctionNonTypeParameterTypeAt(
142+
ImplItemNode impl, Function f, FunctionPosition pos, TypePath path
143+
) {
144+
result = getAssocFunctionTypeAt(f, impl, pos, path) and
145+
not result instanceof TypeParameter
146+
}
147+
127148
/**
128-
* Holds if resolving the function `f` in `impl` requires inspecting the type
129-
* of applied _arguments_ at position `pos` (possibly including the return type)
130-
* in order to determine whether it is the correct resolution.
149+
* Holds if `f` inside `impl` has a sibling implementation inside `sibling`, where
150+
* those two implementations agree on the instantiation of `traitTp`, which occurs
151+
* in a positional position inside `f`.
131152
*/
132153
pragma[nomagic]
133-
predicate functionResolutionDependsOnArgument(
134-
ImplItemNode impl, Function f, TypeParameter traitTp, FunctionPosition pos, TypePath path,
135-
Type type
154+
private predicate hasEquivalentPositionalSibling(
155+
ImplItemNode impl, ImplItemNode sibling, Function f, TypeParameter traitTp
136156
) {
137-
exists(string functionName |
138-
functionResolutionDependsOnArgument0(impl, f, functionName, traitTp, pos, path, type) and
139-
not pos.isSelf()
157+
exists(string functionName, FunctionPosition pos, TypePath path |
158+
functionResolutionDependsOnArgumentCand(impl, f, functionName, traitTp, pos, path) and
159+
pos.isPosition()
140160
|
141-
exists(FunctionPosition pos0 |
142-
functionResolutionDependsOnArgument0(impl, f, functionName, traitTp, pos0, _, _)
161+
exists(Function f1 |
162+
implSiblings(_, impl, sibling) and
163+
f1 = sibling.getASuccessor(functionName)
143164
|
144-
pos0.isPosition()
145-
or
146-
pos0.isReturn() and
147-
not exists(FunctionPosition pos1 |
148-
functionResolutionDependsOnArgument0(impl, f, functionName, _, pos1, _, _) and
149-
pos1.isPosition()
165+
forall(TypePath path0, Type t |
166+
t = getAssocFunctionNonTypeParameterTypeAt(impl, f, pos, path0) and
167+
(path = path0.getAPrefix() or path = path0)
168+
|
169+
t = getAssocFunctionNonTypeParameterTypeAt(sibling, f1, pos, path0)
170+
) and
171+
forall(TypePath path0, Type t |
172+
t = getAssocFunctionNonTypeParameterTypeAt(sibling, f1, pos, path0) and
173+
(path = path0.getAPrefix() or path = path0)
174+
|
175+
t = getAssocFunctionNonTypeParameterTypeAt(impl, f, pos, path0)
150176
)
151177
)
152-
// |
153-
// pos.isPosition()
154-
// or
155-
// // Only disambiguate based on return type when all other positions are trivially
156-
// // satisfied for all arguments.
157-
// pos.isReturn() and
158-
// forall(FunctionPosition pos0, TypePath path0, Type type0 |
159-
// pos0.isPosition() and
160-
// functionResolutionDependsOnArgument0(impl, f, _, _, pos0, path0, type0)
161-
// |
162-
// type0.(TypeParamTypeParameter).getTypeParam() = any(TypeParam tp | not tp.hasTypeBound())
163-
// or
164-
// forall(ImplItemNode impl1, Function f1, Type type1 |
165-
// implSiblings(_, impl, impl1) and
166-
// f1 = impl1.getASuccessor(functionName) and
167-
// type1 = getAssocFunctionTypeAt(f1, impl1, pos0, path0)
168-
// |
169-
// type1.(TypeParamTypeParameter).getTypeParam() = any(TypeParam tp | not tp.hasTypeBound())
170-
// or
171-
// type0 = type1
172-
// )
173-
// )
178+
)
179+
}
180+
181+
/**
182+
* Holds if resolving the function `f` in `impl` requires inspecting the type
183+
* of applied _arguments_ (possibly including knowing the return type).
184+
*
185+
* `traitTp` is a type parameter of the trait being implemented by `impl`, and
186+
* we need to check that the type of `f` corresponding to `traitTp` is satisfied
187+
* at any one of the positions `pos` in which that type occurs in `f`.
188+
*
189+
* Type parameters that only occur in return positions are only included when
190+
* all other type parameters that occur in a positional position are insufficient
191+
* to disambiguate.
192+
*
193+
* Example:
194+
*
195+
* ```rust
196+
* trait Trait1<T1> {
197+
* fn f(self, x: T1) -> T1;
198+
* }
199+
*
200+
* impl Trait1<i32> for i32 {
201+
* fn f(self, x: i32) -> i32 { 0 } // f1
202+
* }
203+
*
204+
* impl Trait1<i64> for i32 {
205+
* fn f(self, x: i64) -> i64 { 0 } // f2
206+
* }
207+
* ```
208+
*
209+
* The type for `T1` above occurs in both a positional position and a return position
210+
* in `f`, so both may be used to disambiguate between `f1` and `f2`. That is, `f(0i32)`
211+
* is sufficient to resolve to `f1`, and so is `let y: i64 = f(Default::default())`.
212+
*/
213+
pragma[nomagic]
214+
predicate functionResolutionDependsOnArgument(
215+
ImplItemNode impl, Function f, TypeParameter traitTp, FunctionPosition pos
216+
) {
217+
exists(string functionName, TypePath path |
218+
functionResolutionDependsOnArgumentCand(impl, f, functionName, traitTp, pos, path)
219+
|
220+
// Only include type parameters occuring in return positions when there are
221+
// no other type parameters occuring in positional positions. Note that if
222+
// a type parameter occurs in both a positional and a return position, then
223+
// both argument type and return type can be used to satisfy the constraint.
224+
if functionResolutionDependsOnPositionalArgument(impl, f, functionName, traitTp)
225+
then any()
226+
else
227+
exists(ImplItemNode sibling |
228+
implSiblings(_, impl, sibling) and
229+
forall(TypeParameter otherTraitTp |
230+
functionResolutionDependsOnPositionalArgument(impl, f, functionName, otherTraitTp)
231+
|
232+
hasEquivalentPositionalSibling(impl, sibling, f, otherTraitTp)
233+
)
234+
)
174235
)
175236
}

rust/ql/lib/codeql/rust/internal/typeinference/FunctionType.qll

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,8 @@ module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
416416
/**
417417
* Holds if all arguments of `call` have types that are instantiations of the
418418
* types of the corresponding parameters of `f` inside `i`.
419+
*
420+
* TODO: Check type parameter constraints as well.
419421
*/
420422
pragma[nomagic]
421423
predicate argsAreInstantiationsOf(Input::Call call, ImplOrTraitItemNode i, Function f) {
@@ -451,6 +453,8 @@ module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
451453
/**
452454
* Holds if _some_ argument of `call` has a type that is not an instantiation of the
453455
* type of the corresponding parameter of `f` inside `i`.
456+
*
457+
* TODO: Check type parameter constraints as well.
454458
*/
455459
pragma[nomagic]
456460
predicate argsAreNotInstantiationsOf(Input::Call call, ImplOrTraitItemNode i, Function f) {

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2143,7 +2143,7 @@ private module MethodResolution {
21432143
pragma[nomagic]
21442144
Method resolveCallTarget(ImplOrTraitItemNode i) {
21452145
result = this.resolveCallTargetCand(i) and
2146-
not FunctionOverloading::functionResolutionDependsOnArgument(i, result, _, _, _, _)
2146+
not FunctionOverloading::functionResolutionDependsOnArgument(i, result, _, _)
21472147
or
21482148
MethodArgsAreInstantiationsOf::argsAreInstantiationsOf(this, i, result)
21492149
}
@@ -2378,10 +2378,7 @@ private module MethodResolution {
23782378
*/
23792379
private module MethodArgsAreInstantiationsOfInput implements ArgsAreInstantiationsOfInputSig {
23802380
predicate toCheck(ImplOrTraitItemNode i, Function f, TypeParameter traitTp, FunctionPosition pos) {
2381-
exists(TypePath path, Type t0 |
2382-
FunctionOverloading::functionResolutionDependsOnArgument(i, f, traitTp, pos, path, t0) and
2383-
typeCanBeUsedForDisambiguation(t0)
2384-
)
2381+
FunctionOverloading::functionResolutionDependsOnArgument(i, f, traitTp, pos)
23852382
}
23862383

23872384
class Call extends MethodCallCand {
@@ -2825,7 +2822,7 @@ private module NonMethodResolution {
28252822
pragma[nomagic]
28262823
NonMethodFunction resolveCallTargetViaTypeInference(ImplOrTraitItemNode i) {
28272824
result = this.resolveCallTargetBlanketCand(i) and
2828-
not FunctionOverloading::functionResolutionDependsOnArgument(_, result, _, _, _, _)
2825+
not FunctionOverloading::functionResolutionDependsOnArgument(_, result, _, _)
28292826
or
28302827
NonMethodArgsAreInstantiationsOfBlanket1::argsAreInstantiationsOf(this, i, result)
28312828
or
@@ -2933,9 +2930,7 @@ private module NonMethodResolution {
29332930
ArgsAreInstantiationsOfInputSig
29342931
{
29352932
predicate toCheck(ImplOrTraitItemNode i, Function f, TypeParameter traitTp, FunctionPosition pos) {
2936-
exists(Type t0 | typeCanBeUsedForDisambiguation(t0) |
2937-
FunctionOverloading::functionResolutionDependsOnArgument(i, f, traitTp, pos, _, t0)
2938-
)
2933+
FunctionOverloading::functionResolutionDependsOnArgument(i, f, traitTp, pos)
29392934
}
29402935

29412936
final class Call extends NonMethodCall {
@@ -2979,10 +2974,7 @@ private module NonMethodResolution {
29792974
ArgsAreInstantiationsOfInputSig
29802975
{
29812976
predicate toCheck(ImplOrTraitItemNode i, Function f, TypeParameter traitTp, FunctionPosition pos) {
2982-
// NonMethodArgsAreInstantiationsOfBlanketInput::toCheck(i, f, pos, t)
2983-
exists(Type t0 | typeCanBeUsedForDisambiguation(t0) |
2984-
FunctionOverloading::functionResolutionDependsOnArgument(i, f, traitTp, pos, _, t0)
2985-
)
2977+
FunctionOverloading::functionResolutionDependsOnArgument(i, f, traitTp, pos)
29862978
}
29872979

29882980
class Call extends NonMethodCall {
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
multipleResolvedTargets
2+
| main.rs:2220:9:2220:31 | ... .my_add(...) |
3+
| main.rs:2222:9:2222:29 | ... .my_add(...) |
24
| main.rs:2720:13:2720:17 | x.f() |

rust/ql/test/library-tests/type-inference/overloading.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,53 @@ pub mod impl_overlap {
151151
}
152152
}
153153

154+
mod foo {
155+
trait Trait1<T1> {
156+
fn f(self, x: T1) -> T1;
157+
}
158+
159+
impl Trait1<i32> for i32 {
160+
// f1
161+
fn f(self, x: i32) -> i32 {
162+
0
163+
}
164+
}
165+
166+
impl Trait1<i64> for i32 {
167+
// f2
168+
fn f(self, x: i64) -> i64 {
169+
0
170+
}
171+
}
172+
173+
fn f() {
174+
let x = 0;
175+
let y = x.f(0i32); // $ target=f1
176+
let z: i32 = x.f(Default::default()); // $ target=f1
177+
let z = x.f(0i64); // $ target=f2
178+
let z: i64 = x.f(Default::default()); // $ target=f2
179+
let z: i64 = x.g(0i32); // $ target=g4
180+
}
181+
182+
trait Trait2<T1, T2> {
183+
fn g(self, x: T1) -> T2;
184+
}
185+
186+
impl Trait2<i32, i32> for i32 {
187+
// g3
188+
fn g(self, x: i32) -> i32 {
189+
0
190+
}
191+
}
192+
193+
impl Trait2<i32, i64> for i32 {
194+
// g4
195+
fn g(self, x: i32) -> i64 {
196+
0
197+
}
198+
}
199+
}
200+
154201
mod from_default {
155202
#[derive(Default)]
156203
struct S;

0 commit comments

Comments
 (0)