Skip to content

Commit 0c087a1

Browse files
committed
[AArch64][SME] Support saving/restoring ZT0 in the MachineSMEABIPass
This patch extends the MachineSMEABIPass to support ZT0. This is done with the addition of two new states: - `ACTIVE_ZT0_SAVED` * This is used when calling a function that shares ZA, but does not share ZT0 (i.e., no ZT0 attributes). * This state indicates ZT0 must be saved to the save slot, but ZA must remain on, with no lazy save setup. - `LOCAL_COMMITTED` * This is used for saving ZT0 in functions without ZA state. * This state indicates ZA is off and ZT0 has been saved. * This state is general enough to support ZA, but the ZA transitions have not been implemented†. To aid with readability, the state transitions have been reworked to a switch of `transitionFrom(<FromState>).to(<ToState>)`, rather than nested ifs, which helps manage more transitions. † This could be implemented to handle some cases of undefined behavior better. Change-Id: I14be4a7f8b998fe667bfaade5088f88039515f91
1 parent ae3ec41 commit 0c087a1

File tree

7 files changed

+321
-105
lines changed

7 files changed

+321
-105
lines changed

llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,6 +1717,7 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
17171717
}
17181718
case AArch64::InOutZAUsePseudo:
17191719
case AArch64::RequiresZASavePseudo:
1720+
case AArch64::RequiresZT0SavePseudo:
17201721
case AArch64::SMEStateAllocPseudo:
17211722
case AArch64::COALESCER_BARRIER_FPR16:
17221723
case AArch64::COALESCER_BARRIER_FPR32:

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9457,6 +9457,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
94579457
if (CallAttrs.requiresLazySave() ||
94589458
CallAttrs.requiresPreservingAllZAState())
94599459
ZAMarkerNode = AArch64ISD::REQUIRES_ZA_SAVE;
9460+
else if (CallAttrs.requiresPreservingZT0())
9461+
ZAMarkerNode = AArch64ISD::REQUIRES_ZT0_SAVE;
94609462
else if (CallAttrs.caller().hasZAState() ||
94619463
CallAttrs.caller().hasZT0State())
94629464
ZAMarkerNode = AArch64ISD::INOUT_ZA_USE;
@@ -9576,7 +9578,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
95769578

95779579
SDValue ZTFrameIdx;
95789580
MachineFrameInfo &MFI = MF.getFrameInfo();
9579-
bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
9581+
bool ShouldPreserveZT0 =
9582+
!UseNewSMEABILowering && CallAttrs.requiresPreservingZT0();
95809583

95819584
// If the caller has ZT0 state which will not be preserved by the callee,
95829585
// spill ZT0 before the call.
@@ -9589,7 +9592,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
95899592

95909593
// If caller shares ZT0 but the callee is not shared ZA, we need to stop
95919594
// PSTATE.ZA before the call if there is no lazy-save active.
9592-
bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
9595+
bool DisableZA =
9596+
!UseNewSMEABILowering && CallAttrs.requiresDisablingZABeforeCall();
95939597
assert((!DisableZA || !RequiresLazySave) &&
95949598
"Lazy-save should have PSTATE.SM=1 on entry to the function");
95959599

@@ -10074,7 +10078,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
1007410078
getSMToggleCondition(CallAttrs));
1007510079
}
1007610080

10077-
if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall())
10081+
if (!UseNewSMEABILowering &&
10082+
(RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall()))
1007810083
// Unconditionally resume ZA.
1007910084
Result = DAG.getNode(
1008010085
AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def : Pat<(i64 (AArch64AllocateSMESaveBuffer GPR64:$size)),
102102
let hasSideEffects = 1, isMeta = 1 in {
103103
def InOutZAUsePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
104104
def RequiresZASavePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
105+
def RequiresZT0SavePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
105106
}
106107

107108
def SMEStateAllocPseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
@@ -122,6 +123,11 @@ def AArch64_requires_za_save
122123
[SDNPHasChain, SDNPInGlue]>;
123124
def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>;
124125

126+
def AArch64_requires_zt0_save
127+
: SDNode<"AArch64ISD::REQUIRES_ZT0_SAVE", SDTypeProfile<0, 0, []>,
128+
[SDNPHasChain, SDNPInGlue]>;
129+
def : Pat<(AArch64_requires_zt0_save), (RequiresZT0SavePseudo)>;
130+
125131
def AArch64_sme_state_alloc
126132
: SDNode<"AArch64ISD::SME_STATE_ALLOC", SDTypeProfile<0, 0,[]>,
127133
[SDNPHasChain]>;

llvm/lib/Target/AArch64/MachineSMEABIPass.cpp

Lines changed: 150 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,30 @@ using namespace llvm;
7272

7373
namespace {
7474

75-
enum ZAState {
75+
// Note: For agnostic ZA, we assume the function is always entered/exited in the
76+
// "ACTIVE" state -- this _may_ not be the case (since OFF is also a
77+
// possibility, but for the purpose of placing ZA saves/restores, that does not
78+
// matter).
79+
enum ZAState : uint8_t {
7680
// Any/unknown state (not valid)
7781
ANY = 0,
7882

7983
// ZA is in use and active (i.e. within the accumulator)
8084
ACTIVE,
8185

86+
// ZA is active, but ZT0 has been saved.
87+
// This handles the edge case of sharesZA && !sharesZT0.
88+
ACTIVE_ZT0_SAVED,
89+
8290
// A ZA save has been set up or committed (i.e. ZA is dormant or off)
91+
// If the function uses ZT0 it must also be saved.
8392
LOCAL_SAVED,
8493

94+
// ZA has been committed to the lazy save buffer of the current function.
95+
// If the function uses ZT0 it must also be saved.
96+
// ZA is off when a save has been committed.
97+
LOCAL_COMMITTED,
98+
8599
// The ZA/ZT0 state on entry to the function.
86100
ENTRY,
87101

@@ -164,6 +178,14 @@ class EmitContext {
164178
return AgnosticZABufferPtr;
165179
}
166180

181+
/// Returns the frame index of the stack slot used to save ZT0.
///
/// The slot is created lazily: the first call creates a spill stack object in
/// \p MF's frame info and caches its index in ZT0SaveFI; later calls return the
/// cached index without creating another object.
int getZT0SaveSlot(MachineFunction &MF) {
182+
if (ZT0SaveFI)
183+
return *ZT0SaveFI;
184+
MachineFrameInfo &MFI = MF.getFrameInfo();
185+
// NOTE(review): size 64 / Align(16) are presumably bytes sized for ZT0 --
// confirm against the ZT0 register width.
ZT0SaveFI = MFI.CreateSpillStackObject(64, Align(16));
186+
return *ZT0SaveFI;
187+
}
188+
167189
/// Returns true if the function must allocate a ZA save buffer on entry. This
168190
/// will be the case if, at any point in the function, a ZA save was emitted.
169191
bool needsSaveBuffer() const {
@@ -173,6 +195,7 @@ class EmitContext {
173195
}
174196

175197
private:
198+
std::optional<int> ZT0SaveFI;
176199
std::optional<int> TPIDR2BlockFI;
177200
Register AgnosticZABufferPtr = AArch64::NoRegister;
178201
};
@@ -184,8 +207,10 @@ class EmitContext {
184207
/// state would not be legal, as transitioning to it drops the content of ZA.
185208
static bool isLegalEdgeBundleZAState(ZAState State) {
186209
switch (State) {
187-
case ZAState::ACTIVE: // ZA state within the accumulator/ZT0.
188-
case ZAState::LOCAL_SAVED: // ZA state is saved on the stack.
210+
case ZAState::ACTIVE: // ZA state within the accumulator/ZT0.
211+
case ZAState::ACTIVE_ZT0_SAVED: // ZT0 is saved (ZA is active).
212+
case ZAState::LOCAL_SAVED: // ZA state may be saved on the stack.
213+
case ZAState::LOCAL_COMMITTED: // ZA state is saved on the stack.
189214
return true;
190215
default:
191216
return false;
@@ -199,7 +224,9 @@ StringRef getZAStateString(ZAState State) {
199224
switch (State) {
200225
MAKE_CASE(ZAState::ANY)
201226
MAKE_CASE(ZAState::ACTIVE)
227+
MAKE_CASE(ZAState::ACTIVE_ZT0_SAVED)
202228
MAKE_CASE(ZAState::LOCAL_SAVED)
229+
MAKE_CASE(ZAState::LOCAL_COMMITTED)
203230
MAKE_CASE(ZAState::ENTRY)
204231
MAKE_CASE(ZAState::OFF)
205232
default:
@@ -221,18 +248,34 @@ static bool isZAorZTRegOp(const TargetRegisterInfo &TRI,
221248
/// Returns the required ZA state needed before \p MI and an iterator pointing
222249
/// to where any code required to change the ZA state should be inserted.
223250
static std::pair<ZAState, MachineBasicBlock::iterator>
224-
getZAStateBeforeInst(const TargetRegisterInfo &TRI, MachineInstr &MI,
225-
bool ZAOffAtReturn) {
251+
getInstNeededZAState(const TargetRegisterInfo &TRI, MachineInstr &MI,
252+
SMEAttrs SMEFnAttrs) {
226253
MachineBasicBlock::iterator InsertPt(MI);
227254

228255
if (MI.getOpcode() == AArch64::InOutZAUsePseudo)
229256
return {ZAState::ACTIVE, std::prev(InsertPt)};
230257

258+
// Note: If we need to save both ZA and ZT0 we use RequiresZASavePseudo.
231259
if (MI.getOpcode() == AArch64::RequiresZASavePseudo)
232260
return {ZAState::LOCAL_SAVED, std::prev(InsertPt)};
233261

234-
if (MI.isReturn())
262+
// If we only need to save ZT0 there are two cases to consider:
263+
// 1. The function has ZA state (that we don't need to save).
264+
// - In this case we switch to the "ACTIVE_ZT0_SAVED" state.
265+
// This only saves ZT0.
266+
// 2. The function does not have ZA state
267+
// - In this case we switch to "LOCAL_COMMITTED" state.
268+
// This saves ZT0 and turns ZA off.
269+
if (MI.getOpcode() == AArch64::RequiresZT0SavePseudo) {
270+
return {SMEFnAttrs.hasZAState() ? ZAState::ACTIVE_ZT0_SAVED
271+
: ZAState::LOCAL_COMMITTED,
272+
std::prev(InsertPt)};
273+
}
274+
275+
if (MI.isReturn()) {
276+
bool ZAOffAtReturn = SMEFnAttrs.hasPrivateZAInterface();
235277
return {ZAOffAtReturn ? ZAState::OFF : ZAState::ACTIVE, InsertPt};
278+
}
236279

237280
for (auto &MO : MI.operands()) {
238281
if (isZAorZTRegOp(TRI, MO))
@@ -280,6 +323,9 @@ struct MachineSMEABI : public MachineFunctionPass {
280323
/// predecessors).
281324
void propagateDesiredStates(FunctionInfo &FnInfo, bool Forwards = true);
282325

326+
void emitZT0SaveRestore(EmitContext &, MachineBasicBlock &MBB,
327+
MachineBasicBlock::iterator MBBI, bool IsSave);
328+
283329
// Emission routines for private and shared ZA functions (using lazy saves).
284330
void emitSMEPrologue(MachineBasicBlock &MBB,
285331
MachineBasicBlock::iterator MBBI);
@@ -290,8 +336,8 @@ struct MachineSMEABI : public MachineFunctionPass {
290336
MachineBasicBlock::iterator MBBI);
291337
void emitAllocateLazySaveBuffer(EmitContext &, MachineBasicBlock &MBB,
292338
MachineBasicBlock::iterator MBBI);
293-
void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
294-
bool ClearTPIDR2);
339+
void emitZAMode(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
340+
bool ClearTPIDR2, bool On);
295341

296342
// Emission routines for agnostic ZA functions.
297343
void emitSetupFullZASave(MachineBasicBlock &MBB,
@@ -398,7 +444,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
398444
Block.FixedEntryState = ZAState::ENTRY;
399445
} else if (MBB.isEHPad()) {
400446
// EH entry block:
401-
Block.FixedEntryState = ZAState::LOCAL_SAVED;
447+
Block.FixedEntryState = ZAState::LOCAL_COMMITTED;
402448
}
403449

404450
LiveRegUnits LiveUnits(*TRI);
@@ -420,8 +466,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
420466
PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
421467
}
422468
// Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
423-
auto [NeededState, InsertPt] = getZAStateBeforeInst(
424-
*TRI, MI, /*ZAOffAtReturn=*/SMEFnAttrs.hasPrivateZAInterface());
469+
auto [NeededState, InsertPt] = getInstNeededZAState(*TRI, MI, SMEFnAttrs);
425470
assert((InsertPt == MBBI ||
426471
InsertPt->getOpcode() == AArch64::ADJCALLSTACKDOWN) &&
427472
"Unexpected state change insertion point!");
@@ -742,9 +787,9 @@ void MachineSMEABI::emitRestoreLazySave(EmitContext &Context,
742787
restorePhyRegSave(RegSave, MBB, MBBI, DL);
743788
}
744789

745-
void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
746-
MachineBasicBlock::iterator MBBI,
747-
bool ClearTPIDR2) {
790+
void MachineSMEABI::emitZAMode(MachineBasicBlock &MBB,
791+
MachineBasicBlock::iterator MBBI,
792+
bool ClearTPIDR2, bool On) {
748793
DebugLoc DL = getDebugLoc(MBB, MBBI);
749794

750795
if (ClearTPIDR2)
@@ -755,7 +800,7 @@ void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
755800
// Enable or disable ZA, depending on the value of On.
756801
BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
757802
.addImm(AArch64SVCR::SVCRZA)
758-
.addImm(0);
803+
.addImm(On ? 1 : 0);
759804
}
760805

761806
void MachineSMEABI::emitAllocateLazySaveBuffer(
@@ -884,6 +929,28 @@ void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
884929
restorePhyRegSave(RegSave, MBB, MBBI, DL);
885930
}
886931

932+
/// Emits the instructions that save ZT0 to, or restore ZT0 from, the ZT0 save
/// slot owned by \p Context.
///
/// \param Context Provides the (lazily created) ZT0 save slot frame index.
/// \param MBB     Block the instructions are built in.
/// \param MBBI    Insertion point passed to BuildMI.
/// \param IsSave  True emits a STR_TX of ZT0 to the slot; false emits a LDR_TX
///                of ZT0 from the slot.
void MachineSMEABI::emitZT0SaveRestore(EmitContext &Context,
933+
MachineBasicBlock &MBB,
934+
MachineBasicBlock::iterator MBBI,
935+
bool IsSave) {
936+
DebugLoc DL = getDebugLoc(MBB, MBBI);
937+
// Virtual register that will hold the address of the save slot.
Register ZT0Save = MRI->createVirtualRegister(&AArch64::GPR64spRegClass);
938+
939+
// Form the slot address: ADDXri of the slot's frame index with a zero
// immediate (and zero shift operand).
BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), ZT0Save)
940+
.addFrameIndex(Context.getZT0SaveSlot(*MF))
941+
.addImm(0)
942+
.addImm(0);
943+
944+
if (IsSave) {
945+
// Save: store ZT0 to the slot address.
BuildMI(MBB, MBBI, DL, TII->get(AArch64::STR_TX))
946+
.addReg(AArch64::ZT0)
947+
.addReg(ZT0Save);
948+
} else {
949+
// Restore: load ZT0 back from the slot address.
BuildMI(MBB, MBBI, DL, TII->get(AArch64::LDR_TX), AArch64::ZT0)
950+
.addReg(ZT0Save);
951+
}
952+
}
953+
887954
void MachineSMEABI::emitAllocateFullZASaveBuffer(
888955
EmitContext &Context, MachineBasicBlock &MBB,
889956
MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
@@ -928,6 +995,17 @@ void MachineSMEABI::emitAllocateFullZASaveBuffer(
928995
restorePhyRegSave(RegSave, MBB, MBBI, DL);
929996
}
930997

998+
/// Helper for encoding a ZA state transition as a single integer so that
/// `transitionFrom(A).to(B)` can be used both as a switch condition and as a
/// case label.
struct FromState {
999+
// The state the transition starts from.
ZAState From;
1000+
1001+
/// Returns the (From, To) pair packed into one byte: From occupies the high
/// four bits and \p To the low four bits.
constexpr uint8_t to(ZAState To) const {
1002+
// Each state must fit in a 4-bit field for the packing below to be unique.
static_assert(NUM_ZA_STATE < 16, "expected ZAState to fit in 4-bits");
1003+
return uint8_t(From) << 4 | uint8_t(To);
1004+
}
1005+
};
1006+
1007+
/// Starts a transition description: wraps \p From in a FromState so the caller
/// can chain `.to(<ToState>)` to obtain the packed transition value.
constexpr FromState transitionFrom(ZAState From) { return FromState{From}; }
1008+
9311009
void MachineSMEABI::emitStateChange(EmitContext &Context,
9321010
MachineBasicBlock &MBB,
9331011
MachineBasicBlock::iterator InsertPt,
@@ -959,17 +1037,63 @@ void MachineSMEABI::emitStateChange(EmitContext &Context,
9591037
From = ZAState::ACTIVE;
9601038
}
9611039

962-
if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
963-
emitZASave(Context, MBB, InsertPt, PhysLiveRegs);
964-
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
965-
emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
966-
else if (To == ZAState::OFF) {
967-
assert(From != ZAState::ENTRY &&
968-
"ENTRY to OFF should have already been handled");
969-
assert(!SMEFnAttrs.hasAgnosticZAInterface() &&
970-
"Should not turn ZA off in agnostic ZA function");
971-
emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED);
972-
} else {
1040+
bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface();
1041+
bool HasZT0State = SMEFnAttrs.hasZT0State();
1042+
bool HasZAState = IsAgnosticZA || SMEFnAttrs.hasZAState();
1043+
1044+
switch (transitionFrom(From).to(To)) {
1045+
// This section handles: ACTIVE <-> ACTIVE_ZT0_SAVED
1046+
case transitionFrom(ZAState::ACTIVE).to(ZAState::ACTIVE_ZT0_SAVED):
1047+
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1048+
break;
1049+
case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::ACTIVE):
1050+
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
1051+
break;
1052+
1053+
// This section handles: ACTIVE -> LOCAL_SAVED
1054+
case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_SAVED):
1055+
if (HasZT0State)
1056+
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1057+
if (HasZAState)
1058+
emitZASave(Context, MBB, InsertPt, PhysLiveRegs);
1059+
break;
1060+
1061+
// This section handles: ACTIVE -> LOCAL_COMMITTED
1062+
case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_COMMITTED):
1063+
// Note: We could support ZA state here, but this transition is currently
1064+
// only possible when we _don't_ have ZA state.
1065+
assert(HasZT0State && !HasZAState && "Expect to only have ZT0 state.");
1066+
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1067+
emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/false);
1068+
break;
1069+
1070+
// This section handles: LOCAL_COMMITTED -> (OFF|LOCAL_SAVED)
1071+
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::OFF):
1072+
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::LOCAL_SAVED):
1073+
// These transitions are a no-op.
1074+
break;
1075+
1076+
// This section handles: LOCAL_(SAVED|COMMITTED) -> ACTIVE[_ZT0_SAVED]
1077+
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE):
1078+
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE_ZT0_SAVED):
1079+
case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::ACTIVE):
1080+
if (HasZAState)
1081+
emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
1082+
else
1083+
emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/true);
1084+
if (HasZT0State && To == ZAState::ACTIVE)
1085+
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
1086+
break;
1087+
default:
1088+
if (To == ZAState::OFF) {
1089+
assert(From != ZAState::ENTRY &&
1090+
"ENTRY to OFF should have already been handled");
1091+
assert(SMEFnAttrs.hasPrivateZAInterface() &&
1092+
"Did not expect to turn ZA off in shared/agnostic ZA function");
1093+
emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED,
1094+
/*On=*/false);
1095+
break;
1096+
}
9731097
dbgs() << "Error: Transition from " << getZAStateString(From) << " to "
9741098
<< getZAStateString(To) << '\n';
9751099
llvm_unreachable("Unimplemented state transition");

llvm/test/CodeGen/AArch64/sme-peephole-opts.ll

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,6 @@ define void @test7() nounwind "aarch64_inout_zt0" {
230230
; CHECK-NEXT: str zt0, [x19]
231231
; CHECK-NEXT: smstop za
232232
; CHECK-NEXT: bl callee
233-
; CHECK-NEXT: smstart za
234-
; CHECK-NEXT: ldr zt0, [x19]
235-
; CHECK-NEXT: str zt0, [x19]
236-
; CHECK-NEXT: smstop za
237233
; CHECK-NEXT: bl callee
238234
; CHECK-NEXT: smstart za
239235
; CHECK-NEXT: ldr zt0, [x19]

0 commit comments

Comments
 (0)