diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp index add61054ca..040ffb0e3a 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp @@ -139,9 +139,14 @@ void addPassesNNPA(mlir::OwningOpRef &module, else { // Partially lower Krnl ops to Affine dialect. addKrnlToAffinePasses(pm); + // Replace zlow.stick, zlow.unstick by inserting dlf16 conversion + // directly into affine.for loops. This must be done before + // normalize-memrefs so that access indices are automatically generated. + pm.addPass(zlow::createZLowInsertDLF16ConversionPass()); + pm.addPass(mlir::createCanonicalizerPass()); // Normalize MemRefs. normalizeMemRefsPasses(pm); - // Some Knrl ops, e.g. KrnlMemset, potentially exist and will be lowered + // Some Krnl ops, e.g. KrnlMemset, potentially exist and will be lowered // to Affine when its operands are normalized. addKrnlToAffinePasses(pm); // Optimizations at ZLow. diff --git a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.cpp b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.cpp index b7033470e5..ced2e90b71 100644 --- a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.cpp +++ b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.cpp @@ -27,6 +27,9 @@ using namespace onnx_mlir; namespace onnx_mlir { namespace zlow { +static bool FUNC_CALL_FOR_DLF16_CONVERSION = false; +static bool SIMD_FOR_DLF16_CONVERSION = true; + zdnn_data_layouts UNDEFINED_ZDNN_LAYOUT = (zdnn_data_layouts)255; // Obtain a zDNN API for an elementwise ZLow operation. @@ -1524,6 +1527,331 @@ class ZLowBatchNormLowering : public ConvertToLLVMPattern { ApiRegistry apiRegistry; }; +class ZLowDLF16ToF32Lowering : public ConvertToLLVMPattern { +public: + explicit ZLowDLF16ToF32Lowering(MLIRContext *context, + LLVMTypeConverter &lowering_, ApiRegistry apiRegistry) + : ConvertToLLVMPattern( + ZLowConvertDLF16ToF32Op::getOperationName(), context, lowering_) { + this->apiRegistry = apiRegistry; + } + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + ModuleOp module = op->getParentOfType(); + Location loc = op->getLoc(); + ZLowConvertDLF16ToF32Op::Adaptor operandAdaptor(operands); + Value input = operandAdaptor.getInput(); + Type i16Ty = rewriter.getI16Type(); + Type i32Ty = rewriter.getI32Type(); + Type f32Ty = rewriter.getF32Type(); + + MultiDialectBuilder create(rewriter, loc); + Value outputF32; + Value inputI16 = create.llvm.bitcast(i16Ty, input); + + if (FUNC_CALL_FOR_DLF16_CONVERSION) { + // This code is for the purpose of testing the correctness of the + // generated LLVM code. + outputF32 = callApi( + rewriter, loc, module, apiRegistry, API::DLF16_TO_F32, {inputI16}); + } else { + if (SIMD_FOR_DLF16_CONVERSION) { + // clang-format off + // tail call <4 x i32> asm sideeffect ".insn vrr,0xe60000000056,$0,$1,0,2,0,0", "=&v,v"(i16 %input) + // clang-format on + + // a vector of 8 elements of i16 - for input + Type vecTypeI16 = LLVM::getFixedVectorType(i16Ty, 8); + // a vector of 4 elements of i32 - for output + Type vecTypeI32 = LLVM::getFixedVectorType(i32Ty, 4); + Type vecTypeF32 = LLVM::getFixedVectorType(f32Ty, 4); + + // SIMD instruction in string for z/Linux. + const char *asmStr = ".insn vrr,0xe60000000056,$0,$2,0,2,0,0 \n\t " + ".insn vrr,0xe6000000005E,$1,$2,0,2,0,0 \n\t"; + const char *asmConstraints = "=&v,=v,v"; + + // Prepare the input vector. 
+ // Only care about the first element. + Value inputVecI16 = rewriter.create(loc, vecTypeI16); + inputVecI16 = create.llvm.insertElement(inputVecI16, inputI16, 0); + SmallVector asmVals{inputVecI16}; + + // Emit SIMD instruction for conversion. + Value outVecI32Struct = + rewriter + .create(loc, + LLVM::LLVMStructType::getLiteral(rewriter.getContext(), + {vecTypeI32, vecTypeI32}, /*Packed=*/false), + /*operands=*/asmVals, + /*asm_string=*/asmStr, + /*constraints=*/asmConstraints, /*has_side_effects=*/true, + /*is_align_stack=*/false, + /*asm_dialect=*/LLVM::AsmDialectAttr(), + /*operand_attrs=*/ArrayAttr()) + .getResult(0); + Value outVecI32 = + create.llvm.extractValue(vecTypeI32, outVecI32Struct, 0); + Value outVecF32 = create.llvm.bitcast(vecTypeF32, outVecI32); + outputF32 = create.llvm.extractElement(f32Ty, outVecF32, 0); + } else { + // Generating LLVM instruction here. + // This code is equivalent to the one generated by clang: + // `clang -emit-llvm convert_dlf16_to_f32.cpp -S -O3` + // where `convert_dlf16_to_f32.cpp` can be found at + // https://github.com/tungld/onnx-mlir-tools/blob/main/convert_dlf16_to_f32.cpp + Value inputI32 = create.llvm.zext(i32Ty, inputI16); + // ~DLF16_SIGN + Value c32767 = create.llvm.constant(i32Ty, (int64_t)32767); + // dlf16 & ~DLF16_SIGN + Value v19 = create.llvm.andi(inputI32, c32767); + Value c0 = create.llvm.constant(i32Ty, (int64_t)0); + + // Split the block right before the current op into two blocks. + Block *currentBlock = rewriter.getInsertionBlock(); + // New block contains the terminator of the current block. + Block *newBlock = + currentBlock->splitBlock(rewriter.getInsertionPoint()); + + // Add a block for zero case. + Block *trueBlock = rewriter.createBlock(currentBlock->getParent(), + std::next(Region::iterator(currentBlock))); + + // Add a block for non-zero case. + Block *falseBlock = rewriter.createBlock( + trueBlock->getParent(), std::next(Region::iterator(trueBlock))); + + // Add a new block that acts as a phi node. + Block *endBlock = rewriter.createBlock( + newBlock->getParent(), Region::iterator(newBlock), f32Ty, loc); + rewriter.mergeBlocks(newBlock, endBlock, {}); + + // Emit `if (v19 == 0) then trueBlock else falseBlock` + rewriter.setInsertionPointToEnd(currentBlock); + Value v19Zero = create.llvm.icmp(LLVM::ICmpPredicate::eq, v19, c0); + create.llvm.condBr(v19Zero, trueBlock, {}, falseBlock, {}); + + // Emit code for zero case. + rewriter.setInsertionPointToEnd(trueBlock); + Value cf0 = create.llvm.constant(f32Ty, (float)0.000000e+00); + Value cfm0 = create.llvm.constant(f32Ty, (float)-0.000000e+00); + Value c32768 = create.llvm.constant(i32Ty, (int64_t)32768); + Value v20 = create.llvm.andi(inputI32, c32768); + Value v21 = create.llvm.icmp(LLVM::ICmpPredicate::eq, v20, c0); + Value v22 = create.llvm.select(v21, cf0, cfm0); + create.llvm.br({v22}, endBlock); + + // Emit code for non-zero case. + rewriter.setInsertionPointToEnd(falseBlock); + { + Block *condBlock = rewriter.getInsertionBlock(); + Block *defaultBlock = + condBlock->splitBlock(rewriter.getInsertionPoint()); + + rewriter.setInsertionPointToEnd(condBlock); + Value nan = create.llvm.constant(f32Ty, (float)0x7FC00000); + Value inf = create.llvm.constant(i32Ty, (int64_t)32767); + Value v19Inf = create.llvm.icmp(LLVM::ICmpPredicate::eq, v19, inf); + // Emit `if (v19 == inf) then endBlock(nan) else defaultBlock` + create.llvm.condBr(v19Inf, endBlock, {nan}, defaultBlock, {}); + + // Emit code for non-infinity case. 
+ rewriter.setInsertionPointToEnd(defaultBlock); + Value c14 = create.llvm.constant(i32Ty, (int64_t)14); + Value c16 = create.llvm.constant(i32Ty, (int64_t)16); + Value cm2147483648 = + create.llvm.constant(i32Ty, (int64_t)-2147483648); + Value c528482304 = create.llvm.constant(i32Ty, (int64_t)528482304); + Value c805306368 = create.llvm.constant(i32Ty, (int64_t)805306368); + Value c8372224 = create.llvm.constant(i32Ty, (int64_t)8372224); + Value v23 = create.llvm.shl(inputI32, c16); + Value v24 = create.llvm.andi(v23, cm2147483648); + Value v25 = create.llvm.shl(inputI32, c14); + Value v26 = create.llvm.andi(v25, c528482304); + Value v27 = create.llvm.add(v26, c805306368); + Value v28 = create.llvm.ori(v27, v24); + Value v29 = create.llvm.andi(v25, c8372224); + Value v30 = create.llvm.ori(v28, v29); + Value v31 = create.llvm.bitcast(f32Ty, v30); + create.llvm.br({v31}, endBlock); + } + + rewriter.setInsertionPoint(op); + outputF32 = endBlock->getArgument(0); + } + } + + rewriter.replaceOp(op, {outputF32}); + return success(); + } + +private: + ApiRegistry apiRegistry; +}; + +class ZLowF32ToDLF16Lowering : public ConvertToLLVMPattern { +public: + explicit ZLowF32ToDLF16Lowering(MLIRContext *context, + LLVMTypeConverter &lowering_, ApiRegistry apiRegistry) + : ConvertToLLVMPattern( + ZLowConvertF32ToDLF16Op::getOperationName(), context, lowering_) { + this->apiRegistry = apiRegistry; + } + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + ModuleOp module = op->getParentOfType(); + Location loc = op->getLoc(); + MultiDialectBuilder create(rewriter, loc); + + Type llvmF16Ty = rewriter.getF16Type(); + Type i16Ty = rewriter.getI16Type(); + Type f16Ty = rewriter.getF16Type(); + Type i32Ty = rewriter.getI32Type(); + + ZLowConvertF32ToDLF16Op::Adaptor operandAdaptor(operands); + Value input = operandAdaptor.getInput(); + Value outputDLF16; + + if (FUNC_CALL_FOR_DLF16_CONVERSION) { + // This code is for the purpose of testing the correctness of the + // generated LLVM code. + Value outputI16 = callApi( + rewriter, loc, module, apiRegistry, API::F32_TO_DLF16, {input}); + outputDLF16 = create.llvm.bitcast(llvmF16Ty, outputI16); + } else { + if (SIMD_FOR_DLF16_CONVERSION) { + // a vector of 4 elements of i32 - for input + Type vecTypeI32 = LLVM::getFixedVectorType(i32Ty, 4); + // a vector of 8 elements of i16 - for output + Type vecTypeI16 = LLVM::getFixedVectorType(i16Ty, 8); + Type vecTypeF16 = LLVM::getFixedVectorType(f16Ty, 8); + + // SIMD instruction in string for z/Linux. + const char *asmStr = ".insn vrr,0xe60000000075,$0,$1,$2,0,2,0"; + const char *asmConstraints = "=v,v,v"; + + // Prepare two input vectors: each for left/right four elements. + // Only care about the first element. + Value inputI32 = create.llvm.bitcast(i32Ty, input); + Value inputVecI32Left = rewriter.create(loc, vecTypeI32); + inputVecI32Left = + create.llvm.insertElement(inputVecI32Left, inputI32, 0); + Value inputVecI32Right = + rewriter.create(loc, vecTypeI32); + SmallVector asmVals{inputVecI32Left, inputVecI32Right}; + + // Emit SIMD instruction for conversion. 
+ Value outVecI16 = + rewriter + .create(loc, vecTypeI16, + /*operands=*/asmVals, + /*asm_string=*/asmStr, + /*constraints=*/asmConstraints, /*has_side_effects=*/true, + /*is_align_stack=*/false, + /*asm_dialect=*/LLVM::AsmDialectAttr(), + /*operand_attrs=*/ArrayAttr()) + .getResult(0); + Value outVecDLF16 = create.llvm.bitcast(vecTypeF16, outVecI16); + outputDLF16 = create.llvm.extractElement(f16Ty, outVecDLF16, 0); + } else { + // Generating LLVM instruction here. + // This code is equivalent to the one generated by clang: + // `clang -emit-llvm convert_f32_to_dlf16.cpp -S -O3` + // where `convert_f32_to_dlf16.cpp` can be found at + // https://github.com/tungld/onnx-mlir-tools/blob/main/convert_f32_to_dlf16.cpp + Value c0 = create.llvm.constant(i32Ty, (int64_t)0); + Value c9 = create.llvm.constant(i32Ty, (int64_t)9); + Value c14 = create.llvm.constant(i32Ty, (int64_t)14); + Value c16 = create.llvm.constant(i32Ty, (int64_t)16); + Value c23 = create.llvm.constant(i32Ty, (int64_t)23); + Value c255 = create.llvm.constant(i32Ty, (int64_t)255); + Value c8192 = create.llvm.constant(i32Ty, (int64_t)8192); + Value c32767 = create.llvm.constant(i32Ty, (int64_t)32767); + Value c32768 = create.llvm.constant(i32Ty, (int64_t)32768); + Value c32256 = create.llvm.constant(i32Ty, (int64_t)32256); + Value c8388607 = create.llvm.constant(i32Ty, (int64_t)8388607); + Value c8380415 = create.llvm.constant(i32Ty, (int64_t)8380415); + Value c1342152704 = create.llvm.constant(i32Ty, (int64_t)1342152704); + Value c2147475456 = create.llvm.constant(i32Ty, (int64_t)2147475456); + Value cm1 = create.llvm.constant(i32Ty, (int64_t)-1); + Value cm95 = create.llvm.constant(i32Ty, (int64_t)-95); + Value cm96 = create.llvm.constant(i32Ty, (int64_t)-96); + Value inputI32 = create.llvm.bitcast(i32Ty, input); + Value v24 = create.llvm.lshr(inputI32, c23); + Value v25 = create.llvm.andi(v24, c255); + Value v26 = create.llvm.andi(inputI32, c8388607); + Value v27 = create.llvm.add(v26, c8192); + Value v28 = create.llvm.icmp(LLVM::ICmpPredicate::ugt, v26, c8380415); + Value v29 = create.llvm.select(v28, cm95, cm96); + Value v30 = create.llvm.add(v29, v25); + Value v31 = create.llvm.lshr(inputI32, c16); + Value v32 = create.llvm.andi(v31, c32768); + Value v33 = create.llvm.icmp(LLVM::ICmpPredicate::sgt, v30, cm1); + + // Split the block right before the current op into two blocks. + Block *currentBlock = rewriter.getInsertionBlock(); + // New block contains the terminator of the current block. + Block *newBlock = + currentBlock->splitBlock(rewriter.getInsertionPoint()); + + // Add a new block for the true branch of the conditional statement we + // will add. + Block *trueBlock = rewriter.createBlock(currentBlock->getParent(), + std::next(Region::iterator(currentBlock))); + + // Add a new block that acts as a phi node. 
+ Block *endBlock = rewriter.createBlock(newBlock->getParent(), + Region::iterator(newBlock), v32.getType(), loc); + rewriter.mergeBlocks(newBlock, endBlock, {}); + + rewriter.setInsertionPointToEnd(currentBlock); + create.llvm.condBr(v33, trueBlock, {}, endBlock, {v32}); + + rewriter.setInsertionPointToEnd(trueBlock); + { + Block *currentBlock = rewriter.getInsertionBlock(); + Block *thenBlock = + currentBlock->splitBlock(rewriter.getInsertionPoint()); + Block *elseBlock = rewriter.createBlock( + thenBlock->getParent(), std::next(Region::iterator(thenBlock))); + + rewriter.setInsertionPointToEnd(currentBlock); + Value v34 = create.llvm.andi(inputI32, c2147475456); + Value v35 = + create.llvm.icmp(LLVM::ICmpPredicate::ult, v34, c1342152704); + create.llvm.condBr(v35, thenBlock, {}, elseBlock, {}); + + rewriter.setInsertionPointToEnd(thenBlock); + Value v36 = create.llvm.shl(v30, c9); + Value v37 = create.llvm.andi(v36, c32256); + Value v38 = create.llvm.lshr(v27, c14); + Value v39 = create.llvm.select(v28, c0, v38); + Value v40 = create.llvm.ori(v39, v37); + Value v41 = create.llvm.ori(v40, v32); + create.llvm.br({v41}, endBlock); + + rewriter.setInsertionPointToEnd(elseBlock); + Value v42 = create.llvm.ori(v31, c32767); + create.llvm.br({v42}, endBlock); + } + + rewriter.setInsertionPoint(op); + Value outputI32 = endBlock->getArgument(0); + Value outputI16 = create.llvm.trunc(i16Ty, outputI32); + outputDLF16 = create.llvm.bitcast(llvmF16Ty, outputI16); + } + } + + rewriter.replaceOp(op, {outputDLF16}); + return success(); + } + +private: + ApiRegistry apiRegistry; +}; + void populateZLowToLLVMConversionPattern(mlir::RewritePatternSet &patterns, mlir::LLVMTypeConverter &typeConverter, mlir::MLIRContext *ctx) { ApiRegistry apiRegistry = RegisterAllApis(ctx); @@ -1542,7 +1870,10 @@ void populateZLowToLLVMConversionPattern(mlir::RewritePatternSet &patterns, ZLowMatMulLowering, ZLowConv2DLowering, ZLowMeanReduce2DLowering, - ZLowBatchNormLowering + ZLowBatchNormLowering, + // Scalar operations + ZLowDLF16ToF32Lowering, + ZLowF32ToDLF16Lowering >(ctx, typeConverter, apiRegistry); patterns.insert< // Elementwise operations diff --git a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp index a0c4b06cc6..d11664d346 100644 --- a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp +++ b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp @@ -29,9 +29,12 @@ namespace zlow { ApiRegistry RegisterAllApis(MLIRContext *context) { auto voidTy = LLVM::LLVMVoidType::get(context); - auto opaquePtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); + auto int8Ty = IntegerType::get(context, 8); + auto int16Ty = IntegerType::get(context, 16); auto int32Ty = IntegerType::get(context, 32); auto int64Ty = IntegerType::get(context, 64); + auto float32Ty = FloatType::getF32(context); + auto opaquePtrTy = LLVM::LLVMPointerType::get(int8Ty); // Declare API type as an enum value, its string name and an LLVM Type // specifying its signature. 
@@ -74,9 +77,13 @@ ApiRegistry RegisterAllApis(MLIRContext *context) { ApiSpec(API::ZDNN_MATMUL_BCAST_OP, "zdnn_matmul_bcast_op", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, int64Ty, opaquePtrTy}, false), ApiSpec(API::ZDNN_CONV2D, "zdnn_conv2d", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, int64Ty, int64Ty, int64Ty, int64Ty, opaquePtrTy, opaquePtrTy}, false), ApiSpec(API::ZDNN_AVGPOOL2D, "zdnn_avgpool2d", int32Ty, {opaquePtrTy, int64Ty, int64Ty, int64Ty, int64Ty, int64Ty, opaquePtrTy}, false), - ApiSpec(API:: ZDNN_MAXPOOL2D, "zdnn_maxpool2d", int32Ty, {opaquePtrTy, int64Ty, int64Ty, int64Ty, int64Ty, int64Ty, opaquePtrTy}, false), - ApiSpec(API:: ZDNN_MEANREDUCE2D, "zdnn_meanreduce2d", int32Ty, {opaquePtrTy, opaquePtrTy}, false), - ApiSpec(API:: ZDNN_BATCHNORM, "zdnn_batchnorm", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, opaquePtrTy}, false), + ApiSpec(API::ZDNN_MAXPOOL2D, "zdnn_maxpool2d", int32Ty, {opaquePtrTy, int64Ty, int64Ty, int64Ty, int64Ty, int64Ty, opaquePtrTy}, false), + ApiSpec(API::ZDNN_MEANREDUCE2D, "zdnn_meanreduce2d", int32Ty, {opaquePtrTy, opaquePtrTy}, false), + ApiSpec(API::ZDNN_BATCHNORM, "zdnn_batchnorm", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, opaquePtrTy}, false), + // Scalar operations + // Note: these APIs are for testing purpose only. + ApiSpec(API::DLF16_TO_F32, "cnvt_1_dlf16_to_fp32", float32Ty, {int16Ty}, false), + ApiSpec(API::F32_TO_DLF16, "cnvt_1_fp32_to_dlf16", int16Ty, {float32Ty}, false), }; // clang-format on diff --git a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.hpp b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.hpp index 394e84b176..6dc9b8c65e 100644 --- a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.hpp +++ b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.hpp @@ -56,7 +56,10 @@ enum class API { ZDNN_AVGPOOL2D, ZDNN_MAXPOOL2D, ZDNN_MEANREDUCE2D, - ZDNN_BATCHNORM + ZDNN_BATCHNORM, + // Scalar operations. + DLF16_TO_F32, + F32_TO_DLF16, }; // Obtain a zDNN API for an elementwise ZLow operation. diff --git a/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td b/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td index 9f1c0f5df0..c4915caa5a 100644 --- a/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td +++ b/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td @@ -30,6 +30,14 @@ def ZLow_Dialect : Dialect { class ZLow_Op traits = []> : Op; +// DLFloat16 is not natively supported in MLIR. +// F16 is used instead as a container. +def DLF16 : Type, "dlfloat16 type">, + BuildableType<"$_builder.getF16Type()">; + +// MemRef-like type for zTensor. +def ZMemRef : MemRefOf<[DLF16]>; + //===----------------------------------------------------------------------===// // ZLow Operations //===----------------------------------------------------------------------===// @@ -41,10 +49,10 @@ def ZLowAddOp:ZLow_Op<"add", [MemRefsNormalizable]> { let description = [{ ZLow operation to perform an add. }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$X, - AnyTypeOf<[AnyMemRef]>:$Y, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$Out, + let arguments = (ins ZMemRef:$X, + ZMemRef:$Y, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, StrAttr:$layout); } @@ -53,10 +61,10 @@ def ZLowSubOp:ZLow_Op<"sub", [MemRefsNormalizable]> { let description = [{ ZLow operation to perform a sub. 
}]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$X, - AnyTypeOf<[AnyMemRef]>:$Y, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$Out, + let arguments = (ins ZMemRef:$X, + ZMemRef:$Y, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, StrAttr:$layout); } @@ -65,10 +73,10 @@ def ZLowMulOp:ZLow_Op<"mul", [MemRefsNormalizable]> { let description = [{ ZLow operation to perform a mul. }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$X, - AnyTypeOf<[AnyMemRef]>:$Y, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$Out, + let arguments = (ins ZMemRef:$X, + ZMemRef:$Y, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, StrAttr:$layout); } @@ -77,10 +85,10 @@ def ZLowDivOp:ZLow_Op<"div", [MemRefsNormalizable]> { let description = [{ ZLow operation to perform a div. }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$X, - AnyTypeOf<[AnyMemRef]>:$Y, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$Out, + let arguments = (ins ZMemRef:$X, + ZMemRef:$Y, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, StrAttr:$layout); } @@ -89,9 +97,9 @@ def ZLowLogOp:ZLow_Op<"log", [MemRefsNormalizable]> { let description = [{ ZLow operation to perform a log. }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$X, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$Out, + let arguments = (ins ZMemRef:$X, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, StrAttr:$layout); } @@ -100,9 +108,9 @@ def ZLowExpOp:ZLow_Op<"exp", [MemRefsNormalizable]> { let description = [{ ZLow operation to perform a exp. }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$X, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$Out, + let arguments = (ins ZMemRef:$X, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, StrAttr:$layout); } @@ -111,10 +119,10 @@ def ZLowMinOp:ZLow_Op<"min", [MemRefsNormalizable]> { let description = [{ ZLow operation to perform a min. }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$X, - AnyTypeOf<[AnyMemRef]>:$Y, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$Out, + let arguments = (ins ZMemRef:$X, + ZMemRef:$Y, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, StrAttr:$layout); } @@ -123,10 +131,10 @@ def ZLowMaxOp:ZLow_Op<"max", [MemRefsNormalizable]> { let description = [{ ZLow operation to perform a max. }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$X, - AnyTypeOf<[AnyMemRef]>:$Y, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$Out, + let arguments = (ins ZMemRef:$X, + ZMemRef:$Y, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, StrAttr:$layout); } @@ -135,9 +143,9 @@ def ZLowReluOp:ZLow_Op<"relu", [MemRefsNormalizable]> { let description = [{ ZLow operation to perform a relu. }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$X, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$Out, + let arguments = (ins ZMemRef:$X, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, StrAttr:$layout); } @@ -146,9 +154,9 @@ def ZLowTanhOp:ZLow_Op<"tanh", [MemRefsNormalizable]> { let description = [{ ZLow operation to perform a tanh. }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$X, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$Out, + let arguments = (ins ZMemRef:$X, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, StrAttr:$layout); } @@ -157,9 +165,9 @@ def ZLowSigmoidOp:ZLow_Op<"sigmoid", [MemRefsNormalizable]> { let description = [{ ZLow operation to perform a sigmoid. 
}]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$X, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$Out, + let arguments = (ins ZMemRef:$X, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, StrAttr:$layout); } @@ -170,10 +178,10 @@ def ZLowSoftmaxOp:ZLow_Op<"softmax", [MemRefsNormalizable]> { work_area: a 4K-aligned buffer. act_func: ACT_NONE or ACT_LOG. }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$X, - AnyTypeOf<[AnyMemRef]>:$work_area, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$Out, + let arguments = (ins ZMemRef:$X, + MemRefOf<[I8]>:$work_area, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, StrAttr:$act_func); } @@ -196,11 +204,11 @@ def ZLowMatMulOp:ZLow_Op<"matmul", [MemRefsNormalizable]> { * is_bcast: -1 broadcasting, 0: no broadcasting. * is_stacked: -1 stacked, 0: unstacked. }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$X, - AnyTypeOf<[AnyMemRef]>:$Y, - AnyTypeOf<[AnyMemRef]>:$Bias, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$Out, + let arguments = (ins ZMemRef:$X, + ZMemRef:$Y, + ZMemRef:$Bias, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, DefaultValuedAttr:$is_bcast, DefaultValuedAttr:$is_stacked); } @@ -220,17 +228,17 @@ def ZLowLSTMOp:ZLow_Op<"lstm", [MemRefsNormalizable]> { * return_all_steps: -1 returns all timesteps, 0: returns only the last timestep * prev_layer for where input comes is "none", "uni", or "bidir" }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$input, - AnyTypeOf<[AnyMemRef]>:$h0, - AnyTypeOf<[AnyMemRef]>:$c0, - AnyTypeOf<[AnyMemRef]>:$input_weights, - AnyTypeOf<[AnyMemRef]>:$input_bias, - AnyTypeOf<[AnyMemRef]>:$hidden_weights, - AnyTypeOf<[AnyMemRef]>:$hidden_bias, - AnyTypeOf<[AnyMemRef]>:$work_area, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$hn_output, - AnyTypeOf<[AnyMemRef]>:$cf_output, + let arguments = (ins ZMemRef:$input, + ZMemRef:$h0, + ZMemRef:$c0, + ZMemRef:$input_weights, + ZMemRef:$input_bias, + ZMemRef:$hidden_weights, + ZMemRef:$hidden_bias, + MemRefOf<[I8]>:$work_area, + MemRefOf<[I64]>:$shape, + ZMemRef:$hn_output, + ZMemRef:$cf_output, DefaultValuedStrAttr:$direction, DefaultValuedAttr:$return_all_steps, DefaultValuedStrAttr:$prev_layer); @@ -251,15 +259,15 @@ def ZLowGRUOp:ZLow_Op<"gru", [MemRefsNormalizable]> { * return_all_steps: -1 returns all timesteps, 0: returns only the last timestep. * prev_layer for where input comes is "none", "uni", or "bidir" }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$input, - AnyTypeOf<[AnyMemRef]>:$h0, - AnyTypeOf<[AnyMemRef]>:$input_weights, - AnyTypeOf<[AnyMemRef]>:$input_bias, - AnyTypeOf<[AnyMemRef]>:$hidden_weights, - AnyTypeOf<[AnyMemRef]>:$hidden_bias, - AnyTypeOf<[AnyMemRef]>:$work_area, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$hn_output, + let arguments = (ins ZMemRef:$input, + ZMemRef:$h0, + ZMemRef:$input_weights, + ZMemRef:$input_bias, + ZMemRef:$hidden_weights, + ZMemRef:$hidden_bias, + MemRefOf<[I8]>:$work_area, + MemRefOf<[I64]>:$shape, + ZMemRef:$hn_output, DefaultValuedStrAttr:$direction, DefaultValuedAttr:$return_all_steps, DefaultValuedStrAttr:$prev_layer); @@ -270,8 +278,8 @@ def ZLowStickOp:ZLow_Op<"stick", [MemRefsNormalizable]> { let description = [{ "ZLow operation to perform a stick." 
}]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$X, - AnyTypeOf<[AnyMemRef]>:$Out, + let arguments = (ins MemRefOf<[F16, F32]>:$X, + ZMemRef:$Out, OptionalAttr:$layout); let builders = [ OpBuilder<(ins "::mlir::Value":$X, "::mlir::Value":$Out), [{ @@ -287,11 +295,11 @@ def ZLowStickForLSTMOp:ZLow_Op<"stickForLSTM", [MemRefsNormalizable]> { Variadic: list of pointers for input data to be transformed: - LSTM concatenated: 4 data pointers, one for each input gate in Forget, Input, Cell, Output (FICO) order. }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$f_gate, - AnyTypeOf<[AnyMemRef]>:$i_gate, - AnyTypeOf<[AnyMemRef]>:$c_gate, - AnyTypeOf<[AnyMemRef]>:$o_gate, - AnyTypeOf<[AnyMemRef]>:$out, + let arguments = (ins MemRefOf<[F16, F32]>:$f_gate, + MemRefOf<[F16, F32]>:$i_gate, + MemRefOf<[F16, F32]>:$c_gate, + MemRefOf<[F16, F32]>:$o_gate, + ZMemRef:$out, DefaultValuedStrAttr:$prev_layer); } @@ -302,10 +310,10 @@ def ZLowStickForGRUOp:ZLow_Op<"stickForGRU", [MemRefsNormalizable]> { Variadic: list of pointers for input data to be transformed: - GRU concatenated: 3 data pointers, one for each input gate in (Z)update, Reset, Hidden, (ZRH) gate order. }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$z_gate, - AnyTypeOf<[AnyMemRef]>:$r_gate, - AnyTypeOf<[AnyMemRef]>:$h_gate, - AnyTypeOf<[AnyMemRef]>:$out, + let arguments = (ins MemRefOf<[F16, F32]>:$z_gate, + MemRefOf<[F16, F32]>:$r_gate, + MemRefOf<[F16, F32]>:$h_gate, + ZMemRef:$out, DefaultValuedStrAttr:$prev_layer); } @@ -314,8 +322,8 @@ def ZLowUnstickOp:ZLow_Op<"unstick", [MemRefsNormalizable]> { let description = [{ ZLow operation to perform a unstick. }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$X, - AnyTypeOf<[AnyMemRef]>:$Out, + let arguments = (ins ZMemRef:$X, + MemRefOf<[F16, F32]>:$Out, OptionalAttr:$layout); let builders = [ OpBuilder<(ins "::mlir::Value":$X, "::mlir::Value":$Out), [{ @@ -341,11 +349,11 @@ def ZLowConv2DOp:ZLow_Op<"conv2d", [MemRefsNormalizable]> { * padding_type: SAME_PADDING or VALID_PADDING. * act_func: ACT_NONE or ACT_RELU. }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$input, - AnyTypeOf<[AnyMemRef]>:$input_kernel, - AnyTypeOf<[AnyMemRef]>:$input_bias, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$output, + let arguments = (ins ZMemRef:$input, + ZMemRef:$input_kernel, + ZMemRef:$input_bias, + MemRefOf<[I64]>:$shape, + ZMemRef:$output, I64ArrayAttr:$kernel_shape, I64ArrayAttr:$strides, DefaultValuedStrAttr:$padding_type, @@ -367,9 +375,9 @@ def ZLowAvgPool2DOp:ZLow_Op<"avgpool2d", [MemRefsNormalizable]> { * strides: 1D array of stride height and width * padding_type: SAME_PADDING or VALID_PADDING. }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$input, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$output, + let arguments = (ins ZMemRef:$input, + MemRefOf<[I64]>:$shape, + ZMemRef:$output, I64ArrayAttr:$kernel_shape, I64ArrayAttr:$strides, DefaultValuedStrAttr:$padding_type @@ -391,9 +399,9 @@ def ZLowMaxPool2DOp:ZLow_Op<"maxpool2d", [MemRefsNormalizable]> { * strides: 1D array of stride height and width * padding_type: SAME_PADDING or VALID_PADDING. 
}]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$input, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$output, + let arguments = (ins ZMemRef:$input, + MemRefOf<[I64]>:$shape, + ZMemRef:$output, I64ArrayAttr:$kernel_shape, I64ArrayAttr:$strides, DefaultValuedStrAttr:$padding_type @@ -410,9 +418,9 @@ def ZLowMeanReduce2DOp:ZLow_Op<"meanreduce2d", [MemRefsNormalizable]> { * 3th item: width": 3rd dim of input * 4nd item: channel": 4th dim of input }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$input, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$output); + let arguments = (ins ZMemRef:$input, + MemRefOf<[I64]>:$shape, + ZMemRef:$output); } def ZLowBatchNormOp:ZLow_Op<"batchnorm", [MemRefsNormalizable]> { @@ -425,11 +433,11 @@ def ZLowBatchNormOp:ZLow_Op<"batchnorm", [MemRefsNormalizable]> { * 3rd item: width * 4th item: channel }]; - let arguments = (ins AnyTypeOf<[AnyMemRef]>:$input, - AnyTypeOf<[AnyMemRef]>:$A, - AnyTypeOf<[AnyMemRef]>:$B, - AnyTypeOf<[MemRefOf<[I64]>]>:$shape, - AnyTypeOf<[AnyMemRef]>:$output); + let arguments = (ins ZMemRef:$input, + ZMemRef:$A, + ZMemRef:$B, + MemRefOf<[I64]>:$shape, + ZMemRef:$output); } def ZLowDummyOp:ZLow_Op<"dummy", [MemRefsNormalizable]> { @@ -448,4 +456,34 @@ def ZLowDummyOp:ZLow_Op<"dummy", [MemRefsNormalizable]> { let hasCanonicalizer = 1; } +def ZLowConvertDLF16ToF32Op:ZLow_Op<"dlf16_to_f32", [Pure]> { + let summary = "Convert a dlfloat16 value to a float32 value"; + let description = [{ + This operation converts a dlfloat16 value to a float32 value. + }]; + let arguments = (ins DLF16:$input); + let results = (outs F32:$output); + let builders = [ + OpBuilder<(ins "::mlir::Value":$input), [{ + build($_builder, $_state, $_builder.getF32Type(), input); + }]> + ]; +} + +def ZLowConvertF32ToDLF16Op:ZLow_Op<"f32_to_dlf16", [Pure]> { + let summary = "Convert a float32 value to a dlfloat16 value"; + let description = [{ + This operation converts a float32 value to a dlfloat16 value. + }]; + let arguments = (ins F32:$input); + let results = (outs DLF16:$output); + let builders = [ + OpBuilder<(ins "::mlir::Value":$input), [{ + // F16 is used as a container for DLF16. + build($_builder, $_state, $_builder.getF16Type(), input); + }]> + ]; + let hasCanonicalizer = 1; +} + #endif // ZLOW_OPS diff --git a/src/Accelerators/NNPA/Dialect/ZLow/ZLowCombine.cpp b/src/Accelerators/NNPA/Dialect/ZLow/ZLowCombine.cpp index cee3c9cbe1..5b2668d372 100644 --- a/src/Accelerators/NNPA/Dialect/ZLow/ZLowCombine.cpp +++ b/src/Accelerators/NNPA/Dialect/ZLow/ZLowCombine.cpp @@ -35,5 +35,11 @@ void ZLowDummyOp::getCanonicalizationPatterns( results.insert(context); } +/// ZLowConvertF32ToDLF16Op. 
+void ZLowConvertF32ToDLF16Op::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.insert<DLF16ConversionOpPattern>(context);
+}
+
 } // namespace zlow
 } // namespace onnx_mlir
diff --git a/src/Accelerators/NNPA/Dialect/ZLow/ZLowCombine.td b/src/Accelerators/NNPA/Dialect/ZLow/ZLowCombine.td
index 1a71a045c7..1b744653ad 100644
--- a/src/Accelerators/NNPA/Dialect/ZLow/ZLowCombine.td
+++ b/src/Accelerators/NNPA/Dialect/ZLow/ZLowCombine.td
@@ -37,4 +37,10 @@ def RemoveDummyOpPattern: Pat<
  (replaceWithValue $arg)
 >;
 
+// Remove pairs of dlf16_to_f32 and f32_to_dlf16
+def DLF16ConversionOpPattern: Pat<
+  (ZLowConvertF32ToDLF16Op (ZLowConvertDLF16ToF32Op $arg)),
+  (replaceWithValue $arg)
+>;
+
 #endif // ZLOW_COMBINE
diff --git a/src/Accelerators/NNPA/NNPAAccelerator.cpp b/src/Accelerators/NNPA/NNPAAccelerator.cpp
index a048cf6749..6473e32a3d 100644
--- a/src/Accelerators/NNPA/NNPAAccelerator.cpp
+++ b/src/Accelerators/NNPA/NNPAAccelerator.cpp
@@ -86,6 +86,10 @@ void NNPAAccelerator::initPasses(int optLevel) const {
     return onnx_mlir::createRewriteONNXForZHighPass();
   });
 
+  mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
+    return onnx_mlir::zlow::createZLowInsertDLF16ConversionPass();
+  });
+
   mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
     return onnx_mlir::zlow::createZLowRewritePass();
   });
diff --git a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp
index 6de62c6797..e2ed831b8b 100644
--- a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp
+++ b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp
@@ -47,6 +47,10 @@ namespace zlow {
 /// Add pass for rewriting ZLow ops.
 std::unique_ptr<mlir::Pass> createZLowRewritePass();
 
+/// Add pass for replacing zlow.unstick and zlow.stick by inserting dlf16
+/// conversion directly into affine.for loops.
+std::unique_ptr<mlir::Pass> createZLowInsertDLF16ConversionPass();
+
 /// Add pass for rewriting ZLow ops.
 std::unique_ptr<mlir::Pass> createZLowDummyOpForMultiDerefPass();
 
diff --git a/src/Accelerators/NNPA/Transform/ZLow/CMakeLists.txt b/src/Accelerators/NNPA/Transform/ZLow/CMakeLists.txt
index 9722e39384..9b96c61817 100644
--- a/src/Accelerators/NNPA/Transform/ZLow/CMakeLists.txt
+++ b/src/Accelerators/NNPA/Transform/ZLow/CMakeLists.txt
@@ -19,3 +19,17 @@ add_onnx_mlir_library(OMZLowDummyOpForMultiDeref
  MLIRTransformUtils
  OMZLowOps
  )
+
+add_onnx_mlir_library(OMZLowInsertDLF16Conversion
+  ZLowInsertDLF16Conversion.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRAffineUtils
+  MLIRFuncDialect
+  MLIRRewrite
+  MLIRTransformUtils
+  OMZLowOps
+
+  ACCEL_INCLUDE_DIRS PRIVATE
+  ${NNPA_INCLUDE_PATH}
+  )
diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowInsertDLF16Conversion.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowInsertDLF16Conversion.cpp
new file mode 100644
index 0000000000..c72292dd2b
--- /dev/null
+++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowInsertDLF16Conversion.cpp
@@ -0,0 +1,227 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===-------- ZLowInsertDLF16ConversionPass.cpp - ZLow Rewrite Patterns ---===//
+//
+// Copyright 2023 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This pass removes zlow.unstick and zlow.stick by inserting dlf16 conversion
+// directly into affine.for loops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/Analysis/Utils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp"
+#include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp"
+#include "src/Accelerators/NNPA/Support/LayoutHelper.hpp"
+#include "src/Dialect/Mlir/DialectBuilder.hpp"
+
+using namespace mlir;
+
+namespace onnx_mlir {
+namespace zlow {
+
+// clang-format off
+
+/// This pattern rewrites
+/// ```mlir
+/// "zlow.unstick"(%arg0, %alloc) {layout = "2D"} : (memref<5x7xf16, #map>, memref<5x7xf32>) -> ()
+/// %out = memref.alloc(%dim) {alignment = 16 : i64} : memref<5x7xf32>
+/// affine.for %arg1 = 0 to 5 {
+///   affine.for %arg2 = 0 to 7 {
+///     %0 = affine.load %alloc[%arg1, %arg2] : memref<5x7xf32>
+///     %1 = math.sqrt %0 : f32
+///     affine.store %1, %out[%arg1, %arg2] : memref<5x7xf32>
+///   }
+/// }
+///
+/// into
+///
+/// %out = memref.alloc(%dim) {alignment = 16 : i64} : memref<5x7xf32>
+/// affine.for %arg1 = 0 to 5 {
+///   affine.for %arg2 = 0 to 7 {
+///     %0 = affine.load %arg0[0, %arg2 floordiv 64, 0, %arg1 floordiv 32, %arg1 mod 32, %arg2 mod 64] : memref<1x1x1x1x32x64xf16>
+///     %1 = "zlow.dlf16_to_f32"(%0) : (f16) -> f32
+///     %2 = math.sqrt %1 : f32
+///     affine.store %2, %out[%arg1, %arg2] : memref<5x7xf32>
+///   }
+/// }
+/// ```
+/// where `affine.load` is now loading data directly from a zTensor.
+///
+
+// clang-format on
+
+class DLF16ConversionForLoadPattern : public OpRewritePattern<ZLowUnstickOp> {
+public:
+  using OpRewritePattern<ZLowUnstickOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      ZLowUnstickOp unstickOp, PatternRewriter &rewriter) const override {
+    Location loc = unstickOp.getLoc();
+
+    Operation *op = unstickOp.getOperation();
+    Value zMemRef = unstickOp.getX();
+    Value cpuMemRef = unstickOp.getOut();
+    std::string layout = unstickOp.getLayout().value().str();
+
+    // 1. Match
+
+    // Only support fp32 and identity affine layout in the CPU MemRef.
+    if (auto type = dyn_cast<MemRefType>(cpuMemRef.getType())) {
+      if (!type.getElementType().isa<Float32Type>())
+        return failure();
+      AffineMap m = type.getLayout().getAffineMap();
+      if (m.getNumResults() != 1 && !m.isIdentity())
+        return failure();
+      // Optional<int64_t> sizeInBytes = getMemRefSizeInBytes(type);
+      // if (!sizeInBytes.has_value() || (sizeInBytes.value() > 256 * 4))
+      //   return failure();
+    }
+
+    // Do not support layout 1D and 2DS since their access index functions are
+    // incorrect: https://github.com/onnx/onnx-mlir/issues/1940
+    if ((layout == LAYOUT_1D) || (layout == LAYOUT_2DS))
+      return failure();
+
+    // All users except zlow.unstick must be affine.load, so that zlow.unstick
+    // will be dangling and can be totally removed at the end of this pass.
+    SmallVector<AffineLoadOp> affineLoads;
+    for (Operation *user : cpuMemRef.getUsers()) {
+      if (user == op)
+        continue;
+      if (auto affineLoad = llvm::dyn_cast<AffineLoadOp>(user))
+        affineLoads.emplace_back(affineLoad);
+      else
+        return failure();
+    }
+    if (affineLoads.size() == 0)
+      return failure();
+
+    // 2. Rewrite
+    MultiDialectBuilder<AffineBuilder> create(rewriter, loc);
+    for (AffineLoadOp loadOp : affineLoads) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointAfter(loadOp);
+      ValueRange indices = loadOp.getIndices();
+      // Load a dlf16 directly from the zTensor and convert it to fp32.
+      Value loadDLF16 = create.affine.load(zMemRef, indices);
+      Value toFP32 = rewriter.create<ZLowConvertDLF16ToF32Op>(loc, loadDLF16);
+      rewriter.replaceOp(loadOp, {toFP32});
+    }
+
+    rewriter.eraseOp(unstickOp);
+    return success();
+  }
+};
+
+class DLF16ConversionForStorePattern : public OpRewritePattern<ZLowStickOp> {
+public:
+  using OpRewritePattern<ZLowStickOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      ZLowStickOp stickOp, PatternRewriter &rewriter) const override {
+    Location loc = stickOp.getLoc();
+
+    Operation *op = stickOp.getOperation();
+    Value cpuMemRef = stickOp.getX();
+    Value zMemRef = stickOp.getOut();
+    std::string layout = stickOp.getLayout().value().str();
+
+    // 1. Match
+
+    // Only support fp32 and identity affine layout in the CPU MemRef.
+    if (auto type = dyn_cast<MemRefType>(cpuMemRef.getType())) {
+      if (!type.getElementType().isa<Float32Type>())
+        return failure();
+      AffineMap m = type.getLayout().getAffineMap();
+      if (m.getNumResults() != 1 && !m.isIdentity())
+        return failure();
+      // Optional<int64_t> sizeInBytes = getMemRefSizeInBytes(type);
+      // if (!sizeInBytes.has_value() || (sizeInBytes.value() > 256 * 4))
+      //   return failure();
+    }
+
+    // Do not support layout 1D and 2DS since their access index functions are
+    // incorrect: https://github.com/onnx/onnx-mlir/issues/1940
+    if ((layout == LAYOUT_1D) || (layout == LAYOUT_2DS))
+      return failure();
+
+    // All users except zlow.stick must be affine.store, so that zlow.stick
+    // will be dangling and can be totally removed at the end of this pass.
+    SmallVector<AffineStoreOp> affineStores;
+    for (Operation *user : cpuMemRef.getUsers()) {
+      if (user == op)
+        continue;
+      if (auto affineStore = llvm::dyn_cast<AffineStoreOp>(user))
+        affineStores.emplace_back(affineStore);
+      else
+        return failure();
+    }
+    if (affineStores.size() == 0)
+      return failure();
+
+    // 2. Rewrite
+    // Move up the allocation of zMemRef so that it dominates the following
+    // stores.
+    zMemRef.getDefiningOp()->moveAfter(cpuMemRef.getDefiningOp());
+
+    // Replace AffineStoreOp.
+    MultiDialectBuilder<AffineBuilder> create(rewriter, loc);
+    for (AffineStoreOp storeOp : affineStores) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointAfter(storeOp);
+      ValueRange indices = storeOp.getIndices();
+      Value f32 = storeOp.getValue();
+      // Convert the value to dlf16 and store it to the zTensor directly.
+      Value dlf16 = rewriter.create<ZLowConvertF32ToDLF16Op>(loc, f32);
+      create.affine.store(dlf16, zMemRef, indices);
+      // Remove the old AffineStore.
+      rewriter.eraseOp(storeOp);
+    }
+
+    rewriter.eraseOp(stickOp);
+    return success();
+  }
+};
+/*!
+ * Function pass that optimizes ZLowIR.
+ */ +class ZLowInsertDLF16ConversionPass + : public PassWrapper> { +public: + StringRef getArgument() const override { + return "zlow-insert-dlf16-conversion"; + } + + StringRef getDescription() const override { + return "Replacing zlow.unstick and zlow.stick by inserting dlf16 " + "conversion directly into affine-for loops"; + } + + void runOnOperation() override { + Operation *function = getOperation(); + + ConversionTarget target(getContext()); + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + patterns.insert(&getContext()); + if (failed(applyPatternsAndFoldGreedily(function, std::move(patterns)))) + return signalPassFailure(); + } +}; + +std::unique_ptr createZLowInsertDLF16ConversionPass() { + return std::make_unique(); +} + +} // namespace zlow +} // namespace onnx_mlir diff --git a/src/Dialect/Mlir/DialectBuilder.cpp b/src/Dialect/Mlir/DialectBuilder.cpp index 66760c9476..459b09c8e4 100644 --- a/src/Dialect/Mlir/DialectBuilder.cpp +++ b/src/Dialect/Mlir/DialectBuilder.cpp @@ -940,6 +940,11 @@ Value LLVMBuilder::_alloca( return b().create(loc(), resultType, size, alignment); } +Value LLVMBuilder::andi(Value lhs, Value rhs) const { + assert(lhs.getType() == rhs.getType() && "expected same type"); + return b().create(loc(), lhs, rhs); +} + Value LLVMBuilder::bitcast(Type type, Value val) const { return b().create(loc(), type, val); } @@ -1036,6 +1041,13 @@ Value LLVMBuilder::constant(Type type, double val) const { return constant; } +Value LLVMBuilder::extractElement( + Type resultType, Value container, int64_t position) const { + Value posVal = constant(b().getI64Type(), position); + return b().create( + loc(), resultType, container, posVal); +} + Value LLVMBuilder::extractValue( Type resultType, Value container, ArrayRef position) const { return b().create( @@ -1062,6 +1074,11 @@ Value LLVMBuilder::icmp(LLVM::ICmpPredicate cond, Value lhs, Value rhs) const { return b().create(loc(), cond, lhs, rhs); } +Value LLVMBuilder::insertElement(Value vec, Value val, int64_t position) const { + Value posVal = constant(b().getI64Type(), position); + return b().create(loc(), vec, val, posVal); +} + Value LLVMBuilder::insertValue(Type resultType, Value container, Value val, llvm::ArrayRef position) const { return b().create( @@ -1072,6 +1089,10 @@ Value LLVMBuilder::inttoptr(Type type, Value val) const { return b().create(loc(), type, val); } +Value LLVMBuilder::lshr(Value lhs, Value rhs) const { + return b().create(loc(), lhs, rhs); +} + Value LLVMBuilder::load(Value addr) const { return b().create(loc(), addr); } @@ -1089,6 +1110,11 @@ Value LLVMBuilder::nullI8Ptr() const { return b().create(loc(), I8PtrTy); } +Value LLVMBuilder::ori(Value lhs, Value rhs) const { + assert(lhs.getType() == rhs.getType() && "expected same type"); + return b().create(loc(), lhs, rhs); +} + Value LLVMBuilder::ptrtoint(Type type, Value val) const { return b().create(loc(), type, val); } @@ -1097,14 +1123,31 @@ void LLVMBuilder::_return(Value val) const { b().create(loc(), ArrayRef({val})); } +Value LLVMBuilder::select(Value cmp, Value lhs, Value rhs) const { + assert(lhs.getType() == rhs.getType() && "expected same type"); + return b().create(loc(), cmp, lhs, rhs); +} + Value LLVMBuilder::sext(Type type, Value val) const { return b().create(loc(), type, val); } +Value LLVMBuilder::shl(Value lhs, Value rhs) const { + return b().create(loc(), lhs, rhs); +} + void LLVMBuilder::store(Value val, Value addr) const { b().create(loc(), val, addr); } +Value LLVMBuilder::trunc(Type type, Value 
val) const { + return b().create(loc(), type, val); +} + +Value LLVMBuilder::zext(Type type, Value val) const { + return b().create(loc(), type, val); +} + FlatSymbolRefAttr LLVMBuilder::getOrInsertSymbolRef(ModuleOp module, StringRef funcName, Type resultType, ArrayRef operandTypes, bool isVarArg) const { diff --git a/src/Dialect/Mlir/DialectBuilder.hpp b/src/Dialect/Mlir/DialectBuilder.hpp index 7e1fe16408..c729e772a1 100644 --- a/src/Dialect/Mlir/DialectBuilder.hpp +++ b/src/Dialect/Mlir/DialectBuilder.hpp @@ -367,6 +367,9 @@ struct LLVMBuilder final : DialectBuilder { mlir::Value _alloca( mlir::Type resultType, mlir::Value size, int64_t alignment) const; + // AndOp + mlir::Value andi(mlir::Value lhs, mlir::Value rhs) const; + // BitcastOp mlir::Value bitcast(mlir::Type type, mlir::Value val) const; mlir::Value bitcastI8Ptr(mlir::Value val) const; @@ -392,6 +395,10 @@ struct LLVMBuilder final : DialectBuilder { mlir::Value constant(mlir::Type type, int64_t val) const; mlir::Value constant(mlir::Type type, double val) const; + // ExtractElementOp + mlir::Value extractElement( + mlir::Type resultType, mlir::Value container, int64_t position) const; + // ExtractValueOp mlir::Value extractValue(mlir::Type resultType, mlir::Value container, llvm::ArrayRef position) const; @@ -412,6 +419,10 @@ struct LLVMBuilder final : DialectBuilder { mlir::Value icmp( mlir::LLVM::ICmpPredicate cond, mlir::Value lhs, mlir::Value rhs) const; + // InsertElementOp + mlir::Value insertElement( + mlir::Value vec, mlir::Value val, int64_t position) const; + // InsertValueOp mlir::Value insertValue(mlir::Type resultType, mlir::Value container, mlir::Value val, llvm::ArrayRef position) const; @@ -419,6 +430,9 @@ struct LLVMBuilder final : DialectBuilder { // Inttoptr mlir::Value inttoptr(mlir::Type type, mlir::Value val) const; + // LShrOp + mlir::Value lshr(mlir::Value lhs, mlir::Value rhs) const; + // LoadOp mlir::Value load(mlir::Value addr) const; @@ -429,18 +443,33 @@ struct LLVMBuilder final : DialectBuilder { mlir::Value null(mlir::Type type) const; mlir::Value nullI8Ptr() const; + // OrOp + mlir::Value ori(mlir::Value lhs, mlir::Value rhs) const; + // Ptrtoint mlir::Value ptrtoint(mlir::Type type, mlir::Value val) const; // ReturnOp void _return(mlir::Value val) const; + // SelectOp + mlir::Value select(mlir::Value cmp, mlir::Value lhs, mlir::Value rhs) const; + // SExtOp mlir::Value sext(mlir::Type type, mlir::Value val) const; + // ShlOp + mlir::Value shl(mlir::Value lhs, mlir::Value rhs) const; + // StoreOp void store(mlir::Value val, mlir::Value addr) const; + // TruncOp + mlir::Value trunc(mlir::Type type, mlir::Value val) const; + + // ZExtOp + mlir::Value zext(mlir::Type type, mlir::Value val) const; + //===--------------------------------------------------------------------===// // Helper functions //===--------------------------------------------------------------------===// diff --git a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm.mlir b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm.mlir index 1197b92c74..360036bfed 100644 --- a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm.mlir +++ b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm.mlir @@ -4,9 +4,9 @@ func.func @test_lower_both_zlow_and_krnl() -> () { %0 = memref.alloc() : memref<10x10xf32> - %1 = memref.alloc() : memref<1x1x32x64xf32> + %1 = memref.alloc() : memref<1x1x32x64xf16> %2 = "krnl.global"() {name = "constant_0", shape = [1, 2], value = dense<[[0., 1.0]]> : tensor<1x2xf32>} : () -> 
memref<1x2xf32> - "zlow.stick"(%0, %1) : (memref<10x10xf32>, memref<1x1x32x64xf32>) -> () + "zlow.stick"(%0, %1) : (memref<10x10xf32>, memref<1x1x32x64xf16>) -> () return // CHECK-DAG: llvm.mlir.global internal constant @{{.*}}(dense<{{\[}}[0.000000e+00, 1.000000e+00]{{\]}}> : tensor<1x2xf32>) {addr_space = 0 : i32, alignment = 16 : i64} : !llvm.array<1 x array<2 x f32>> @@ -16,16 +16,16 @@ func.func @test_lower_both_zlow_and_krnl() -> () { func.func @test_stick() -> () { %0 = memref.alloc() : memref<10x10xf32> - %1 = memref.alloc() : memref<1x1x32x64xf32> - "zlow.stick"(%0, %1) : (memref<10x10xf32>, memref<1x1x32x64xf32>) -> () + %1 = memref.alloc() : memref<1x1x32x64xf16> + "zlow.stick"(%0, %1) : (memref<10x10xf32>, memref<1x1x32x64xf16>) -> () return // CHECK-LABEL: test_stick // CHECK: [[UNSTICKIFIED_MEMREF:%.+]] = llvm.insertvalue {{.*}}, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: [[STICKIFIED_MEMREF:%.+]] = llvm.insertvalue {{.*}}, {{.*}}[4, 3] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + // CHECK: [[STICKIFIED_MEMREF:%.+]] = llvm.insertvalue {{.*}}, {{.*}}[4, 3] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> - // CHECK: [[ALIGNED_BUFFER:%.+]] = llvm.extractvalue [[STICKIFIED_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> - // CHECK: [[ALIGNED_BUFFER_I8PTR:%.+]] = llvm.bitcast [[ALIGNED_BUFFER]] : !llvm.ptr to !llvm.ptr + // CHECK: [[ALIGNED_BUFFER:%.+]] = llvm.extractvalue [[STICKIFIED_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + // CHECK: [[ALIGNED_BUFFER_I8PTR:%.+]] = llvm.bitcast [[ALIGNED_BUFFER]] : !llvm.ptr to !llvm.ptr // CHECK: [[PRE_TRANSFORMED_DESC:%.+]] = llvm.alloca {{.*}} x !llvm.struct<(i32, i32, i32, i32, i32, i32, i32)> : (i64) -> !llvm.ptr> // CHECK: [[DATA_LAYOUT:%.+]] = llvm.mlir.constant(1 : i64) : i64 @@ -74,17 +74,17 @@ func.func @test_stick() -> () { // ----- func.func @test_unstick() -> () { - %0 = memref.alloc() : memref<1x1x32x64xf32> + %0 = memref.alloc() : memref<1x1x32x64xf16> %1 = memref.alloc() : memref<10x10xf32> - "zlow.unstick"(%0, %1) : (memref<1x1x32x64xf32>, memref<10x10xf32>) -> () + "zlow.unstick"(%0, %1) : (memref<1x1x32x64xf16>, memref<10x10xf32>) -> () return // CHECK-LABEL: test_unstick - // CHECK: [[STICKIFIED_MEMREF:%.+]] = llvm.insertvalue {{.*}}, {{.*}}[4, 3] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + // CHECK: [[STICKIFIED_MEMREF:%.+]] = llvm.insertvalue {{.*}}, {{.*}}[4, 3] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> // CHECK: [[UNSTICKIFIED_MEMREF:%.+]] = llvm.insertvalue {{.*}}, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: [[ALIGNED_BUFFER:%.+]] = llvm.extractvalue [[STICKIFIED_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> - // CHECK: [[ALIGNED_BUFFER_I8PTR:%.+]] = llvm.bitcast [[ALIGNED_BUFFER]] : !llvm.ptr to !llvm.ptr + // CHECK: [[ALIGNED_BUFFER:%.+]] = llvm.extractvalue [[STICKIFIED_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + // CHECK: [[ALIGNED_BUFFER_I8PTR:%.+]] = llvm.bitcast [[ALIGNED_BUFFER]] : !llvm.ptr to !llvm.ptr // CHECK: [[PRE_TRANSFORMED_DESC:%.+]] = llvm.alloca {{.*}} x !llvm.struct<(i32, i32, i32, i32, i32, i32, i32)> : (i64) -> !llvm.ptr> // CHECK: [[DATA_LAYOUT:%.+]] = llvm.mlir.constant(1 : i64) : i64 @@ -134,10 +134,10 @@ func.func @test_unstick() -> () { // Check whether the lowering of zlow.relu calls the correct 
zDNN API or not. func.func @test_call_zdnn_relu() -> () { - %0 = memref.alloc() : memref<1x1x32x64xf32> - %1 = memref.alloc() : memref<1x1x32x64xf32> + %0 = memref.alloc() : memref<1x1x32x64xf16> + %1 = memref.alloc() : memref<1x1x32x64xf16> %shape = memref.alloc() : memref<2xi64> - "zlow.relu"(%0, %shape, %1) {layout = "2D"} : (memref<1x1x32x64xf32>, memref<2xi64>, memref<1x1x32x64xf32>) -> () + "zlow.relu"(%0, %shape, %1) {layout = "2D"} : (memref<1x1x32x64xf16>, memref<2xi64>, memref<1x1x32x64xf16>) -> () return // CHECK-LABEL: test_call_zdnn_relu @@ -148,10 +148,10 @@ func.func @test_call_zdnn_relu() -> () { // Check whether the lowering of zlow.tanh calls the correct zDNN API or not. func.func @test_call_zdnn_tanh() -> () { - %0 = memref.alloc() : memref<1x1x32x64xf32> - %1 = memref.alloc() : memref<1x1x32x64xf32> + %0 = memref.alloc() : memref<1x1x32x64xf16> + %1 = memref.alloc() : memref<1x1x32x64xf16> %shape = memref.alloc() : memref<2xi64> - "zlow.tanh"(%0, %shape, %1) {layout = "2D"} : (memref<1x1x32x64xf32>, memref<2xi64>, memref<1x1x32x64xf32>) -> () + "zlow.tanh"(%0, %shape, %1) {layout = "2D"} : (memref<1x1x32x64xf16>, memref<2xi64>, memref<1x1x32x64xf16>) -> () return // CHECK-LABEL: test_call_zdnn_tanh @@ -162,10 +162,10 @@ func.func @test_call_zdnn_tanh() -> () { // Check whether the lowering of zlow.sigmoid calls the correct zDNN API or not. func.func @test_call_zdnn_sigmoid() -> () { - %0 = memref.alloc() : memref<1x1x32x64xf32> - %1 = memref.alloc() : memref<1x1x32x64xf32> + %0 = memref.alloc() : memref<1x1x32x64xf16> + %1 = memref.alloc() : memref<1x1x32x64xf16> %shape = memref.alloc() : memref<2xi64> - "zlow.sigmoid"(%0, %shape, %1) {layout = "2D"} : (memref<1x1x32x64xf32>, memref<2xi64>, memref<1x1x32x64xf32>) -> () + "zlow.sigmoid"(%0, %shape, %1) {layout = "2D"} : (memref<1x1x32x64xf16>, memref<2xi64>, memref<1x1x32x64xf16>) -> () return // CHECK-LABEL: test_call_zdnn_sigmoid @@ -176,11 +176,11 @@ func.func @test_call_zdnn_sigmoid() -> () { // Check whether the lowering of zlow.add calls the correct zDNN API or not. func.func @test_call_zdnn_add() -> () { - %0 = memref.alloc() : memref<1x1x32x64xf32> - %1 = memref.alloc() : memref<1x1x32x64xf32> - %2 = memref.alloc() : memref<1x1x32x64xf32> + %0 = memref.alloc() : memref<1x1x32x64xf16> + %1 = memref.alloc() : memref<1x1x32x64xf16> + %2 = memref.alloc() : memref<1x1x32x64xf16> %shape = memref.alloc() : memref<2xi64> - "zlow.add"(%0, %1, %shape, %2) {layout = "2D"} : (memref<1x1x32x64xf32>, memref<1x1x32x64xf32>, memref<2xi64>, memref<1x1x32x64xf32>) -> () + "zlow.add"(%0, %1, %shape, %2) {layout = "2D"} : (memref<1x1x32x64xf16>, memref<1x1x32x64xf16>, memref<2xi64>, memref<1x1x32x64xf16>) -> () return // CHECK-LABEL: test_call_zdnn_add @@ -191,11 +191,11 @@ func.func @test_call_zdnn_add() -> () { // Check whether the lowering of zlow.sub calls the correct zDNN API or not. 
func.func @test_call_zdnn_sub() -> () { - %0 = memref.alloc() : memref<1x1x32x64xf32> - %1 = memref.alloc() : memref<1x1x32x64xf32> - %2 = memref.alloc() : memref<1x1x32x64xf32> + %0 = memref.alloc() : memref<1x1x32x64xf16> + %1 = memref.alloc() : memref<1x1x32x64xf16> + %2 = memref.alloc() : memref<1x1x32x64xf16> %shape = memref.alloc() : memref<2xi64> - "zlow.sub"(%0, %1, %shape, %2) {layout = "2D"} : (memref<1x1x32x64xf32>, memref<1x1x32x64xf32>, memref<2xi64>, memref<1x1x32x64xf32>) -> () + "zlow.sub"(%0, %1, %shape, %2) {layout = "2D"} : (memref<1x1x32x64xf16>, memref<1x1x32x64xf16>, memref<2xi64>, memref<1x1x32x64xf16>) -> () return // CHECK-LABEL: test_call_zdnn_sub @@ -206,11 +206,11 @@ func.func @test_call_zdnn_sub() -> () { // Check whether the lowering of zlow.mul calls the correct zDNN API or not. func.func @test_call_zdnn_mul() -> () { - %0 = memref.alloc() : memref<1x1x32x64xf32> - %1 = memref.alloc() : memref<1x1x32x64xf32> - %2 = memref.alloc() : memref<1x1x32x64xf32> + %0 = memref.alloc() : memref<1x1x32x64xf16> + %1 = memref.alloc() : memref<1x1x32x64xf16> + %2 = memref.alloc() : memref<1x1x32x64xf16> %shape = memref.alloc() : memref<2xi64> - "zlow.mul"(%0, %1, %shape, %2) {layout = "2D"} : (memref<1x1x32x64xf32>, memref<1x1x32x64xf32>, memref<2xi64>, memref<1x1x32x64xf32>) -> () + "zlow.mul"(%0, %1, %shape, %2) {layout = "2D"} : (memref<1x1x32x64xf16>, memref<1x1x32x64xf16>, memref<2xi64>, memref<1x1x32x64xf16>) -> () return // CHECK-LABEL: test_call_zdnn_mul @@ -221,11 +221,11 @@ func.func @test_call_zdnn_mul() -> () { // Check whether the lowering of zlow.div calls the correct zDNN API or not. func.func @test_call_zdnn_div() -> () { - %0 = memref.alloc() : memref<1x1x32x64xf32> - %1 = memref.alloc() : memref<1x1x32x64xf32> - %2 = memref.alloc() : memref<1x1x32x64xf32> + %0 = memref.alloc() : memref<1x1x32x64xf16> + %1 = memref.alloc() : memref<1x1x32x64xf16> + %2 = memref.alloc() : memref<1x1x32x64xf16> %shape = memref.alloc() : memref<2xi64> - "zlow.div"(%0, %1, %shape, %2) {layout = "2D"} : (memref<1x1x32x64xf32>, memref<1x1x32x64xf32>, memref<2xi64>, memref<1x1x32x64xf32>) -> () + "zlow.div"(%0, %1, %shape, %2) {layout = "2D"} : (memref<1x1x32x64xf16>, memref<1x1x32x64xf16>, memref<2xi64>, memref<1x1x32x64xf16>) -> () return // CHECK-LABEL: test_call_zdnn_div @@ -236,11 +236,11 @@ func.func @test_call_zdnn_div() -> () { // Check whether the lowering of zlow.softmax calls the correct zDNN API or not. 
 func.func @test_call_zdnn_softmax() -> () {
-  %0 = memref.alloc() : memref<1x1x1x1x32x64xf32>
-  %1 = memref.alloc() : memref<1x1x1x1x32x64xf32>
+  %0 = memref.alloc() : memref<1x1x1x1x32x64xf16>
+  %1 = memref.alloc() : memref<1x1x1x1x32x64xf16>
   %work_area = memref.alloc() {alignment = 4096 : i64} : memref<8192xi8>
   %shape = memref.alloc() : memref<3xi64>
-  "zlow.softmax"(%0, %work_area, %shape, %1) {act_func = "ACT_NONE"} : (memref<1x1x1x1x32x64xf32>, memref<8192xi8>, memref<3xi64>, memref<1x1x1x1x32x64xf32>) -> ()
+  "zlow.softmax"(%0, %work_area, %shape, %1) {act_func = "ACT_NONE"} : (memref<1x1x1x1x32x64xf16>, memref<8192xi8>, memref<3xi64>, memref<1x1x1x1x32x64xf16>) -> ()
   return
   // CHECK-LABEL: test_call_zdnn_softmax
@@ -256,8 +256,8 @@ func.func @test_stick_for_lstm() -> () {
   %i = memref.alloc() : memref<1x10x10xf32>
   %c = memref.alloc() : memref<1x10x10xf32>
   %o = memref.alloc() : memref<1x10x10xf32>
-  %1 = memref.alloc() : memref<1x1x32x64xf32>
-  "zlow.stickForLSTM"(%f, %i, %c, %o, %1) : (memref<1x10x10xf32>, memref<1x10x10xf32>, memref<1x10x10xf32>, memref<1x10x10xf32>, memref<1x1x32x64xf32>) -> ()
+  %1 = memref.alloc() : memref<1x1x32x64xf16>
+  "zlow.stickForLSTM"(%f, %i, %c, %o, %1) : (memref<1x10x10xf32>, memref<1x10x10xf32>, memref<1x10x10xf32>, memref<1x10x10xf32>, memref<1x1x32x64xf16>) -> ()
   return
   // CHECK-LABEL: test_stick_for_lstm
@@ -272,8 +272,8 @@ func.func @test_stick_for_gru() -> () {
   %g = memref.alloc() : memref<1x10x10xf32>
   %r = memref.alloc() : memref<1x10x10xf32>
   %h = memref.alloc() : memref<1x10x10xf32>
-  %1 = memref.alloc() : memref<1x1x32x64xf32>
-  "zlow.stickForGRU"(%g, %r, %h, %1) : (memref<1x10x10xf32>, memref<1x10x10xf32>, memref<1x10x10xf32>, memref<1x1x32x64xf32>) -> ()
+  %1 = memref.alloc() : memref<1x1x32x64xf16>
+  "zlow.stickForGRU"(%g, %r, %h, %1) : (memref<1x10x10xf32>, memref<1x10x10xf32>, memref<1x10x10xf32>, memref<1x1x32x64xf16>) -> ()
   return
   // CHECK-LABEL: test_stick_for_gru
@@ -284,11 +284,11 @@ func.func @test_stick_for_gru() -> () {
 // Check whether the lowering of zlow.min calls the correct zDNN API or not.
 func.func @test_call_zdnn_min() -> () {
-  %0 = memref.alloc() : memref<1x1x32x64xf32>
-  %1 = memref.alloc() : memref<1x1x32x64xf32>
-  %2 = memref.alloc() : memref<1x1x32x64xf32>
+  %0 = memref.alloc() : memref<1x1x32x64xf16>
+  %1 = memref.alloc() : memref<1x1x32x64xf16>
+  %2 = memref.alloc() : memref<1x1x32x64xf16>
   %shape = memref.alloc() : memref<2xi64>
-  "zlow.min"(%0, %1, %shape, %2) {layout = "2D"} : (memref<1x1x32x64xf32>, memref<1x1x32x64xf32>, memref<2xi64>, memref<1x1x32x64xf32>) -> ()
+  "zlow.min"(%0, %1, %shape, %2) {layout = "2D"} : (memref<1x1x32x64xf16>, memref<1x1x32x64xf16>, memref<2xi64>, memref<1x1x32x64xf16>) -> ()
   return
   // CHECK-LABEL: test_call_zdnn_min
@@ -299,11 +299,11 @@ func.func @test_call_zdnn_min() -> () {
 // Check whether the lowering of zlow.max calls the correct zDNN API or not.
 func.func @test_call_zdnn_max() -> () {
-  %0 = memref.alloc() : memref<1x1x32x64xf32>
-  %1 = memref.alloc() : memref<1x1x32x64xf32>
-  %2 = memref.alloc() : memref<1x1x32x64xf32>
+  %0 = memref.alloc() : memref<1x1x32x64xf16>
+  %1 = memref.alloc() : memref<1x1x32x64xf16>
+  %2 = memref.alloc() : memref<1x1x32x64xf16>
   %shape = memref.alloc() : memref<2xi64>
-  "zlow.max"(%0, %1, %shape, %2) {layout = "2D"} : (memref<1x1x32x64xf32>, memref<1x1x32x64xf32>, memref<2xi64>, memref<1x1x32x64xf32>) -> ()
+  "zlow.max"(%0, %1, %shape, %2) {layout = "2D"} : (memref<1x1x32x64xf16>, memref<1x1x32x64xf16>, memref<2xi64>, memref<1x1x32x64xf16>) -> ()
   return
   // CHECK-LABEL: test_call_zdnn_max
@@ -314,10 +314,10 @@ func.func @test_call_zdnn_max() -> () {
 // Check whether the lowering of zlow.exp calls the correct zDNN API or not.
 func.func @test_call_zdnn_exp() -> () {
-  %0 = memref.alloc() : memref<1x1x32x64xf32>
-  %1 = memref.alloc() : memref<1x1x32x64xf32>
+  %0 = memref.alloc() : memref<1x1x32x64xf16>
+  %1 = memref.alloc() : memref<1x1x32x64xf16>
   %shape = memref.alloc() : memref<2xi64>
-  "zlow.exp"(%0, %shape, %1) {layout = "2D"} : (memref<1x1x32x64xf32>, memref<2xi64>, memref<1x1x32x64xf32>) -> ()
+  "zlow.exp"(%0, %shape, %1) {layout = "2D"} : (memref<1x1x32x64xf16>, memref<2xi64>, memref<1x1x32x64xf16>) -> ()
   return
   // CHECK-LABEL: test_call_zdnn_exp
@@ -328,10 +328,10 @@ func.func @test_call_zdnn_exp() -> () {
 // Check whether the lowering of zlow.log calls the correct zDNN API or not.
 func.func @test_call_zdnn_log() -> () {
-  %0 = memref.alloc() : memref<1x1x32x64xf32>
-  %1 = memref.alloc() : memref<1x1x32x64xf32>
+  %0 = memref.alloc() : memref<1x1x32x64xf16>
+  %1 = memref.alloc() : memref<1x1x32x64xf16>
   %shape = memref.alloc() : memref<2xi64>
-  "zlow.log"(%0, %shape, %1) {layout = "2D"} : (memref<1x1x32x64xf32>, memref<2xi64>, memref<1x1x32x64xf32>) -> ()
+  "zlow.log"(%0, %shape, %1) {layout = "2D"} : (memref<1x1x32x64xf16>, memref<2xi64>, memref<1x1x32x64xf16>) -> ()
   return
   // CHECK-LABEL: test_call_zdnn_log
@@ -341,10 +341,10 @@ func.func @test_call_zdnn_log() -> () {
 // -----
 // Check whether the lowering of zlow.matmul calls the correct zDNN API or not.
-func.func @test_matmul_no_bcast_unstacked(%x: memref<2048xf32>,%y: memref<2048xf32>,%bias: memref<2048xf32>, %shape: memref<3xi64>) -> memref<2048xf32> {
-  %res = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast = 0 : si64, is_stacked = 0 : si64} : (memref<2048xf32>, memref<2048xf32>, memref<2048xf32>, memref<3xi64>, memref<2048xf32>) -> ()
-  return %res : memref<2048xf32>
+func.func @test_matmul_no_bcast_unstacked(%x: memref<2048xf16>,%y: memref<2048xf16>,%bias: memref<2048xf16>, %shape: memref<3xi64>) -> memref<2048xf16> {
+  %res = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast = 0 : si64, is_stacked = 0 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> ()
+  return %res : memref<2048xf16>
   // CHECK-LABEL: test_matmul_no_bcast_unstacked
   // CHECK: %{{.*}} = llvm.call @zdnn_matmul_op(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> i32
 }
@@ -352,10 +352,10 @@ func.func @test_matmul_no_bcast_unstacked(%x: memref<2048xf32>,%y: memref<2048xf
 // -----
 // Check whether the lowering of zlow.matmul calls the correct zDNN API or not.
-func.func @test_matmul_no_bcast_stacked(%x: memref<2048xf32>,%y: memref<2048xf32>,%bias: memref<2048xf32>, %shape: memref<3xi64>) -> memref<2048xf32> {
-  %res = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast = 0 : si64, is_stacked = -1 : si64} : (memref<2048xf32>, memref<2048xf32>, memref<2048xf32>, memref<3xi64>, memref<2048xf32>) -> ()
-  return %res : memref<2048xf32>
+func.func @test_matmul_no_bcast_stacked(%x: memref<2048xf16>,%y: memref<2048xf16>,%bias: memref<2048xf16>, %shape: memref<3xi64>) -> memref<2048xf16> {
+  %res = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast = 0 : si64, is_stacked = -1 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> ()
+  return %res : memref<2048xf16>
   // CHECK-LABEL: test_matmul_no_bcast_stacked
   // CHECK: %{{.*}} = llvm.call @zdnn_matmul_op(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> i32
 }
@@ -363,10 +363,10 @@ func.func @test_matmul_no_bcast_stacked(%x: memref<2048xf32
 // -----
 // Check whether the lowering of zlow.matmul calls the correct zDNN API or not.
-func.func @test_matmul_bcast_stacked(%x: memref<2048xf32>,%y: memref<2048xf32>,%bias: memref<2048xf32>, %shape: memref<3xi64>) -> memref<2048xf32> {
-  %res = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast = -1 : si64, is_stacked = -1 : si64} : (memref<2048xf32>, memref<2048xf32>, memref<2048xf32>, memref<3xi64>, memref<2048xf32>) -> ()
-  return %res : memref<2048xf32>
+func.func @test_matmul_bcast_stacked(%x: memref<2048xf16>,%y: memref<2048xf16>,%bias: memref<2048xf16>, %shape: memref<3xi64>) -> memref<2048xf16> {
+  %res = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast = -1 : si64, is_stacked = -1 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> ()
+  return %res : memref<2048xf16>
   // CHECK-LABEL: test_matmul_bcast_stacked
   // CHECK: %{{.*}} = llvm.call @zdnn_matmul_bcast_op(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> i32
 }
@@ -374,10 +374,10 @@ func.func @test_matmul_bcast_stacked(%x: memref<2048xf32>,%
 // -----
 // Check whether the lowering of zlow.matmul calls the correct zDNN API or not.
-func.func @test_matmul_bcast_unstacked(%x: memref<2048xf32>,%y: memref<2048xf32>,%bias: memref<2048xf32>, %shape: memref<3xi64>) -> memref<2048xf32> {
-  %res = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast = -1 : si64, is_stacked = 0 : si64} : (memref<2048xf32>, memref<2048xf32>, memref<2048xf32>, memref<3xi64>, memref<2048xf32>) -> ()
-  return %res : memref<2048xf32>
+func.func @test_matmul_bcast_unstacked(%x: memref<2048xf16>,%y: memref<2048xf16>,%bias: memref<2048xf16>, %shape: memref<3xi64>) -> memref<2048xf16> {
+  %res = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast = -1 : si64, is_stacked = 0 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> ()
+  return %res : memref<2048xf16>
   // CHECK-LABEL: test_matmul_bcast_unstacked
   // CHECK: %{{.*}} = llvm.call @zdnn_matmul_bcast_op(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> i32
 }
@@ -386,12 +386,12 @@ func.func @test_matmul_bcast_unstacked(%x: memref<2048xf32>
 // Check whether conv2d calls the correct zDNN API or not.
 func.func @test_call_zdnn_cond2d() -> () {
-  %input = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  %kernel = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  %bias = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  %output = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
+  %input = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  %kernel = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  %bias = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  %output = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
   %shape = memref.alloc() : memref<7xi64>
-  "zlow.conv2d"(%input, %kernel, %bias, %shape, %output) {kernel_shape = [5, 5], strides = [2, 2], padding_type = "SAME_PADDING", act_func = "ACT_NONE" } : (memref<2048xf32>, memref<2048xf32>, memref<2048xf32>, memref<7xi64>, memref<2048xf32>)-> ()
+  "zlow.conv2d"(%input, %kernel, %bias, %shape, %output) {kernel_shape = [5, 5], strides = [2, 2], padding_type = "SAME_PADDING", act_func = "ACT_NONE" } : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<7xi64>, memref<2048xf16>)-> ()
   return
   // CHECK-LABEL: test_call_zdnn_cond2d
@@ -402,12 +402,12 @@ func.func @test_call_zdnn_cond2d() -> () {
 // Check whether conv2d calls the correct zDNN API or not.
 func.func @test_call_zdnn_cond2d_valid_padding() -> () {
-  %input = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  %kernel = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  %bias = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  %output = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
+  %input = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  %kernel = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  %bias = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  %output = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
   %shape = memref.alloc() : memref<7xi64>
-  "zlow.conv2d"(%input, %kernel, %bias, %shape, %output) {kernel_shape = [5, 5], strides = [2, 2], padding_type = "VALID_PADDING", act_func = "ACT_NONE" } : (memref<2048xf32>, memref<2048xf32>, memref<2048xf32>, memref<7xi64>, memref<2048xf32>)-> ()
+  "zlow.conv2d"(%input, %kernel, %bias, %shape, %output) {kernel_shape = [5, 5], strides = [2, 2], padding_type = "VALID_PADDING", act_func = "ACT_NONE" } : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<7xi64>, memref<2048xf16>)-> ()
   return
   // CHECK-LABEL: test_call_zdnn_cond2d_valid_padding
@@ -418,12 +418,12 @@ func.func @test_call_zdnn_cond2d_valid_padding() -> () {
 // Check whether conv2d calls the correct zDNN API or not.
 func.func @test_call_zdnn_cond2d_relu_act() -> () {
-  %input = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  %kernel = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  %bias = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  %output = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
+  %input = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  %kernel = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  %bias = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  %output = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
   %shape = memref.alloc() : memref<7xi64>
-  "zlow.conv2d"(%input, %kernel, %bias, %shape, %output) {kernel_shape = [5, 5], strides = [2, 2], padding_type = "SAME_PADDING", act_func = "ACT_RELU" } : (memref<2048xf32>, memref<2048xf32>, memref<2048xf32>, memref<7xi64>, memref<2048xf32>)-> ()
+  "zlow.conv2d"(%input, %kernel, %bias, %shape, %output) {kernel_shape = [5, 5], strides = [2, 2], padding_type = "SAME_PADDING", act_func = "ACT_RELU" } : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<7xi64>, memref<2048xf16>)-> ()
   return
   // CHECK-LABEL: test_call_zdnn_cond2d_relu_act
@@ -434,10 +434,10 @@ func.func @test_call_zdnn_cond2d_relu_act() -> () {
 // Check whether avgpool2d calls the correct zDNN API or not.
 func.func @test_call_zdnn_avgpool2d() -> () {
-  %input = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  %output = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
+  %input = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  %output = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
   %shape = memref.alloc() : memref<6xi64>
-  "zlow.avgpool2d"(%input, %shape, %output) {kernel_shape = [5, 5], strides = [2, 2], padding_type = "SAME_PADDING" } : (memref<2048xf32>, memref<6xi64>, memref<2048xf32>)-> ()
+  "zlow.avgpool2d"(%input, %shape, %output) {kernel_shape = [5, 5], strides = [2, 2], padding_type = "SAME_PADDING" } : (memref<2048xf16>, memref<6xi64>, memref<2048xf16>)-> ()
   return
   // CHECK-LABEL: test_call_zdnn_avgpool2d
@@ -448,10 +448,10 @@ func.func @test_call_zdnn_avgpool2d() -> () {
 // Check whether maxpool2d calls the correct zDNN API or not.
 func.func @test_call_zdnn_maxpool2d() -> () {
-  %input = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  %output = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
+  %input = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  %output = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
   %shape = memref.alloc() : memref<6xi64>
-  "zlow.maxpool2d"(%input, %shape, %output) {kernel_shape = [5, 5], strides = [2, 2], padding_type = "SAME_PADDING" } : (memref<2048xf32>, memref<6xi64>, memref<2048xf32>)-> ()
+  "zlow.maxpool2d"(%input, %shape, %output) {kernel_shape = [5, 5], strides = [2, 2], padding_type = "SAME_PADDING" } : (memref<2048xf16>, memref<6xi64>, memref<2048xf16>)-> ()
   return
   // CHECK-LABEL: test_call_zdnn_maxpool2d
@@ -463,10 +463,10 @@ func.func @test_call_zdnn_maxpool2d() -> () {
 // Check whether meanreduce2d calls the correct zDNN API or not.
 func.func @test_call_zdnn_meanreduce2d() -> () {
-  %input = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  %output = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
+  %input = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  %output = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
   %shape = memref.alloc() : memref<4xi64>
-  "zlow.meanreduce2d"(%input, %shape, %output) : (memref<2048xf32>, memref<4xi64>, memref<2048xf32>)-> ()
+  "zlow.meanreduce2d"(%input, %shape, %output) : (memref<2048xf16>, memref<4xi64>, memref<2048xf16>)-> ()
   return
   // CHECK-LABEL: test_call_zdnn_meanreduce2d
@@ -477,12 +477,12 @@ func.func @test_call_zdnn_meanreduce2d() -> () {
 // Check whether batchnorm calls the correct zDNN API or not.
 func.func @test_call_zdnn_batchnorm() -> () {
-  %input = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  %a = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  %b = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
+  %input = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  %a = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  %b = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
   %shape = memref.alloc() : memref<4xi64>
-  %output = memref.alloc() {alignment = 4096 : i64} : memref<2048xf32>
-  "zlow.batchnorm"(%input, %a, %b, %shape, %output) : (memref<2048xf32>, memref<2048xf32>, memref<2048xf32>, memref<4xi64>, memref<2048xf32>)-> ()
+  %output = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16>
+  "zlow.batchnorm"(%input, %a, %b, %shape, %output) : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<4xi64>, memref<2048xf16>)-> ()
   return
   // CHECK-LABEL: test_call_zdnn_batchnorm