Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions rust/ql/lib/codeql/rust/frameworks/stdlib/Stdlib.qll
Original file line number Diff line number Diff line change
Expand Up @@ -143,23 +143,46 @@
TypeAlias getOutputType() { result = this.(TraitItemNode).getAssocItem("Output") }
}

/** Any of the function traits `FnOnce`, `FnMut`, or `Fn`. */

Check warning

Code scanning / CodeQL

Class QLDoc style Warning

The QLDoc for a class should start with 'A', 'An', or 'The'.
class AnyFnTrait extends Trait {
/** Gets the `Args` type parameter of this trait. */
TypeParam getTypeParam() { result = this.getGenericParamList().getGenericParam(0) }
}

/**
* The [`FnOnce` trait][1].
*
* [1]: https://doc.rust-lang.org/std/ops/trait.FnOnce.html
*/
class FnOnceTrait extends Trait {
class FnOnceTrait extends AnyFnTrait {
pragma[nomagic]
FnOnceTrait() { this.getCanonicalPath() = "core::ops::function::FnOnce" }

/** Gets the type parameter of this trait. */
TypeParam getTypeParam() { result = this.getGenericParamList().getGenericParam(0) }

/** Gets the `Output` associated type. */
pragma[nomagic]
TypeAlias getOutputType() { result = this.(TraitItemNode).getAssocItem("Output") }
}

/**
* The [`FnMut` trait][1].
*
* [1]: https://doc.rust-lang.org/std/ops/trait.FnMut.html
*/
class FnMutTrait extends AnyFnTrait {
pragma[nomagic]
FnMutTrait() { this.getCanonicalPath() = "core::ops::function::FnMut" }
}

/**
* The [`Fn` trait][1].
*
* [1]: https://doc.rust-lang.org/std/ops/trait.Fn.html
*/
class FnTrait extends AnyFnTrait {
pragma[nomagic]
FnTrait() { this.getCanonicalPath() = "core::ops::function::Fn" }
}

/**
* The [`Iterator` trait][1].
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3825,16 +3825,29 @@ private Type invokedClosureFnTypeAt(InvokedClosureExpr ce, TypePath path) {
_, path, result)
}

/**
* Gets the root type of a closure.
*
* We model closures as `dyn Fn` trait object types. A closure might implement
* only `Fn`, `FnMut`, or `FnOnce`. But since `Fn` is a subtrait of the others,
* giving closures the type `dyn Fn` works well in practice—even if not entirely
* accurate.
*/
private DynTraitType closureRootType() {
result = TDynTraitType(any(FnTrait t)) // always exists because of the mention in `builtins/mentions.rs`
}

/** Gets the path to a closure's return type. */
private TypePath closureReturnPath() {
result = TypePath::singleton(getDynTraitTypeParameter(any(FnOnceTrait t).getOutputType()))
result =
TypePath::singleton(TDynTraitTypeParameter(any(FnTrait t), any(FnOnceTrait t).getOutputType()))
}

/** Gets the path to a closure with arity `arity`s `index`th parameter type. */
pragma[nomagic]
private TypePath closureParameterPath(int arity, int index) {
result =
TypePath::cons(TDynTraitTypeParameter(_, any(FnOnceTrait t).getTypeParam()),
TypePath::cons(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam()),
TypePath::singleton(getTupleTypeParameter(arity, index)))
}

Expand Down Expand Up @@ -3872,9 +3885,7 @@ private Type inferDynamicCallExprType(Expr n, TypePath path) {
or
// _If_ the invoked expression has the type of a closure, then we propagate
// the surrounding types into the closure.
exists(int arity, TypePath path0 |
ce.getTypeAt(TypePath::nil()).(DynTraitType).getTrait() instanceof FnOnceTrait
|
exists(int arity, TypePath path0 | ce.getTypeAt(TypePath::nil()) = closureRootType() |
// Propagate the type of arguments to the parameter types of closure
exists(int index, ArgList args |
n = ce and
Expand All @@ -3898,10 +3909,10 @@ private Type inferClosureExprType(AstNode n, TypePath path) {
exists(ClosureExpr ce |
n = ce and
path.isEmpty() and
result = TDynTraitType(any(FnOnceTrait t)) // always exists because of the mention in `builtins/mentions.rs`
result = closureRootType()
or
n = ce and
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnOnceTrait t).getTypeParam())) and
path = TypePath::singleton(TDynTraitTypeParameter(_, any(FnTrait t).getTypeParam())) and
result.(TupleType).getArity() = ce.getNumberOfParams()
or
// Propagate return type annotation to body
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,15 @@ class NonAliasPathTypeMention extends PathTypeMention {
// associated types of `Fn` and `FnMut` yet.
//
// [1]: https://doc.rust-lang.org/reference/paths.html#grammar-TypePathFn
exists(FnOnceTrait t, PathSegment s |
exists(AnyFnTrait t, PathSegment s |
t = resolved and
s = this.getSegment() and
s.hasParenthesizedArgList()
|
tp = TTypeParamTypeParameter(t.getTypeParam()) and
result = s.getParenthesizedArgList().(TypeMention).resolveTypeAt(path)
or
tp = TAssociatedTypeTypeParameter(t, t.getOutputType()) and
tp = TAssociatedTypeTypeParameter(t, any(FnOnceTrait tr).getOutputType()) and
(
result = s.getRetType().getTypeRepr().(TypeMention).resolveTypeAt(path)
or
Expand Down
74 changes: 74 additions & 0 deletions rust/ql/test/library-tests/type-inference/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,80 @@ mod fn_once_trait {
}
}

mod fn_mut_trait {
fn return_type<F: FnMut(bool) -> i64>(mut f: F) {
let _return = f(true); // $ type=_return:i64
}

fn return_type_omitted<F: FnMut(bool)>(mut f: F) {
let _return = f(true); // $ type=_return:()
}

fn argument_type<F: FnMut(bool) -> i64>(mut f: F) {
let arg = Default::default(); // $ target=default type=arg:bool
f(arg);
}

fn apply<A, B, F: FnMut(A) -> B>(mut f: F, a: A) -> B {
f(a)
}

fn apply_two(mut f: impl FnMut(i64) -> i64) -> i64 {
f(2)
}

fn test() {
let f = |x: bool| -> i64 {
if x {
1
} else {
0
}
};
let _r = apply(f, true); // $ target=apply type=_r:i64

let f = |x| x + 1; // $ MISSING: type=x:i64 target=add
let _r2 = apply_two(f); // $ target=apply_two certainType=_r2:i64
}
}

mod fn_trait {
fn return_type<F: Fn(bool) -> i64>(f: F) {
let _return = f(true); // $ type=_return:i64
}

fn return_type_omitted<F: Fn(bool)>(f: F) {
let _return = f(true); // $ type=_return:()
}

fn argument_type<F: Fn(bool) -> i64>(f: F) {
let arg = Default::default(); // $ target=default type=arg:bool
f(arg);
}

fn apply<A, B, F: Fn(A) -> B>(f: F, a: A) -> B {
f(a)
}

fn apply_two(f: impl Fn(i64) -> i64) -> i64 {
f(2)
}

fn test() {
let f = |x: bool| -> i64 {
if x {
1
} else {
0
}
};
let _r = apply(f, true); // $ target=apply type=_r:i64

let f = |x| x + 1; // $ MISSING: type=x:i64 target=add
let _r2 = apply_two(f); // $ target=apply_two certainType=_r2:i64
}
}

mod dyn_fn_once {
fn apply_boxed<A, B, F: FnOnce(A) -> B + ?Sized>(f: Box<F>, arg: A) -> B {
f(arg)
Expand Down
Loading