Skip to content

Commit 3a44f70

Browse files
authored
[MLIR] Reverse Mode return arg conversions (#2354)
* reverse mode conversion for active -> const Signed-off-by: Vimarsh Sathia <[email protected]> * Added test Signed-off-by: Vimarsh Sathia <[email protected]> * Added placeholder for rev: dup -> const Signed-off-by: Vimarsh Sathia <[email protected]> * move to ReverseRetOpt * Add return activity conversions * test active -> const for return arg * remove ReverseInOpt for now --------- Signed-off-by: Vimarsh Sathia <[email protected]>
1 parent a96989d commit 3a44f70

File tree

2 files changed

+124
-30
lines changed

2 files changed

+124
-30
lines changed

enzyme/Enzyme/MLIR/Dialect/Ops.cpp

Lines changed: 116 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -573,18 +573,16 @@ void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input,
573573
* shadow(which we then accumulate into and return). So specifically,
574574
*
575575
* pInput' = pInput (if the activity is enzyme_active, enzyme_const)
576-
* | pInput, dInput (if the activity is enzyme_dup)
577-
* | dInput (if the activity is enzyme_dupnoneed)
576+
* | pInput, dInput (if the activity is enzyme_dup,
577+
* enzyme_dupnoneed)
578578
*
579579
* Now that we have fixed the codegen semantics, we can go ahead and optimize
580580
* for both the input and return activities based on usage. Possible activity
581581
* promotion flow for the inputs can be as follows:
582582
* 1. enzyme_active --> enzyme_const (dInput is never used, so we simply don't
583583
* compute it)
584-
* 2. enzyme_activenoneed --> enzyme_constnoneed (It is the noneed equivalent
585-
* of the previous rule and semantically makes sense, although I can't think of
586-
* a function where you don't pass in the input but still compute the derivative
587-
* w.r.t that input)
584+
* 2. enzyme_dup --> enzyme_const (pInput is mutable, readonly, nocapture,
585+
* dInput is never used post AD)
588586
*
589587
* Similarly, one can define an activity promotion flow for the outputs:
590588
* 1. enzyme_active --> enzyme_activenoneed (we do need to pass in dOutput into
@@ -611,22 +609,43 @@ class ReverseRetOpt final : public OpRewritePattern<AutoDiffOp> {
611609
auto inpActivity = uop.getActivity();
612610
auto retActivity = uop.getRetActivity();
613611
auto out_idx = 0;
612+
SmallVector<mlir::Value, 2> in_args;
614613
SmallVector<mlir::Value, 2> outs_args;
614+
SmallVector<Type, 2> in_ty;
615615
SmallVector<Type, 2> out_ty;
616616
SmallVector<ActivityAttr, 2> newInActivityArgs;
617617
SmallVector<ActivityAttr, 2> newRetActivityArgs;
618618

619619
bool changed = false;
620+
auto in_idx = 0;
621+
622+
// go up to dOutput
623+
for (auto [idx, act] : llvm::enumerate(inpActivity)) {
624+
auto iattr = cast<ActivityAttr>(act);
625+
auto val = iattr.getValue();
626+
mlir::Value res = uop.getInputs()[in_idx];
627+
in_args.push_back(res);
628+
in_ty.push_back(res.getType());
629+
in_idx++;
630+
631+
if (val == Activity::enzyme_dup || val == Activity::enzyme_dupnoneed) {
632+
mlir::Value dres = uop.getInputs()[in_idx];
633+
in_args.push_back(dres);
634+
in_ty.push_back(dres.getType());
635+
in_idx++;
636+
}
637+
}
638+
// function isn't differentiable
639+
if (in_idx == uop.getInputs().size())
640+
return failure();
620641

621642
// handle pOutput
622643
for (auto [idx, act] : llvm::enumerate(retActivity)) {
623-
624644
auto iattr = cast<ActivityAttr>(act);
625645
auto val = iattr.getValue();
626646

627647
// skip primal return
628648
if (val == Activity::enzyme_constnoneed ||
629-
val == Activity::enzyme_activenoneed ||
630649
val == Activity::enzyme_dupnoneed) {
631650
newRetActivityArgs.push_back(iattr);
632651
continue;
@@ -635,49 +654,108 @@ class ReverseRetOpt final : public OpRewritePattern<AutoDiffOp> {
635654
mlir::Value res = uop.getOutputs()[out_idx];
636655

637656
switch (val) {
638-
case Activity::enzyme_active:
657+
case Activity::enzyme_active: {
658+
// active -> activenoneed(if res isn't used)
659+
// active -> const(if dres == 0)
660+
// active -> constnoneed(both)
661+
662+
mlir::Value dres = uop.getInputs()[in_idx];
663+
in_idx++;
664+
665+
auto dres_type = dres.getType();
666+
auto dres_type_intf = dyn_cast<AutoDiffTypeInterface>(dres_type);
667+
639668
if (!res.use_empty()) {
640669
outs_args.push_back(res);
641670
out_ty.push_back(res.getType());
642-
newRetActivityArgs.push_back(iattr);
671+
ActivityAttr new_act = iattr;
672+
if (dres_type_intf && !isMutable(dres_type) &&
673+
dres_type_intf.isZero(dres)) {
674+
// const
675+
changed = true;
676+
new_act = ActivityAttr::get(rewriter.getContext(),
677+
Activity::enzyme_const);
678+
} else {
679+
in_args.push_back(dres);
680+
in_ty.push_back(dres_type);
681+
}
682+
newRetActivityArgs.push_back(new_act);
643683
} else {
644684
changed = true;
645-
auto new_activenn = ActivityAttr::get(rewriter.getContext(),
646-
Activity::enzyme_activenoneed);
647-
newRetActivityArgs.push_back(new_activenn);
685+
ActivityAttr new_act = ActivityAttr::get(
686+
rewriter.getContext(), Activity::enzyme_activenoneed);
687+
if (dres_type_intf && !isMutable(dres_type) &&
688+
dres_type_intf.isZero(dres)) {
689+
// constnoneed
690+
new_act = ActivityAttr::get(rewriter.getContext(),
691+
Activity::enzyme_constnoneed);
692+
} else {
693+
// activenoneed
694+
in_args.push_back(dres);
695+
in_ty.push_back(dres_type);
696+
}
697+
newRetActivityArgs.push_back(new_act);
648698
}
699+
700+
++out_idx;
649701
break;
702+
}
650703

651-
case Activity::enzyme_const:
652-
if (!res.use_empty()) {
653-
outs_args.push_back(res);
654-
out_ty.push_back(res.getType());
704+
case Activity::enzyme_activenoneed:
705+
// activenoneed -> constnoneed
706+
{
707+
mlir::Value dres = uop.getInputs()[in_idx];
708+
in_idx++;
709+
auto new_act = iattr;
710+
711+
auto dres_type = dres.getType();
712+
auto dres_type_intf = dyn_cast<AutoDiffTypeInterface>(dres_type);
713+
if (dres_type_intf && !isMutable(dres_type) &&
714+
dres_type_intf.isZero(dres)) {
715+
// constnoneed
716+
new_act = ActivityAttr::get(rewriter.getContext(),
717+
Activity::enzyme_constnoneed);
718+
} else {
719+
in_args.push_back(dres);
720+
in_ty.push_back(dres_type);
721+
}
655722
newRetActivityArgs.push_back(iattr);
656-
} else {
657-
changed = true;
658-
auto new_constnn = ActivityAttr::get(rewriter.getContext(),
659-
Activity::enzyme_constnoneed);
660-
newRetActivityArgs.push_back(new_constnn);
723+
break;
724+
}
725+
case Activity::enzyme_const:
726+
// const -> constnoneed
727+
{
728+
auto new_act = iattr;
729+
if (!res.use_empty()) {
730+
outs_args.push_back(res);
731+
out_ty.push_back(res.getType());
732+
newRetActivityArgs.push_back(new_act);
733+
} else {
734+
changed = true;
735+
new_act = ActivityAttr::get(rewriter.getContext(),
736+
Activity::enzyme_constnoneed);
737+
newRetActivityArgs.push_back(new_act);
738+
}
739+
++out_idx;
740+
break;
661741
}
662-
break;
663742

664743
case Activity::enzyme_dup:
665-
// dont do anything here for now
744+
// TODO: check if ret_arg == enzyme_dup inserts a derivative as the
745+
// output and input both
666746
outs_args.push_back(res);
667747
out_ty.push_back(res.getType());
668748
newRetActivityArgs.push_back(iattr);
749+
++out_idx;
669750
break;
670751

671-
case Activity::enzyme_activenoneed:
672752
case Activity::enzyme_constnoneed:
673753
case Activity::enzyme_dupnoneed:
674754
break;
675755

676756
default:
677757
llvm_unreachable("unexpected activity arg");
678758
}
679-
680-
++out_idx;
681759
}
682760

683761
// handle dInputs
@@ -732,15 +810,14 @@ class ReverseRetOpt final : public OpRewritePattern<AutoDiffOp> {
732810
newRetActivityArgs.end()));
733811

734812
AutoDiffOp newOp = rewriter.create<AutoDiffOp>(
735-
uop.getLoc(), out_ty, uop.getFnAttr(), uop.getInputs(), newInActivity,
813+
uop.getLoc(), out_ty, uop.getFnAttr(), in_args, newInActivity,
736814
newRetActivity, uop.getWidthAttr(), uop.getStrongZeroAttr());
737815

738816
// Map old uses of uop to newOp
739817
auto oldIdx = 0;
740818
auto newIdx = 0;
741819
for (auto [idx, old_act, new_act] :
742820
llvm::enumerate(retActivity, newRetActivityArgs)) {
743-
744821
auto iattr = cast<ActivityAttr>(old_act);
745822
auto old_val = iattr.getValue();
746823
auto new_val = new_act.getValue();
@@ -763,6 +840,16 @@ class ReverseRetOpt final : public OpRewritePattern<AutoDiffOp> {
763840
} else if (new_val == Activity::enzyme_constnoneed &&
764841
old_val == Activity::enzyme_const) {
765842
++oldIdx; // skip const primal
843+
} else if (old_val == Activity::enzyme_active &&
844+
new_val == Activity::enzyme_const) {
845+
uop.getOutputs()[oldIdx++].replaceAllUsesWith(
846+
newOp.getOutputs()[newIdx++]);
847+
} else if (old_val == Activity::enzyme_active &&
848+
new_val == Activity::enzyme_constnoneed) {
849+
++oldIdx;
850+
} else if (old_val == Activity::enzyme_activenoneed &&
851+
new_val == Activity::enzyme_constnoneed) {
852+
// just skip
766853
}
767854
}
768855
}
@@ -788,7 +875,6 @@ class ReverseRetOpt final : public OpRewritePattern<AutoDiffOp> {
788875
}
789876
}
790877
}
791-
792878
rewriter.eraseOp(uop);
793879
return success();
794880
}

enzyme/test/MLIR/ReverseMode/canonicalize.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,12 @@ module {
4141
// CHECK: enzyme.autodiff @square2(%arg0, %arg1, %arg2, %arg3){{.*}}activity = [#enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>]{{.*}}ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_activenoneed>]{{.*}}
4242
return %cst : f32
4343
}
44+
45+
// Test 5: active -> const for ret_activity (iff derivative is 0)
46+
func.func @test5(%x: f32, %y: f32, %dr0: f32) -> (f32,f32,f32,f32) {
47+
%cst = arith.constant 0.0000e+00 : f32
48+
%r:4 = enzyme.autodiff @square2(%x,%y,%dr0,%cst) { activity=[#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>] } : (f32,f32,f32,f32) -> (f32,f32,f32,f32)
49+
// CHECK: %{{.*}} = enzyme.autodiff @square2(%arg0, %arg1, %arg2){{.*}}activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>]{{.*}}ret_activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>]{{.*}}
50+
return %r#0,%r#1,%r#2,%r#3 : f32,f32,f32,f32
51+
}
4452
}

0 commit comments

Comments
 (0)