Skip to content

Commit

Permalink
Fix regression in batched SSE2 patch.
Browse files Browse the repository at this point in the history
Signed-off-by: Tuomas Tonteri <[email protected]>
  • Loading branch information
johnfea committed Aug 16, 2024
1 parent e0197db commit ae88298
Showing 1 changed file with 74 additions and 31 deletions.
105 changes: 74 additions & 31 deletions src/liboslexec/llvm_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ LLVM_Util::LLVM_Util(const PerThreadInfo& per_thread_info, int debuglevel,
// TODO: why are there casts to the base class llvm::Type *?
m_vector_width = OIIO::floor2(OIIO::clamp(m_vector_width, 4, 16));
m_llvm_type_wide_float = llvm_vector_type(m_llvm_type_float,
m_vector_width);
m_vector_width);
m_llvm_type_wide_double = llvm_vector_type(m_llvm_type_double,
m_vector_width);
m_llvm_type_wide_int = llvm_vector_type(m_llvm_type_int, m_vector_width);
Expand Down Expand Up @@ -790,8 +790,8 @@ LLVM_Util::debug_push_inlined_function(OIIO::ustring function_name,
method_scope_line, // Scope Line,
fnFlags,
llvm::DISubprogram::toSPFlags(true /*isLocalToUnit*/,
true /*isDefinition*/,
true /*false*/ /*isOptimized*/));
true /*isDefinition*/,
true /*false*/ /*isOptimized*/));

mLexicalBlocks.push_back(function);
}
Expand Down Expand Up @@ -3698,12 +3698,21 @@ LLVM_Util::mask_as_int(llvm::Value* mask)
// Convert <4 x i1> -> <4 x i32>
llvm::Value* w4_int_mask = builder().CreateSExt(mask,
type_wide_int());

// Now we will use the horizontal sign extraction intrinsic
// to build a 32 bit mask value. However the only 128bit
// version works on floats, so we will cast from int32 to
// float beforehand
llvm::Type* w4_float_type = llvm_vector_type(m_llvm_type_float, 4);
llvm::Value* w4_float_mask = builder().CreateBitCast(w4_int_mask,
w4_float_type);

// Now we will use the horizontal sign extraction intrinsic
// to build a 32 bit mask value.
llvm::Function* func = llvm::Intrinsic::getDeclaration(
module(), llvm::Intrinsic::x86_sse2_pmovmskb_128);
module(), llvm::Intrinsic::x86_sse_movmsk_ps);

llvm::Value* args[1] = { w4_int_mask };
llvm::Value* args[1] = { w4_float_mask };
llvm::Value* int8_mask;
int8_mask = builder().CreateCall(func, toArrayRef(args));
return int8_mask;
Expand All @@ -3727,18 +3736,28 @@ LLVM_Util::mask_as_int(llvm::Value* mask)
auto w4_int_masks = op_quarter_16x(wide_int_mask);

// Now we will use the horizontal sign extraction intrinsic
// to build a 32 bit mask value.
// to build a 32 bit mask value. However the only 128bit
// version works on floats, so we will cast from int32 to
// float beforehand
llvm::Type* w4_float_type = llvm_vector_type(m_llvm_type_float, 4);
std::array<llvm::Value*, 4> w4_float_masks = {
{ builder().CreateBitCast(w4_int_masks[0], w4_float_type),
builder().CreateBitCast(w4_int_masks[1], w4_float_type),
builder().CreateBitCast(w4_int_masks[2], w4_float_type),
builder().CreateBitCast(w4_int_masks[3], w4_float_type) }
};

llvm::Function* func = llvm::Intrinsic::getDeclaration(
module(), llvm::Intrinsic::x86_sse2_pmovmskb_128);
module(), llvm::Intrinsic::x86_sse_movmsk_ps);

llvm::Value* args[1] = { w4_int_masks[0] };
llvm::Value* args[1] = { w4_float_masks[0] };
std::array<llvm::Value*, 4> int4_masks;
int4_masks[0] = builder().CreateCall(func, toArrayRef(args));
args[0] = w4_int_masks[1];
args[0] = w4_float_masks[1];
int4_masks[1] = builder().CreateCall(func, toArrayRef(args));
args[0] = w4_int_masks[2];
args[0] = w4_float_masks[2];
int4_masks[2] = builder().CreateCall(func, toArrayRef(args));
args[0] = w4_int_masks[3];
args[0] = w4_float_masks[3];
int4_masks[3] = builder().CreateCall(func, toArrayRef(args));

llvm::Value* bits12_15 = op_shl(int4_masks[3], constant(12));
Expand All @@ -3759,14 +3778,22 @@ LLVM_Util::mask_as_int(llvm::Value* mask)
auto w4_int_masks = op_split_8x(wide_int_mask);

// Now we will use the horizontal sign extraction intrinsic
// to build a 32 bit mask value.
// to build a 32 bit mask value. However the only 128bit
// version works on floats, so we will cast from int32 to
// float beforehand
llvm::Type* w4_float_type = llvm_vector_type(m_llvm_type_float, 4);
std::array<llvm::Value*, 2> w4_float_masks = {
{ builder().CreateBitCast(w4_int_masks[0], w4_float_type),
builder().CreateBitCast(w4_int_masks[1], w4_float_type) }
};

llvm::Function* func = llvm::Intrinsic::getDeclaration(
module(), llvm::Intrinsic::x86_sse2_pmovmskb_128);
module(), llvm::Intrinsic::x86_sse_movmsk_ps);

llvm::Value* args[1] = { w4_int_masks[0] };
llvm::Value* args[1] = { w4_float_masks[0] };
std::array<llvm::Value*, 2> int4_masks;
int4_masks[0] = builder().CreateCall(func, toArrayRef(args));
args[0] = w4_int_masks[1];
args[0] = w4_float_masks[1];
int4_masks[1] = builder().CreateCall(func, toArrayRef(args));

llvm::Value* bits4_7 = op_shl(int4_masks[1], constant(4));
Expand All @@ -3782,12 +3809,20 @@ LLVM_Util::mask_as_int(llvm::Value* mask)
llvm::Value* w4_int_mask = builder().CreateSExt(mask,
type_wide_int());

// Now we will use the horizontal sign extraction intrinsic
// to build a 32 bit mask value. However the only 128bit
// version works on floats, so we will cast from int32 to
// float beforehand
llvm::Type* w4_float_type = llvm_vector_type(m_llvm_type_float, 4);
llvm::Value* w4_float_mask = builder().CreateBitCast(w4_int_mask,
w4_float_type);

// Now we will use the horizontal sign extraction intrinsic
// to build a 32 bit mask value.
llvm::Function* func = llvm::Intrinsic::getDeclaration(
module(), llvm::Intrinsic::x86_sse2_pmovmskb_128);
module(), llvm::Intrinsic::x86_sse_movmsk_ps);

llvm::Value* args[1] = { w4_int_mask };
llvm::Value* args[1] = { w4_float_mask };
llvm::Value* int4_mask = builder().CreateCall(func,
toArrayRef(args));

Expand Down Expand Up @@ -3833,12 +3868,20 @@ LLVM_Util::mask4_as_int8(llvm::Value* mask)
// Convert <4 x i1> -> <4 x i32>
llvm::Value* w4_int_mask = builder().CreateSExt(mask, type_wide_int());

// Now we will use the horizontal sign extraction intrinsic
// to build a 32 bit mask value. However the only 128bit
// version works on floats, so we will cast from int32 to
// float beforehand
llvm::Type* w4_float_type = llvm_vector_type(m_llvm_type_float, 4);
llvm::Value* w4_float_mask = builder().CreateBitCast(w4_int_mask,
w4_float_type);

// Now we will use the horizontal sign extraction intrinsic
// to build a 32 bit mask value.
llvm::Function* func = llvm::Intrinsic::getDeclaration(
module(), llvm::Intrinsic::x86_sse2_pmovmskb_128);
module(), llvm::Intrinsic::x86_sse_movmsk_ps);

llvm::Value* args[1] = { w4_int_mask };
llvm::Value* args[1] = { w4_float_mask };
llvm::Value* int32 = builder().CreateCall(func, toArrayRef(args));
llvm::Value* i8 = builder().CreateIntCast(int32, type_int8(), true);

Expand Down Expand Up @@ -4685,7 +4728,7 @@ LLVM_Util::op_gather(llvm::Type* src_type, llvm::Value* src_ptr,

llvm::Value* unmasked_value = wide_constant(0);
llvm::Value* args[] = { unmasked_value, void_ptr(src_ptr),
wide_index, int_mask, constant(4) };
wide_index, int_mask, constant(4) };
return builder().CreateCall(func_avx512_gather_pi,
toArrayRef(args));
} else if (m_supports_avx2) {
Expand All @@ -4705,8 +4748,8 @@ LLVM_Util::op_gather(llvm::Type* src_type, llvm::Value* src_ptr,
auto w8_int_masks = op_split_16x(wide_int_mask);
auto w8_int_indices = op_split_16x(wide_index);
llvm::Value* args[] = { avx2_unmasked_value, void_ptr(src_ptr),
w8_int_indices[0], w8_int_masks[0],
constant8((uint8_t)4) };
w8_int_indices[0], w8_int_masks[0],
constant8((uint8_t)4) };
llvm::Value* gather1 = builder().CreateCall(func_avx2_gather_pi,
toArrayRef(args));
args[2] = w8_int_indices[1];
Expand Down Expand Up @@ -4794,8 +4837,8 @@ LLVM_Util::op_gather(llvm::Type* src_type, llvm::Value* src_ptr,
toArrayRef(args));
args[2] = w8_int_indices[1];
args[3] = builder().CreateBitCast(w8_int_masks[1],
llvm_vector_type(type_float(),
8));
llvm_vector_type(type_float(),
8));
llvm::Value* gather2 = builder().CreateCall(func_avx2_gather_ps,
toArrayRef(args));
return op_combine_8x_vectors(gather1, gather2);
Expand Down Expand Up @@ -4990,8 +5033,8 @@ LLVM_Util::op_gather(llvm::Type* src_type, llvm::Value* src_ptr,
toArrayRef(args));
args[2] = w8_int_indices[1];
args[3] = builder().CreateBitCast(w8_int_masks[1],
llvm_vector_type(type_float(),
8));
llvm_vector_type(type_float(),
8));
llvm::Value* gather2 = builder().CreateCall(func_avx2_gather_ps,
toArrayRef(args));
return op_combine_8x_vectors(gather1, gather2);
Expand Down Expand Up @@ -5092,8 +5135,8 @@ LLVM_Util::op_gather(llvm::Type* src_type, llvm::Value* src_ptr,
auto w8_int_indices = op_split_16x(
op_linearize_16x_indices(wide_index));
llvm::Value* args[] = { avx2_unmasked_value, void_ptr(src_ptr),
w8_int_indices[0], w8_int_masks[0],
constant8((uint8_t)4) };
w8_int_indices[0], w8_int_masks[0],
constant8((uint8_t)4) };
llvm::Value* gather1 = builder().CreateCall(func_avx2_gather_pi,
toArrayRef(args));
args[2] = w8_int_indices[1];
Expand Down Expand Up @@ -5863,9 +5906,9 @@ LLVM_Util::apply_return_to(llvm::Value* existing_mask)
OSL_ASSERT(masked_function_context().return_count > 0);

llvm::Value* loc_of_return_mask = masked_function_context().location_of_mask;
llvm::Value* rs_mask = op_load_mask(loc_of_return_mask);
llvm::Value* result = builder().CreateSelect(rs_mask, existing_mask,
rs_mask);
llvm::Value* rs_mask = op_load_mask(loc_of_return_mask);
llvm::Value* result = builder().CreateSelect(rs_mask, existing_mask,
rs_mask);
return result;
}

Expand Down

0 comments on commit ae88298

Please sign in to comment.