Skip to content

Commit 61a5390

Browse files
committed
[AArch64][SME] Support saving/restoring ZT0 in the MachineSMEABIPass
This patch extends the MachineSMEABIPass to support ZT0. This is done with the addition of two new states: - `ACTIVE_ZT0_SAVED` * This is used when calling a function that shares ZA, but does not share ZT0 (i.e., no ZT0 attributes). * This state indicates ZT0 must be saved to the save slot, but ZA must remain on, with no lazy save setup. - `LOCAL_COMMITTED` * This is used for saving ZT0 in functions without ZA state. * This state indicates ZA is off and ZT0 has been saved. * This state is general enough to support ZA, but that has not been implemented†. To aid with readability, the state transitions have been reworked to a switch of `transitionFrom(<FromState>).to(<ToState>)`, rather than nested ifs, which helps manage more transitions. † This could be implemented to handle some cases of undefined behavior better. Change-Id: I14be4a7f8b998fe667bfaade5088f88039515f91
1 parent 6a9d864 commit 61a5390

File tree

7 files changed

+321
-105
lines changed

7 files changed

+321
-105
lines changed

llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,6 +1717,7 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
17171717
}
17181718
case AArch64::InOutZAUsePseudo:
17191719
case AArch64::RequiresZASavePseudo:
1720+
case AArch64::RequiresZT0SavePseudo:
17201721
case AArch64::SMEStateAllocPseudo:
17211722
case AArch64::COALESCER_BARRIER_FPR16:
17221723
case AArch64::COALESCER_BARRIER_FPR32:

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9524,6 +9524,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
95249524
if (CallAttrs.requiresLazySave() ||
95259525
CallAttrs.requiresPreservingAllZAState())
95269526
ZAMarkerNode = AArch64ISD::REQUIRES_ZA_SAVE;
9527+
else if (CallAttrs.requiresPreservingZT0())
9528+
ZAMarkerNode = AArch64ISD::REQUIRES_ZT0_SAVE;
95279529
else if (CallAttrs.caller().hasZAState() ||
95289530
CallAttrs.caller().hasZT0State())
95299531
ZAMarkerNode = AArch64ISD::INOUT_ZA_USE;
@@ -9643,7 +9645,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96439645

96449646
SDValue ZTFrameIdx;
96459647
MachineFrameInfo &MFI = MF.getFrameInfo();
9646-
bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
9648+
bool ShouldPreserveZT0 =
9649+
!UseNewSMEABILowering && CallAttrs.requiresPreservingZT0();
96479650

96489651
// If the caller has ZT0 state which will not be preserved by the callee,
96499652
// spill ZT0 before the call.
@@ -9656,7 +9659,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96569659

96579660
// If caller shares ZT0 but the callee is not shared ZA, we need to stop
96589661
// PSTATE.ZA before the call if there is no lazy-save active.
9659-
bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
9662+
bool DisableZA =
9663+
!UseNewSMEABILowering && CallAttrs.requiresDisablingZABeforeCall();
96609664
assert((!DisableZA || !RequiresLazySave) &&
96619665
"Lazy-save should have PSTATE.SM=1 on entry to the function");
96629666

@@ -10142,7 +10146,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
1014210146
getSMToggleCondition(CallAttrs));
1014310147
}
1014410148

10145-
if (RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall())
10149+
if (!UseNewSMEABILowering &&
10150+
(RequiresLazySave || CallAttrs.requiresEnablingZAAfterCall()))
1014610151
// Unconditionally resume ZA.
1014710152
Result = DAG.getNode(
1014810153
AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def : Pat<(i64 (AArch64AllocateSMESaveBuffer GPR64:$size)),
102102
let hasSideEffects = 1, isMeta = 1 in {
103103
def InOutZAUsePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
104104
def RequiresZASavePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
105+
def RequiresZT0SavePseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
105106
}
106107

107108
def SMEStateAllocPseudo : Pseudo<(outs), (ins), []>, Sched<[]>;
@@ -122,6 +123,11 @@ def AArch64_requires_za_save
122123
[SDNPHasChain, SDNPInGlue, SDNPOutGlue]>;
123124
def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>;
124125

126+
// Marker node for AArch64ISD::REQUIRES_ZT0_SAVE. It takes no operands and
// produces no results (chain/glue only) and is selected to the
// RequiresZT0SavePseudo meta instruction by the pattern below.
def AArch64_requires_zt0_save
127+
: SDNode<"AArch64ISD::REQUIRES_ZT0_SAVE", SDTypeProfile<0, 0, []>,
128+
[SDNPHasChain, SDNPInGlue, SDNPOutGlue]>;
129+
def : Pat<(AArch64_requires_zt0_save), (RequiresZT0SavePseudo)>;
130+
125131
def AArch64_sme_state_alloc
126132
: SDNode<"AArch64ISD::SME_STATE_ALLOC", SDTypeProfile<0, 0,[]>,
127133
[SDNPHasChain]>;

llvm/lib/Target/AArch64/MachineSMEABIPass.cpp

Lines changed: 150 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,30 @@ using namespace llvm;
7272

7373
namespace {
7474

75-
enum ZAState {
75+
// Note: For agnostic ZA, we assume the function is always entered/exited in the
76+
// "ACTIVE" state -- this _may_ not be the case (since OFF is also a
77+
// possibility, but for the purpose of placing ZA saves/restores, that does not
78+
// matter).
79+
enum ZAState : uint8_t {
7680
// Any/unknown state (not valid)
7781
ANY = 0,
7882

7983
// ZA is in use and active (i.e. within the accumulator)
8084
ACTIVE,
8185

86+
// ZA is active, but ZT0 has been saved.
87+
// This handles the edge case of sharesZA && !sharesZT0.
88+
ACTIVE_ZT0_SAVED,
89+
8290
// A ZA save has been set up or committed (i.e. ZA is dormant or off)
91+
// If the function uses ZT0 it must also be saved.
8392
LOCAL_SAVED,
8493

94+
// ZA has been committed to the lazy save buffer of the current function.
95+
// If the function uses ZT0 it must also be saved.
96+
// ZA is off when a save has been committed.
97+
LOCAL_COMMITTED,
98+
8599
// The ZA/ZT0 state on entry to the function.
86100
ENTRY,
87101

@@ -164,6 +178,14 @@ class EmitContext {
164178
return AgnosticZABufferPtr;
165179
}
166180

181+
/// Returns the frame index of the ZT0 save slot. The slot is created lazily on
/// the first call and cached in ZT0SaveFI, so later calls return the same
/// frame index.
int getZT0SaveSlot(MachineFunction &MF) {
182+
// A slot has already been created: reuse it.
if (ZT0SaveFI)
183+
return *ZT0SaveFI;
184+
MachineFrameInfo &MFI = MF.getFrameInfo();
185+
// Create a 64-byte spill object with 16-byte alignment to hold ZT0.
ZT0SaveFI = MFI.CreateSpillStackObject(64, Align(16));
186+
return *ZT0SaveFI;
187+
}
188+
167189
/// Returns true if the function must allocate a ZA save buffer on entry. This
168190
/// will be the case if, at any point in the function, a ZA save was emitted.
169191
bool needsSaveBuffer() const {
@@ -173,6 +195,7 @@ class EmitContext {
173195
}
174196

175197
private:
198+
std::optional<int> ZT0SaveFI;
176199
std::optional<int> TPIDR2BlockFI;
177200
Register AgnosticZABufferPtr = AArch64::NoRegister;
178201
};
@@ -184,8 +207,10 @@ class EmitContext {
184207
/// state would not be legal, as transitioning to it drops the content of ZA.
185208
static bool isLegalEdgeBundleZAState(ZAState State) {
186209
switch (State) {
187-
case ZAState::ACTIVE: // ZA state within the accumulator/ZT0.
188-
case ZAState::LOCAL_SAVED: // ZA state is saved on the stack.
210+
case ZAState::ACTIVE: // ZA state within the accumulator/ZT0.
211+
case ZAState::ACTIVE_ZT0_SAVED: // ZT0 is saved (ZA is active).
212+
case ZAState::LOCAL_SAVED: // ZA state may be saved on the stack.
213+
case ZAState::LOCAL_COMMITTED: // ZA state is saved on the stack.
189214
return true;
190215
default:
191216
return false;
@@ -199,7 +224,9 @@ StringRef getZAStateString(ZAState State) {
199224
switch (State) {
200225
MAKE_CASE(ZAState::ANY)
201226
MAKE_CASE(ZAState::ACTIVE)
227+
MAKE_CASE(ZAState::ACTIVE_ZT0_SAVED)
202228
MAKE_CASE(ZAState::LOCAL_SAVED)
229+
MAKE_CASE(ZAState::LOCAL_COMMITTED)
203230
MAKE_CASE(ZAState::ENTRY)
204231
MAKE_CASE(ZAState::OFF)
205232
default:
@@ -221,18 +248,34 @@ static bool isZAorZTRegOp(const TargetRegisterInfo &TRI,
221248
/// Returns the required ZA state needed before \p MI and an iterator pointing
222249
/// to where any code required to change the ZA state should be inserted.
223250
static std::pair<ZAState, MachineBasicBlock::iterator>
224-
getZAStateBeforeInst(const TargetRegisterInfo &TRI, MachineInstr &MI,
225-
bool ZAOffAtReturn) {
251+
getInstNeededZAState(const TargetRegisterInfo &TRI, MachineInstr &MI,
252+
SMEAttrs SMEFnAttrs) {
226253
MachineBasicBlock::iterator InsertPt(MI);
227254

228255
if (MI.getOpcode() == AArch64::InOutZAUsePseudo)
229256
return {ZAState::ACTIVE, std::prev(InsertPt)};
230257

258+
// Note: If we need to save both ZA and ZT0 we use RequiresZASavePseudo.
231259
if (MI.getOpcode() == AArch64::RequiresZASavePseudo)
232260
return {ZAState::LOCAL_SAVED, std::prev(InsertPt)};
233261

234-
if (MI.isReturn())
262+
// If we only need to save ZT0 there are two cases to consider:
263+
// 1. The function has ZA state (that we don't need to save).
264+
// - In this case we switch to the "ACTIVE_ZT0_SAVED" state.
265+
// This only saves ZT0.
266+
// 2. The function does not have ZA state
267+
// - In this case we switch to "LOCAL_COMMITTED" state.
268+
// This saves ZT0 and turns ZA off.
269+
if (MI.getOpcode() == AArch64::RequiresZT0SavePseudo) {
270+
return {SMEFnAttrs.hasZAState() ? ZAState::ACTIVE_ZT0_SAVED
271+
: ZAState::LOCAL_COMMITTED,
272+
std::prev(InsertPt)};
273+
}
274+
275+
if (MI.isReturn()) {
276+
bool ZAOffAtReturn = SMEFnAttrs.hasPrivateZAInterface();
235277
return {ZAOffAtReturn ? ZAState::OFF : ZAState::ACTIVE, InsertPt};
278+
}
236279

237280
for (auto &MO : MI.operands()) {
238281
if (isZAorZTRegOp(TRI, MO))
@@ -280,6 +323,9 @@ struct MachineSMEABI : public MachineFunctionPass {
280323
/// predecessors).
281324
void propagateDesiredStates(FunctionInfo &FnInfo, bool Forwards = true);
282325

326+
void emitZT0SaveRestore(EmitContext &, MachineBasicBlock &MBB,
327+
MachineBasicBlock::iterator MBBI, bool IsSave);
328+
283329
// Emission routines for private and shared ZA functions (using lazy saves).
284330
void emitSMEPrologue(MachineBasicBlock &MBB,
285331
MachineBasicBlock::iterator MBBI);
@@ -290,8 +336,8 @@ struct MachineSMEABI : public MachineFunctionPass {
290336
MachineBasicBlock::iterator MBBI);
291337
void emitAllocateLazySaveBuffer(EmitContext &, MachineBasicBlock &MBB,
292338
MachineBasicBlock::iterator MBBI);
293-
void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
294-
bool ClearTPIDR2);
339+
void emitZAMode(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
340+
bool ClearTPIDR2, bool On);
295341

296342
// Emission routines for agnostic ZA functions.
297343
void emitSetupFullZASave(MachineBasicBlock &MBB,
@@ -409,7 +455,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
409455
Block.FixedEntryState = ZAState::ENTRY;
410456
} else if (MBB.isEHPad()) {
411457
// EH entry block:
412-
Block.FixedEntryState = ZAState::LOCAL_SAVED;
458+
Block.FixedEntryState = ZAState::LOCAL_COMMITTED;
413459
}
414460

415461
LiveRegUnits LiveUnits(*TRI);
@@ -431,8 +477,7 @@ FunctionInfo MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
431477
PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
432478
}
433479
// Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
434-
auto [NeededState, InsertPt] = getZAStateBeforeInst(
435-
*TRI, MI, /*ZAOffAtReturn=*/SMEFnAttrs.hasPrivateZAInterface());
480+
auto [NeededState, InsertPt] = getInstNeededZAState(*TRI, MI, SMEFnAttrs);
436481
assert((InsertPt == MBBI || isCallStartOpcode(InsertPt->getOpcode())) &&
437482
"Unexpected state change insertion point!");
438483
// TODO: Do something to avoid state changes where NZCV is live.
@@ -752,9 +797,9 @@ void MachineSMEABI::emitRestoreLazySave(EmitContext &Context,
752797
restorePhyRegSave(RegSave, MBB, MBBI, DL);
753798
}
754799

755-
void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
756-
MachineBasicBlock::iterator MBBI,
757-
bool ClearTPIDR2) {
800+
void MachineSMEABI::emitZAMode(MachineBasicBlock &MBB,
801+
MachineBasicBlock::iterator MBBI,
802+
bool ClearTPIDR2, bool On) {
758803
DebugLoc DL = getDebugLoc(MBB, MBBI);
759804

760805
if (ClearTPIDR2)
@@ -765,7 +810,7 @@ void MachineSMEABI::emitZAOff(MachineBasicBlock &MBB,
765810
// Disable ZA.
766811
BuildMI(MBB, MBBI, DL, TII->get(AArch64::MSRpstatesvcrImm1))
767812
.addImm(AArch64SVCR::SVCRZA)
768-
.addImm(0);
813+
.addImm(On ? 1 : 0);
769814
}
770815

771816
void MachineSMEABI::emitAllocateLazySaveBuffer(
@@ -894,6 +939,28 @@ void MachineSMEABI::emitFullZASaveRestore(EmitContext &Context,
894939
restorePhyRegSave(RegSave, MBB, MBBI, DL);
895940
}
896941

942+
/// Emits, before \p MBBI in \p MBB, the instructions that either store ZT0 to
/// the ZT0 save slot (\p IsSave is true) or reload ZT0 from that slot
/// (\p IsSave is false). The slot is obtained from \p Context, which creates
/// it on first use.
void MachineSMEABI::emitZT0SaveRestore(EmitContext &Context,
943+
MachineBasicBlock &MBB,
944+
MachineBasicBlock::iterator MBBI,
945+
bool IsSave) {
946+
DebugLoc DL = getDebugLoc(MBB, MBBI);
947+
// Virtual register that will hold the address of the save slot.
Register ZT0Save = MRI->createVirtualRegister(&AArch64::GPR64spRegClass);
948+
949+
// ZT0Save = <frame index of the ZT0 save slot> + 0 (no shift).
BuildMI(MBB, MBBI, DL, TII->get(AArch64::ADDXri), ZT0Save)
950+
.addFrameIndex(Context.getZT0SaveSlot(*MF))
951+
.addImm(0)
952+
.addImm(0);
953+
954+
if (IsSave) {
955+
// Save: store ZT0 to the address in ZT0Save.
BuildMI(MBB, MBBI, DL, TII->get(AArch64::STR_TX))
956+
.addReg(AArch64::ZT0)
957+
.addReg(ZT0Save);
958+
} else {
959+
// Restore: load ZT0 from the address in ZT0Save.
BuildMI(MBB, MBBI, DL, TII->get(AArch64::LDR_TX), AArch64::ZT0)
960+
.addReg(ZT0Save);
961+
}
962+
}
963+
897964
void MachineSMEABI::emitAllocateFullZASaveBuffer(
898965
EmitContext &Context, MachineBasicBlock &MBB,
899966
MachineBasicBlock::iterator MBBI, LiveRegs PhysLiveRegs) {
@@ -938,6 +1005,17 @@ void MachineSMEABI::emitAllocateFullZASaveBuffer(
9381005
restorePhyRegSave(RegSave, MBB, MBBI, DL);
9391006
}
9401007

1008+
/// Helper for spelling a ZA state transition as a single integer, so that a
/// (From, To) pair can be used directly as a `switch` case label via
/// `transitionFrom(From).to(To)`.
struct FromState {
1009+
// The state the transition starts in.
ZAState From;
1010+
1011+
/// Packs the transition From -> \p To into one byte: `From` occupies the
/// high 4 bits and \p To the low 4 bits.
constexpr uint8_t to(ZAState To) const {
1012+
// Each state must fit in a nibble for the packing below to be unique.
static_assert(NUM_ZA_STATE < 16, "expected ZAState to fit in 4-bits");
1013+
return uint8_t(From) << 4 | uint8_t(To);
1014+
}
1015+
};
1016+
1017+
/// Starts a transition description: wraps \p From in a FromState so the
/// destination can be supplied with `.to(To)`.
constexpr FromState transitionFrom(ZAState From) { return FromState{From}; }
1018+
9411019
void MachineSMEABI::emitStateChange(EmitContext &Context,
9421020
MachineBasicBlock &MBB,
9431021
MachineBasicBlock::iterator InsertPt,
@@ -969,17 +1047,63 @@ void MachineSMEABI::emitStateChange(EmitContext &Context,
9691047
From = ZAState::ACTIVE;
9701048
}
9711049

972-
if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
973-
emitZASave(Context, MBB, InsertPt, PhysLiveRegs);
974-
else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
975-
emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
976-
else if (To == ZAState::OFF) {
977-
assert(From != ZAState::ENTRY &&
978-
"ENTRY to OFF should have already been handled");
979-
assert(!SMEFnAttrs.hasAgnosticZAInterface() &&
980-
"Should not turn ZA off in agnostic ZA function");
981-
emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED);
982-
} else {
1050+
bool IsAgnosticZA = SMEFnAttrs.hasAgnosticZAInterface();
1051+
bool HasZT0State = SMEFnAttrs.hasZT0State();
1052+
bool HasZAState = IsAgnosticZA || SMEFnAttrs.hasZAState();
1053+
1054+
switch (transitionFrom(From).to(To)) {
1055+
// This section handles: ACTIVE <-> ACTIVE_ZT0_SAVED
1056+
case transitionFrom(ZAState::ACTIVE).to(ZAState::ACTIVE_ZT0_SAVED):
1057+
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1058+
break;
1059+
case transitionFrom(ZAState::ACTIVE_ZT0_SAVED).to(ZAState::ACTIVE):
1060+
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
1061+
break;
1062+
1063+
// This section handles: ACTIVE -> LOCAL_SAVED
1064+
case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_SAVED):
1065+
if (HasZT0State)
1066+
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1067+
if (HasZAState)
1068+
emitZASave(Context, MBB, InsertPt, PhysLiveRegs);
1069+
break;
1070+
1071+
// This section handles: ACTIVE -> LOCAL_COMMITTED
1072+
case transitionFrom(ZAState::ACTIVE).to(ZAState::LOCAL_COMMITTED):
1073+
// Note: We could support ZA state here, but this transition is currently
1074+
// only possible when we _don't_ have ZA state.
1075+
assert(HasZT0State && !HasZAState && "Expect to only have ZT0 state.");
1076+
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/true);
1077+
emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/false);
1078+
break;
1079+
1080+
// This section handles: LOCAL_COMMITTED -> (OFF|LOCAL_SAVED)
1081+
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::OFF):
1082+
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::LOCAL_SAVED):
1083+
// These transitions are a no-op.
1084+
break;
1085+
1086+
// This section handles: LOCAL_(SAVED|COMMITTED) -> ACTIVE[_ZT0_SAVED]
1087+
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE):
1088+
case transitionFrom(ZAState::LOCAL_COMMITTED).to(ZAState::ACTIVE_ZT0_SAVED):
1089+
case transitionFrom(ZAState::LOCAL_SAVED).to(ZAState::ACTIVE):
1090+
if (HasZAState)
1091+
emitZARestore(Context, MBB, InsertPt, PhysLiveRegs);
1092+
else
1093+
emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/false, /*On=*/true);
1094+
if (HasZT0State && To == ZAState::ACTIVE)
1095+
emitZT0SaveRestore(Context, MBB, InsertPt, /*IsSave=*/false);
1096+
break;
1097+
default:
1098+
if (To == ZAState::OFF) {
1099+
assert(From != ZAState::ENTRY &&
1100+
"ENTRY to OFF should have already been handled");
1101+
assert(SMEFnAttrs.hasPrivateZAInterface() &&
1102+
"Did not expect to turn ZA off in shared/agnostic ZA function");
1103+
emitZAMode(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED,
1104+
/*On=*/false);
1105+
break;
1106+
}
9831107
dbgs() << "Error: Transition from " << getZAStateString(From) << " to "
9841108
<< getZAStateString(To) << '\n';
9851109
llvm_unreachable("Unimplemented state transition");

llvm/test/CodeGen/AArch64/sme-peephole-opts.ll

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,6 @@ define void @test7() nounwind "aarch64_inout_zt0" {
230230
; CHECK-NEXT: str zt0, [x19]
231231
; CHECK-NEXT: smstop za
232232
; CHECK-NEXT: bl callee
233-
; CHECK-NEXT: smstart za
234-
; CHECK-NEXT: ldr zt0, [x19]
235-
; CHECK-NEXT: str zt0, [x19]
236-
; CHECK-NEXT: smstop za
237233
; CHECK-NEXT: bl callee
238234
; CHECK-NEXT: smstart za
239235
; CHECK-NEXT: ldr zt0, [x19]

0 commit comments

Comments
 (0)