From 7cbddeb2346efc95e59f0b142e0a628154370421 Mon Sep 17 00:00:00 2001
From: BradleyWood
Date: Wed, 30 Aug 2023 08:42:40 -0600
Subject: [PATCH] x86: Support masked bitwise operations for i8, i16

AVX-512 does not support masked bitwise and/or/xor operations for the
i8 and i16 element types. This commit adds native masking support for
these operations by emitting the i32 bitwise opcode followed by an
i8/i16 masked move instruction.

Signed-off-by: BradleyWood
---
 compiler/x/codegen/OMRTreeEvaluator.cpp | 39 ++++++++++++++++++-------
 compiler/x/codegen/OMRTreeEvaluator.hpp |  3 +-
 compiler/x/codegen/X86OpcodeTable.hpp   |  6 ++--
 3 files changed, 34 insertions(+), 14 deletions(-)

diff --git a/compiler/x/codegen/OMRTreeEvaluator.cpp b/compiler/x/codegen/OMRTreeEvaluator.cpp
index 2567e6b185b..f711020fe80 100644
--- a/compiler/x/codegen/OMRTreeEvaluator.cpp
+++ b/compiler/x/codegen/OMRTreeEvaluator.cpp
@@ -4047,18 +4047,12 @@ TR::InstOpCode OMR::X86::TreeEvaluator::getNativeSIMDOpcode(TR::ILOpCodes opcode
          break;
       case TR::vand:
         binaryOp = BinaryArithmeticAnd;
-         // Masking opcodes require lanewise support for each element type, however, int8/int16
-         // bitwise instructions with masking are not supported without AVX-512. In non-masking
-         // operations, the element type does not matter.
-         if (!isMaskOp) elementType = TR::Int32;
         break;
      case TR::vor:
         binaryOp = BinaryArithmeticOr;
-         if (!isMaskOp) elementType = TR::Int32;
         break;
      case TR::vxor:
         binaryOp = BinaryArithmeticXor;
-         if (!isMaskOp) elementType = TR::Int32;
         break;
      case TR::vmmin:
      case TR::vmin:
@@ -4458,6 +4452,24 @@ TR::Register* OMR::X86::TreeEvaluator::vectorBinaryArithmeticEvaluator(TR::Node*
 
    TR::Register *tmpNaNReg = NULL;
    bool useRegMemForm = cg->comp()->target().cpu.supportsAVX() && !mask;
+   bool maskTypeMismatch = false;
+
+   if (et == TR::Int8 || et == TR::Int16)
+      {
+      switch (node->getOpCode().getVectorOperation())
+         {
+         case TR::vand:
+         case TR::vor:
+         case TR::vxor:
+            // There are no native bitwise opcodes for these element types, so if
+            // masking is required, the operation cannot be performed with a single
+            // instruction because of the element type mismatch.
+            maskTypeMismatch = true;
+            break;
+         default:
+            break;
+         }
+      }
 
    if (useRegMemForm)
       {
@@ -4522,7 +4534,7 @@ TR::Register* OMR::X86::TreeEvaluator::vectorBinaryArithmeticEvaluator(TR::Node*
       TR::Register *rSrcReg = tmpNaNReg ? vectorFPNaNHelper(node, tmpNaNReg, lhsReg, rhsReg, NULL, cg) : rhsReg;
       if (maskReg)
          {
-         binaryVectorMaskHelper(nativeOpcode, simdEncoding, node, resultReg, lhsReg, rSrcReg, maskReg, cg);
+         binaryVectorMaskHelper(nativeOpcode, simdEncoding, node, resultReg, lhsReg, rSrcReg, maskReg, cg, maskTypeMismatch);
         }
      else
         {
@@ -4533,7 +4545,7 @@ TR::Register* OMR::X86::TreeEvaluator::vectorBinaryArithmeticEvaluator(TR::Node*
    else if (maskReg)
       {
      TR::Register *rSrcReg = tmpNaNReg ? vectorFPNaNHelper(node, tmpNaNReg, lhsReg, rhsReg, NULL, cg) : rhsReg;
-      binaryVectorMaskHelper(nativeOpcode, simdEncoding, node, resultReg, lhsReg, rSrcReg, maskReg, cg);
+      binaryVectorMaskHelper(nativeOpcode, simdEncoding, node, resultReg, lhsReg, rSrcReg, maskReg, cg, maskTypeMismatch);
      }
   else
      {
@@ -5355,7 +5367,8 @@ OMR::X86::TreeEvaluator::binaryVectorMaskHelper(TR::InstOpCode opcode,
                                                 TR::Register *lhsReg,
                                                 TR::Register *rhsReg,
                                                 TR::Register *maskReg,
-                                                TR::CodeGenerator *cg)
+                                                TR::CodeGenerator *cg,
+                                                bool maskTypeMismatch)
    {
    TR_ASSERT_FATAL(encoding != OMR::X86::Bad, "No suitable encoding method for opcode");
    bool vectorMask = maskReg->getKind() == TR_VRF;
@@ -5375,13 +5388,19 @@ OMR::X86::TreeEvaluator::binaryVectorMaskHelper(TR::InstOpCode opcode,
       cg->stopUsingRegister(tmpReg);
       return resultReg;
       }
-   else if (vectorMask)
+   else if (vectorMask && maskTypeMismatch)
       {
       generateRegRegRegInstruction(opcode.getMnemonic(), node, tmpReg, lhsReg, rhsReg, cg, encoding);
       vectorMergeMaskHelper(node, resultReg, tmpReg, maskReg, cg);
       cg->stopUsingRegister(tmpReg);
       return resultReg;
       }
+   else if (vectorMask)
+      {
+      generateRegMaskRegRegInstruction(opcode.getMnemonic(), node, resultReg, maskReg, lhsReg, rhsReg, cg, encoding);
+      cg->stopUsingRegister(tmpReg);
+      return resultReg;
+      }
    else
       {
       TR::InstOpCode movOpcode = TR::InstOpCode::MOVDQURegReg;
diff --git a/compiler/x/codegen/OMRTreeEvaluator.hpp b/compiler/x/codegen/OMRTreeEvaluator.hpp
index d5d1f6f0ecb..206ffc915e6 100644
--- a/compiler/x/codegen/OMRTreeEvaluator.hpp
+++ b/compiler/x/codegen/OMRTreeEvaluator.hpp
@@ -409,7 +409,8 @@ class OMR_EXTENSIBLE TreeEvaluator: public OMR::TreeEvaluator
                                              TR::Register *lhsReg,
                                              TR::Register *rhsReg,
                                              TR::Register *maskReg,
-                                             TR::CodeGenerator *cg);
+                                             TR::CodeGenerator *cg,
+                                             bool maskTypeMismatch = false);
    static TR::Register *ternaryVectorMaskHelper(TR::InstOpCode opcode,
                                                 OMR::X86::Encoding encoding,
                                                 TR::Node *node,
diff --git a/compiler/x/codegen/X86OpcodeTable.hpp b/compiler/x/codegen/X86OpcodeTable.hpp
index 6839ff33c90..af763181d86 100644
--- a/compiler/x/codegen/X86OpcodeTable.hpp
+++ b/compiler/x/codegen/X86OpcodeTable.hpp
@@ -89,9 +89,9 @@ static const TR::InstOpCode::Mnemonic VectorBinaryArithmeticOpCodesForReg[NumBin
    { TR::InstOpCode::PSUBBRegReg, TR::InstOpCode::PSUBWRegReg, TR::InstOpCode::PSUBDRegReg, TR::InstOpCode::PSUBQRegReg, TR::InstOpCode::SUBPSRegReg, TR::InstOpCode::SUBPDRegReg }, // BinaryArithmeticSub
    { TR::InstOpCode::bad, TR::InstOpCode::PMULLWRegReg, TR::InstOpCode::PMULLDRegReg, TR::InstOpCode::bad, TR::InstOpCode::MULPSRegReg, TR::InstOpCode::MULPDRegReg }, // BinaryArithmeticMul
    { TR::InstOpCode::bad, TR::InstOpCode::bad, TR::InstOpCode::bad, TR::InstOpCode::bad, TR::InstOpCode::DIVPSRegReg, TR::InstOpCode::DIVPDRegReg }, // BinaryArithmeticDiv
-   { TR::InstOpCode::bad, TR::InstOpCode::bad, TR::InstOpCode::PANDRegReg, TR::InstOpCode::PANDRegReg, TR::InstOpCode::bad, TR::InstOpCode::bad }, // BinaryArithmeticAnd
-   { TR::InstOpCode::bad, TR::InstOpCode::bad, TR::InstOpCode::PORRegReg, TR::InstOpCode::PORRegReg, TR::InstOpCode::bad, TR::InstOpCode::bad }, // BinaryArithmeticOr,
-   { TR::InstOpCode::bad, TR::InstOpCode::bad, TR::InstOpCode::PXORRegReg, TR::InstOpCode::PXORRegReg, TR::InstOpCode::bad, TR::InstOpCode::bad }, // BinaryArithmeticXor
+   { TR::InstOpCode::PANDRegReg, TR::InstOpCode::PANDRegReg, TR::InstOpCode::PANDRegReg, TR::InstOpCode::PANDRegReg, TR::InstOpCode::bad, TR::InstOpCode::bad }, // BinaryArithmeticAnd
+   { TR::InstOpCode::PORRegReg, TR::InstOpCode::PORRegReg, TR::InstOpCode::PORRegReg, TR::InstOpCode::PORRegReg, TR::InstOpCode::bad, TR::InstOpCode::bad }, // BinaryArithmeticOr,
+   { TR::InstOpCode::PXORRegReg, TR::InstOpCode::PXORRegReg, TR::InstOpCode::PXORRegReg, TR::InstOpCode::PXORRegReg, TR::InstOpCode::bad, TR::InstOpCode::bad }, // BinaryArithmeticXor
    { TR::InstOpCode::PMINSBRegReg, TR::InstOpCode::PMINSWRegReg, TR::InstOpCode::PMINSDRegReg, TR::InstOpCode::PMINSQRegReg, TR::InstOpCode::MINPSRegReg, TR::InstOpCode::MINPDRegReg }, // BinaryArithmeticMin
    { TR::InstOpCode::PMAXSBRegReg, TR::InstOpCode::PMAXSWRegReg, TR::InstOpCode::PMAXSDRegReg, TR::InstOpCode::PMAXSQRegReg, TR::InstOpCode::MAXPSRegReg, TR::InstOpCode::MAXPDRegReg }, // BinaryArithmeticMax
    { TR::InstOpCode::bad, TR::InstOpCode::VPSLLVWRegRegReg, TR::InstOpCode::VPSLLVDRegRegReg, TR::InstOpCode::VPSLLVQRegRegReg, TR::InstOpCode::bad, TR::InstOpCode::bad }, // BinaryLogicalShiftLeft
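
For reference, the lowering this patch enables corresponds to roughly the
following intrinsics-level sketch. This is illustrative only: the helper name
is hypothetical, AVX-512F plus AVX-512BW are assumed, and unmasked lanes merge
from the destination operand as with EVEX merge masking.

    #include <immintrin.h>

    // Masked i8 vand: AVX-512 has no byte-granular masked AND, so the AND is
    // done with the lane-agnostic i32/i64 bitwise opcode and the mask is
    // applied afterwards with the byte-granular masked move (VMOVDQU8).
    static __m512i masked_vand_i8_sketch(__m512i dst, __m512i lhs, __m512i rhs,
                                         __mmask64 mask)
       {
       __m512i tmp = _mm512_and_si512(lhs, rhs);     // VPANDD/VPANDQ: element size is irrelevant to a bitwise AND
       return _mm512_mask_mov_epi8(dst, mask, tmp);  // VMOVDQU8 dst{k}: merge the result per byte under the mask
       }

The i16 case has the same shape with __mmask32 and _mm512_mask_mov_epi16
(VMOVDQU16). When the mask lives in a vector register rather than a k register
(the vectorMask path in binaryVectorMaskHelper), the merge instead goes through
vectorMergeMaskHelper as a blend.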