Conversation

@MacDue MacDue commented Nov 4, 2025

In the MachineSMEABIPass, if a function has ZT0 state, there are some additional cases where we need to zero ZA and ZT0.

If the function has a private ZA interface, i.e., new ZT0 (and new ZA if present), then ZT0/ZA must be zeroed when committing the incoming ZA save.
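For illustration, a minimal IR sketch of the private-interface case, modelled on the zt0_new_caller test in sme-zt0-state.ll (the function name here is just an example): with only new ZT0, the prologue emitted by the pass conditionally commits the caller's lazy ZA save and zeroes ZT0 as part of that commit, before smstart za.

; New ZT0 only => private ZA interface. Expected entry sequence (per the
; updated test): read TPIDR2_EL0, conditionally call __arm_tpidr2_save,
; zero { zt0 } as part of the commit, then smstart za.
define void @example_zt0_new_private(ptr %callee) "aarch64_new_zt0" nounwind {
  call void %callee()
  ret void
}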

If the function has a shared ZA interface, e.g. new ZA and shared ZT0, then ZA must be zeroed on function entry (without committing a ZA save).
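And a sketch of the shared-interface case, mirroring new_za_shared_zt0_caller in the test diff below (again, the name is illustrative): since ZT0 is shared with the caller, ZA is already live on entry, so the pass only needs to emit zero {za} and can skip the TPIDR2 check and the __arm_tpidr2_save call.

; New ZA + shared (in) ZT0 => shared ZA interface. Expected entry sequence
; (per the CHECK-COMMON lines in the updated test): just zero {za}, with no
; TPIDR2_EL0 check and no __arm_tpidr2_save call.
define void @example_new_za_shared_zt0(ptr %callee) "aarch64_new_za" "aarch64_in_zt0" nounwind {
  call void %callee() "aarch64_inout_za" "aarch64_in_zt0"
  ret void
}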

The logic in the ABI pass has been reworked to use an "ENTRY" state to handle this (rather than the more specific "CALLER_DORMANT" state).

llvmbot commented Nov 4, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Benjamin Maxwell (MacDue)


Full diff: https://github.com/llvm/llvm-project/pull/166361.diff

3 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (-9)
  • (modified) llvm/lib/Target/AArch64/MachineSMEABIPass.cpp (+57-42)
  • (modified) llvm/test/CodeGen/AArch64/sme-zt0-state.ll (+11-18)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 60aa61e993b26..30f961043e78b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -8735,15 +8735,6 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
     }
   }
 
-  if (getTM().useNewSMEABILowering()) {
-    // Clear new ZT0 state. TODO: Move this to the SME ABI pass.
-    if (Attrs.isNewZT0())
-      Chain = DAG.getNode(
-          ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
-          DAG.getConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32),
-          DAG.getTargetConstant(0, DL, MVT::i32));
-  }
-
   return Chain;
 }
 
diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
index 8f9aae944ad6d..bb4dfe8c60904 100644
--- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
@@ -82,8 +82,8 @@ enum ZAState {
   // A ZA save has been set up or committed (i.e. ZA is dormant or off)
   LOCAL_SAVED,
 
-  // ZA is off or a lazy save has been set up by the caller
-  CALLER_DORMANT,
+  // The ZA/ZT0 state on entry to the function.
+  ENTRY,
 
   // ZA is off
   OFF,
@@ -200,7 +200,7 @@ StringRef getZAStateString(ZAState State) {
     MAKE_CASE(ZAState::ANY)
     MAKE_CASE(ZAState::ACTIVE)
     MAKE_CASE(ZAState::LOCAL_SAVED)
-    MAKE_CASE(ZAState::CALLER_DORMANT)
+    MAKE_CASE(ZAState::ENTRY)
     MAKE_CASE(ZAState::OFF)
   default:
     llvm_unreachable("Unexpected ZAState");
@@ -281,8 +281,8 @@ struct MachineSMEABI : public MachineFunctionPass {
   void propagateDesiredStates(FunctionInfo &FnInfo, bool Forwards = true);
 
   // Emission routines for private and shared ZA functions (using lazy saves).
-  void emitNewZAPrologue(MachineBasicBlock &MBB,
-                         MachineBasicBlock::iterator MBBI);
+  void emitSMEPrologue(MachineBasicBlock &MBB,
+                       MachineBasicBlock::iterator MBBI);
   void emitRestoreLazySave(EmitContext &, MachineBasicBlock &MBB,
                            MachineBasicBlock::iterator MBBI,
                            LiveRegs PhysLiveRegs);
@@ -395,9 +395,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
 
     if (MBB.isEntryBlock()) {
       // Entry block:
-      Block.FixedEntryState = SMEFnAttrs.hasPrivateZAInterface()
-                                  ? ZAState::CALLER_DORMANT
-                                  : ZAState::ACTIVE;
+      Block.FixedEntryState = ZAState::ENTRY;
     } else if (MBB.isEHPad()) {
       // EH entry block:
       Block.FixedEntryState = ZAState::LOCAL_SAVED;
@@ -815,32 +813,49 @@ void MachineSMEABI::emitAllocateLazySaveBuffer(
   }
 }
 
-void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
-                                      MachineBasicBlock::iterator MBBI) {
+static constexpr unsigned ZERO_ALL_ZA_MASK = 0b11111111;
+
+void MachineSMEABI::emitSMEPrologue(MachineBasicBlock &MBB,
+                                    MachineBasicBlock::iterator MBBI) {
   auto *TLI = Subtarget->getTargetLowering();
   DebugLoc DL = getDebugLoc(MBB, MBBI);
 
-  // Get current TPIDR2_EL0.
-  Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
-  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS))
-      .addReg(TPIDR2EL0, RegState::Define)
-      .addImm(AArch64SysReg::TPIDR2_EL0);
-  // If TPIDR2_EL0 is non-zero, commit the lazy save.
-  // NOTE: Functions that only use ZT0 don't need to zero ZA.
-  bool ZeroZA = AFI->getSMEFnAttrs().hasZAState();
-  auto CommitZASave =
-      BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo))
-          .addReg(TPIDR2EL0)
-          .addImm(ZeroZA ? 1 : 0)
-          .addImm(/*ZeroZT0=*/false)
-          .addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_SAVE))
-          .addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
-  if (ZeroZA)
-    CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
-  // Enable ZA (as ZA could have previously been in the OFF state).
-  BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
-      .addImm(AArch64SVCR::SVCRZA)
-      .addImm(1);
+  bool ZeroZA = AFI->getSMEFnAttrs().isNewZA();
+  bool ZeroZT0 = AFI->getSMEFnAttrs().isNewZT0();
+  if (AFI->getSMEFnAttrs().hasPrivateZAInterface()) {
+    // Get current TPIDR2_EL0.
+    Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
+    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS))
+        .addReg(TPIDR2EL0, RegState::Define)
+        .addImm(AArch64SysReg::TPIDR2_EL0);
+    // If TPIDR2_EL0 is non-zero, commit the lazy save.
+    // NOTE: Functions that only use ZT0 don't need to zero ZA.
+    auto CommitZASave =
+        BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo))
+            .addReg(TPIDR2EL0)
+            .addImm(ZeroZA)
+            .addImm(ZeroZT0)
+            .addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_SAVE))
+            .addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
+    if (ZeroZA)
+      CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
+    if (ZeroZT0)
+      CommitZASave.addDef(AArch64::ZT0, RegState::ImplicitDefine);
+    // Enable ZA (as ZA could have previously been in the OFF state).
+    BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
+        .addImm(AArch64SVCR::SVCRZA)
+        .addImm(1);
+  } else if (AFI->getSMEFnAttrs().hasSharedZAInterface()) {
+    if (ZeroZA) {
+      BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_M))
+          .addImm(ZERO_ALL_ZA_MASK)
+          .addDef(AArch64::ZAB0, RegState::ImplicitDefine);
+    }
+    if (ZeroZT0) {
+      DebugLoc DL = getDebugLoc(MBB, MBBI);
+      BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_T)).addDef(AArch64::ZT0);
+    }
+  }
 }
 
 void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
@@ -922,19 +937,19 @@ void MachineSMEABI::emitStateChange(EmitContext &Context,
   if (From == ZAState::ANY || To == ZAState::ANY)
     return;
 
-  // If we're exiting from the CALLER_DORMANT state that means this new ZA
-  // function did not touch ZA (so ZA was never turned on).
-  if (From == ZAState::CALLER_DORMANT && To == ZAState::OFF)
+  // If we're exiting from the ENTRY state that means that the function has not
+  // used ZA, so in the case of private ZA/ZT0 functions we can omit any set up.
+  if (From == ZAState::ENTRY && To == ZAState::OFF)
     return;
 
+  SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
+
   // TODO: Avoid setting up the save buffer if there's no transition to
   // LOCAL_SAVED.
-  if (From == ZAState::CALLER_DORMANT) {
-    assert(AFI->getSMEFnAttrs().hasPrivateZAInterface() &&
-           "CALLER_DORMANT state requires private ZA interface");
+  if (From == ZAState::ENTRY) {
     assert(&MBB == &MBB.getParent()->front() &&
-           "CALLER_DORMANT state only valid in entry block");
-    emitNewZAPrologue(MBB, MBB.getFirstNonPHI());
+           "ENTRY state only valid in entry block");
+    emitSMEPrologue(MBB, MBB.getFirstNonPHI());
     if (To == ZAState::ACTIVE)
       return; // Nothing more to do (ZA is active after the prologue).
 
@@ -949,9 +964,9 @@ void MachineSMEABI::emitStateChange(EmitContext &Context,
   else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
     emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
   else if (To == ZAState::OFF) {
-    assert(From != ZAState::CALLER_DORMANT &&
-           "CALLER_DORMANT to OFF should have already been handled");
-    assert(!AFI->getSMEFnAttrs().hasAgnosticZAInterface() &&
+    assert(From != ZAState::ENTRY &&
+           "ENTRY to OFF should have already been handled");
+    assert(!SMEFnAttrs.hasAgnosticZAInterface() &&
            "Should not turn ZA off in agnostic ZA function");
     emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED);
   } else {
diff --git a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
index 5b81f5dafe421..4c48e41294a3a 100644
--- a/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
+++ b/llvm/test/CodeGen/AArch64/sme-zt0-state.ll
@@ -199,9 +199,9 @@ define void @zt0_new_caller_zt0_new_callee(ptr %callee) "aarch64_new_zt0" nounwi
 ; CHECK-NEWLOWERING-NEXT:  // %bb.1:
 ; CHECK-NEWLOWERING-NEXT:    bl __arm_tpidr2_save
 ; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEWLOWERING-NEXT:    zero { zt0 }
 ; CHECK-NEWLOWERING-NEXT:  .LBB6_2:
 ; CHECK-NEWLOWERING-NEXT:    smstart za
-; CHECK-NEWLOWERING-NEXT:    zero { zt0 }
 ; CHECK-NEWLOWERING-NEXT:    mov x19, sp
 ; CHECK-NEWLOWERING-NEXT:    str zt0, [x19]
 ; CHECK-NEWLOWERING-NEXT:    smstop za
@@ -252,9 +252,9 @@ define i64 @zt0_new_caller_abi_routine_callee() "aarch64_new_zt0" nounwind {
 ; CHECK-NEWLOWERING-NEXT:  // %bb.1:
 ; CHECK-NEWLOWERING-NEXT:    bl __arm_tpidr2_save
 ; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEWLOWERING-NEXT:    zero { zt0 }
 ; CHECK-NEWLOWERING-NEXT:  .LBB7_2:
 ; CHECK-NEWLOWERING-NEXT:    smstart za
-; CHECK-NEWLOWERING-NEXT:    zero { zt0 }
 ; CHECK-NEWLOWERING-NEXT:    mov x19, sp
 ; CHECK-NEWLOWERING-NEXT:    str zt0, [x19]
 ; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_state
@@ -302,9 +302,9 @@ define void @zt0_new_caller(ptr %callee) "aarch64_new_zt0" nounwind {
 ; CHECK-NEWLOWERING-NEXT:  // %bb.1:
 ; CHECK-NEWLOWERING-NEXT:    bl __arm_tpidr2_save
 ; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, xzr
+; CHECK-NEWLOWERING-NEXT:    zero { zt0 }
 ; CHECK-NEWLOWERING-NEXT:  .LBB8_2:
 ; CHECK-NEWLOWERING-NEXT:    smstart za
-; CHECK-NEWLOWERING-NEXT:    zero { zt0 }
 ; CHECK-NEWLOWERING-NEXT:    blr x0
 ; CHECK-NEWLOWERING-NEXT:    smstop za
 ; CHECK-NEWLOWERING-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
@@ -343,9 +343,9 @@ define void @new_za_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_new_zt0" n
 ; CHECK-NEWLOWERING-NEXT:    bl __arm_tpidr2_save
 ; CHECK-NEWLOWERING-NEXT:    msr TPIDR2_EL0, xzr
 ; CHECK-NEWLOWERING-NEXT:    zero {za}
+; CHECK-NEWLOWERING-NEXT:    zero { zt0 }
 ; CHECK-NEWLOWERING-NEXT:  .LBB9_2:
 ; CHECK-NEWLOWERING-NEXT:    smstart za
-; CHECK-NEWLOWERING-NEXT:    zero { zt0 }
 ; CHECK-NEWLOWERING-NEXT:    blr x0
 ; CHECK-NEWLOWERING-NEXT:    smstop za
 ; CHECK-NEWLOWERING-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
@@ -356,20 +356,13 @@ define void @new_za_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_new_zt0" n
 
 ; Expect clear ZA on entry
 define void @new_za_shared_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_in_zt0" nounwind {
-; CHECK-LABEL: new_za_shared_zt0_caller:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT:    zero {za}
-; CHECK-NEXT:    blr x0
-; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
-; CHECK-NEXT:    ret
-;
-; CHECK-NEWLOWERING-LABEL: new_za_shared_zt0_caller:
-; CHECK-NEWLOWERING:       // %bb.0:
-; CHECK-NEWLOWERING-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEWLOWERING-NEXT:    blr x0
-; CHECK-NEWLOWERING-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
-; CHECK-NEWLOWERING-NEXT:    ret
+; CHECK-COMMON-LABEL: new_za_shared_zt0_caller:
+; CHECK-COMMON:       // %bb.0:
+; CHECK-COMMON-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-COMMON-NEXT:    zero {za}
+; CHECK-COMMON-NEXT:    blr x0
+; CHECK-COMMON-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-COMMON-NEXT:    ret
   call void %callee() "aarch64_inout_za" "aarch64_in_zt0";
   ret void;
 }
