@@ -337,7 +337,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
337337 ) -> ClosureSignatures < ' tcx > {
338338 debug ! ( "sig_of_closure_no_expectation()" ) ;
339339
340- let bound_sig = self . supplied_sig_of_closure ( expr_def_id, decl) ;
340+ let bound_sig = self . supplied_sig_of_closure ( expr_def_id, decl, body ) ;
341341
342342 self . closure_sigs ( expr_def_id, body, bound_sig)
343343 }
@@ -490,7 +490,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
490490 //
491491 // (See comment on `sig_of_closure_with_expectation` for the
492492 // meaning of these letters.)
493- let supplied_sig = self . supplied_sig_of_closure ( expr_def_id, decl) ;
493+ let supplied_sig = self . supplied_sig_of_closure ( expr_def_id, decl, body ) ;
494494
495495 debug ! (
496496 "check_supplied_sig_against_expectation: supplied_sig={:?}" ,
@@ -591,14 +591,31 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
591591 & self ,
592592 expr_def_id : DefId ,
593593 decl : & hir:: FnDecl ,
594+ body : & hir:: Body ,
594595 ) -> ty:: PolyFnSig < ' tcx > {
595596 let astconv: & dyn AstConv < ' _ > = self ;
596597
598+ debug ! (
599+ "supplied_sig_of_closure(decl={:?}, body.generator_kind={:?})" ,
600+ decl,
601+ body. generator_kind,
602+ ) ;
603+
597604 // First, convert the types that the user supplied (if any).
598605 let supplied_arguments = decl. inputs . iter ( ) . map ( |a| astconv. ast_ty_to_ty ( a) ) ;
599606 let supplied_return = match decl. output {
600607 hir:: Return ( ref output) => astconv. ast_ty_to_ty ( & output) ,
601- hir:: DefaultReturn ( _) => astconv. ty_infer ( None , decl. output . span ( ) ) ,
608+ hir:: DefaultReturn ( _) => match body. generator_kind {
609+ // In the case of the async block that we create for a function body,
610+ // we expect the return type of the block to match that of the enclosing
611+ // function.
612+ Some ( hir:: GeneratorKind :: Async ( hir:: AsyncGeneratorKind :: Fn ) ) => {
613+ debug ! ( "supplied_sig_of_closure: closure is async fn body" ) ;
614+ self . deduce_future_output_from_obligations ( expr_def_id)
615+ }
616+
617+ _ => astconv. ty_infer ( None , decl. output . span ( ) ) ,
618+ }
602619 } ;
603620
604621 let result = ty:: Binder :: bind ( self . tcx . mk_fn_sig (
@@ -620,6 +637,117 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
620637 result
621638 }
622639
640+ /// Invoked when we are translating the generator that results
641+ /// from desugaring an `async fn`. Returns the "sugared" return
642+ /// type of the `async fn` -- that is, the return type that the
643+ /// user specified. The "desugared" return type is a `impl
644+ /// Future<Output = T>`, so we do this by searching through the
645+ /// obligations to extract the `T`.
646+ fn deduce_future_output_from_obligations (
647+ & self ,
648+ expr_def_id : DefId ,
649+ ) -> Ty < ' tcx > {
650+ debug ! ( "deduce_future_output_from_obligations(expr_def_id={:?})" , expr_def_id) ;
651+
652+ let ret_coercion =
653+ self . ret_coercion
654+ . as_ref ( )
655+ . unwrap_or_else ( || span_bug ! (
656+ self . tcx. def_span( expr_def_id) ,
657+ "async fn generator outside of a fn"
658+ ) ) ;
659+
660+ // In practice, the return type of the surrounding function is
661+ // always a (not yet resolved) inference variable, because it
662+ // is the hidden type for an `impl Trait` that we are going to
663+ // be inferring.
664+ let ret_ty = ret_coercion. borrow ( ) . expected_ty ( ) ;
665+ let ret_ty = self . inh . infcx . shallow_resolve ( ret_ty) ;
666+ let ret_vid = match ret_ty. kind {
667+ ty:: Infer ( ty:: TyVar ( ret_vid) ) => ret_vid,
668+ _ => {
669+ span_bug ! (
670+ self . tcx. def_span( expr_def_id) ,
671+ "async fn generator return type not an inference variable"
672+ )
673+ }
674+ } ;
675+
676+ // Search for a pending obligation like
677+ //
678+ // `<R as Future>::Output = T`
679+ //
680+ // where R is the return type we are expecting. This type `T`
681+ // will be our output.
682+ let output_ty = self . obligations_for_self_ty ( ret_vid)
683+ . find_map ( |( _, obligation) | {
684+ if let ty:: Predicate :: Projection ( ref proj_predicate) = obligation. predicate {
685+ self . deduce_future_output_from_projection (
686+ obligation. cause . span ,
687+ proj_predicate
688+ )
689+ } else {
690+ None
691+ }
692+ } )
693+ . unwrap ( ) ;
694+
695+ debug ! ( "deduce_future_output_from_obligations: output_ty={:?}" , output_ty) ;
696+ output_ty
697+ }
698+
699+ /// Given a projection like
700+ ///
701+ /// `<X as Future>::Output = T`
702+ ///
703+ /// where `X` is some type that has no late-bound regions, returns
704+ /// `Some(T)`. If the projection is for some other trait, returns
705+ /// `None`.
706+ fn deduce_future_output_from_projection (
707+ & self ,
708+ cause_span : Span ,
709+ predicate : & ty:: PolyProjectionPredicate < ' tcx > ,
710+ ) -> Option < Ty < ' tcx > > {
711+ debug ! ( "deduce_future_output_from_projection(predicate={:?})" , predicate) ;
712+
713+ // We do not expect any bound regions in our predicate, so
714+ // skip past the bound vars.
715+ let predicate = match predicate. no_bound_vars ( ) {
716+ Some ( p) => p,
717+ None => {
718+ debug ! ( "deduce_future_output_from_projection: has late-bound regions" ) ;
719+ return None ;
720+ }
721+ } ;
722+
723+ // Check that this is a projection from the `Future` trait.
724+ let trait_ref = predicate. projection_ty . trait_ref ( self . tcx ) ;
725+ let future_trait = self . tcx . lang_items ( ) . future_trait ( ) . unwrap ( ) ;
726+ if trait_ref. def_id != future_trait {
727+ debug ! ( "deduce_future_output_from_projection: not a future" ) ;
728+ return None ;
729+ }
730+
731+ // The `Future` trait has only one associted item, `Output`,
732+ // so check that this is what we see.
733+ let output_assoc_item = self . tcx . associated_items ( future_trait) . nth ( 0 ) . unwrap ( ) . def_id ;
734+ if output_assoc_item != predicate. projection_ty . item_def_id {
735+ span_bug ! (
736+ cause_span,
737+ "projecting associated item `{:?}` from future, which is not Output `{:?}`" ,
738+ predicate. projection_ty. item_def_id,
739+ output_assoc_item,
740+ ) ;
741+ }
742+
743+ // Extract the type from the projection. Note that there can
744+ // be no bound variables in this type because the "self type"
745+ // does not have any regions in it.
746+ let output_ty = self . resolve_vars_if_possible ( & predicate. ty ) ;
747+ debug ! ( "deduce_future_output_from_projection: output_ty={:?}" , output_ty) ;
748+ Some ( output_ty)
749+ }
750+
623751 /// Converts the types that the user supplied, in case that doing
624752 /// so should yield an error, but returns back a signature where
625753 /// all parameters are of type `TyErr`.
0 commit comments