@@ -458,6 +458,25 @@ swift::rewriting::desugarRequirement(Requirement req, SourceLoc loc,
458458 }
459459}
460460
461+ void swift::rewriting::desugarRequirements (SmallVector<StructuralRequirement, 2 > &reqs,
462+ SmallVectorImpl<RequirementError> &errors) {
463+ SmallVector<StructuralRequirement, 2 > result;
464+ for (auto req : reqs) {
465+ SmallVector<Requirement, 2 > desugaredReqs;
466+ SmallVector<RequirementError, 2 > ignoredErrors;
467+
468+ if (req.inferred )
469+ desugarRequirement (req.req , SourceLoc (), desugaredReqs, ignoredErrors);
470+ else
471+ desugarRequirement (req.req , req.loc , desugaredReqs, errors);
472+
473+ for (auto desugaredReq : desugaredReqs)
474+ result.push_back ({desugaredReq, req.loc , req.inferred });
475+ }
476+
477+ std::swap (reqs, result);
478+ }
479+
461480//
462481// Requirement realization and inference.
463482//
@@ -467,8 +486,6 @@ static void realizeTypeRequirement(DeclContext *dc,
467486 SourceLoc loc,
468487 SmallVectorImpl<StructuralRequirement> &result,
469488 SmallVectorImpl<RequirementError> &errors) {
470- SmallVector<Requirement, 2 > reqs;
471-
472489 // The GenericSignatureBuilder allowed the right hand side of a
473490 // conformance or superclass requirement to reference a protocol
474491 // typealias whose underlying type was a protocol or class.
@@ -497,22 +514,19 @@ static void realizeTypeRequirement(DeclContext *dc,
497514 }
498515
499516 if (constraintType->isConstraintType ()) {
500- Requirement req (RequirementKind::Conformance, subjectType, constraintType);
501- desugarRequirement (req, loc, reqs, errors);
517+ result.push_back ({Requirement (RequirementKind::Conformance,
518+ subjectType, constraintType),
519+ loc, /* wasInferred=*/ false });
502520 } else if (constraintType->getClassOrBoundGenericClass ()) {
503- Requirement req (RequirementKind::Superclass, subjectType, constraintType);
504- desugarRequirement (req, loc, reqs, errors);
521+ result.push_back ({Requirement (RequirementKind::Superclass,
522+ subjectType, constraintType),
523+ loc, /* wasInferred=*/ false });
505524 } else {
506525 errors.push_back (
507526 RequirementError::forInvalidTypeRequirement (subjectType,
508527 constraintType,
509528 loc));
510- return ;
511529 }
512-
513- // Add source location information.
514- for (auto req : reqs)
515- result.push_back ({req, loc, /* wasInferred=*/ false });
516530}
517531
518532namespace {
@@ -521,11 +535,11 @@ namespace {
521535struct InferRequirementsWalker : public TypeWalker {
522536 ModuleDecl *module ;
523537 DeclContext *dc;
524- SmallVector<Requirement, 2 > reqs;
525- SmallVector<RequirementError, 2 > errors;
538+ SmallVectorImpl<StructuralRequirement> &reqs;
526539
527- explicit InferRequirementsWalker (ModuleDecl *module , DeclContext *dc)
528- : module(module ), dc(dc) {}
540+ explicit InferRequirementsWalker (ModuleDecl *module , DeclContext *dc,
541+ SmallVectorImpl<StructuralRequirement> &reqs)
542+ : module(module ), dc(dc), reqs(reqs) {}
529543
530544 Action walkToTypePre (Type ty) override {
531545 // Unbound generic types are the result of recovered-but-invalid code, and
@@ -555,8 +569,7 @@ struct InferRequirementsWalker : public TypeWalker {
555569 return false ;
556570
557571 return (req.getKind () == RequirementKind::Conformance &&
558- req.getSecondType ()->castTo <ProtocolType>()->getDecl ()
559- ->isSpecificProtocol (KnownProtocolKind::Sendable));
572+ req.getProtocolDecl ()->isSpecificProtocol (KnownProtocolKind::Sendable));
560573 };
561574
562575 // Infer from generic typealiases.
@@ -567,7 +580,7 @@ struct InferRequirementsWalker : public TypeWalker {
567580 if (skipRequirement (rawReq, decl))
568581 continue ;
569582
570- desugarRequirement ( rawReq.subst (subMap), SourceLoc (), reqs, errors );
583+ reqs. push_back ({ rawReq.subst (subMap), SourceLoc (), /* inferred= */ true } );
571584 }
572585
573586 return Action::Continue;
@@ -581,10 +594,9 @@ struct InferRequirementsWalker : public TypeWalker {
581594 packExpansion->getPatternType ()->getTypeParameterPacks (packReferences);
582595
583596 auto countType = packExpansion->getCountType ();
584- for (auto pack : packReferences) {
585- Requirement req (RequirementKind::SameShape, countType, pack);
586- desugarRequirement (req, SourceLoc (), reqs, errors);
587- }
597+ for (auto pack : packReferences)
598+ reqs.push_back ({Requirement (RequirementKind::SameShape, countType, pack),
599+ SourceLoc (), /* inferred=*/ true });
588600 }
589601
590602 // Infer requirements from `@differentiable` function types.
@@ -596,9 +608,9 @@ struct InferRequirementsWalker : public TypeWalker {
596608 if (auto *fnTy = ty->getAs <AnyFunctionType>()) {
597609 // Add a new conformance constraint for a fixed protocol.
598610 auto addConformanceConstraint = [&](Type type, ProtocolDecl *protocol) {
599- Requirement req (RequirementKind::Conformance, type,
600- protocol->getDeclaredInterfaceType ());
601- desugarRequirement (req, SourceLoc (), reqs, errors );
611+ reqs. push_back ({ Requirement (RequirementKind::Conformance, type,
612+ protocol->getDeclaredInterfaceType ()),
613+ SourceLoc (), /* inferred= */ true } );
602614 };
603615
604616 auto &ctx = module ->getASTContext ();
@@ -610,8 +622,9 @@ struct InferRequirementsWalker : public TypeWalker {
610622 auto secondType = assocType->getDeclaredInterfaceType ()
611623 ->castTo <DependentMemberType>()
612624 ->substBaseType (module , firstType);
613- Requirement req (RequirementKind::SameType, firstType, secondType);
614- desugarRequirement (req, SourceLoc (), reqs, errors);
625+ reqs.push_back ({Requirement (RequirementKind::SameType,
626+ firstType, secondType),
627+ SourceLoc (), /* inferred=*/ true });
615628 };
616629 auto *tangentVectorAssocType =
617630 differentiableProtocol->getAssociatedType (ctx.Id_TangentVector );
@@ -659,8 +672,7 @@ struct InferRequirementsWalker : public TypeWalker {
659672 if (skipRequirement (rawReq, decl))
660673 continue ;
661674
662- auto req = rawReq.subst (subMap);
663- desugarRequirement (req, SourceLoc (), reqs, errors);
675+ reqs.push_back ({rawReq.subst (subMap), SourceLoc (), /* inferred=*/ true });
664676 }
665677
666678 return Action::Continue;
@@ -683,15 +695,12 @@ void swift::rewriting::inferRequirements(
683695 if (!type)
684696 return ;
685697
686- InferRequirementsWalker walker (module , dc);
698+ InferRequirementsWalker walker (module , dc, result );
687699 type.walk (walker);
688-
689- for (const auto &req : walker.reqs )
690- result.push_back ({req, loc, /* wasInferred=*/ true });
691700}
692701
693- // / Desugar a requirement and perform requirement inference if requested
694- // / to obtain zero or more structural requirements .
702+ // / Perform requirement inference from the type representations in the
703+ // / requirement itself (eg, `T == Set<U>` infers `U: Hashable`) .
695704void swift::rewriting::realizeRequirement (
696705 DeclContext *dc,
697706 Requirement req, RequirementRepr *reqRepr,
@@ -732,12 +741,7 @@ void swift::rewriting::realizeRequirement(
732741 inferRequirements (firstType, firstLoc, moduleForInference, dc, result);
733742 }
734743
735- SmallVector<Requirement, 2 > reqs;
736- desugarRequirement (req, loc, reqs, errors);
737-
738- for (auto req : reqs)
739- result.push_back ({req, loc, /* wasInferred=*/ false });
740-
744+ result.push_back ({req, loc, /* wasInferred=*/ false });
741745 break ;
742746 }
743747
@@ -754,11 +758,7 @@ void swift::rewriting::realizeRequirement(
754758 inferRequirements (secondType, secondLoc, moduleForInference, dc, result);
755759 }
756760
757- SmallVector<Requirement, 2 > reqs;
758- desugarRequirement (req, loc, reqs, errors);
759-
760- for (auto req : reqs)
761- result.push_back ({req, loc, /* wasInferred=*/ false });
761+ result.push_back ({req, loc, /* wasInferred=*/ false });
762762 break ;
763763 }
764764 }
@@ -903,13 +903,13 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator,
903903 ProtocolDecl *proto) const {
904904 assert (!proto->hasLazyRequirementSignature ());
905905
906- SmallVector<StructuralRequirement, 4 > result;
907- SmallVector<RequirementError, 4 > errors;
906+ SmallVector<StructuralRequirement, 2 > result;
907+ SmallVector<RequirementError, 2 > errors;
908908
909909 auto &ctx = proto->getASTContext ();
910910 auto selfTy = proto->getSelfInterfaceType ();
911911
912- SmallVector<Type, 4 > needsDefaultReqirements ({selfTy});
912+ SmallVector<Type, 4 > needsDefaultRequirements ({selfTy});
913913
914914 unsigned errorCount = errors.size ();
915915 realizeInheritedRequirements (proto, selfTy,
@@ -950,7 +950,12 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator,
950950 result.push_back ({Requirement (RequirementKind::Layout, selfTy, layout),
951951 proto->getLoc (), /* inferred=*/ true });
952952
953- expandDefaultRequirements (ctx, needsDefaultReqirements, result, errors);
953+ desugarRequirements (result, errors);
954+ expandDefaultRequirements (ctx, needsDefaultRequirements, result, errors);
955+
956+ diagnoseRequirementErrors (ctx, errors,
957+ AllowConcreteTypePolicy::NestedAssocTypes);
958+
954959 return ctx.AllocateCopy (result);
955960 }
956961
@@ -976,7 +981,7 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator,
976981 return false ;
977982 });
978983
979- needsDefaultReqirements .push_back (assocType);
984+ needsDefaultRequirements .push_back (assocType);
980985 }
981986
982987 // Add requirements for each typealias.
@@ -1014,7 +1019,8 @@ StructuralRequirementsRequest::evaluate(Evaluator &evaluator,
10141019 }
10151020 }
10161021
1017- expandDefaultRequirements (ctx, needsDefaultReqirements, result, errors);
1022+ desugarRequirements (result, errors);
1023+ expandDefaultRequirements (ctx, needsDefaultRequirements, result, errors);
10181024
10191025 diagnoseRequirementErrors (ctx, errors,
10201026 AllowConcreteTypePolicy::NestedAssocTypes);
0 commit comments