diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeProbProg.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeProbProg.cpp
index 00e271e62..6151ec940 100644
--- a/src/enzyme_ad/jax/Passes/LowerEnzymeProbProg.cpp
+++ b/src/enzyme_ad/jax/Passes/LowerEnzymeProbProg.cpp
@@ -1,23 +1,22 @@
 #include "Enzyme/MLIR/Dialect/Ops.h"
 #include "mhlo/IR/hlo_ops.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "src/enzyme_ad/jax/Dialect/Dialect.h"
 #include "src/enzyme_ad/jax/Dialect/Ops.h"
 #include "src/enzyme_ad/jax/Passes/Passes.h"
 #include "src/enzyme_ad/jax/Utils.h"
+#include "stablehlo/dialect/ChloOps.h"
 #include "stablehlo/dialect/StablehloOps.h"
-#include "llvm/ADT/DynamicAPInt.h"
-#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/ErrorHandling.h"
-#include "llvm/Support/LogicalResult.h"
-#include "llvm/Support/MathExtras.h"
+#include <cmath>
 #include <cstdint>
 
 #define DEBUG_TYPE "lower-enzyme-probprog"
@@ -32,15 +31,6 @@ namespace enzyme {
 using namespace mlir;
 using namespace mlir::enzyme;
 
-// Forward declarations for Enzyme probabilistic programming ops/types that are
-// generated via TableGen but may not be visible to clang-tidy.
-namespace mlir {
-namespace enzyme {
-class GetSampleFromConstraintOp;
-class ConstraintType;
-} // namespace enzyme
-} // namespace mlir
-
 static std::string getTensorSignature(Type tensorType) {
   if (auto rankedType = dyn_cast<RankedTensorType>(tensorType)) {
     std::string sig;
@@ -1078,6 +1068,630 @@ struct GetSubconstraintOpConversion
   }
 };
 
+struct ArithSelectOpConversion : public OpConversionPattern<arith::SelectOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  std::string backend;
+  ArithSelectOpConversion(std::string backend, TypeConverter &typeConverter,
+                          MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit), backend(backend) {
+  }
+
+  LogicalResult
+  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!isa<enzyme::TraceType>(op.getType()))
+      return failure();
+
+    // Pattern: %extracted = tensor.extract %tensor[] : tensor<i1>
+    //          %result = arith.select %extracted, ... : !enzyme.Trace
+    auto extractOp = op.getCondition().getDefiningOp<tensor::ExtractOp>();
+    if (!extractOp)
+      return failure();
+
+    Value tensorCondition = extractOp.getTensor();
+
+    auto newOp = rewriter.create<stablehlo::SelectOp>(
+        op.getLoc(), adaptor.getTrueValue().getType(), tensorCondition,
+        adaptor.getTrueValue(), adaptor.getFalseValue());
+
+    rewriter.replaceOp(op, newOp.getResult());
+
+    return success();
+  }
+};
+
+// Erase the tensor.extract that only produces the scalar condition for a
+// trace-valued arith.select: the select lowering above consumes the original
+// tensor<i1> condition directly.
+struct TensorExtractOpElimination
+    : public OpConversionPattern<tensor::ExtractOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  std::string backend;
+  TensorExtractOpElimination(std::string backend, TypeConverter &typeConverter,
+                             MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit), backend(backend) {
+  }
+
+  LogicalResult
+  matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op->hasOneUse())
+      return failure();
+
+    auto selectOp = dyn_cast<arith::SelectOp>(*op->user_begin());
+    if (!selectOp || !isa<enzyme::TraceType>(selectOp.getType()))
+      return failure();
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct GetSampleFromTraceOpConversion
+    : public OpConversionPattern<enzyme::GetSampleFromTraceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  std::string backend;
+  GetSampleFromTraceOpConversion(std::string backend,
+                                 TypeConverter &typeConverter,
+                                 MLIRContext *context,
+                                 PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit), backend(backend) {
+  }
+
+  LogicalResult
+  matchAndRewrite(enzyme::GetSampleFromTraceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto ctx = op->getContext();
+
+    Value trace = adaptor.getTrace();
+    auto outputs = op.getSample();
+
+    auto symbolWrappedAttr = op.getSymbolAttr();
+    if (!symbolWrappedAttr) {
+      return rewriter.notifyMatchFailure(op, "Missing symbol attribute");
+    }
+
+    uint64_t symbolValue = symbolWrappedAttr.getPtr();
+
+    size_t numOutputs = outputs.size();
+    if (numOutputs == 0)
+      return rewriter.notifyMatchFailure(op,
+                                         "GetSampleFromTraceOp has no outputs");
+
+    if (backend == "cpu") {
+      auto moduleOp = op->getParentOfType<ModuleOp>();
+
+      auto llvmPtrType = LLVM::LLVMPointerType::get(ctx);
+      auto llvmVoidType = LLVM::LLVMVoidType::get(ctx);
+      auto llvmI64Type = IntegerType::get(ctx, 64);
+
+      std::string getSampleFn = "enzyme_probprog_get_sample_from_trace";
+
+      auto i64TensorType = RankedTensorType::get({}, llvmI64Type);
+      auto symbolConst = rewriter.create<stablehlo::ConstantOp>(
+          op.getLoc(), i64TensorType,
+          cast<ElementsAttr>(makeAttr(i64TensorType, symbolValue)));
+
+      SmallVector<Type> llvmArgTypes; // (trace, symbol, out_ptrs...)
+ llvmArgTypes.push_back(llvmPtrType); + llvmArgTypes.push_back(llvmPtrType); + llvmArgTypes.append(numOutputs, llvmPtrType); + + auto funcType = LLVM::LLVMFunctionType::get(llvmVoidType, llvmArgTypes, + /*isVarArg=*/false); + + SmallVector originalTypes; + for (auto output : outputs) { + originalTypes.push_back(output.getType()); + } + + std::string wrapperFn = getOrCreateWrapper(getSampleFn, originalTypes); + + if (!moduleOp.lookupSymbol(wrapperFn)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + auto func = + rewriter.create(op.getLoc(), wrapperFn, funcType); + + rewriter.setInsertionPointToStart(func.addEntryBlock(rewriter)); + + auto oneConst = rewriter.create( + op.getLoc(), llvmI64Type, rewriter.getIntegerAttr(llvmI64Type, 1)); + + auto numOutputsConst = rewriter.create( + op.getLoc(), llvmI64Type, + rewriter.getIntegerAttr(llvmI64Type, numOutputs)); + + auto numOutputsAlloca = rewriter.create( + op.getLoc(), llvmPtrType, llvmI64Type, oneConst); + rewriter.create(op.getLoc(), numOutputsConst, + numOutputsAlloca); + + auto samplePtrArrayAlloca = rewriter.create( + op.getLoc(), llvmPtrType, llvmPtrType, numOutputsConst); + auto numDimsArrayAlloca = rewriter.create( + op.getLoc(), llvmPtrType, llvmI64Type, numOutputsConst); + auto shapePtrArrayAlloca = rewriter.create( + op.getLoc(), llvmPtrType, llvmPtrType, numOutputsConst); + auto dtypeWidthArrayAlloca = rewriter.create( + op.getLoc(), llvmPtrType, llvmI64Type, numOutputsConst); + + for (size_t i = 0; i < numOutputs; ++i) { + auto outType = cast(outputs[i].getType()); + auto outShape = outType.getShape(); + size_t outNumDims = outShape.size(); + size_t outWidth = outType.getElementType().getIntOrFloatBitWidth(); + + auto ptrGEP = rewriter.create( + op.getLoc(), llvmPtrType, llvmI64Type, samplePtrArrayAlloca, + ValueRange{rewriter.create( + op.getLoc(), llvmI64Type, + rewriter.getIntegerAttr(llvmI64Type, i))}); + rewriter.create(op.getLoc(), func.getArgument(2 + i), + ptrGEP); + + auto numDimsConst = rewriter.create( + op.getLoc(), llvmI64Type, + rewriter.getIntegerAttr(llvmI64Type, outNumDims)); + auto numDimsGEP = rewriter.create( + op.getLoc(), llvmPtrType, llvmI64Type, numDimsArrayAlloca, + ValueRange{rewriter.create( + op.getLoc(), llvmI64Type, + rewriter.getIntegerAttr(llvmI64Type, i))}); + rewriter.create(op.getLoc(), numDimsConst, numDimsGEP); + + auto widthConst = rewriter.create( + op.getLoc(), llvmI64Type, + rewriter.getIntegerAttr(llvmI64Type, outWidth)); + auto widthGEP = rewriter.create( + op.getLoc(), llvmPtrType, llvmI64Type, dtypeWidthArrayAlloca, + ValueRange{rewriter.create( + op.getLoc(), llvmI64Type, + rewriter.getIntegerAttr(llvmI64Type, i))}); + rewriter.create(op.getLoc(), widthConst, widthGEP); + + auto shapeSizeConst = rewriter.create( + op.getLoc(), llvmI64Type, + rewriter.getIntegerAttr(llvmI64Type, outNumDims)); + auto shapeArrAlloca = rewriter.create( + op.getLoc(), llvmPtrType, llvmI64Type, shapeSizeConst); + + for (size_t j = 0; j < outNumDims; ++j) { + auto dimConst = rewriter.create( + op.getLoc(), llvmI64Type, + rewriter.getIntegerAttr(llvmI64Type, outShape[j])); + auto dimGEP = rewriter.create( + op.getLoc(), llvmPtrType, llvmI64Type, shapeArrAlloca, + ValueRange{rewriter.create( + op.getLoc(), llvmI64Type, + rewriter.getIntegerAttr(llvmI64Type, j))}); + rewriter.create(op.getLoc(), dimConst, dimGEP); + } + + auto shapePtrGEP = rewriter.create( + op.getLoc(), llvmPtrType, llvmI64Type, shapePtrArrayAlloca, + ValueRange{rewriter.create( + 
op.getLoc(), llvmI64Type, + rewriter.getIntegerAttr(llvmI64Type, i))}); + rewriter.create(op.getLoc(), shapeArrAlloca, + shapePtrGEP); + } + + rewriter.create( + op.getLoc(), TypeRange{}, SymbolRefAttr::get(ctx, getSampleFn), + ValueRange{func.getArgument(0), func.getArgument(1), + samplePtrArrayAlloca, numOutputsAlloca, + numDimsArrayAlloca, shapePtrArrayAlloca, + dtypeWidthArrayAlloca}); + + rewriter.create(op.getLoc(), ValueRange{}); + } + + if (!moduleOp.lookupSymbol(getSampleFn)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + auto funcTypeExt = LLVM::LLVMFunctionType::get( + llvmVoidType, + {llvmPtrType, llvmPtrType, llvmPtrType, llvmPtrType, llvmPtrType, + llvmPtrType, llvmPtrType}, + /*isVarArg=*/false); + rewriter.create(op.getLoc(), getSampleFn, funcTypeExt, + LLVM::Linkage::External); + } + + SmallVector jitOperands; + jitOperands.push_back(trace); + jitOperands.push_back(symbolConst); + + for (size_t i = 0; i < numOutputs; ++i) { + auto outType = outputs[i].getType(); + auto bufConst = rewriter.create( + op.getLoc(), outType, cast(makeAttr(outType, 0))); + jitOperands.push_back(bufConst); + } + + SmallVector aliases; + for (size_t i = 0; i < numOutputs; ++i) { + aliases.push_back(stablehlo::OutputOperandAliasAttr::get( + ctx, std::vector{}, /*operand_index=*/2 + i, + std::vector{})); + } + + auto jitCall = rewriter.create( + op.getLoc(), op->getResultTypes(), + mlir::FlatSymbolRefAttr::get(ctx, wrapperFn), jitOperands, + rewriter.getStringAttr(""), + /*operand_layouts=*/nullptr, /*result_layouts=*/nullptr, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, + /*output_operand_aliases=*/rewriter.getArrayAttr(aliases), + /*xla_side_effect_free=*/nullptr); + + rewriter.replaceOp(op, jitCall.getResults()); + + return success(); + } + + return rewriter.notifyMatchFailure(op, "Unknown backend " + backend); + } +}; + +struct RandomOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + std::string backend; + RandomOpConversion(std::string backend, TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), backend(backend) { + } + + LogicalResult + matchAndRewrite(enzyme::RandomOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto distribution = op.getRngDistribution(); + auto resultType = op.getResult().getType(); + auto rankedType = dyn_cast(resultType); + if (!rankedType) { + return rewriter.notifyMatchFailure(op, "Result must be a ranked tensor"); + } + + auto elemType = rankedType.getElementType(); + assert(isa(elemType)); + auto rngStateType = adaptor.getRngState().getType(); + auto rngStateTensorType = dyn_cast(rngStateType); + if (!rngStateTensorType) { + return rewriter.notifyMatchFailure(op, "RNG state must be a tensor"); + } + + unsigned nbits = elemType.getIntOrFloatBitWidth(); + Type uintType = + IntegerType::get(rewriter.getContext(), nbits, IntegerType::Unsigned); + if (!uintType) + return rewriter.notifyMatchFailure( + op, "Failed to create unsigned integer type"); + + auto uintResultType = + RankedTensorType::get(rankedType.getShape(), uintType); + auto rngAlgorithm = mlir::stablehlo::RngAlgorithmAttr::get( + rewriter.getContext(), mlir::stablehlo::RngAlgorithm::DEFAULT); + auto rngBitGenOp = rewriter.create( + op.getLoc(), + /*output_state=*/rngStateTensorType, + /*output=*/uintResultType, + /*rng_algorithm=*/rngAlgorithm, + 
/*initial_state=*/adaptor.getRngState()); + + Value outputState = rngBitGenOp.getOutputState(); + Value randomBits = rngBitGenOp.getOutput(); + Value result; + + if (distribution == enzyme::RngDistribution::UNIFORM) { + unsigned mantissaBits; + if (nbits == 16) + mantissaBits = 10; // TODO bfloat16 + else if (nbits == 32) + mantissaBits = 23; + else if (nbits == 64) + mantissaBits = 52; + else + return rewriter.notifyMatchFailure(op, "Unsupported float type"); + + auto shiftAmount = rewriter.create( + op.getLoc(), uintResultType, + DenseElementsAttr::get( + uintResultType, + rewriter.getIntegerAttr(uintType, nbits - mantissaBits))); + auto shiftedBits = rewriter.create( + op.getLoc(), uintResultType, randomBits, shiftAmount); + + uint64_t onePattern; + if (nbits == 16) + onePattern = 0x3C00; // TODO bfloat16 + else if (nbits == 32) + onePattern = 0x3F800000; + else if (nbits == 64) + onePattern = 0x3FF0000000000000ULL; + else + return rewriter.notifyMatchFailure(op, + "Unsupported float type: $(nbits)"); + + auto onePatternConst = rewriter.create( + op.getLoc(), uintResultType, + DenseElementsAttr::get( + uintResultType, rewriter.getIntegerAttr(uintType, onePattern))); + auto floatBits = rewriter.create( + op.getLoc(), uintResultType, shiftedBits, onePatternConst); + auto floatValue = rewriter.create( + op.getLoc(), rankedType, floatBits); + auto oneConst = rewriter.create( + op.getLoc(), rankedType, + DenseElementsAttr::get(rankedType, + rewriter.getFloatAttr(elemType, 1.0))); + result = rewriter.create(op.getLoc(), rankedType, + floatValue, oneConst); + } else if (distribution == enzyme::RngDistribution::NORMAL) { + unsigned mantissaBits; + if (nbits == 16) + mantissaBits = 10; // TODO bfloat16 + else if (nbits == 32) + mantissaBits = 23; + else if (nbits == 64) + mantissaBits = 52; + else + return rewriter.notifyMatchFailure(op, + "Unsupported float type: $(nbits)"); + + auto shiftAmount = rewriter.create( + op.getLoc(), uintResultType, + DenseElementsAttr::get( + uintResultType, + rewriter.getIntegerAttr(uintType, nbits - mantissaBits))); + auto shiftedBits = rewriter.create( + op.getLoc(), uintResultType, randomBits, shiftAmount); + + uint64_t onePattern; + if (nbits == 16) + onePattern = 0x3C00; + else if (nbits == 32) + onePattern = 0x3F800000; + else if (nbits == 64) + onePattern = 0x3FF0000000000000ULL; + else + return rewriter.notifyMatchFailure(op, + "Unsupported float type: $(nbits)"); + + auto onePatternConst = rewriter.create( + op.getLoc(), uintResultType, + DenseElementsAttr::get( + uintResultType, rewriter.getIntegerAttr(uintType, onePattern))); + auto floatBits = rewriter.create( + op.getLoc(), uintResultType, shiftedBits, onePatternConst); + + Value randUniform = rewriter + .create( + op.getLoc(), rankedType, floatBits) + .getResult(); + auto oneConst = rewriter.create( + op.getLoc(), rankedType, + DenseElementsAttr::get(rankedType, + rewriter.getFloatAttr(elemType, 1.0))); + randUniform = rewriter + .create(op.getLoc(), rankedType, + randUniform, oneConst) + .getResult(); + auto twoConst = rewriter.create( + op.getLoc(), rankedType, + DenseElementsAttr::get(rankedType, + rewriter.getFloatAttr(elemType, 2.0))); + Value scaledUniform = + rewriter + .create(op.getLoc(), rankedType, randUniform, + twoConst) + .getResult(); + scaledUniform = rewriter + .create( + op.getLoc(), rankedType, scaledUniform, oneConst) + .getResult(); + auto probit = rewriter.create(op.getLoc(), rankedType, + scaledUniform); + double sqrt2 = std::sqrt(2.0); + auto sqrt2Const = rewriter.create( + 
op.getLoc(), rankedType, + DenseElementsAttr::get(rankedType, + rewriter.getFloatAttr(elemType, sqrt2))); + result = rewriter + .create(op.getLoc(), rankedType, probit, + sqrt2Const) + .getResult(); + } else { + return rewriter.notifyMatchFailure(op, "Unknown RNG distribution"); + } + + rewriter.replaceOp(op, {outputState, result}); + return success(); + } +}; + +struct GetSubtraceOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + std::string backend; + GetSubtraceOpConversion(std::string backend, TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), backend(backend) { + } + + LogicalResult + matchAndRewrite(enzyme::GetSubtraceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto ctx = op->getContext(); + + Value trace = adaptor.getTrace(); + + auto symbolAttr = op.getSymbolAttr(); + if (!symbolAttr) { + return rewriter.notifyMatchFailure(op, "Missing symbol attribute"); + } + + uint64_t symbolValue = symbolAttr.getPtr(); + + if (backend == "cpu") { + auto moduleOp = op->getParentOfType(); + + auto llvmPtrType = LLVM::LLVMPointerType::get(ctx); + auto llvmVoidType = LLVM::LLVMVoidType::get(ctx); + auto llvmI64Type = IntegerType::get(ctx, 64); + auto loweredTraceType = RankedTensorType::get( + {}, IntegerType::get(ctx, 64, IntegerType::Unsigned)); + + std::string getSubtraceFn = "enzyme_probprog_get_subtrace"; + + auto i64TensorType = RankedTensorType::get({}, llvmI64Type); + auto symbolConst = rewriter.create( + op.getLoc(), i64TensorType, + cast(makeAttr(i64TensorType, symbolValue))); + + auto subtracePtr = rewriter.create( + op.getLoc(), loweredTraceType, + cast(makeAttr(loweredTraceType, 0))); + + std::string wrapperFn = getOrCreateWrapper(getSubtraceFn); + + if (!moduleOp.lookupSymbol(wrapperFn)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + auto funcType = LLVM::LLVMFunctionType::get( + llvmVoidType, {llvmPtrType, llvmPtrType, llvmPtrType}, + /*isVarArg=*/false); + auto func = + rewriter.create(op.getLoc(), wrapperFn, funcType); + + rewriter.setInsertionPointToStart(func.addEntryBlock(rewriter)); + rewriter.create( + op.getLoc(), TypeRange{}, SymbolRefAttr::get(ctx, getSubtraceFn), + ValueRange{func.getArgument(0), func.getArgument(1), + func.getArgument(2)}); + rewriter.create(op.getLoc(), ValueRange{}); + } + + if (!moduleOp.lookupSymbol(getSubtraceFn)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + auto funcType = LLVM::LLVMFunctionType::get( + llvmVoidType, {llvmPtrType, llvmPtrType, llvmPtrType}, + /*isVarArg=*/false); + rewriter.create(op.getLoc(), getSubtraceFn, funcType, + LLVM::Linkage::External); + } + + SmallVector aliases; + aliases.push_back(stablehlo::OutputOperandAliasAttr::get( + ctx, std::vector{}, 2, std::vector{})); + + auto jitCall = rewriter.create( + op.getLoc(), TypeRange{loweredTraceType}, + mlir::FlatSymbolRefAttr::get(ctx, wrapperFn), + ValueRange{trace, symbolConst, subtracePtr}, + rewriter.getStringAttr(""), + /*operand_layouts=*/nullptr, + /*result_layouts=*/nullptr, + /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr, + /*output_operand_aliases=*/rewriter.getArrayAttr(aliases), + /*xla_side_effect_free=*/nullptr); + + rewriter.replaceOp(op, jitCall.getResults()); + + return success(); + } + + return rewriter.notifyMatchFailure(op, "Unknown backend " + backend); + } +}; + 
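(Aside, not part of the diff: the UNIFORM branch above relies on a standard bits-to-float trick: keep the top mantissaBits of the generated word, OR in the bit pattern of 1.0, bitcast to float to land in [1, 2), then subtract 1.0. A minimal standalone C++ sketch of the same mapping for f64, with a hypothetical helper name, is shown below.)

#include <cstdint>
#include <cstring>

// Illustrative only: mirrors the stablehlo ops emitted by RandomOpConversion
// for the f64 UNIFORM case (shift_right_logical by 64 - 52, or with
// 0x3FF0000000000000, bitcast_convert, subtract 1.0).
double bitsToUniformF64(uint64_t bits) {
  uint64_t shifted = bits >> (64 - 52);               // keep 52 mantissa bits
  uint64_t pattern = shifted | 0x3FF0000000000000ULL; // exponent of 1.0 -> [1, 2)
  double inOneTwo;
  std::memcpy(&inOneTwo, &pattern, sizeof(double));   // bitcast
  return inOneTwo - 1.0;                              // [1, 2) -> [0, 1)
}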
+struct GetWeightFromTraceOpConversion
+    : public OpConversionPattern<enzyme::GetWeightFromTraceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  std::string backend;
+  GetWeightFromTraceOpConversion(std::string backend,
+                                 TypeConverter &typeConverter,
+                                 MLIRContext *context,
+                                 PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit), backend(backend) {
+  }
+
+  LogicalResult
+  matchAndRewrite(enzyme::GetWeightFromTraceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto ctx = op->getContext();
+
+    Value trace = adaptor.getTrace();
+    auto weightType = op.getWeight().getType();
+
+    if (backend == "cpu") {
+      auto moduleOp = op->getParentOfType<ModuleOp>();
+
+      auto llvmPtrType = LLVM::LLVMPointerType::get(ctx);
+      auto llvmVoidType = LLVM::LLVMVoidType::get(ctx);
+
+      std::string getWeightFn = "enzyme_probprog_get_weight_from_trace";
+      SmallVector<Type> originalTypes = {weightType};
+      std::string wrapperFn = getOrCreateWrapper(getWeightFn, originalTypes);
+      auto weightConst = rewriter.create<stablehlo::ConstantOp>(
+          op.getLoc(), weightType, cast<ElementsAttr>(makeAttr(weightType, 0)));
+
+      if (!moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(wrapperFn)) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+        auto funcType = LLVM::LLVMFunctionType::get(
+            llvmVoidType, {llvmPtrType, llvmPtrType}, /*isVarArg=*/false);
+        auto func =
+            rewriter.create<LLVM::LLVMFuncOp>(op.getLoc(), wrapperFn, funcType);
+
+        rewriter.setInsertionPointToStart(func.addEntryBlock(rewriter));
+        rewriter.create<LLVM::CallOp>(op.getLoc(), TypeRange{},
+                                      SymbolRefAttr::get(ctx, getWeightFn),
+                                      func.getArguments());
+        rewriter.create<LLVM::ReturnOp>(op.getLoc(), ValueRange{});
+      }
+
+      if (!moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(getWeightFn)) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(moduleOp.getBody());
+        auto funcType = LLVM::LLVMFunctionType::get(
+            llvmVoidType, {llvmPtrType, llvmPtrType}, /*isVarArg=*/false);
+        rewriter.create<LLVM::LLVMFuncOp>(op.getLoc(), getWeightFn, funcType,
+                                          LLVM::Linkage::External);
+      }
+
+      SmallVector<Attribute> aliases;
+      aliases.push_back(stablehlo::OutputOperandAliasAttr::get(
+          ctx, std::vector<int64_t>{}, 1, std::vector<int64_t>{}));
+
+      auto jitCall = rewriter.create<enzymexla::JitCallOp>(
+          op.getLoc(), TypeRange{weightType},
+          mlir::FlatSymbolRefAttr::get(ctx, wrapperFn),
+          ValueRange{trace, weightConst}, rewriter.getStringAttr(""),
+          /*operand_layouts=*/nullptr, /*result_layouts=*/nullptr,
+          /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
+          /*output_operand_aliases=*/rewriter.getArrayAttr(aliases),
+          /*xla_side_effect_free=*/nullptr);
+
+      rewriter.replaceOp(op, jitCall.getResults());
+
+      return success();
+    }
+
+    return rewriter.notifyMatchFailure(op, "Unknown backend " + backend);
+  }
+};
+
 struct LowerEnzymeProbProgPass
     : public enzyme::impl::LowerEnzymeProbProgPassBase<
           LowerEnzymeProbProgPass> {
@@ -1113,6 +1727,10 @@
     target.addIllegalOp<enzyme::AddRetvalToTraceOp>();
     target.addIllegalOp<enzyme::GetSampleFromConstraintOp>();
     target.addIllegalOp<enzyme::GetSubconstraintOp>();
+    target.addIllegalOp<enzyme::GetSampleFromTraceOp>();
+    target.addIllegalOp<enzyme::GetSubtraceOp>();
+    target.addIllegalOp<enzyme::GetWeightFromTraceOp>();
+    target.addIllegalOp<enzyme::RandomOp>();
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp f) {
       return typeConverter.isSignatureLegal(f.getFunctionType());
     });
@@ -1127,6 +1745,17 @@
       return typeConverter.isLegal(c.getOperandTypes()) &&
              typeConverter.isLegal(c.getResultTypes());
     });
+    target.addDynamicallyLegalOp<arith::SelectOp>(
+        [&](arith::SelectOp s) { return typeConverter.isLegal(s.getType()); });
+    target.addDynamicallyLegalOp<tensor::ExtractOp>(
+        [&](tensor::ExtractOp extract) {
+          if (!extract->hasOneUse())
+            return true;
+          auto selectOp = dyn_cast<arith::SelectOp>(*extract->user_begin());
+          if (!selectOp)
+            return true;
+          return typeConverter.isLegal(selectOp.getType());
+        });
 
     RewritePatternSet patterns(context);
 
@@ -1135,12 +1764,15 @@
     populateCallOpTypeConversionPattern(patterns, typeConverter);
     populateReturnOpTypeConversionPattern(patterns, typeConverter);
 
-    patterns.add<
-        InitTraceOpConversion, AddSampleToTraceOpConversion,
-        AddSubtraceOpConversion, AddWeightToTraceOpConversion,
-        AddRetvalToTraceOpConversion, GetSampleFromConstraintOpConversion,
-        GetSubconstraintOpConversion, UnrealizedConversionCastOpConversion>(
-        backend, typeConverter, context);
+    patterns
+        .add<InitTraceOpConversion, AddSampleToTraceOpConversion,
+             AddSubtraceOpConversion, AddWeightToTraceOpConversion,
+             AddRetvalToTraceOpConversion, GetSampleFromConstraintOpConversion,
+             GetSubconstraintOpConversion, GetSampleFromTraceOpConversion,
+             GetSubtraceOpConversion, GetWeightFromTraceOpConversion,
+             RandomOpConversion, ArithSelectOpConversion,
+             TensorExtractOpElimination, UnrealizedConversionCastOpConversion>(
+            backend, typeConverter, context);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
diff --git a/test/lit_tests/probprog/mh.mlir b/test/lit_tests/probprog/mh.mlir
new file mode 100644
index 000000000..9a3601501
--- /dev/null
+++ b/test/lit_tests/probprog/mh.mlir
@@ -0,0 +1,164 @@
+// RUN: enzymexlamlir-opt %s --arith-raise --lower-enzyme-probprog | FileCheck %s --check-prefix=CPU
+
+module {
+  func.func private @model.regenerate(%arg0: !enzyme.Trace, %arg1: tensor<2xui64>) -> (!enzyme.Trace, tensor<f64>, tensor<2xui64>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<f64>
+    %0 = enzyme.initTrace : !enzyme.Trace
+    %1 = enzyme.getSampleFromTrace %arg0 {symbol = #enzyme.symbol<1>} : tensor<2xf64>
+    %2 = enzyme.addSampleToTrace(%1 : tensor<2xf64>) into %0 {symbol = #enzyme.symbol<2>}
+    return %2, %cst, %arg1 : !enzyme.Trace, tensor<f64>, tensor<2xui64>
+  }
+
+  func.func @mh_program(%arg0: tensor<2xui64>) -> (tensor<ui64>, tensor<2xui64>) {
+    %zero = arith.constant dense<0.000000e+00> : tensor<f64>
+    %one = arith.constant dense<1.000000e+00> : tensor<f64>
+    %c0 = stablehlo.constant dense<0> : tensor<i64>
+    %c100 = stablehlo.constant dense<100> : tensor<i64>
+    %c1 = stablehlo.constant dense<1> : tensor<i64>
+    %init_trace = stablehlo.constant dense<0> : tensor<ui64>
+
+    %0:3 = stablehlo.while(%iterArg = %c0, %iterArg_trace = %init_trace, %iterArg_rng = %arg0) : tensor<i64>, tensor<ui64>, tensor<2xui64> attributes {enzymexla.disable_min_cut}
+    cond {
+      %cond = stablehlo.compare LT, %iterArg, %c100 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+      stablehlo.return %cond : tensor<i1>
+    } do {
+      %iter_next = stablehlo.add %iterArg, %c1 : tensor<i64>
+      %old_trace = builtin.unrealized_conversion_cast %iterArg_trace : tensor<ui64> to !enzyme.Trace
+      %new_trace, %new_weight, %rng1 = func.call @model.regenerate(%old_trace, %iterArg_rng) : (!enzyme.Trace, tensor<2xui64>) -> (!enzyme.Trace, tensor<f64>, tensor<2xui64>)
+      %old_weight = enzyme.getWeightFromTrace %old_trace : tensor<f64>
+      %log_alpha = arith.subf %new_weight, %old_weight : tensor<f64>
+      %rng2, %uniform = enzyme.random %rng1, %zero, %one {rng_distribution = #enzyme<rng_distribution UNIFORM>} : (tensor<2xui64>, tensor<f64>, tensor<f64>) -> (tensor<2xui64>, tensor<f64>)
+      %log_uniform = math.log %uniform : tensor<f64>
+      %accept = arith.cmpf olt, %log_uniform, %log_alpha : tensor<f64>
+      %accept_extracted = tensor.extract %accept[] : tensor<i1>
+      %selected_trace = arith.select %accept_extracted, %new_trace, %old_trace : !enzyme.Trace
+      %selected_trace_ui64 = builtin.unrealized_conversion_cast %selected_trace : !enzyme.Trace to tensor<ui64>
+      stablehlo.return %iter_next, %selected_trace_ui64, %rng2 : tensor<i64>, tensor<ui64>, tensor<2xui64>
+    }
+    return %0#1, %0#2 : tensor<ui64>, tensor<2xui64>
+  }
+}
+
+// CPU: llvm.func @enzyme_probprog_get_weight_from_trace(!llvm.ptr, !llvm.ptr)
+// CPU: llvm.func @enzyme_probprog_get_weight_from_trace_wrapper_0(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+// CPU-NEXT:   llvm.call @enzyme_probprog_get_weight_from_trace(%arg0, %arg1) : (!llvm.ptr, !llvm.ptr) -> ()
+//
CPU-NEXT: llvm.return +// CPU-NEXT: } + +// CPU: llvm.func @enzyme_probprog_add_sample_to_trace(!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) +// CPU: llvm.func @enzyme_probprog_add_sample_to_trace_wrapper_0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { +// CPU-NEXT: %0 = llvm.mlir.constant(1 : i64) : i64 +// CPU-NEXT: %1 = llvm.mlir.constant(1 : i64) : i64 +// CPU-NEXT: %2 = llvm.alloca %0 x i64 : (i64) -> !llvm.ptr +// CPU-NEXT: llvm.store %1, %2 : i64, !llvm.ptr +// CPU-NEXT: %3 = llvm.alloca %1 x !llvm.ptr : (i64) -> !llvm.ptr +// CPU-NEXT: %4 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr +// CPU-NEXT: %5 = llvm.alloca %1 x !llvm.ptr : (i64) -> !llvm.ptr +// CPU-NEXT: %6 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr +// CPU-NEXT: %7 = llvm.mlir.constant(0 : i64) : i64 +// CPU-NEXT: %8 = llvm.getelementptr %3[%7] : (!llvm.ptr, i64) -> !llvm.ptr, i64 +// CPU-NEXT: llvm.store %arg2, %8 : !llvm.ptr, !llvm.ptr +// CPU-NEXT: %9 = llvm.mlir.constant(1 : i64) : i64 +// CPU-NEXT: %10 = llvm.mlir.constant(0 : i64) : i64 +// CPU-NEXT: %11 = llvm.getelementptr %4[%10] : (!llvm.ptr, i64) -> !llvm.ptr, i64 +// CPU-NEXT: llvm.store %9, %11 : i64, !llvm.ptr +// CPU-NEXT: %12 = llvm.mlir.constant(64 : i64) : i64 +// CPU-NEXT: %13 = llvm.mlir.constant(0 : i64) : i64 +// CPU-NEXT: %14 = llvm.getelementptr %6[%13] : (!llvm.ptr, i64) -> !llvm.ptr, i64 +// CPU-NEXT: llvm.store %12, %14 : i64, !llvm.ptr +// CPU-NEXT: %15 = llvm.mlir.constant(1 : i64) : i64 +// CPU-NEXT: %16 = llvm.alloca %15 x i64 : (i64) -> !llvm.ptr +// CPU-NEXT: %17 = llvm.mlir.constant(2 : i64) : i64 +// CPU-NEXT: %18 = llvm.mlir.constant(0 : i64) : i64 +// CPU-NEXT: %19 = llvm.getelementptr %16[%18] : (!llvm.ptr, i64) -> !llvm.ptr, i64 +// CPU-NEXT: llvm.store %17, %19 : i64, !llvm.ptr +// CPU-NEXT: %20 = llvm.mlir.constant(0 : i64) : i64 +// CPU-NEXT: %21 = llvm.getelementptr %5[%20] : (!llvm.ptr, i64) -> !llvm.ptr, i64 +// CPU-NEXT: llvm.store %16, %21 : !llvm.ptr, !llvm.ptr +// CPU-NEXT: llvm.call @enzyme_probprog_add_sample_to_trace(%arg0, %arg1, %3, %2, %4, %5, %6) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> () +// CPU-NEXT: llvm.return +// CPU-NEXT: } + +// CPU: llvm.func @enzyme_probprog_get_sample_from_trace(!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) +// CPU: llvm.func @enzyme_probprog_get_sample_from_trace_wrapper_0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { +// CPU-NEXT: %0 = llvm.mlir.constant(1 : i64) : i64 +// CPU-NEXT: %1 = llvm.mlir.constant(1 : i64) : i64 +// CPU-NEXT: %2 = llvm.alloca %0 x i64 : (i64) -> !llvm.ptr +// CPU-NEXT: llvm.store %1, %2 : i64, !llvm.ptr +// CPU-NEXT: %3 = llvm.alloca %1 x !llvm.ptr : (i64) -> !llvm.ptr +// CPU-NEXT: %4 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr +// CPU-NEXT: %5 = llvm.alloca %1 x !llvm.ptr : (i64) -> !llvm.ptr +// CPU-NEXT: %6 = llvm.alloca %1 x i64 : (i64) -> !llvm.ptr +// CPU-NEXT: %7 = llvm.mlir.constant(0 : i64) : i64 +// CPU-NEXT: %8 = llvm.getelementptr %3[%7] : (!llvm.ptr, i64) -> !llvm.ptr, i64 +// CPU-NEXT: llvm.store %arg2, %8 : !llvm.ptr, !llvm.ptr +// CPU-NEXT: %9 = llvm.mlir.constant(1 : i64) : i64 +// CPU-NEXT: %10 = llvm.mlir.constant(0 : i64) : i64 +// CPU-NEXT: %11 = llvm.getelementptr %4[%10] : (!llvm.ptr, i64) -> !llvm.ptr, i64 +// CPU-NEXT: llvm.store %9, %11 : i64, !llvm.ptr +// CPU-NEXT: %12 = llvm.mlir.constant(64 : i64) : i64 +// CPU-NEXT: %13 = llvm.mlir.constant(0 : i64) : i64 +// CPU-NEXT: %14 = llvm.getelementptr %6[%13] : 
(!llvm.ptr, i64) -> !llvm.ptr, i64 +// CPU-NEXT: llvm.store %12, %14 : i64, !llvm.ptr +// CPU-NEXT: %15 = llvm.mlir.constant(1 : i64) : i64 +// CPU-NEXT: %16 = llvm.alloca %15 x i64 : (i64) -> !llvm.ptr +// CPU-NEXT: %17 = llvm.mlir.constant(2 : i64) : i64 +// CPU-NEXT: %18 = llvm.mlir.constant(0 : i64) : i64 +// CPU-NEXT: %19 = llvm.getelementptr %16[%18] : (!llvm.ptr, i64) -> !llvm.ptr, i64 +// CPU-NEXT: llvm.store %17, %19 : i64, !llvm.ptr +// CPU-NEXT: %20 = llvm.mlir.constant(0 : i64) : i64 +// CPU-NEXT: %21 = llvm.getelementptr %5[%20] : (!llvm.ptr, i64) -> !llvm.ptr, i64 +// CPU-NEXT: llvm.store %16, %21 : !llvm.ptr, !llvm.ptr +// CPU-NEXT: llvm.call @enzyme_probprog_get_sample_from_trace(%arg0, %arg1, %3, %2, %4, %5, %6) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> () +// CPU-NEXT: llvm.return +// CPU-NEXT: } + +// CPU: llvm.func @enzyme_probprog_init_trace(!llvm.ptr) +// CPU: llvm.func @enzyme_probprog_init_trace_wrapper_0(%arg0: !llvm.ptr) { +// CPU-NEXT: llvm.call @enzyme_probprog_init_trace(%arg0) : (!llvm.ptr) -> () +// CPU-NEXT: llvm.return +// CPU-NEXT: } + +// CPU: func.func private @model.regenerate(%arg0: tensor, %arg1: tensor<2xui64>) -> (tensor, tensor, tensor<2xui64>) { +// CPU-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor +// CPU-NEXT: %c = stablehlo.constant dense<0> : tensor +// CPU-NEXT: %0 = enzymexla.jit_call @enzyme_probprog_init_trace_wrapper_0 (%c) {operand_layouts = [dense<> : tensor<0xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<> : tensor<0xindex>]} : (tensor) -> tensor +// CPU-NEXT: %c_0 = stablehlo.constant dense<1> : tensor +// CPU-NEXT: %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<2xf64> +// CPU-NEXT: %1 = enzymexla.jit_call @enzyme_probprog_get_sample_from_trace_wrapper_0 (%arg0, %c_0, %cst_1) {output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor, tensor, tensor<2xf64>) -> tensor<2xf64> +// CPU-NEXT: %c_2 = stablehlo.constant dense<2> : tensor +// CPU-NEXT: %2 = enzymexla.jit_call @enzyme_probprog_add_sample_to_trace_wrapper_0 (%0, %c_2, %1) {output_operand_aliases = [#stablehlo.output_operand_alias]} : (tensor, tensor, tensor<2xf64>) -> tensor +// CPU-NEXT: return %2, %cst, %arg1 : tensor, tensor, tensor<2xui64> +// CPU-NEXT: } + +// CPU: func.func @mh_program(%arg0: tensor<2xui64>) -> (tensor, tensor<2xui64>) { +// CPU-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor +// CPU-NEXT: %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor +// CPU-NEXT: %c = stablehlo.constant dense<0> : tensor +// CPU-NEXT: %c_1 = stablehlo.constant dense<100> : tensor +// CPU-NEXT: %c_2 = stablehlo.constant dense<1> : tensor +// CPU-NEXT: %c_3 = stablehlo.constant dense<0> : tensor +// CPU-NEXT: %0:3 = stablehlo.while(%iterArg = %c, %iterArg_4 = %c_3, %iterArg_5 = %arg0) : tensor, tensor, tensor<2xui64> attributes {enzymexla.disable_min_cut} +// CPU-NEXT: cond { +// CPU-NEXT: %1 = stablehlo.compare LT, %iterArg, %c_1 : (tensor, tensor) -> tensor +// CPU-NEXT: stablehlo.return %1 : tensor +// CPU-NEXT: } do { +// CPU-NEXT: %1 = stablehlo.add %iterArg, %c_2 : tensor +// CPU-NEXT: %2:3 = func.call @model.regenerate(%iterArg_4, %iterArg_5) : (tensor, tensor<2xui64>) -> (tensor, tensor, tensor<2xui64>) +// CPU-NEXT: %cst_6 = stablehlo.constant dense<0.000000e+00> : tensor +// CPU-NEXT: %3 = enzymexla.jit_call @enzyme_probprog_get_weight_from_trace_wrapper_0 (%iterArg_4, %cst_6) {output_operand_aliases = 
[#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 1, operand_tuple_indices = []>]} : (tensor<ui64>, tensor<f64>) -> tensor<f64>
+// CPU-NEXT:      %4 = stablehlo.subtract %2#1, %3 : tensor<f64>
+// CPU-NEXT:      %output_state, %output = stablehlo.rng_bit_generator %2#2, algorithm = DEFAULT : (tensor<2xui64>) -> (tensor<2xui64>, tensor<ui64>)
+// CPU-NEXT:      %c_7 = stablehlo.constant dense<12> : tensor<ui64>
+// CPU-NEXT:      %5 = stablehlo.shift_right_logical %output, %c_7 : tensor<ui64>
+// CPU-NEXT:      %c_8 = stablehlo.constant dense<4607182418800017408> : tensor<ui64>
+// CPU-NEXT:      %6 = stablehlo.or %5, %c_8 : tensor<ui64>
+// CPU-NEXT:      %7 = stablehlo.bitcast_convert %6 : (tensor<ui64>) -> tensor<f64>
+// CPU-NEXT:      %cst_9 = stablehlo.constant dense<1.000000e+00> : tensor<f64>
+// CPU-NEXT:      %8 = stablehlo.subtract %7, %cst_9 : tensor<f64>
+// CPU-NEXT:      %9 = stablehlo.log %8 : tensor<f64>
+// CPU-NEXT:      %10 = stablehlo.compare LT, %9, %4, FLOAT : (tensor<f64>, tensor<f64>) -> tensor<i1>
+// CPU-NEXT:      %11 = stablehlo.select %10, %2#0, %iterArg_4 : tensor<i1>, tensor<ui64>
+// CPU-NEXT:      stablehlo.return %1, %11, %output_state : tensor<i64>, tensor<ui64>, tensor<2xui64>
+// CPU-NEXT:    }
+// CPU-NEXT:    return %0#1, %0#2 : tensor<ui64>, tensor<2xui64>
+// CPU-NEXT:  }
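
// Note (illustrative summary, not additional CHECK lines): one behavior this
// test pins down is that a trace-valued `arith.select` fed by a scalar
// `tensor.extract` lowers to a tensor-level `stablehlo.select` once
// `!enzyme.Trace` has been converted to `tensor<ui64>`:
//
//   %accept_extracted = tensor.extract %accept[] : tensor<i1>
//   %selected_trace = arith.select %accept_extracted, %new_trace, %old_trace : !enzyme.Trace
// becomes
//   %11 = stablehlo.select %10, %2#0, %iterArg_4 : tensor<i1>, tensor<ui64>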