Skip to content

Commit 6a9d864

Browse files
committed
[AArch64][SME] Handle zeroing ZA and ZT0 in functions with ZT0 state
In the MachineSMEABIPass, if we have a function with ZT0 state, then there are some additional cases where we need to zero ZA and ZT0. If the function has a private ZA interface, i.e., new ZT0 (and new ZA if present). Then ZT0/ZA must be zeroed when committing the incoming ZA save. If the function has a shared ZA interface, e.g. new ZA and shared ZT0. Then ZA must be zeroed on function entry (without a ZA save commit). The logic in the ABI pass has been reworked to use an "ENTRY" state to handle this (rather than the more specific "CALLER_DORMANT" state). Change-Id: Ib91e9b13ffa4752320fe6a7a720afe919cf00198
1 parent 3d5d32c commit 6a9d864

File tree

3 files changed

+68
-69
lines changed

3 files changed

+68
-69
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8802,15 +8802,6 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
88028802
}
88038803
}
88048804

8805-
if (getTM().useNewSMEABILowering()) {
8806-
// Clear new ZT0 state. TODO: Move this to the SME ABI pass.
8807-
if (Attrs.isNewZT0())
8808-
Chain = DAG.getNode(
8809-
ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
8810-
DAG.getTargetConstant(Intrinsic::aarch64_sme_zero_zt, DL, MVT::i32),
8811-
DAG.getTargetConstant(0, DL, MVT::i32));
8812-
}
8813-
88148805
return Chain;
88158806
}
88168807

llvm/lib/Target/AArch64/MachineSMEABIPass.cpp

Lines changed: 57 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ enum ZAState {
8282
// A ZA save has been set up or committed (i.e. ZA is dormant or off)
8383
LOCAL_SAVED,
8484

85-
// ZA is off or a lazy save has been set up by the caller
86-
CALLER_DORMANT,
85+
// The ZA/ZT0 state on entry to the function.
86+
ENTRY,
8787

8888
// ZA is off
8989
OFF,
@@ -200,7 +200,7 @@ StringRef getZAStateString(ZAState State) {
200200
MAKE_CASE(ZAState::ANY)
201201
MAKE_CASE(ZAState::ACTIVE)
202202
MAKE_CASE(ZAState::LOCAL_SAVED)
203-
MAKE_CASE(ZAState::CALLER_DORMANT)
203+
MAKE_CASE(ZAState::ENTRY)
204204
MAKE_CASE(ZAState::OFF)
205205
default:
206206
llvm_unreachable("Unexpected ZAState");
@@ -281,8 +281,8 @@ struct MachineSMEABI : public MachineFunctionPass {
281281
void propagateDesiredStates(FunctionInfo &FnInfo, bool Forwards = true);
282282

283283
// Emission routines for private and shared ZA functions (using lazy saves).
284-
void emitNewZAPrologue(MachineBasicBlock &MBB,
285-
MachineBasicBlock::iterator MBBI);
284+
void emitSMEPrologue(MachineBasicBlock &MBB,
285+
MachineBasicBlock::iterator MBBI);
286286
void emitRestoreLazySave(EmitContext &, MachineBasicBlock &MBB,
287287
MachineBasicBlock::iterator MBBI,
288288
LiveRegs PhysLiveRegs);
@@ -406,9 +406,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
406406

407407
if (MBB.isEntryBlock()) {
408408
// Entry block:
409-
Block.FixedEntryState = SMEFnAttrs.hasPrivateZAInterface()
410-
? ZAState::CALLER_DORMANT
411-
: ZAState::ACTIVE;
409+
Block.FixedEntryState = ZAState::ENTRY;
412410
} else if (MBB.isEHPad()) {
413411
// EH entry block:
414412
Block.FixedEntryState = ZAState::LOCAL_SAVED;
@@ -825,32 +823,49 @@ void MachineSMEABI::emitAllocateLazySaveBuffer(
825823
}
826824
}
827825

828-
void MachineSMEABI::emitNewZAPrologue(MachineBasicBlock &MBB,
829-
MachineBasicBlock::iterator MBBI) {
826+
static constexpr unsigned ZERO_ALL_ZA_MASK = 0b11111111;
827+
828+
void MachineSMEABI::emitSMEPrologue(MachineBasicBlock &MBB,
829+
MachineBasicBlock::iterator MBBI) {
830830
auto *TLI = Subtarget->getTargetLowering();
831831
DebugLoc DL = getDebugLoc(MBB, MBBI);
832832

833-
// Get current TPIDR2_EL0.
834-
Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
835-
BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS))
836-
.addReg(TPIDR2EL0, RegState::Define)
837-
.addImm(AArch64SysReg::TPIDR2_EL0);
838-
// If TPIDR2_EL0 is non-zero, commit the lazy save.
839-
// NOTE: Functions that only use ZT0 don't need to zero ZA.
840-
bool ZeroZA = AFI->getSMEFnAttrs().hasZAState();
841-
auto CommitZASave =
842-
BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo))
843-
.addReg(TPIDR2EL0)
844-
.addImm(ZeroZA ? 1 : 0)
845-
.addImm(/*ZeroZT0=*/false)
846-
.addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_SAVE))
847-
.addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
848-
if (ZeroZA)
849-
CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
850-
// Enable ZA (as ZA could have previously been in the OFF state).
851-
BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
852-
.addImm(AArch64SVCR::SVCRZA)
853-
.addImm(1);
833+
bool ZeroZA = AFI->getSMEFnAttrs().isNewZA();
834+
bool ZeroZT0 = AFI->getSMEFnAttrs().isNewZT0();
835+
if (AFI->getSMEFnAttrs().hasPrivateZAInterface()) {
836+
// Get current TPIDR2_EL0.
837+
Register TPIDR2EL0 = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
838+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::MRS))
839+
.addReg(TPIDR2EL0, RegState::Define)
840+
.addImm(AArch64SysReg::TPIDR2_EL0);
841+
// If TPIDR2_EL0 is non-zero, commit the lazy save.
842+
// NOTE: Functions that only use ZT0 don't need to zero ZA.
843+
auto CommitZASave =
844+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo))
845+
.addReg(TPIDR2EL0)
846+
.addImm(ZeroZA)
847+
.addImm(ZeroZT0)
848+
.addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_TPIDR2_SAVE))
849+
.addRegMask(TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
850+
if (ZeroZA)
851+
CommitZASave.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
852+
if (ZeroZT0)
853+
CommitZASave.addDef(AArch64::ZT0, RegState::ImplicitDefine);
854+
// Enable ZA (as ZA could have previously been in the OFF state).
855+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
856+
.addImm(AArch64SVCR::SVCRZA)
857+
.addImm(1);
858+
} else if (AFI->getSMEFnAttrs().hasSharedZAInterface()) {
859+
if (ZeroZA) {
860+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_M))
861+
.addImm(ZERO_ALL_ZA_MASK)
862+
.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
863+
}
864+
if (ZeroZT0) {
865+
DebugLoc DL = getDebugLoc(MBB, MBBI);
866+
BuildMI(MBB, MBBI, DL, TII->get(AArch64::ZERO_T)).addDef(AArch64::ZT0);
867+
}
868+
}
854869
}
855870

856871
void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
@@ -932,19 +947,19 @@ void MachineSMEABI::emitStateChange(EmitContext &Context,
932947
if (From == ZAState::ANY || To == ZAState::ANY)
933948
return;
934949

935-
// If we're exiting from the CALLER_DORMANT state that means this new ZA
936-
// function did not touch ZA (so ZA was never turned on).
937-
if (From == ZAState::CALLER_DORMANT && To == ZAState::OFF)
950+
// If we're exiting from the ENTRY state that means that the function has not
951+
// used ZA, so in the case of private ZA/ZT0 functions we can omit any set up.
952+
if (From == ZAState::ENTRY && To == ZAState::OFF)
938953
return;
939954

955+
SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
956+
940957
// TODO: Avoid setting up the save buffer if there's no transition to
941958
// LOCAL_SAVED.
942-
if (From == ZAState::CALLER_DORMANT) {
943-
assert(AFI->getSMEFnAttrs().hasPrivateZAInterface() &&
944-
"CALLER_DORMANT state requires private ZA interface");
959+
if (From == ZAState::ENTRY) {
945960
assert(&MBB == &MBB.getParent()->front() &&
946-
"CALLER_DORMANT state only valid in entry block");
947-
emitNewZAPrologue(MBB, MBB.getFirstNonPHI());
961+
"ENTRY state only valid in entry block");
962+
emitSMEPrologue(MBB, MBB.getFirstNonPHI());
948963
if (To == ZAState::ACTIVE)
949964
return; // Nothing more to do (ZA is active after the prologue).
950965

@@ -959,9 +974,9 @@ void MachineSMEABI::emitStateChange(EmitContext &Context,
959974
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
960975
emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
961976
else if (To == ZAState::OFF) {
962-
assert(From != ZAState::CALLER_DORMANT &&
963-
"CALLER_DORMANT to OFF should have already been handled");
964-
assert(!AFI->getSMEFnAttrs().hasAgnosticZAInterface() &&
977+
assert(From != ZAState::ENTRY &&
978+
"ENTRY to OFF should have already been handled");
979+
assert(!SMEFnAttrs.hasAgnosticZAInterface() &&
965980
"Should not turn ZA off in agnostic ZA function");
966981
emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED);
967982
} else {

llvm/test/CodeGen/AArch64/sme-zt0-state.ll

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,9 @@ define void @zt0_new_caller_zt0_new_callee(ptr %callee) "aarch64_new_zt0" nounwi
199199
; CHECK-NEWLOWERING-NEXT: // %bb.1:
200200
; CHECK-NEWLOWERING-NEXT: bl __arm_tpidr2_save
201201
; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr
202+
; CHECK-NEWLOWERING-NEXT: zero { zt0 }
202203
; CHECK-NEWLOWERING-NEXT: .LBB6_2:
203204
; CHECK-NEWLOWERING-NEXT: smstart za
204-
; CHECK-NEWLOWERING-NEXT: zero { zt0 }
205205
; CHECK-NEWLOWERING-NEXT: mov x19, sp
206206
; CHECK-NEWLOWERING-NEXT: str zt0, [x19]
207207
; CHECK-NEWLOWERING-NEXT: smstop za
@@ -252,9 +252,9 @@ define i64 @zt0_new_caller_abi_routine_callee() "aarch64_new_zt0" nounwind {
252252
; CHECK-NEWLOWERING-NEXT: // %bb.1:
253253
; CHECK-NEWLOWERING-NEXT: bl __arm_tpidr2_save
254254
; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr
255+
; CHECK-NEWLOWERING-NEXT: zero { zt0 }
255256
; CHECK-NEWLOWERING-NEXT: .LBB7_2:
256257
; CHECK-NEWLOWERING-NEXT: smstart za
257-
; CHECK-NEWLOWERING-NEXT: zero { zt0 }
258258
; CHECK-NEWLOWERING-NEXT: mov x19, sp
259259
; CHECK-NEWLOWERING-NEXT: str zt0, [x19]
260260
; CHECK-NEWLOWERING-NEXT: bl __arm_sme_state
@@ -302,9 +302,9 @@ define void @zt0_new_caller(ptr %callee) "aarch64_new_zt0" nounwind {
302302
; CHECK-NEWLOWERING-NEXT: // %bb.1:
303303
; CHECK-NEWLOWERING-NEXT: bl __arm_tpidr2_save
304304
; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr
305+
; CHECK-NEWLOWERING-NEXT: zero { zt0 }
305306
; CHECK-NEWLOWERING-NEXT: .LBB8_2:
306307
; CHECK-NEWLOWERING-NEXT: smstart za
307-
; CHECK-NEWLOWERING-NEXT: zero { zt0 }
308308
; CHECK-NEWLOWERING-NEXT: blr x0
309309
; CHECK-NEWLOWERING-NEXT: smstop za
310310
; CHECK-NEWLOWERING-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
@@ -343,9 +343,9 @@ define void @new_za_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_new_zt0" n
343343
; CHECK-NEWLOWERING-NEXT: bl __arm_tpidr2_save
344344
; CHECK-NEWLOWERING-NEXT: msr TPIDR2_EL0, xzr
345345
; CHECK-NEWLOWERING-NEXT: zero {za}
346+
; CHECK-NEWLOWERING-NEXT: zero { zt0 }
346347
; CHECK-NEWLOWERING-NEXT: .LBB9_2:
347348
; CHECK-NEWLOWERING-NEXT: smstart za
348-
; CHECK-NEWLOWERING-NEXT: zero { zt0 }
349349
; CHECK-NEWLOWERING-NEXT: blr x0
350350
; CHECK-NEWLOWERING-NEXT: smstop za
351351
; CHECK-NEWLOWERING-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
@@ -356,20 +356,13 @@ define void @new_za_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_new_zt0" n
356356

357357
; Expect clear ZA on entry
358358
define void @new_za_shared_zt0_caller(ptr %callee) "aarch64_new_za" "aarch64_in_zt0" nounwind {
359-
; CHECK-LABEL: new_za_shared_zt0_caller:
360-
; CHECK: // %bb.0:
361-
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
362-
; CHECK-NEXT: zero {za}
363-
; CHECK-NEXT: blr x0
364-
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
365-
; CHECK-NEXT: ret
366-
;
367-
; CHECK-NEWLOWERING-LABEL: new_za_shared_zt0_caller:
368-
; CHECK-NEWLOWERING: // %bb.0:
369-
; CHECK-NEWLOWERING-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
370-
; CHECK-NEWLOWERING-NEXT: blr x0
371-
; CHECK-NEWLOWERING-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
372-
; CHECK-NEWLOWERING-NEXT: ret
359+
; CHECK-COMMON-LABEL: new_za_shared_zt0_caller:
360+
; CHECK-COMMON: // %bb.0:
361+
; CHECK-COMMON-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
362+
; CHECK-COMMON-NEXT: zero {za}
363+
; CHECK-COMMON-NEXT: blr x0
364+
; CHECK-COMMON-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
365+
; CHECK-COMMON-NEXT: ret
373366
call void %callee() "aarch64_inout_za" "aarch64_in_zt0";
374367
ret void;
375368
}

0 commit comments

Comments
 (0)