diff --git a/src/enzyme_ad/jax/Passes/Tessera/FuncToTessera.cpp b/src/enzyme_ad/jax/Passes/Tessera/FuncToTessera.cpp new file mode 100644 index 000000000..fc717b913 --- /dev/null +++ b/src/enzyme_ad/jax/Passes/Tessera/FuncToTessera.cpp @@ -0,0 +1,128 @@ +//===----------------------------------------------------------------------===// +// +// This file implements patterns to convert the Func dialect to the Tessera +// dialect and from the Tessera dialect to the Func dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "src/enzyme_ad/jax/Dialect/Tessera/Dialect.h" +#include "src/enzyme_ad/jax/Passes/Passes.h" + +using namespace mlir; +using namespace mlir::enzyme::tessera; + +namespace { +} // namespace + + +//===----------------------------------------------------------------------===// +// Rewrite Patterns +//===----------------------------------------------------------------------===// + +namespace { + +// Rewrite 'func.func' -> 'tessera.define' +class FuncOpRewrite final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult + matchAndRewrite(func::FuncOp funcOp, + PatternRewriter &rewriter) const override { + FunctionType fnType = funcOp.getFunctionType(); + + if (fnType.getNumResults() > 1) + return rewriter.notifyMatchFailure( + funcOp, "only functions with zero or one result can be rewritten"); + + + // Create the `tessera.define` op + auto tesseraDefineOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), fnType); + + + // Copy over all attributes other than the function name and type. + for (const auto &namedAttr : funcOp->getAttrs()) { + if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() && + namedAttr.getName() != SymbolTable::getSymbolAttrName()) + tesseraDefineOp->setAttr(namedAttr.getName(), namedAttr.getValue()); + } + + // Add `extern` to specifiers if `func.func` is declaration only. + if (funcOp.isDeclaration()) { + ArrayAttr specifiers = rewriter.getStrArrayAttr({"extern"}); + tesseraDefineOp.setSpecifiersAttr(specifiers); + } + + // Add `static` to specifiers if `func.func` is private but not a + // declaration. + if (funcOp.isPrivate() && !funcOp.isDeclaration()) { + ArrayAttr specifiers = rewriter.getStrArrayAttr({"static"}); + tesseraDefineOp.setSpecifiersAttr(specifiers); + } + + if (!funcOp.isDeclaration()) { + rewriter.inlineRegionBefore(funcOp.getBody(), tesseraDefineOp.getBody(), + tesseraDefineOp.end()); + } + + + rewriter.eraseOp(funcOp); + + return success(); + } +}; + +// Rewrite 'func.call' -> 'tessera.call' +class CallOpRewrite final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult + matchAndRewrite(func::CallOp callOp, + PatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp(callOp, callOp.getResultTypes(), + callOp.getOperands(), + callOp->getAttrs()); + + return success(); + } +}; + +// Rewrite 'func.return' -> 'tessera.return' +class ReturnOpRewrite final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult + matchAndRewrite(func::ReturnOp returnOp, + PatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp(returnOp, + returnOp.getOperands()); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Pass to convert Func operations into Tessera operations +//===----------------------------------------------------------------------===// + +struct FuncToTesseraPass + : public PassWrapper> { + + void runOnOperation() override { + MLIRContext &ctx = patterns.getContext(); + RewritePatternSet patterns(&ctx); + + patterns.add(&ctx); + + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) + signalPassFailure(); + } +}; + diff --git a/src/enzyme_ad/jax/Passes/Tessera/Passes.td b/src/enzyme_ad/jax/Passes/Tessera/Passes.td new file mode 100644 index 000000000..91bcda7ad --- /dev/null +++ b/src/enzyme_ad/jax/Passes/Tessera/Passes.td @@ -0,0 +1,14 @@ +#ifndef ENZYME_AD_JAX_PASSES_TESSERA_PASSES_TD +#define ENZYME_AD_JAX_PASSES_TESSERA_PASSES_TD + +include "mlir/Pass/PassBase.td" + +def FuncToTesseraPass : Pass<"func-to-tessera"> { + let summary = "Convert operations in the FuncDialect to operations in the TesseraDialect and vice versa"; + let dependentDialects = [ + "func::FuncDialect", + "tessera::TesseraDialect" + ]; +} + +#endif // ENZYME_AD_JAX_PASSES_TESSERA_PASSES_TD \ No newline at end of file diff --git a/workspace.bzl b/workspace.bzl index 7ac0532c8..a75afc217 100644 --- a/workspace.bzl +++ b/workspace.bzl @@ -8,7 +8,7 @@ ENZYME_SHA256 = "" # otherwise this should be a path to the folder containing the BUILD file for enzyme OVERRIDE_ENZYME_PATH = "" -HEDRON_COMPILE_COMMANDS_COMMIT = "4f28899228fb3ad0126897876f147ca15026151e" +HEDRON_COMPILE_COMMANDS_COMMIT = "d107d9c9025915902fd52346f1c6e18d87f7013a" HEDRON_COMPILE_COMMANDS_SHA256 = "" XLA_PATCHES = [