@@ -7513,46 +7513,36 @@ void Sema::checkCall(NamedDecl *FDecl, const FunctionProtoType *Proto,
75137513 }
75147514 }
75157515
7516- // If the callee uses AArch64 SME ZA state but the caller doesn't define
7517- // any, then this is an error.
7518- FunctionType::ArmStateValue ArmZAState =
7516+ FunctionType::ArmStateValue CalleeArmZAState =
75197517 FunctionType::getArmZAState(ExtInfo.AArch64SMEAttributes);
7520- if (ArmZAState != FunctionType::ARM_None) {
7518+ FunctionType::ArmStateValue CalleeArmZT0State =
7519+ FunctionType::getArmZT0State(ExtInfo.AArch64SMEAttributes);
7520+ if (CalleeArmZAState != FunctionType::ARM_None ||
7521+ CalleeArmZT0State != FunctionType::ARM_None) {
75217522 bool CallerHasZAState = false;
7523+ bool CallerHasZT0State = false;
75227524 if (const auto *CallerFD = dyn_cast<FunctionDecl>(CurContext)) {
75237525 auto *Attr = CallerFD->getAttr<ArmNewAttr>();
75247526 if (Attr && Attr->isNewZA())
75257527 CallerHasZAState = true;
7526- else if (const auto *FPT =
7527- CallerFD->getType()->getAs<FunctionProtoType>())
7528- CallerHasZAState = FunctionType::getArmZAState(
7529- FPT->getExtProtoInfo().AArch64SMEAttributes) !=
7530- FunctionType::ARM_None;
7531- }
7532-
7533- if (!CallerHasZAState)
7534- Diag(Loc, diag::err_sme_za_call_no_za_state);
7535- }
7536-
7537- // If the callee uses AArch64 SME ZT0 state but the caller doesn't define
7538- // any, then this is an error.
7539- FunctionType::ArmStateValue ArmZT0State =
7540- FunctionType::getArmZT0State(ExtInfo.AArch64SMEAttributes);
7541- if (ArmZT0State != FunctionType::ARM_None) {
7542- bool CallerHasZT0State = false;
7543- if (const auto *CallerFD = dyn_cast<FunctionDecl>(CurContext)) {
7544- auto *Attr = CallerFD->getAttr<ArmNewAttr>();
75457528 if (Attr && Attr->isNewZT0())
75467529 CallerHasZT0State = true;
7547- else if (const auto *FPT =
7548- CallerFD->getType()->getAs<FunctionProtoType>())
7549- CallerHasZT0State =
7530+ if (const auto *FPT = CallerFD->getType()->getAs<FunctionProtoType>()) {
7531+ CallerHasZAState |=
7532+ FunctionType::getArmZAState(
7533+ FPT->getExtProtoInfo().AArch64SMEAttributes) !=
7534+ FunctionType::ARM_None;
7535+ CallerHasZT0State |=
75507536 FunctionType::getArmZT0State(
75517537 FPT->getExtProtoInfo().AArch64SMEAttributes) !=
75527538 FunctionType::ARM_None;
7539+ }
75537540 }
75547541
7555- if (!CallerHasZT0State)
7542+ if (CalleeArmZAState != FunctionType::ARM_None && !CallerHasZAState)
7543+ Diag(Loc, diag::err_sme_za_call_no_za_state);
7544+
7545+ if (CalleeArmZT0State != FunctionType::ARM_None && !CallerHasZT0State)
75567546 Diag(Loc, diag::err_sme_zt0_call_no_zt0_state);
75577547 }
75587548 }
0 commit comments