Skip to content

Commit b88d83a

Browse files
authored
[GLUTEN-7548][VL] Follow up hash join optimization PR 8931 to resolve comments (#11728)
1 parent a3c973f commit b88d83a

9 files changed

Lines changed: 111 additions & 81 deletions

File tree

backends-velox/src/main/java/org/apache/gluten/vectorized/HashJoinBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public long rtHandle() {
4242
public static native long nativeBuild(
4343
String buildHashTableId,
4444
long[] batchHandlers,
45-
String joinKeys,
45+
String[] joinKeys,
4646
int joinType,
4747
boolean hasMixedFiltCondition,
4848
boolean isExistenceJoin,

backends-velox/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,4 @@ object GlutenRpcMessages {
3434

3535
case class GlutenCleanExecutionResource(executionId: String, broadcastHashIds: util.Set[String])
3636
extends GlutenRpcMessage
37-
38-
// for mergetree cache
39-
case class GlutenMergeTreeCacheLoad(
40-
mergeTreeTable: String,
41-
columns: util.Set[String],
42-
onlyMetaCache: Boolean)
43-
extends GlutenRpcMessage
44-
45-
case class GlutenCacheLoadStatus(jobId: String)
46-
47-
case class CacheJobInfo(status: Boolean, jobId: String, reason: String = "")
48-
extends GlutenRpcMessage
49-
50-
case class GlutenFilesCacheLoad(files: Array[Byte]) extends GlutenRpcMessage
51-
52-
case class GlutenFilesCacheLoadStatus(jobId: String)
5337
}

backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -197,20 +197,18 @@ case class ColumnarBuildSideRelation(
197197
)
198198
}
199199

200-
val joinKey = keys.asScala
201-
.map {
202-
key =>
203-
val attr = ConverterUtils.getAttrFromExpr(key)
204-
ConverterUtils.genColumnNameWithExprId(attr)
205-
}
206-
.mkString(",")
200+
val joinKeys = keys.asScala.map {
201+
key =>
202+
val attr = ConverterUtils.getAttrFromExpr(key)
203+
ConverterUtils.genColumnNameWithExprId(attr)
204+
}.toArray
207205

208206
// Build the hash table
209207
hashTableData = HashJoinBuilder
210208
.nativeBuild(
211209
broadcastContext.buildHashTableId,
212210
batchArray.toArray,
213-
joinKey,
211+
joinKeys,
214212
broadcastContext.substraitJoinType.ordinal(),
215213
broadcastContext.hasMixedFiltCondition,
216214
broadcastContext.isExistenceJoin,

backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -167,20 +167,18 @@ class UnsafeColumnarBuildSideRelation(
167167
)
168168
}
169169

170-
val joinKey = keys.asScala
171-
.map {
172-
key =>
173-
val attr = ConverterUtils.getAttrFromExpr(key)
174-
ConverterUtils.genColumnNameWithExprId(attr)
175-
}
176-
.mkString(",")
170+
val joinKeys = keys.asScala.map {
171+
key =>
172+
val attr = ConverterUtils.getAttrFromExpr(key)
173+
ConverterUtils.genColumnNameWithExprId(attr)
174+
}.toArray
177175

178176
// Build the hash table
179177
hashTableData = HashJoinBuilder
180178
.nativeBuild(
181179
broadcastContext.buildHashTableId,
182180
batchArray.toArray,
183-
joinKey,
181+
joinKeys,
184182
broadcastContext.substraitJoinType.ordinal(),
185183
broadcastContext.hasMixedFiltCondition,
186184
broadcastContext.isExistenceJoin,

cpp/velox/compute/VeloxBackend.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,6 @@ void VeloxBackend::tearDown() {
362362
filesystem->close();
363363
}
364364
#endif
365-
gluten::hashTableObjStore.reset();
366365

367366
// Destruct IOThreadPoolExecutor will join all threads.
368367
// On threads exit, thread local variables can be constructed with referencing global variables.

cpp/velox/jni/JniHashTable.cc

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,34 @@
2929

3030
namespace gluten {
3131

32-
static jclass jniVeloxBroadcastBuildSideCache = nullptr;
33-
static jmethodID jniGet = nullptr;
32+
void JniHashTableContext::initialize(JNIEnv* env, JavaVM* javaVm) {
33+
vm_ = javaVm;
34+
const char* classSig = "Lorg/apache/gluten/execution/VeloxBroadcastBuildSideCache;";
35+
jniVeloxBroadcastBuildSideCache_ = createGlobalClassReferenceOrError(env, classSig);
36+
jniGet_ = getStaticMethodId(env, jniVeloxBroadcastBuildSideCache_, "get", "(Ljava/lang/String;)J");
37+
}
3438

35-
jlong callJavaGet(const std::string& id) {
39+
void JniHashTableContext::finalize(JNIEnv* env) {
40+
if (jniVeloxBroadcastBuildSideCache_ != nullptr) {
41+
env->DeleteGlobalRef(jniVeloxBroadcastBuildSideCache_);
42+
jniVeloxBroadcastBuildSideCache_ = nullptr;
43+
}
44+
}
45+
46+
jlong JniHashTableContext::callJavaGet(const std::string& id) const {
3647
JNIEnv* env;
37-
if (vm->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) {
48+
if (vm_->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) {
3849
throw gluten::GlutenException("JNIEnv was not attached to current thread");
3950
}
4051

4152
const jstring s = env->NewStringUTF(id.c_str());
42-
43-
auto result = env->CallStaticLongMethod(jniVeloxBroadcastBuildSideCache, jniGet, s);
53+
auto result = env->CallStaticLongMethod(jniVeloxBroadcastBuildSideCache_, jniGet_, s);
4454
return result;
4555
}
4656

4757
// Return the Velox hash table builder.
4858
std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
49-
const std::string& joinKeys,
59+
const std::vector<std::string>& joinKeys,
5060
std::vector<std::string> names,
5161
std::vector<facebook::velox::TypePtr> veloxTypeList,
5262
int joinType,
@@ -98,12 +108,9 @@ std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
98108
VELOX_NYI("Unsupported Join type: {}", std::to_string(sJoin));
99109
}
100110

101-
std::vector<std::string> joinKeyNames;
102-
folly::split(',', joinKeys, joinKeyNames);
103-
104111
std::vector<std::shared_ptr<const facebook::velox::core::FieldAccessTypedExpr>> joinKeyTypes;
105-
joinKeyTypes.reserve(joinKeyNames.size());
106-
for (const auto& name : joinKeyNames) {
112+
joinKeyTypes.reserve(joinKeys.size());
113+
for (const auto& name : joinKeys) {
107114
joinKeyTypes.emplace_back(
108115
std::make_shared<facebook::velox::core::FieldAccessTypedExpr>(rowType->findChild(name), name));
109116
}
@@ -125,21 +132,8 @@ std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
125132
return hashTableBuilder;
126133
}
127134

128-
long getJoin(std::string hashTableId) {
129-
return callJavaGet(hashTableId);
130-
}
131-
132-
void initVeloxJniHashTable(JNIEnv* env) {
133-
if (env->GetJavaVM(&vm) != JNI_OK) {
134-
throw gluten::GlutenException("Unable to get JavaVM instance");
135-
}
136-
const char* classSig = "Lorg/apache/gluten/execution/VeloxBroadcastBuildSideCache;";
137-
jniVeloxBroadcastBuildSideCache = createGlobalClassReferenceOrError(env, classSig);
138-
jniGet = getStaticMethodId(env, jniVeloxBroadcastBuildSideCache, "get", "(Ljava/lang/String;)J");
139-
}
140-
141-
void finalizeVeloxJniHashTable(JNIEnv* env) {
142-
env->DeleteGlobalRef(jniVeloxBroadcastBuildSideCache);
135+
long getJoin(const std::string& hashTableId) {
136+
return JniHashTableContext::getInstance().callJavaGet(hashTableId);
143137
}
144138

145139
} // namespace gluten

cpp/velox/jni/JniHashTable.h

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,53 @@
2626

2727
namespace gluten {
2828

29-
inline static JavaVM* vm = nullptr;
29+
// Wrapper class to encapsulate JNI-related static objects for hash table operations.
30+
// This avoids exposing global variables in the gluten namespace.
31+
class JniHashTableContext {
32+
public:
33+
static JniHashTableContext& getInstance() {
34+
static JniHashTableContext instance;
35+
return instance;
36+
}
3037

31-
inline static std::unique_ptr<ObjectStore> hashTableObjStore = ObjectStore::create();
38+
// Delete copy and move constructors/operators
39+
JniHashTableContext(const JniHashTableContext&) = delete;
40+
JniHashTableContext& operator=(const JniHashTableContext&) = delete;
41+
JniHashTableContext(JniHashTableContext&&) = delete;
42+
JniHashTableContext& operator=(JniHashTableContext&&) = delete;
43+
44+
void initialize(JNIEnv* env, JavaVM* javaVm);
45+
void finalize(JNIEnv* env);
46+
47+
JavaVM* getJavaVM() const {
48+
return vm_;
49+
}
50+
51+
ObjectStore* getHashTableObjStore() const {
52+
return hashTableObjStore_.get();
53+
}
54+
55+
jlong callJavaGet(const std::string& id) const;
56+
57+
private:
58+
JniHashTableContext() : hashTableObjStore_(ObjectStore::create()) {}
59+
60+
~JniHashTableContext() {
61+
// Note: The destructor runs during static destruction at process exit.
62+
// finalize() is expected to have been invoked from JNI_OnUnload before then, releasing the
63+
// JNI global references while the JVM is still valid. NOTE(review): confirm JNI_OnUnload always runs.
64+
// The singleton itself (including hashTableObjStore_) is destroyed here.
65+
}
66+
67+
JavaVM* vm_{nullptr};
68+
std::unique_ptr<ObjectStore> hashTableObjStore_;
69+
jclass jniVeloxBroadcastBuildSideCache_{nullptr};
70+
jmethodID jniGet_{nullptr};
71+
};
3272

3373
// Return the hash table builder address.
3474
std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
35-
const std::string& joinKeys,
75+
const std::vector<std::string>& joinKeys,
3676
std::vector<std::string> names,
3777
std::vector<facebook::velox::TypePtr> veloxTypeList,
3878
int joinType,
@@ -43,12 +83,21 @@ std::shared_ptr<HashTableBuilder> nativeHashTableBuild(
4383
std::vector<std::shared_ptr<ColumnarBatch>>& batches,
4484
std::shared_ptr<facebook::velox::memory::MemoryPool> memoryPool);
4585

46-
long getJoin(std::string hashTableId);
86+
long getJoin(const std::string& hashTableId);
4787

48-
void initVeloxJniHashTable(JNIEnv* env);
88+
// Initialize the JNI hash table context
89+
inline void initVeloxJniHashTable(JNIEnv* env, JavaVM* javaVm) {
90+
JniHashTableContext::getInstance().initialize(env, javaVm);
91+
}
4992

50-
void finalizeVeloxJniHashTable(JNIEnv* env);
93+
// Finalize the JNI hash table context
94+
inline void finalizeVeloxJniHashTable(JNIEnv* env) {
95+
JniHashTableContext::getInstance().finalize(env);
96+
}
5197

52-
jlong callJavaGet(const std::string& id);
98+
// Get hash table object store
99+
inline ObjectStore* getHashTableObjStore() {
100+
return JniHashTableContext::getInstance().getHashTableObjStore();
101+
}
53102

54103
} // namespace gluten

cpp/velox/jni/VeloxJniWrapper.cc

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ jint JNI_OnLoad(JavaVM* vm, void*) {
8080
getJniErrorState()->ensureInitialized(env);
8181
initVeloxJniFileSystem(env);
8282
initVeloxJniUDF(env);
83-
initVeloxJniHashTable(env);
83+
initVeloxJniHashTable(env, vm);
8484

8585
infoCls = createGlobalClassReferenceOrError(env, "Lorg/apache/gluten/validate/NativePlanValidationInfo;");
8686
infoClsInitMethod = getMethodIdOrError(env, infoCls, "<init>", "(ILjava/lang/String;)V");
@@ -94,8 +94,6 @@ jint JNI_OnLoad(JavaVM* vm, void*) {
9494

9595
DLOG(INFO) << "Loaded Velox backend.";
9696

97-
gluten::vm = vm;
98-
9997
return jniVersion;
10098
}
10199

@@ -108,6 +106,7 @@ void JNI_OnUnload(JavaVM* vm, void*) {
108106

109107
finalizeVeloxJniUDF(env);
110108
finalizeVeloxJniFileSystem(env);
109+
finalizeVeloxJniHashTable(env);
111110
getJniErrorState()->close();
112111
getJniCommonState()->close();
113112
google::ShutdownGoogleLogging();
@@ -939,7 +938,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
939938
jclass,
940939
jstring tableId,
941940
jlongArray batchHandles,
942-
jstring joinKey,
941+
jobjectArray joinKeys,
943942
jint joinType,
944943
jboolean hasMixedJoinCondition,
945944
jboolean isExistenceJoin,
@@ -949,7 +948,16 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
949948
jint broadcastHashTableBuildThreads) {
950949
JNI_METHOD_START
951950
const auto hashTableId = jStringToCString(env, tableId);
952-
const auto hashJoinKey = jStringToCString(env, joinKey);
951+
952+
// Convert Java String array to C++ vector<string>
953+
std::vector<std::string> hashJoinKeys;
954+
jsize joinKeysCount = env->GetArrayLength(joinKeys);
955+
hashJoinKeys.reserve(joinKeysCount);
956+
for (jsize i = 0; i < joinKeysCount; ++i) {
957+
jstring jkey = (jstring)env->GetObjectArrayElement(joinKeys, i);
958+
hashJoinKeys.emplace_back(jStringToCString(env, jkey));
959+
}
960+
953961
const auto inputType = gluten::getByteArrayElementsSafe(env, namedStruct);
954962
std::string structString{
955963
reinterpret_cast<const char*>(inputType.elems()), static_cast<std::string::size_type>(inputType.length())};
@@ -988,7 +996,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
988996

989997
if (numThreads <= 1) {
990998
auto builder = nativeHashTableBuild(
991-
hashJoinKey,
999+
hashJoinKeys,
9921000
names,
9931001
veloxTypeList,
9941002
joinType,
@@ -1008,7 +1016,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
10081016
nullptr);
10091017
builder->setHashTable(std::move(mainTable));
10101018

1011-
return gluten::hashTableObjStore->save(builder);
1019+
return gluten::getHashTableObjStore()->save(builder);
10121020
}
10131021

10141022
std::vector<std::thread> threads;
@@ -1027,7 +1035,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
10271035
}
10281036

10291037
auto builder = nativeHashTableBuild(
1030-
hashJoinKey,
1038+
hashJoinKeys,
10311039
names,
10321040
veloxTypeList,
10331041
joinType,
@@ -1073,7 +1081,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_native
10731081
}
10741082

10751083
hashTableBuilders[0]->setHashTable(std::move(mainTable));
1076-
return gluten::hashTableObjStore->save(hashTableBuilders[0]);
1084+
return gluten::getHashTableObjStore()->save(hashTableBuilders[0]);
10771085
JNI_METHOD_END(kInvalidObjectHandle)
10781086
}
10791087

@@ -1083,7 +1091,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_HashJoinBuilder_cloneH
10831091
jlong tableHandler) {
10841092
JNI_METHOD_START
10851093
auto hashTableHandler = ObjectStore::retrieve<gluten::HashTableBuilder>(tableHandler);
1086-
return gluten::hashTableObjStore->save(hashTableHandler);
1094+
return gluten::getHashTableObjStore()->save(hashTableHandler);
10871095
JNI_METHOD_END(kInvalidObjectHandle)
10881096
}
10891097

gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/FallbackRules.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,22 +44,22 @@ case class FallbackMultiCodegens(session: SparkSession) extends Rule[SparkPlan]
4444
plan match {
4545
case plan: CodegenSupport if plan.supportCodegen =>
4646
if (
47-
(count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum == outputSize
47+
(count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum >= outputSize
4848
) {
4949
return true
5050
}
5151
plan.children.exists(existsMultiCodegens(_, count + 1))
5252
case plan: ShuffledHashJoinExec =>
5353
if (
54-
(count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum == outputSize
54+
(count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum >= outputSize
5555
) {
5656
return true
5757
}
5858

5959
plan.children.exists(existsMultiCodegens(_, count + 1))
6060
case plan: SortMergeJoinExec if GlutenConfig.get.forceShuffledHashJoin =>
6161
if (
62-
(count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum == outputSize
62+
(count + 1) >= optimizeLevel && plan.output.map(_.dataType.defaultSize).sum >= outputSize
6363
) {
6464
return true
6565
}

0 commit comments

Comments
 (0)