@@ -573,18 +573,16 @@ void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input,
573
573
* shadow(which we then accumulate into and return). So specifically,
574
574
*
575
575
* pInput' = pInput (if the activity is enzyme_active, enzyme_const)
576
- * | pInput, dInput (if the activity is enzyme_dup)
577
- * | dInput (if the activity is enzyme_dupnoneed)
576
+ * | pInput, dInput (if the activity is enzyme_dup,
577
+ * enzyme_dupnoneed)
578
578
*
579
579
* Now that we have fixed the codegen semantics, we can go ahead and optimize
580
580
* for both the input and return activities based on usage. Possible activity
581
581
* promotion flow for the inputs can be as follows:
582
582
* 1. enzyme_active --> enzyme_const (dInput is never used, so we simply don't
583
583
* compute it)
584
- * 2. enzyme_activenoneed --> enzyme_constnoneed (It is the noneed equivalent
585
- * of the previous rule and semantically makes sense, although I can't think of
586
- * a function where you don't pass in the input but still compute the derivative
587
- * w.r.t that input)
584
+ * 2. enzyme_dup --> enzyme_const (pInput is mutable, readonly, nocapture,
585
+ * dInput is never used post AD)
588
586
*
589
587
* Similarly, one can define an activity promotion flow for the outputs:
590
588
* 1. enzyme_active --> enzyme_activenoneed (we do need to pass in dOutput into
@@ -611,22 +609,43 @@ class ReverseRetOpt final : public OpRewritePattern<AutoDiffOp> {
611
609
auto inpActivity = uop.getActivity ();
612
610
auto retActivity = uop.getRetActivity ();
613
611
auto out_idx = 0 ;
612
+ SmallVector<mlir::Value, 2 > in_args;
614
613
SmallVector<mlir::Value, 2 > outs_args;
614
+ SmallVector<Type, 2 > in_ty;
615
615
SmallVector<Type, 2 > out_ty;
616
616
SmallVector<ActivityAttr, 2 > newInActivityArgs;
617
617
SmallVector<ActivityAttr, 2 > newRetActivityArgs;
618
618
619
619
bool changed = false ;
620
+ auto in_idx = 0 ;
621
+
622
+ // go up to dOutput
623
+ for (auto [idx, act] : llvm::enumerate (inpActivity)) {
624
+ auto iattr = cast<ActivityAttr>(act);
625
+ auto val = iattr.getValue ();
626
+ mlir::Value res = uop.getInputs ()[in_idx];
627
+ in_args.push_back (res);
628
+ in_ty.push_back (res.getType ());
629
+ in_idx++;
630
+
631
+ if (val == Activity::enzyme_dup || val == Activity::enzyme_dupnoneed) {
632
+ mlir::Value dres = uop.getInputs ()[in_idx];
633
+ in_args.push_back (dres);
634
+ in_ty.push_back (dres.getType ());
635
+ in_idx++;
636
+ }
637
+ }
638
+ // function isn't differentiable
639
+ if (in_idx == uop.getInputs ().size ())
640
+ return failure ();
620
641
621
642
// handle pOutput
622
643
for (auto [idx, act] : llvm::enumerate (retActivity)) {
623
-
624
644
auto iattr = cast<ActivityAttr>(act);
625
645
auto val = iattr.getValue ();
626
646
627
647
// skip primal return
628
648
if (val == Activity::enzyme_constnoneed ||
629
- val == Activity::enzyme_activenoneed ||
630
649
val == Activity::enzyme_dupnoneed) {
631
650
newRetActivityArgs.push_back (iattr);
632
651
continue ;
@@ -635,49 +654,108 @@ class ReverseRetOpt final : public OpRewritePattern<AutoDiffOp> {
635
654
mlir::Value res = uop.getOutputs ()[out_idx];
636
655
637
656
switch (val) {
638
- case Activity::enzyme_active:
657
+ case Activity::enzyme_active: {
658
+ // active -> activenoneed(if res isn't used)
659
+ // active -> const(if dres == 0)
660
+ // active -> constnoneed(both)
661
+
662
+ mlir::Value dres = uop.getInputs ()[in_idx];
663
+ in_idx++;
664
+
665
+ auto dres_type = dres.getType ();
666
+ auto dres_type_intf = dyn_cast<AutoDiffTypeInterface>(dres_type);
667
+
639
668
if (!res.use_empty ()) {
640
669
outs_args.push_back (res);
641
670
out_ty.push_back (res.getType ());
642
- newRetActivityArgs.push_back (iattr);
671
+ ActivityAttr new_act = iattr;
672
+ if (dres_type_intf && !isMutable (dres_type) &&
673
+ dres_type_intf.isZero (dres)) {
674
+ // const
675
+ changed = true ;
676
+ new_act = ActivityAttr::get (rewriter.getContext (),
677
+ Activity::enzyme_const);
678
+ } else {
679
+ in_args.push_back (dres);
680
+ in_ty.push_back (dres_type);
681
+ }
682
+ newRetActivityArgs.push_back (new_act);
643
683
} else {
644
684
changed = true ;
645
- auto new_activenn = ActivityAttr::get (rewriter.getContext (),
646
- Activity::enzyme_activenoneed);
647
- newRetActivityArgs.push_back (new_activenn);
685
+ ActivityAttr new_act = ActivityAttr::get (
686
+ rewriter.getContext (), Activity::enzyme_activenoneed);
687
+ if (dres_type_intf && !isMutable (dres_type) &&
688
+ dres_type_intf.isZero (dres)) {
689
+ // constnoneed
690
+ new_act = ActivityAttr::get (rewriter.getContext (),
691
+ Activity::enzyme_constnoneed);
692
+ } else {
693
+ // activenoneed
694
+ in_args.push_back (dres);
695
+ in_ty.push_back (dres_type);
696
+ }
697
+ newRetActivityArgs.push_back (new_act);
648
698
}
699
+
700
+ ++out_idx;
649
701
break ;
702
+ }
650
703
651
- case Activity::enzyme_const:
652
- if (!res.use_empty ()) {
653
- outs_args.push_back (res);
654
- out_ty.push_back (res.getType ());
704
+ case Activity::enzyme_activenoneed:
705
+ // activenoneed -> constnoneed
706
+ {
707
+ mlir::Value dres = uop.getInputs ()[in_idx];
708
+ in_idx++;
709
+ auto new_act = iattr;
710
+
711
+ auto dres_type = dres.getType ();
712
+ auto dres_type_intf = dyn_cast<AutoDiffTypeInterface>(dres_type);
713
+ if (dres_type_intf && !isMutable (dres_type) &&
714
+ dres_type_intf.isZero (dres)) {
715
+ // constnoneed
716
+ new_act = ActivityAttr::get (rewriter.getContext (),
717
+ Activity::enzyme_constnoneed);
718
+ } else {
719
+ in_args.push_back (dres);
720
+ in_ty.push_back (dres_type);
721
+ }
655
722
newRetActivityArgs.push_back (iattr);
656
- } else {
657
- changed = true ;
658
- auto new_constnn = ActivityAttr::get (rewriter.getContext (),
659
- Activity::enzyme_constnoneed);
660
- newRetActivityArgs.push_back (new_constnn);
723
+ break ;
724
+ }
725
+ case Activity::enzyme_const:
726
+ // const -> constnoneed
727
+ {
728
+ auto new_act = iattr;
729
+ if (!res.use_empty ()) {
730
+ outs_args.push_back (res);
731
+ out_ty.push_back (res.getType ());
732
+ newRetActivityArgs.push_back (new_act);
733
+ } else {
734
+ changed = true ;
735
+ new_act = ActivityAttr::get (rewriter.getContext (),
736
+ Activity::enzyme_constnoneed);
737
+ newRetActivityArgs.push_back (new_act);
738
+ }
739
+ ++out_idx;
740
+ break ;
661
741
}
662
- break ;
663
742
664
743
case Activity::enzyme_dup:
665
- // dont do anything here for now
744
+ // TODO: check if ret_arg == enzyme_dup inserts a derivative as the
745
+ // output and input both
666
746
outs_args.push_back (res);
667
747
out_ty.push_back (res.getType ());
668
748
newRetActivityArgs.push_back (iattr);
749
+ ++out_idx;
669
750
break ;
670
751
671
- case Activity::enzyme_activenoneed:
672
752
case Activity::enzyme_constnoneed:
673
753
case Activity::enzyme_dupnoneed:
674
754
break ;
675
755
676
756
default :
677
757
llvm_unreachable (" unexpected activity arg" );
678
758
}
679
-
680
- ++out_idx;
681
759
}
682
760
683
761
// handle dInputs
@@ -732,15 +810,14 @@ class ReverseRetOpt final : public OpRewritePattern<AutoDiffOp> {
732
810
newRetActivityArgs.end ()));
733
811
734
812
AutoDiffOp newOp = rewriter.create <AutoDiffOp>(
735
- uop.getLoc (), out_ty, uop.getFnAttr (), uop. getInputs () , newInActivity,
813
+ uop.getLoc (), out_ty, uop.getFnAttr (), in_args , newInActivity,
736
814
newRetActivity, uop.getWidthAttr (), uop.getStrongZeroAttr ());
737
815
738
816
// Map old uses of uop to newOp
739
817
auto oldIdx = 0 ;
740
818
auto newIdx = 0 ;
741
819
for (auto [idx, old_act, new_act] :
742
820
llvm::enumerate (retActivity, newRetActivityArgs)) {
743
-
744
821
auto iattr = cast<ActivityAttr>(old_act);
745
822
auto old_val = iattr.getValue ();
746
823
auto new_val = new_act.getValue ();
@@ -763,6 +840,16 @@ class ReverseRetOpt final : public OpRewritePattern<AutoDiffOp> {
763
840
} else if (new_val == Activity::enzyme_constnoneed &&
764
841
old_val == Activity::enzyme_const) {
765
842
++oldIdx; // skip const primal
843
+ } else if (old_val == Activity::enzyme_active &&
844
+ new_val == Activity::enzyme_const) {
845
+ uop.getOutputs ()[oldIdx++].replaceAllUsesWith (
846
+ newOp.getOutputs ()[newIdx++]);
847
+ } else if (old_val == Activity::enzyme_active &&
848
+ new_val == Activity::enzyme_constnoneed) {
849
+ ++oldIdx;
850
+ } else if (old_val == Activity::enzyme_activenoneed &&
851
+ new_val == Activity::enzyme_constnoneed) {
852
+ // just skip
766
853
}
767
854
}
768
855
}
@@ -788,7 +875,6 @@ class ReverseRetOpt final : public OpRewritePattern<AutoDiffOp> {
788
875
}
789
876
}
790
877
}
791
-
792
878
rewriter.eraseOp (uop);
793
879
return success ();
794
880
}
0 commit comments