From 50a75f47be60f89210c1dc2d8974ab63a8298ad1 Mon Sep 17 00:00:00 2001 From: "zhongheng.gy" Date: Thu, 21 May 2026 15:59:15 +0800 Subject: [PATCH 1/3] [SPARK-57487][SQL] Support distributed map join for medium-sized build tables via SQL hint --- .../client/ManagedRpcResponseCallback.java | 33 + .../spark/network/client/TransportClient.java | 15 + .../client/TransportResponseHandler.java | 14 +- .../spark/network/server/RpcHandler.java | 10 + .../server/TransportRequestHandler.java | 67 +- .../network/shard/ShardLookupListener.java | 34 + .../spark/network/shard/ShardStoreClient.java | 58 ++ .../shard/protocol/BatchLookupReq.java | 98 +++ .../shard/protocol/BatchLookupResp.java | 90 +++ .../shard/protocol/ShardLookupMessage.java | 73 +++ .../org/apache/spark/internal/LogKeys.java | 2 + .../org/apache/spark/ContextCleaner.scala | 24 +- .../scala/org/apache/spark/SparkContext.scala | 1 + .../scala/org/apache/spark/SparkEnv.scala | 48 ++ .../org/apache/spark/executor/Executor.scala | 1 + .../spark/internal/config/package.scala | 9 + .../spark/network/ShardLookupService.scala | 67 ++ .../network/ShardLookupServiceFactory.scala | 72 +++ .../netty/NettyShardLookupService.scala | 131 ++++ .../network/netty/NettyShardRpcServer.scala | 80 +++ .../apache/spark/scheduler/DAGScheduler.scala | 3 + .../CoarseGrainedSchedulerBackend.scala | 4 + .../scheduler/dynalloc/ExecutorMonitor.scala | 54 +- .../scala/org/apache/spark/shard/Shard.scala | 22 + .../org/apache/spark/shard/ShardManager.scala | 589 ++++++++++++++++++ .../spark/shard/ShardManagerEndpoint.scala | 67 ++ .../apache/spark/shard/ShardManagerId.scala | 110 ++++ .../spark/shard/ShardManagerMaster.scala | 109 ++++ .../shard/ShardManagerMasterEndpoint.scala | 256 ++++++++ .../spark/shard/ShardManagerMessages.scala | 47 ++ .../org/apache/spark/shard/package-info.java | 18 + .../org/apache/spark/storage/BlockId.scala | 9 + .../apache/spark/storage/BlockManager.scala | 8 + .../spark/storage/BlockManagerMaster.scala | 10 + 
.../storage/BlockManagerMasterEndpoint.scala | 15 + .../spark/storage/BlockManagerMessages.scala | 2 + .../storage/BlockManagerStorageEndpoint.scala | 8 +- .../apache/spark/ContextCleanerSuite.scala | 14 +- .../sql/catalyst/analysis/ResolveHints.scala | 61 +- .../spark/sql/catalyst/optimizer/joins.scala | 52 ++ .../sql/catalyst/plans/logical/hints.scala | 20 + .../plans/physical/partitioning.scala | 49 +- .../apache/spark/sql/internal/SQLConf.scala | 35 ++ .../sql/execution/BufferedShardRowMap.java | 485 ++++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 56 +- .../adaptive/AdaptiveSparkPlanExec.scala | 10 +- .../adaptive/LogicalQueryStageStrategy.scala | 15 +- .../execution/adaptive/QueryStageExec.scala | 44 ++ .../exchange/EnsureRequirements.scala | 2 + .../exchange/ShardExchangeExec.scala | 259 ++++++++ .../execution/exchange/ShardedRowRDD.scala | 36 ++ .../joins/DistributedMapJoinExec.scala | 567 +++++++++++++++++ .../joins/HashedRelationAdapter.scala | 122 ++++ .../execution/joins/UnsafeRowBufCodec.scala | 72 +++ .../apache/spark/sql/CachedTableSuite.scala | 1 + .../joins/DistributedMapJoinSuite.scala | 400 ++++++++++++ 56 files changed, 4508 insertions(+), 50 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/client/ManagedRpcResponseCallback.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/shard/ShardLookupListener.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/shard/ShardStoreClient.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/shard/protocol/BatchLookupReq.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/shard/protocol/BatchLookupResp.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/shard/protocol/ShardLookupMessage.java create mode 100644 core/src/main/scala/org/apache/spark/network/ShardLookupService.scala create mode 
100644 core/src/main/scala/org/apache/spark/network/ShardLookupServiceFactory.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/NettyShardLookupService.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/NettyShardRpcServer.scala create mode 100644 core/src/main/scala/org/apache/spark/shard/Shard.scala create mode 100644 core/src/main/scala/org/apache/spark/shard/ShardManager.scala create mode 100644 core/src/main/scala/org/apache/spark/shard/ShardManagerEndpoint.scala create mode 100644 core/src/main/scala/org/apache/spark/shard/ShardManagerId.scala create mode 100644 core/src/main/scala/org/apache/spark/shard/ShardManagerMaster.scala create mode 100644 core/src/main/scala/org/apache/spark/shard/ShardManagerMasterEndpoint.scala create mode 100644 core/src/main/scala/org/apache/spark/shard/ShardManagerMessages.scala create mode 100644 core/src/main/scala/org/apache/spark/shard/package-info.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/BufferedShardRowMap.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShardExchangeExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShardedRowRDD.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/DistributedMapJoinExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelationAdapter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeRowBufCodec.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/DistributedMapJoinSuite.scala diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/ManagedRpcResponseCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/ManagedRpcResponseCallback.java new file mode 100644 index 0000000000000..932061d72e47e --- /dev/null +++ 
b/common/network-common/src/main/java/org/apache/spark/network/client/ManagedRpcResponseCallback.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.client; + +import org.apache.spark.network.buffer.ManagedBuffer; + +public interface ManagedRpcResponseCallback extends BaseResponseCallback { + + /** + * Successful response body. + * Ownership of {@code response} is transferred to the callback. + * The callback implementation MUST ensure {@code response} is released exactly once: + * either hand it off to the transport (e.g. wrap it in {@code RpcResponse} and write it to + * the channel, letting Netty release it after write completes), or call + * {@code response.release()} itself if it is not handed off. + * It may call {@code retain()} if it needs to pass the buffer to another thread. 
+ */ + void onSuccess(ManagedBuffer response); +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index f02f2c63ecd4c..4661212c69f8b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -200,6 +200,21 @@ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) { return requestId; } + public long sendManagedRpc(ManagedBuffer message, ManagedRpcResponseCallback callback) { + if (logger.isTraceEnabled()) { + logger.trace("Sending RPC to {}", getRemoteAddress(channel)); + } + + long requestId = requestId(); + handler.addRpcRequest(requestId, callback); + + RpcChannelListener listener = new RpcChannelListener(requestId, callback); + channel.writeAndFlush(new RpcRequest(requestId, message)) + .addListener(listener); + + return requestId; + } + /** * Sends a MergedBlockMetaRequest message to the server. The response of this message is * either a {@link MergedBlockMetaSuccess} or {@link RpcFailure}. 
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index d27fa08d829bb..14ee022303669 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -189,7 +189,7 @@ public void handle(ResponseMessage message) throws Exception { "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString)); } } else if (message instanceof RpcResponse resp) { - RpcResponseCallback listener = (RpcResponseCallback) outstandingRpcs.get(resp.requestId); + BaseResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", MDC.of(LogKeys.REQUEST_ID, resp.requestId), @@ -198,10 +198,14 @@ public void handle(ResponseMessage message) throws Exception { resp.body().release(); } else { outstandingRpcs.remove(resp.requestId); - try { - listener.onSuccess(resp.body().nioByteBuffer()); - } finally { - resp.body().release(); + if (listener instanceof ManagedRpcResponseCallback) { + ((ManagedRpcResponseCallback) listener).onSuccess(resp.body()); + } else { + try { + ((RpcResponseCallback) listener).onSuccess(resp.body().nioByteBuffer()); + } finally { + resp.body().release(); + } } } } else if (message instanceof RpcFailure resp) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java index a7c38917d17f6..da5b1f41dad54 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -21,6 +21,8 @@ import 
org.apache.spark.internal.SparkLogger; import org.apache.spark.internal.SparkLoggerFactory; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.ManagedRpcResponseCallback; import org.apache.spark.network.client.MergedBlockMetaResponseCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallbackWithID; @@ -136,6 +138,14 @@ public void onFailure(Throwable e) { } + public interface ManagedRpcHandler { + + void receive( + TransportClient client, + ManagedBuffer message, + ManagedRpcResponseCallback callback); + } + /** * Handler for {@link MergedBlockMetaRequest}. * diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 464d4d9eb378f..965f66960124b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -163,25 +163,58 @@ private void processStreamRequest(final StreamRequest req) { } private void processRpcRequest(final RpcRequest req) { - try { - rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); + if (rpcHandler instanceof RpcHandler.ManagedRpcHandler) { + boolean handedOff = false; + try { + ((RpcHandler.ManagedRpcHandler) rpcHandler).receive(reverseClient, req.body(), + new ManagedRpcResponseCallback() { + @Override + public void onSuccess(ManagedBuffer response) { + boolean sent = false; + try { + respond(new RpcResponse(req.requestId, response)); + sent = true; + } finally { + if (!sent) { + response.release(); + } + } + } + + @Override + public void onFailure(Throwable e) { + 
respond(new RpcFailure(req.requestId, JavaUtils.stackTraceToString(e))); + } + }); + handedOff = true; + } catch (Exception e) { + respond(new RpcFailure(req.requestId, JavaUtils.stackTraceToString(e))); + } finally { + if (!handedOff) { + req.body().release(); } + } + } else { + try { + rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + respond(new RpcResponse(req.requestId, new NioManagedBuffer(response))); + } - @Override - public void onFailure(Throwable e) { - respond(new RpcFailure(req.requestId, JavaUtils.stackTraceToString(e))); - } - }); - } catch (Exception e) { - logger.error("Error while invoking RpcHandler#receive() on RPC id {} from {}", e, - MDC.of(LogKeys.REQUEST_ID, req.requestId), - MDC.of(LogKeys.HOST_PORT, getRemoteAddress(channel))); - respond(new RpcFailure(req.requestId, JavaUtils.stackTraceToString(e))); - } finally { - req.body().release(); + @Override + public void onFailure(Throwable e) { + respond(new RpcFailure(req.requestId, JavaUtils.stackTraceToString(e))); + } + }); + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() on RPC id {} from {}", e, + MDC.of(LogKeys.REQUEST_ID, req.requestId), + MDC.of(LogKeys.HOST_PORT, getRemoteAddress(channel))); + respond(new RpcFailure(req.requestId, JavaUtils.stackTraceToString(e))); + } finally { + req.body().release(); + } } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/shard/ShardLookupListener.java b/common/network-common/src/main/java/org/apache/spark/network/shard/ShardLookupListener.java new file mode 100644 index 0000000000000..2652a14729f85 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/shard/ShardLookupListener.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shard; + +import java.util.EventListener; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Listener for asynchronous shard batch lookup results in distributed map join. + */ +public interface ShardLookupListener extends EventListener { + + /** Called when a batch lookup RPC completes successfully. */ + void onBatchFetchSuccess(ManagedBuffer response); + + /** Called when a batch lookup RPC fails. */ + void onBatchFetchFailure(Throwable exception); +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/shard/ShardStoreClient.java b/common/network-common/src/main/java/org/apache/spark/network/shard/ShardStoreClient.java new file mode 100644 index 0000000000000..9cd3b2be82e4f --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/shard/ShardStoreClient.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shard; + +import java.io.Closeable; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.util.TransportConf; + +/** + * Client interface for shard-based RPC lookups in distributed map join. + * Each executor has one instance, used by the probe side to send batched + * key lookups to build-side executors. + */ +public abstract class ShardStoreClient implements Closeable { + protected final Logger logger = LoggerFactory.getLogger(this.getClass()); + protected volatile TransportClientFactory clientFactory; + protected String appId; + protected TransportConf transportConf; + + protected void checkInit() { + assert appId != null : "Called before init()"; + } + + /** + * Send a batched key lookup request to the specified host. 
+ * + * @param host the target executor's hostname + * @param port the target executor's shard service port + * @param reqMsg the serialized batch of probe keys + * @param listener callback for success or failure + */ + public abstract void fetchBatch( + String host, + int port, + ManagedBuffer reqMsg, + ShardLookupListener listener); + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/shard/protocol/BatchLookupReq.java b/common/network-common/src/main/java/org/apache/spark/network/shard/protocol/BatchLookupReq.java new file mode 100644 index 0000000000000..e3a23d49994a2 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/shard/protocol/BatchLookupReq.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shard.protocol; + +import java.util.Arrays; +import java.util.Objects; + +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** + * Request message for a batched key lookup in distributed map join. + * Contains the shard set ID, target shard, number of key fields, + * and the serialized key data. 
+ */ +public class BatchLookupReq extends ShardLookupMessage { + + public final long setId; + public final int shardId; + public final int numFields; + public final byte[] keysData; + + public BatchLookupReq(long setId, int shardId, int numFields, byte[] keysData) { + this.setId = setId; + this.shardId = shardId; + this.numFields = numFields; + this.keysData = keysData; + } + + @Override + protected Type type() { + return Type.BATCH_LOOKUP_REQ; + } + + @Override + public int hashCode() { + return (Objects.hash(setId, shardId) * 31 + Objects.hash(numFields)) * 31 + + Arrays.hashCode(keysData); + } + + @Override + public String toString() { + return String.format("BatchLookupReq[setId=%d,shardId=%d,numFields=%d,keysSize=%d]", + setId, shardId, numFields, keysData.length); + } + + @Override + public boolean equals(Object other) { + if (other instanceof BatchLookupReq) { + BatchLookupReq o = (BatchLookupReq) other; + return setId == o.setId + && shardId == o.shardId + && numFields == o.numFields + && Arrays.equals(keysData, o.keysData); + } + return false; + } + + @Override + public int encodedLength() { + return Long.BYTES + + Integer.BYTES + + Integer.BYTES + + Encoders.ByteArrays.encodedLength(keysData); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(setId); + buf.writeInt(shardId); + buf.writeInt(numFields); + Encoders.ByteArrays.encode(buf, keysData); + } + + public static BatchLookupReq decode(ByteBuf buf) { + long setId = buf.readLong(); + int shardId = buf.readInt(); + int numFields = buf.readInt(); + byte[] keysData = Encoders.ByteArrays.decode(buf); + return new BatchLookupReq(setId, shardId, numFields, keysData); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/shard/protocol/BatchLookupResp.java b/common/network-common/src/main/java/org/apache/spark/network/shard/protocol/BatchLookupResp.java new file mode 100644 index 0000000000000..b43eea7bd562a --- /dev/null +++ 
b/common/network-common/src/main/java/org/apache/spark/network/shard/protocol/BatchLookupResp.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shard.protocol; + +import java.util.Arrays; +import java.util.Objects; + +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** + * Response message containing matching build-side rows for a batched + * key lookup in distributed map join. 
+ */ +public class BatchLookupResp extends ShardLookupMessage { + + public final long setId; + public final int shardId; + public final byte[] rowsData; + + public BatchLookupResp(long setId, int shardId, byte[] rowsData) { + this.setId = setId; + this.shardId = shardId; + this.rowsData = rowsData; + } + + @Override + protected ShardLookupMessage.Type type() { + return Type.BATCH_LOOKUP_RESP; + } + + @Override + public int hashCode() { + return Objects.hash(setId, shardId) * 31 + Arrays.hashCode(rowsData); + } + + @Override + public String toString() { + return String.format("BatchLookupResp[setId=%d,shardId=%d,rowsSize=%d]", + setId, shardId, rowsData.length); + } + + @Override + public boolean equals(Object other) { + if (other instanceof BatchLookupResp) { + BatchLookupResp o = (BatchLookupResp) other; + return setId == o.setId + && shardId == o.shardId + && Arrays.equals(rowsData, o.rowsData); + } + return false; + } + + @Override + public int encodedLength() { + return Long.BYTES + + Integer.BYTES + + Encoders.ByteArrays.encodedLength(rowsData); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(setId); + buf.writeInt(shardId); + Encoders.ByteArrays.encode(buf, rowsData); + } + + public static BatchLookupResp decode(ByteBuf buf) { + long setId = buf.readLong(); + int shardId = buf.readInt(); + byte[] rowsData = Encoders.ByteArrays.decode(buf); + return new BatchLookupResp(setId, shardId, rowsData); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/shard/protocol/ShardLookupMessage.java b/common/network-common/src/main/java/org/apache/spark/network/shard/protocol/ShardLookupMessage.java new file mode 100644 index 0000000000000..7dfdf2e5b3a2e --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/shard/protocol/ShardLookupMessage.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shard.protocol; + +import java.nio.ByteBuffer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.protocol.Encodable; + +/** + * Base class for shard lookup RPC messages used in distributed map join. + * Provides encoding/decoding and type-based dispatch similar to + * {@code BlockTransferMessage}. 
+ */ +public abstract class ShardLookupMessage implements Encodable { + + protected abstract Type type(); + + public enum Type { + BATCH_LOOKUP_REQ(0), BATCH_LOOKUP_RESP(1); + + private final byte id; + + Type(int id) { + assert id < 128 : "Cannot have more than 128 message types"; + this.id = (byte) id; + } + + public byte id() { + return id; + } + } + + public static class Decoder { + public static ShardLookupMessage fromByteBuffer(ByteBuffer msg) { + ByteBuf buf = Unpooled.wrappedBuffer(msg); + byte type = buf.readByte(); + switch (type) { + case 0: // BATCH_LOOKUP_REQ + return BatchLookupReq.decode(buf); + case 1: // BATCH_LOOKUP_RESP + return BatchLookupResp.decode(buf); + default: + throw new IllegalArgumentException("Unknown message type: " + type); + } + } + } + + public ByteBuffer toByteBuffer() { + ByteBuf buf = Unpooled.buffer(encodedLength() + 1); + buf.writeByte(type().id()); + encode(buf); + assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes(); + return buf.nioBuffer(); + } +} diff --git a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java index 37064bf776312..62148566bd893 100644 --- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java +++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java @@ -722,6 +722,8 @@ public enum LogKeys implements LogKey { SESSION_KEY, SET_CLIENT_INFO_REQUEST, SHARD_ID, + SHARD_MANAGER_ID, + SHARD_SET_ID, SHORTER_SERVICE_NAME, SHORT_USER_NAME, SHUFFLE_BLOCK_INFO, diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 54ea8c94daac1..185bddc3d6537 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -25,10 +25,11 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.broadcast.Broadcast 
import org.apache.spark.internal.Logging -import org.apache.spark.internal.LogKeys.{ACCUMULATOR_ID, BROADCAST_ID, LISTENER, RDD_ID, SHUFFLE_ID} +import org.apache.spark.internal.LogKeys.{ACCUMULATOR_ID, BROADCAST_ID, LISTENER, RDD_ID, SHARD_SET_ID, SHUFFLE_ID} import org.apache.spark.internal.config._ import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} import org.apache.spark.scheduler.SparkListener +import org.apache.spark.shard.ShardSetRef import org.apache.spark.shuffle.api.ShuffleDriverComponents import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, ThreadUtils, Utils} @@ -39,6 +40,7 @@ private sealed trait CleanupTask private case class CleanRDD(rddId: Int) extends CleanupTask private case class CleanShuffle(shuffleId: Int) extends CleanupTask private case class CleanBroadcast(broadcastId: Long) extends CleanupTask +private case class CleanShardSet(setId: Long) extends CleanupTask private case class CleanAccum(accId: Long) extends CleanupTask private case class CleanCheckpoint(rddId: Int) extends CleanupTask private case class CleanSparkListener(listener: SparkListener) extends CleanupTask @@ -168,6 +170,10 @@ private[spark] class ContextCleaner( registerForCleanup(broadcast, CleanBroadcast(broadcast.id)) } + def registerShardSetForCleanup(setRef: ShardSetRef): Unit = { + registerForCleanup(setRef, CleanShardSet(setRef.setId)) + } + /** Register a RDDCheckpointData for cleanup when it is garbage collected. 
*/ def registerRDDCheckpointDataForCleanup[T](rdd: RDD[_], parentId: Int): Unit = { registerForCleanup(rdd, CleanCheckpoint(parentId)) @@ -203,6 +209,8 @@ private[spark] class ContextCleaner( doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks) case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) + case CleanShardSet(setId) => + doCleanupShardSet(setId, blocking = blockOnCleanupTasks) case CleanAccum(accId) => doCleanupAccum(accId, blocking = blockOnCleanupTasks) case CleanCheckpoint(rddId) => @@ -263,6 +271,18 @@ private[spark] class ContextCleaner( } } + def doCleanupShardSet(setId: Long, blocking: Boolean): Unit = { + try { + logInfo(log"Cleaning shard-set ${MDC(SHARD_SET_ID, setId)}") + shardManager.unpersist(setId, blocking) + listeners.asScala.foreach(_.shardSetCleaned(setId)) + logInfo(log"Cleaned shard-set ${MDC(SHARD_SET_ID, setId)}") + } catch { + case e: Exception => + logError(log"Error cleaning shard-set ${MDC(SHARD_SET_ID, setId)}", e) + } + } + /** Perform accumulator cleanup. 
*/ def doCleanupAccum(accId: Long, blocking: Boolean): Unit = { try { @@ -305,6 +325,7 @@ private[spark] class ContextCleaner( } private def broadcastManager = sc.env.broadcastManager + private def shardManager = sc.env.shardManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] } @@ -319,6 +340,7 @@ private[spark] trait CleanerListener { def rddCleaned(rddId: Int): Unit def shuffleCleaned(shuffleId: Int): Unit def broadcastCleaned(broadcastId: Long): Unit + def shardSetCleaned(setId: Long): Unit def accumCleaned(accId: Long): Unit def checkpointCleaned(rddId: Long): Unit } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e00e34b89b842..60066add19c9f 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -650,6 +650,7 @@ class SparkContext(config: SparkConf) extends Logging { _env.blockManager.initialize(_applicationId) FallbackStorage.registerBlockManagerIfNeeded( _env.blockManager.master, _conf, _hadoopConfiguration) + _env.initializeShardManager() // The metrics system for Driver need to be set spark.app.id to app ID. // So it should start after we get app ID from the task scheduler and set spark.app.id. 
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 39f12e03a9336..084cf007ea8c4 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -38,6 +38,7 @@ import org.apache.spark.internal.LogKeys import org.apache.spark.internal.config._ import org.apache.spark.memory.{MemoryManager, UnifiedMemoryManager} import org.apache.spark.metrics.{MetricsSystem, MetricsSystemInstances} +import org.apache.spark.network.ShardLookupServiceFactory import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} import org.apache.spark.network.shuffle.ExternalBlockStoreClient import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} @@ -45,6 +46,7 @@ import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager} +import org.apache.spark.shard.{ShardManager, ShardManagerId, ShardManagerInfo, ShardManagerMaster, ShardManagerMasterEndpoint} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.streaming.{MultiShuffleManager, StreamingShuffleManager} import org.apache.spark.storage._ @@ -70,6 +72,7 @@ class SparkEnv ( val serializerManager: SerializerManager, val mapOutputTracker: MapOutputTracker, val broadcastManager: BroadcastManager, + shardManagerFactory: () => ShardManager, val blockManager: BlockManager, val securityManager: SecurityManager, val metricsSystem: MetricsSystem, @@ -111,6 +114,14 @@ class SparkEnv ( def memoryManager: MemoryManager = _memoryManager + @volatile private[spark] var shardManager: ShardManager = _ + + private[spark] def initializeShardManager(): Unit = { + if (conf.get(config.SHARD_ENABLED)) { + shardManager = 
shardManagerFactory() + } + } + @volatile private[spark] var isStopped = false /** @@ -189,6 +200,10 @@ class SparkEnv ( shuffleManager.stop() } broadcastManager.stop() + if (shardManager != null) { + shardManager.stop() + shardManager.master.stop() + } blockManager.stop() blockManager.master.stop() metricsSystem.stop() @@ -573,6 +588,38 @@ object SparkEnv extends Logging { securityManager, externalShuffleClient) + val shardManagerFactory: () => ShardManager = () => { + val managerInfo = new concurrent.TrieMap[ShardManagerId, ShardManagerInfo] + val managerMaster = new ShardManagerMaster( + registerOrLookupEndpoint( + ShardManagerMaster.DRIVER_ENDPOINT_NAME, + new ShardManagerMasterEndpoint( + rpcEnv, + isLocal, + conf, + managerInfo, + isDriver)), + conf, + isDriver) + val lookupService = + ShardLookupServiceFactory.create( + conf, + bindAddress, + advertiseAddress, + 0, + numUsableCores, + managerMaster.masterEndpoint) + val sm = new ShardManager( + executorId, + rpcEnv, + managerMaster, + conf, + lookupService, + isDriver) + sm.initialize(conf.get("spark.app.id", "")) + sm + } + val metricsSystem = if (isDriver) { // Don't start metrics system right now for Driver. // We need to wait for the task scheduler to give us an app ID. 
@@ -603,6 +650,7 @@ object SparkEnv extends Logging {
       serializerManager,
       mapOutputTracker,
       broadcastManager,
+      shardManagerFactory,
       blockManager,
       securityManager,
       metricsSystem,
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 805ec68a89476..6313bf7e8891b 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -532,6 +532,7 @@ private[spark] class Executor(
     Utils.withContextClassLoader(defaultSessionState.replClassLoader) {
       env.initializeShuffleManager()
     }
+    env.initializeShardManager()
   }
   metricsPoller.start()
 
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 141e033ea992f..514f319744458 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2994,4 +2994,13 @@ package object config {
       .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE)
       .booleanConf
       .createWithDefault(false)
+
+  private[spark] val SHARD_ENABLED =
+    ConfigBuilder("spark.shard.enabled")
+      .doc("When true, the shard service infrastructure (endpoints, lookup server) is " +
+        "started at application launch, enabling the distributed map join strategy " +
+        "via SQL hints.")
+      .version("5.0.0")
+      .booleanConf
+      .createWithDefault(false)
 }
diff --git a/core/src/main/scala/org/apache/spark/network/ShardLookupService.scala b/core/src/main/scala/org/apache/spark/network/ShardLookupService.scala
new file mode 100644
index 0000000000000..928b12d701826
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/ShardLookupService.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import scala.concurrent.{Future, Promise}
+
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shard.{ShardLookupListener, ShardStoreClient}
+import org.apache.spark.shard.ShardManager
+
+/**
+ * Network service through which a [[ShardManager]] serves and issues shard lookups.
+ *
+ * Implementations expose the `hostName` and `port` they are reachable at and provide the
+ * listener-based `fetchBatch` inherited from [[ShardStoreClient]]; this class layers a
+ * `Future`-returning variant on top of it.
+ */
+private[spark] abstract class ShardLookupService extends ShardStoreClient {
+
+  /** Called by `ShardManager.initialize` before it reads `hostName` and `port`. */
+  def init(shardManager: ShardManager): Unit
+
+  /** Port advertised for this service in the `ShardManagerId` registered with the master. */
+  def port: Int
+
+  /** Host advertised for this service in the `ShardManagerId` registered with the master. */
+  def hostName: String
+
+  /**
+   * Install a shard replica on this executor. Called when a remote executor
+   * requests this executor to host a copy of a shard.
+   *
+   * The default reads the shard data from BlockManager (Java-serialized
+   * HashedRelation). Override in custom implementations (e.g., native
+   * backends) to load data in a different format.
+   *
+   * Returns nothing; the default implementation delegates to
+   * `ShardManager.loadShardData`, so any failure surfaces as an exception from that call.
+   */
+  def onInstallReplica(shardManager: ShardManager, setId: Long, shardId: Int): Unit = {
+    shardManager.loadShardData(setId, shardId)
+  }
+
+  /**
+   * `Future`-based variant of the listener-based `fetchBatch`.
+   *
+   * @param host host of the executor that serves the shard
+   * @param port port of that executor's lookup service
+   * @param reqMsg request message handed unchanged to the listener-based overload
+   * @return a future completed with the response buffer, or failed with the reported error
+   */
+  def fetchBatch(host: String, port: Int, reqMsg: ManagedBuffer): Future[ManagedBuffer] = {
+    val result = Promise[ManagedBuffer]()
+    fetchBatch(
+      host,
+      port,
+      reqMsg,
+      new ShardLookupListener {
+
+        // trySuccess/tryFailure keep the first notification and ignore later ones; the
+        // strict success/failure variants would throw IllegalStateException if the
+        // transport reported the same request twice.
+        override def onBatchFetchSuccess(response: ManagedBuffer): Unit = {
+          result.trySuccess(response)
+        }
+
+        override def onBatchFetchFailure(exception: Throwable): Unit = {
+          result.tryFailure(exception)
+        }
+      })
+    result.future
+  }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/network/ShardLookupServiceFactory.scala b/core/src/main/scala/org/apache/spark/network/ShardLookupServiceFactory.scala
new file mode 100644
index 0000000000000..aadae80f413d2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/ShardLookupServiceFactory.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.LogKeys.CLASS_NAME
+import org.apache.spark.network.netty.NettyShardLookupService
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.util.Utils
+
+/**
+ * Factory for creating [[ShardLookupService]] instances.
+ *
+ * The implementation class is read from `spark.shard.service`; the default is
+ * [[NettyShardLookupService]]. A custom class is instantiated reflectively and must
+ * declare a constructor `(SparkConf, String, String, Int, Int, RpcEndpointRef)`. If that
+ * fails with an exception, a warning is logged and the Netty implementation is used.
+ */
+private[spark] object ShardLookupServiceFactory extends Logging {
+
+  private val SHARD_LOOKUP_SERVICE_CLASS_KEY = "spark.shard.service"
+
+  private val DEFAULT_CLASS =
+    classOf[NettyShardLookupService].getName
+
+  def create(
+      conf: SparkConf,
+      bindAddress: String,
+      advertiseAddress: String,
+      port: Int,
+      numCores: Int,
+      masterEndpoint: RpcEndpointRef): ShardLookupService = {
+    val className = conf.get(SHARD_LOOKUP_SERVICE_CLASS_KEY, DEFAULT_CLASS)
+    if (className == DEFAULT_CLASS) {
+      new NettyShardLookupService(conf, bindAddress, advertiseAddress, port,
+        numCores, masterEndpoint)
+    } else {
+      try {
+        logInfo(log"Creating custom ShardLookupService: ${MDC(CLASS_NAME, className)}")
+        Utils.classForName(className)
+          .getDeclaredConstructor(
+            classOf[SparkConf], classOf[String], classOf[String],
+            classOf[Int], classOf[Int], classOf[RpcEndpointRef])
+          .newInstance(conf, bindAddress, advertiseAddress,
+            port.asInstanceOf[AnyRef], numCores.asInstanceOf[AnyRef],
+            masterEndpoint)
+          .asInstanceOf[ShardLookupService]
+      } catch {
+        case e: Exception =>
+          logWarning(log"Failed to create ${MDC(CLASS_NAME, className)}, falling back to Netty", e)
+          new NettyShardLookupService(conf, bindAddress, advertiseAddress, port,
+            numCores, masterEndpoint)
+      }
+    }
+  }
+}
diff --git
a/core/src/main/scala/org/apache/spark/network/netty/NettyShardLookupService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyShardLookupService.scala
new file mode 100644
index 0000000000000..e724a5f2efed6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyShardLookupService.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network.{ShardLookupService, TransportContext}
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.client.ManagedRpcResponseCallback
+import org.apache.spark.network.server.{TransportServer, TransportServerBootstrap}
+import org.apache.spark.network.shard.ShardLookupListener
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shard.{ShardLookupAdapter, ShardManager}
+import org.apache.spark.util.Utils
+
+/**
+ * [[ShardLookupService]] built on the Netty transport layer.
+ *
+ * `init` starts a [[TransportServer]] on `bindAddress` whose RPC handler is a
+ * [[NettyShardRpcServer]] and creates the client factory that `fetchBatch` uses to send
+ * lookup requests to other hosts.
+ */
+private[spark] class NettyShardLookupService(
+    conf: SparkConf,
+    bindAddress: String,
+    val hostName: String,
+    _port: Int,
+    numCores: Int,
+    masterEndpoint: RpcEndpointRef = null)
+  extends ShardLookupService {
+
+  private val serializer = new JavaSerializer(conf)
+  private[this] var transportContext: TransportContext = _
+  private[this] var server: TransportServer = _
+  private[this] var rpcHandler: NettyShardRpcServer = _
+
+  // The adapter is looked up by class name and built through its no-arg constructor.
+  private val lookupAdapter: ShardLookupAdapter =
+    Utils
+      .classForName("org.apache.spark.sql.execution.joins.HashedRelationAdapter")
+      .getDeclaredConstructor()
+      .newInstance()
+      .asInstanceOf[ShardLookupAdapter]
+
+  /**
+   * Creates the RPC handler, transport context, client factory and server. Transport
+   * settings that are absent from `conf` receive the defaults listed below; values that
+   * are already set are left alone.
+   */
+  override def init(shardManager: ShardManager): Unit = {
+    rpcHandler = new NettyShardRpcServer(conf.getAppId, serializer, shardManager, lookupAdapter)
+    val shardConf = conf.clone
+    Seq(
+      "spark.shard.io.mode" -> "NIO",
+      "spark.shard.io.clientThreads" -> "8",
+      "spark.shard.io.serverThreads" -> "8",
+      "spark.shard.io.numConnectionsPerPeer" -> "8",
+      "spark.shard.io.connectionCreationTimeout" -> "2s",
+      "spark.shard.io.retryWait" -> "300ms",
+      "spark.shard.io.maxRetries" -> "3",
+      "spark.network.waitForReachable" -> "false",
+      "spark.network.sharedByteBufAllocators.enabled" -> "true",
+      "spark.network.io.preferDirectBufs" -> "true"
+    ).foreach { case (key, fallback) => shardConf.setIfMissing(key, fallback) }
+    transportConf = SparkTransportConf.fromSparkConf(shardConf, "shard", numCores)
+    transportContext = new TransportContext(transportConf, rpcHandler)
+    clientFactory = transportContext.createClientFactory()
+    server = createNonAuthServer()
+    appId = conf.getAppId
+    logger.info(s"Server created on $hostName $bindAddress:${server.getPort}")
+  }
+
+  /** Port the server created by `init` is listening on. */
+  override def port: Int = server.getPort
+
+  /** Starts the server with no bootstraps via `Utils.startServiceOnPort`, starting at `_port`. */
+  private def createNonAuthServer(): TransportServer = {
+    val bind: Int => (TransportServer, Int) = { candidatePort =>
+      val started = transportContext.createServer(
+        bindAddress,
+        candidatePort,
+        List.empty[TransportServerBootstrap].asJava)
+      (started, started.getPort)
+    }
+    val (boundServer, _) = Utils.startServiceOnPort(_port, bind, conf, getClass.getName)
+    boundServer
+  }
+
+  /**
+   * Sends `reqMsg` to `host:port` and relays the outcome to `listener`. An exception
+   * thrown while creating the client or sending is reported through
+   * `listener.onBatchFetchFailure` rather than propagated.
+   */
+  override def fetchBatch(
+      host: String,
+      port: Int,
+      reqMsg: ManagedBuffer,
+      listener: ShardLookupListener): Unit = {
+    try {
+      val client = clientFactory.createClient(host, port, true)
+      val relay = new ManagedRpcResponseCallback() {
+        override def onSuccess(response: ManagedBuffer): Unit =
+          listener.onBatchFetchSuccess(response)
+
+        override def onFailure(e: Throwable): Unit =
+          listener.onBatchFetchFailure(e)
+      }
+      client.sendManagedRpc(reqMsg, relay)
+    } catch {
+      case e: Exception => listener.onBatchFetchFailure(e)
+    }
+  }
+
+  /** Closes whichever of server, RPC handler, client factory and context were created. */
+  override def close(): Unit = {
+    Option(server).foreach(_.close())
+    Option(rpcHandler).foreach(_.close())
+    Option(clientFactory).foreach(_.close())
+    Option(transportContext).foreach(_.close())
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyShardRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyShardRpcServer.scala
new file mode 100644
index 0000000000000..eac21b337c827
--- /dev/null
+++
b/core/src/main/scala/org/apache/spark/network/netty/NettyShardRpcServer.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.nio.ByteBuffer
+
+import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService, Future => SFuture}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.client.{ManagedRpcResponseCallback, RpcResponseCallback, TransportClient}
+import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.shard.{ShardLookupAdapter, ShardManager}
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * RPC handler of the shard lookup server.
+ *
+ * Managed-buffer requests are answered by `lookupAdapter.lookup` on a dedicated worker
+ * pool; the `ByteBuffer` form of `receive` is not supported.
+ */
+private[spark] class NettyShardRpcServer(
+    appId: String,
+    serializer: Serializer,
+    shardManager: ShardManager,
+    lookupAdapter: ShardLookupAdapter)
+  extends RpcHandler
+  with RpcHandler.ManagedRpcHandler
+  with Logging {
+
+  private val streamManager = new OneForOneStreamManager()
+
+  // Lookups run on this pool instead of on the thread that calls `receive`.
+  private implicit val workerEc: ExecutionContextExecutorService =
+    ExecutionContext.fromExecutorService(
+      ThreadUtils.newDaemonCachedThreadPool("shard-rpc-worker", 16))
+
+  /** Not supported: this handler only serves the managed-buffer `receive` overload. */
+  override def receive(
+      client: TransportClient,
+      message: ByteBuffer,
+      callback: RpcResponseCallback): Unit = {
+    throw new UnsupportedOperationException()
+  }
+
+  /**
+   * Runs the lookup for `message` on the worker pool and answers through `callback`.
+   *
+   * `message` is released once the lookup has finished, whether it succeeded or not.
+   * Each completed lookup is answered with either `callback.onSuccess` or
+   * `callback.onFailure`: a thrown error and a `null` response are both reported as
+   * failures, so the remote caller is not left waiting for a reply that never comes.
+   */
+  override def receive(
+      client: TransportClient,
+      message: ManagedBuffer,
+      callback: ManagedRpcResponseCallback): Unit = {
+    SFuture {
+      val outcome: Either[Throwable, ManagedBuffer] =
+        try {
+          Option(lookupAdapter.lookup(shardManager, message))
+            .toRight(new IllegalStateException("Shard lookup returned no response"))
+        } catch {
+          // Throwable on purpose: whatever went wrong is relayed to the caller below.
+          case t: Throwable => Left(t)
+        } finally {
+          message.release()
+        }
+      outcome match {
+        case Right(response) =>
+          callback.onSuccess(response)
+        case Left(cause) =>
+          logWarning("Shard lookup request failed", cause)
+          callback.onFailure(cause)
+      }
+    }
+  }
+
+  override def getStreamManager: StreamManager = streamManager
+
+  /** Shuts down the worker pool; tasks already submitted are not cancelled. */
+  def close(): Unit = {
+    workerEc.shutdown()
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 22720b98aafde..4c8041358b71d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -3131,6 +3131,9 @@ private[spark] class DAGScheduler(
           host => blockManagerMaster.removeShufflePushMergerLocation(host))
       }
       blockManagerMaster.removeExecutorAsync(execId)
+      if (env.shardManager != null) {
+        env.shardManager.master.removeExecutor(execId)
+      }
       clearCacheLocs()
     }
     if (fileLost) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index ad92e22424c77..44fb9d906a7d3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -489,6 +489,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
           // about the executor, but the scheduler will not. Therefore, we should remove the
           // executor from the block manager when we hit this case.
scheduler.sc.env.blockManager.master.removeExecutorAsync(executorId) + val sm = scheduler.sc.env.shardManager + if (sm != null) { + sm.master.removeExecutor(executorId) + } // SPARK-35011: If we reach this code path, which means the executor has been // already removed from the scheduler backend but the block manager master may // still know it. In this case, removing the executor from block manager master diff --git a/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala b/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala index a98672fd7db4a..333d67836292c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala @@ -29,7 +29,7 @@ import org.apache.spark.internal.{Logging, LogKeys} import org.apache.spark.internal.config._ import org.apache.spark.resource.ResourceProfile.UNKNOWN_RESOURCE_PROFILE_ID import org.apache.spark.scheduler._ -import org.apache.spark.storage.{RDDBlockId, ShuffleDataBlockId} +import org.apache.spark.storage.{BlockId, RDDBlockId, ShardBlockId, ShuffleDataBlockId} import org.apache.spark.util.Clock /** @@ -388,10 +388,12 @@ private[spark] class ExecutorMonitor( val exec = ensureExecutorIsTracked(event.blockUpdatedInfo.blockManagerId.executorId, UNKNOWN_RESOURCE_PROFILE_ID) + var shouldRet = true // Check if it is a shuffle file, or RDD to pick the correct codepath for update - if (!event.blockUpdatedInfo.blockId.isInstanceOf[RDDBlockId]) { - if (event.blockUpdatedInfo.blockId.isInstanceOf[ShuffleDataBlockId] && - shuffleTrackingEnabled) { + event.blockUpdatedInfo.blockId match { + case _: RDDBlockId => + shouldRet = false + case ShuffleDataBlockId(shuffleId, _, _) if shuffleTrackingEnabled => /** * The executor monitor keeps track of locations of cache and shuffle blocks and this can * be used to decide which executor(s) Spark should shutdown first. 
Since we move shuffle @@ -399,11 +401,19 @@ private[spark] class ExecutorMonitor( * data blocks as index and other blocks blocks do not necessarily mean the entire block * has been committed. */ - event.blockUpdatedInfo.blockId match { - case ShuffleDataBlockId(shuffleId, _, _) => exec.addShuffle(shuffleId) - case _ => // For now we only update on data blocks + exec.addShuffle(shuffleId) + case sh: ShardBlockId if sh.shardId != -1 => + val storageLevel = event.blockUpdatedInfo.storageLevel + if (storageLevel.isValid) { + exec.addNonRddCached(sh) + } else { + exec.removeNonRddCached(sh) } - } + exec.updateTimeout() + case _ => // For now we only update on data blocks + } + + if (shouldRet) { return } @@ -462,6 +472,18 @@ private[spark] class ExecutorMonitor( override def broadcastCleaned(broadcastId: Long): Unit = { } + override def shardSetCleaned(setId: Long): Unit = { + executors.asScala.foreach { case (_, exec) => + val toRemove = exec.nonRddCachedBlocks.collect { + case s: ShardBlockId if s.setId == setId => s + } + if (toRemove.nonEmpty) { + toRemove.foreach(exec.removeNonRddCached) + } + } + nextTimeout.set(Long.MinValue) + } + override def accumCleaned(accId: Long): Unit = { } override def checkpointCleaned(rddId: Long): Unit = { } @@ -556,6 +578,19 @@ private[spark] class ExecutorMonitor( // This should only be used in the event thread. 
private val shuffleIds = if (shuffleTrackingEnabled) new mutable.HashSet[Int]() else null
 
+    val nonRddCachedBlocks = new mutable.HashSet[BlockId]()
+
+    def addNonRddCached(id: BlockId): Unit = {
+      if (nonRddCachedBlocks.add(id) && isIdle) {
+        updateTimeout()
+      }
+    }
+    def removeNonRddCached(id: BlockId): Unit = {
+      if (nonRddCachedBlocks.remove(id) && nonRddCachedBlocks.isEmpty && cachedBlocks.isEmpty) {
+        updateTimeout()
+      }
+    }
+
     def isIdle: Boolean = idleStart >= 0 && !hasActiveShuffle
 
     def updateRunningTasks(delta: Int): Unit = {
@@ -567,7 +602,8 @@ private[spark] class ExecutorMonitor(
     def updateTimeout(): Unit = {
       val oldDeadline = timeoutAt
       val newDeadline = if (idleStart >= 0) {
-        val _cacheTimeout = if (cachedBlocks.nonEmpty) storageTimeoutNs else 0
+        val _cacheTimeout =
+          if (cachedBlocks.nonEmpty || nonRddCachedBlocks.nonEmpty) storageTimeoutNs else 0
         val _shuffleTimeout = if (shuffleIds != null && shuffleIds.nonEmpty) {
           shuffleTimeoutNs
         } else {
diff --git a/core/src/main/scala/org/apache/spark/shard/Shard.scala b/core/src/main/scala/org/apache/spark/shard/Shard.scala
new file mode 100644
index 0000000000000..26145aaaee907
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shard/Shard.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shard
+
+/**
+ * Reference to a shard set: its id together with an array of shard ids.
+ *
+ * NOTE(review): `shardIds` is an `Array`, so the generated `equals`/`hashCode` compare it
+ * by reference rather than by content. Confirm that no caller relies on value equality.
+ */
+case class ShardSetRef(setId: Long, shardIds: Array[Int])
+
+/** Pair of a shard-set id and a shard id; presumably addresses one shard of that set. */
+case class ShardKey(setId: Long, shardId: Int)
diff --git a/core/src/main/scala/org/apache/spark/shard/ShardManager.scala b/core/src/main/scala/org/apache/spark/shard/ShardManager.scala
new file mode 100644
index 0000000000000..3102816e8e087
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shard/ShardManager.scala
@@ -0,0 +1,589 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.shard + +import java.io.{InputStream => JInputStream, OutputStream => JOutputStream, SequenceInputStream} +import java.nio.ByteBuffer +import java.util.{Collections, Map => JMap} +import java.util.concurrent.ConcurrentHashMap +import java.util.zip.Adler32 + +import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService, Future} +import scala.jdk.CollectionConverters._ +import scala.reflect.ClassTag +import scala.util.Random +import scala.util.control.NonFatal + +import org.apache.commons.collections4.map.AbstractReferenceMap._ +import org.apache.commons.collections4.map.ReferenceMap + +import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.LogKeys.{HOST_PORT, SHARD_ID, SHARD_MANAGER_ID, SHARD_SET_ID} +import org.apache.spark.io.CompressionCodec +import org.apache.spark.network.ShardLookupService +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.serializer.Serializer +import org.apache.spark.storage.{BlockData, BlockId, ByteBufferBlockData, ShardBlockId, StorageLevel} +import org.apache.spark.util.{IdGenerator, KeyLock, ThreadUtils, Utils} +import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} + +/** + * Per-executor manager for distributed map join shard data. + * + * Coordinates shard installation, replication, and RPC-based lookups with the + * driver-side [[ShardManagerMaster]]. Each executor has one instance, created + * lazily on first use by a distributed map join operator. 
+ */ +private[spark] class ShardManager( + val executorId: String, + rpcEnv: RpcEnv, + val master: ShardManagerMaster, + conf: SparkConf, + val shardLookupService: ShardLookupService, + isDriver: Boolean) + extends Logging { + + private var shardManagerId: ShardManagerId = _ + + private[shard] val cachedValues: JMap[Any, Any] = + Collections.synchronizedMap( + new ReferenceMap(ReferenceStrength.HARD, ReferenceStrength.WEAK) + .asInstanceOf[JMap[Any, Any]]) + + private val managerEndpoint = rpcEnv.setupEndpoint( + "ShardManagerEndpoint" + ShardManager.ID_GENERATOR.next, + new ShardManagerEndpoint(rpcEnv, this)) + + private val cleanupCallbacks = + new java.util.concurrent.CopyOnWriteArrayList[ShardCleanupCallback]() + + def registerCleanupCallback(cb: ShardCleanupCallback): Unit = { + cleanupCallbacks.add(cb) + } + + private[spark] def invokeCleanupCallbacks(setId: Long): Unit = { + filterBytesCache.remove(setId) + val it = cleanupCallbacks.iterator() + while (it.hasNext) { + try it.next().onShardSetRemoved(setId) + catch { + case NonFatal(e) => + logWarning(log"Shard cleanup callback failed for setId=${MDC(SHARD_SET_ID, setId)}", e) + } + } + } + + private val filterBytesCache = + new ConcurrentHashMap[Long, (Array[Byte], Array[Byte])]() + + private val shardLock = new KeyLock[ShardBlockId] + + private[spark] val lookupEc: ExecutionContextExecutorService = + ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("lookup-join-future", 16)) + + private val managerEc: ExecutionContextExecutorService = + ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("shard-manager-future", 8)) + + private var blockSize: Int = _ + private[shard] var compressionCodec: Option[CompressionCodec] = _ + private def setConf(conf: SparkConf): Unit = { + compressionCodec = Some(CompressionCodec.createCodec(conf)) + blockSize = conf.getInt("spark.shard.blockSize", 4 << 20) + } + + def initialize(appId: String): Unit = { + setConf(conf) + 
shardLookupService.init(this) + val id = + ShardManagerId(executorId, shardLookupService.hostName, shardLookupService.port, None) + val idFromMaster = master.registerShardManager(id, managerEndpoint) + shardManagerId = if (idFromMaster != null) idFromMaster else id + + logInfo(log"Initialized ShardManager: ${MDC(SHARD_MANAGER_ID, shardManagerId)}") + } + + def newShardSet(numShards: Int, replicaCount: Int): Long = { + master.newShardSet(numShards, replicaCount) + } + + def installShard[T: ClassTag]( + value: T, + setId: Long, + id: Int, + filterBytes: Option[(Array[Byte], Array[Byte])] = None)( + bfOutput: JOutputStream => Unit): Unit = { + writeShardBlock(value, setId, id, bfOutput, filterBytes) + filterBytes.foreach(fb => filterBytesCache.putIfAbsent(setId, fb)) + installReplica(setId, id) + } + + def installReplicaSet(setId: Long, id: Int): Unit = { + master.installReplicaSet(setId, id) + } + + def installReplica(setId: Long, id: Int): Unit = { + logInfo(log"installReplica: setId=${MDC(SHARD_SET_ID, setId)}, shardId=${MDC(SHARD_ID, id)}") + shardLookupService.onInstallReplica(this, setId, id) + reportShardInstalled(setId, id) + } + + def getFilterBytes(setId: Long, shardId: Int): Option[(Array[Byte], Array[Byte])] = { + Option(filterBytesCache.get(setId)).orElse { + val metaId = ShardBlockId(setId, shardId, "meta") + val bm = SparkEnv.get.blockManager + bm.getLocalBytes(metaId).flatMap { block => + try { + val meta = SparkEnv.get.serializer.newInstance() + .deserialize[ShardBlockMeta](block.toByteBuffer()) + meta.filterExprBytes.map { exprBytes => + val pair = (exprBytes, + meta.filterSchemaBytes.getOrElse(Array.empty[Byte])) + filterBytesCache.putIfAbsent(setId, pair) + pair + } + } finally { + releaseBlockManagerLock(metaId) + } + } + } + } + + def loadShardData(setId: Long, id: Int): Unit = { + readShardBlock[AnyRef](setId, id) + } + + def reportShardInstalled(setId: Long, id: Int): Unit = { + master.reportShard(shardManagerId, setId, id) + logInfo(log"Shard 
(${MDC(SHARD_SET_ID, setId)}, ${MDC(SHARD_ID, id)}) reported to master")
+  }
+
+  /**
+   * Merges the "bloom" blocks of the given shards into `acc` and stores the merged filter
+   * as the set-level bloom block `ShardBlockId(setId, -1, "bloom")` in the local
+   * BlockManager (reported to the master).
+   *
+   * Fails if a shard bloom can be found neither locally nor remotely, or if the merged
+   * bloom cannot be stored.
+   *
+   * @param setId    shard set whose blooms are merged
+   * @param shardIds shards whose "bloom" blocks are fed to the accumulator
+   * @param acc      accumulator that consumes every shard bloom and writes the merged one
+   */
+  def mergeBloomFilter(setId: Long, shardIds: Array[Int], acc: BloomAccumulator): Unit =
+    Utils.tryOrIOException {
+      val setBloomId = ShardBlockId(setId, -1, "bloom")
+      shardLock.withLock(setBloomId) {
+        val bm = SparkEnv.get.blockManager
+
+        // Feeds one shard's bloom into the accumulator. The stream is fully consumed and
+        // closed BEFORE the backing block is disposed: a remotely fetched block is wrapped
+        // with dispose enabled, so reading from it after dispose() is not safe.
+        def addShardBloom(shardId: Int): Unit = {
+          val bloomId = ShardBlockId(setId, shardId, "bloom")
+          val block =
+            bm.getLocalBytes(bloomId)
+              .map(x => { releaseBlockManagerLock(bloomId); x })
+              .getOrElse {
+                bm.getRemoteBytes(bloomId) match {
+                  case Some(b) =>
+                    new ByteBufferBlockData(b, true)
+                  case None =>
+                    throw new SparkException(s"Failed to get bloom $bloomId")
+                }
+              }
+
+          try {
+            val input = compressionCodec.map(_.compressedInputStream(block.toInputStream()))
+              .getOrElse(block.toInputStream())
+            try acc.add(input)
+            finally input.close()
+          } finally {
+            block.dispose()
+          }
+        }
+
+        // merge all shard blooms
+        shardIds.foreach(addShardBloom)
+
+        // write merged bloom
+        val cbbos = new ChunkedByteBufferOutputStream(blockSize, ByteBuffer.allocate)
+        val out = compressionCodec.map(_.compressedOutputStream(cbbos)).getOrElse(cbbos)
+        try acc.finish(out)
+        finally out.close()
+        if (!bm.putBytes(
+            setBloomId,
+            cbbos.toChunkedByteBuffer,
+            StorageLevel.MEMORY_AND_DISK_SER,
+            tellMaster = true)) {
+          throw new SparkException(s"Failed to store bloom $setBloomId in local BlockManager")
+        }
+      }
+    }
+
+  /**
+   * Returns the set-level bloom filter of `setId`, decoded by `bfInput`.
+   *
+   * The decoded object is looked up in `cachedValues` first. Otherwise the bloom block is
+   * read from the local BlockManager, or fetched remotely and stored locally, then decoded;
+   * a non-null result is cached.
+   *
+   * @param setId   shard set whose merged bloom is requested
+   * @param bfInput decoder applied to the (decompressed) bloom stream; the stream is closed
+   *                right after it returns
+   */
+  def fetchBloomFilter[T: ClassTag](setId: Long)(bfInput: JInputStream => T): T =
+    Utils.tryOrIOException {
+      val bloomId = ShardBlockId(setId, -1, "bloom")
+      shardLock.withLock(bloomId) {
+        Option(cachedValues.get(bloomId)).map(_.asInstanceOf[T]).getOrElse {
+          val bm = SparkEnv.get.blockManager
+          val block =
+            bm.getLocalBytes(bloomId)
+              .map(x => { releaseBlockManagerLock(bloomId); x })
+              .getOrElse {
+                bm.getRemoteBytes(bloomId) match {
+                  case Some(b) =>
+                    if (!bm
+                        .putBytes(
+                          bloomId,
+                          b,
+                          StorageLevel.MEMORY_AND_DISK_SER,
+                          tellMaster = true)) {
+                      throw new SparkException(
+                        s"Failed to store bloom $bloomId in local BlockManager")
+                    }
+                    new ByteBufferBlockData(b, true)
+                  case None =>
+                    throw new SparkException(s"Failed to get bloom $bloomId")
+                }
+              }
+
+          try {
+            val input = compressionCodec.map(_.compressedInputStream(block.toInputStream()))
+              .getOrElse(block.toInputStream())
+            val obj =
+              try { bfInput(input) }
+              finally { input.close() }
+
+            if (obj != null) {
+              cachedValues.put(bloomId, obj)
+            }
+
+            obj
+          } finally {
+            block.dispose()
+          }
+        }
+      }
+    }
+
+  /**
+   * Sends a batch lookup for shard `(setId, shardId)` to one of its holders.
+   *
+   * The request buffer is built on `lookupEc`. Known locations are tried one at a time
+   * (same-host holders first, see `sortLocations`). When an attempt fails and no candidate
+   * is left, the locations are refreshed from the master and the not-yet-tried ones are
+   * attempted. The returned future fails with a SparkException once no location remains.
+   *
+   * @param toReqMsg builds the request buffer; its base reference is released when the whole
+   *                 retry chain has completed, successfully or not
+   */
+  def fetchRemoteBatch(
+      setId: Long,
+      shardId: Int,
+      toReqMsg: () => ManagedBuffer): Future[ManagedBuffer] = {
+    implicit val ec: ExecutionContextExecutorService = lookupEc
+    val message: Future[ManagedBuffer] = Future {
+      toReqMsg()
+    }
+
+    def locations(refresh: Boolean): List[ShardManagerId] =
+      sortLocations(master.getLocations(setId, shardId, refresh)).toList
+
+    def tryWithReqMsg(
+        reqMsg: ManagedBuffer,
+        locs: List[ShardManagerId],
+        tried: Set[ShardManagerId] = Set.empty): Future[ManagedBuffer] =
+      locs match {
+        case Nil =>
+          Future.failed(
+            new SparkException(
+              s"Remote batch not found of ($setId, $shardId) " +
+                s"at locations ${locations(refresh = false)}"))
+        case loc :: tail =>
+          // Before each attempt, retain once and transfer ownership of that reference
+          // to Netty outbound (MessageWithHeader#deallocate will call release).
+          reqMsg.retain()
+          shardLookupService
+            .fetchBatch(loc.host, loc.port, reqMsg)
+            .recoverWith { case NonFatal(e) =>
+              logWarning(log"Failed to fetch remote batch" +
+                log" (${MDC(SHARD_SET_ID, setId)}, ${MDC(SHARD_ID, shardId)})" +
+                log" from ${MDC(HOST_PORT, loc)}", e)
+              val triedLocs = tried + loc
+              val recoverLocs =
+                if (tail.isEmpty) locations(refresh = true).filterNot(triedLocs)
+                else tail
+              tryWithReqMsg(reqMsg, recoverLocs, triedLocs)
+            }
+      }
+
+    message.flatMap { reqMsg =>
+      // The location lookup (a call into `master`) or the first attempt may throw
+      // synchronously. Turn that into a failed future so the release below always runs
+      // and the base reference of the request buffer is not leaked.
+      val attempt =
+        try tryWithReqMsg(reqMsg, locations(refresh = false))
+        catch { case NonFatal(e) => Future.failed[ManagedBuffer](e) }
+      // Release the base reference after the whole retry chain completes.
+      attempt.andThen { case _ => reqMsg.release() }
+    }
+  }
+
+  /**
+   * Returns the value of shard `(setId, id)`: the cached object if present, otherwise the
+   * result of `readShardBlock`.
+   */
+  def getLocalValue[T: ClassTag](setId: Long, id: Int): T = {
+    Option(cachedValues.get(ShardBlockId(setId, id)))
+      .map(_.asInstanceOf[T])
+      .getOrElse(readShardBlock[T](setId, id))
+  }
+
+  /** Asks the BlockManager master to remove the shard set, optionally waiting for it. */
+  def unpersist(setId: Long, blocking: Boolean): Unit = {
+    logDebug(log"Unpersisting shard-set ${MDC(SHARD_SET_ID, setId)}")
+    SparkEnv.get.blockManager.master.removeShardSet(setId, blocking)
+  }
+
+  /** Shuts down both executors, closes the lookup service and stops the RPC endpoint. */
+  def stop(): Unit = {
+    lookupEc.shutdown()
+    managerEc.shutdown()
+    shardLookupService.close()
+    rpcEnv.stop(managerEndpoint)
+    logInfo(log"ShardManager stopped")
+  }
+
+  /**
+   * Orders candidate holders: those on this manager's host first, the rest after; the
+   * order inside each group is random (shuffle followed by a stable sort).
+   */
+  private def sortLocations(locations: Seq[ShardManagerId]): Seq[ShardManagerId] = {
+    Random.shuffle(locations).sortBy(loc => if (loc.host == shardManagerId.host) 0 else 1)
+  }
+
+  private def writeShardBlock[T: ClassTag](
+      value: T,
+      setId: Long,
+      id: Int,
+      bfOutput: JOutputStream => Unit,
+      filterBytes: Option[(Array[Byte], Array[Byte])] = None): Unit =
+    Utils.tryOrIOException {
+      val shardId = ShardBlockId(setId, id)
+      shardLock.withLock(shardId) {
+        val bm = SparkEnv.get.blockManager
+        val blocks = blockifyObject(
+          value, blockSize, SparkEnv.get.serializer, compressionCodec)
+        val checksums = new Array[Int](blocks.length)
+        // write shard
+        blocks.zipWithIndex.foreach { case (block, i) =>
+          checksums(i) = calcChecksum(block)
+          val pieceId = ShardBlockId(setId, id, "piece" + i)
+          val bytes = new ChunkedByteBuffer(block.duplicate())
+          if (!bm.putBytes(pieceId, bytes, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
+            throw new SparkException(
+              s"Failed to store shard piece $pieceId in local BlockManager")
+          }
+        }
+
+        // write meta
+        val meta = ShardBlockMeta(
+          blocks.length,
+          Some(checksums),
+          filterBytes.map(_._1),
+          filterBytes.map(_._2))
+        val metaBytes = new ChunkedByteBuffer(
+          SparkEnv.get.serializer.newInstance().serialize(meta))
+        val metaId = ShardBlockId(setId, id, "meta")
+        if (!bm.putBytes(
+            metaId,
+            metaBytes,
+
StorageLevel.MEMORY_AND_DISK_SER,
+            tellMaster = true)) {
+          throw new SparkException(s"Failed to store shard meta $metaId in local BlockManager")
+        }
+
+        // write bloom
+        val cbbos = new ChunkedByteBufferOutputStream(blockSize, ByteBuffer.allocate)
+        val out = compressionCodec.map(_.compressedOutputStream(cbbos)).getOrElse(cbbos)
+        try bfOutput(out)
+        finally out.close()
+        val bloomId = ShardBlockId(setId, id, "bloom")
+        if (!bm.putBytes(
+            bloomId,
+            cbbos.toChunkedByteBuffer,
+            StorageLevel.MEMORY_AND_DISK_SER,
+            tellMaster = true)) {
+          throw new SparkException(s"Failed to store shard bloom $bloomId in local BlockManager")
+        }
+      }
+    }
+
+  /**
+   * Loads the value of shard `(setId, id)` while holding the shard lock.
+   *
+   * Lookup order: `cachedValues`, then the local BlockManager entry for the shard id itself,
+   * then reassembly from the piece blocks returned by `readBlocks`. A non-null result is
+   * put into `cachedValues`; the piece blocks are disposed once deserialization is done.
+   */
+  private def readShardBlock[T: ClassTag](setId: Long, id: Int): T =
+    Utils.tryOrIOException {
+      val shardId = ShardBlockId(setId, id)
+      shardLock.withLock(shardId) {
+        Option(cachedValues.get(shardId)).map(_.asInstanceOf[T]).getOrElse {
+          val bm = SparkEnv.get.blockManager
+          bm.getLocalValues(shardId) match {
+            case Some(blockResult) =>
+              if (blockResult.data.hasNext) {
+                val x = blockResult.data.next().asInstanceOf[T]
+                releaseBlockManagerLock(shardId)
+
+                if (x != null) {
+                  cachedValues.put(shardId, x)
+                }
+
+                x
+              } else {
+                throw new SparkException(s"Failed to get locally stored shard: $shardId")
+              }
+
+            case None =>
+              val blocks = readBlocks(setId, id)
+              try {
+                val obj = unBlockifyObject[T](
+                  blocks.map(_.toInputStream()),
+                  SparkEnv.get.serializer,
+                  compressionCodec)
+
+                if (obj != null) {
+                  cachedValues.put(shardId, obj)
+                }
+
+                obj
+              } finally {
+                blocks.foreach(_.dispose())
+              }
+          }
+        }
+      }
+    }
+
+  /**
+   * Reads the "meta" block of shard `(setId, id)` and then every piece it announces.
+   *
+   * Meta and pieces are taken from the local BlockManager when present, otherwise fetched
+   * remotely and stored locally. Remotely fetched pieces are verified against the checksum
+   * recorded in the meta block (when checksums are present) before being stored.
+   *
+   * @return the piece blocks indexed by piece number; the caller disposes them
+   */
+  private def readBlocks(setId: Long, id: Int): Array[BlockData] = {
+    val bm = SparkEnv.get.blockManager
+    // read meta
+    val metaId = ShardBlockId(setId, id, "meta")
+    val metaBuf = bm.getLocalBytes(metaId) match {
+      case Some(block) =>
+        val x = block.toByteBuffer()
+        releaseBlockManagerLock(metaId)
+        x
+      case None =>
+        bm.getRemoteBytes(metaId) match {
+          case Some(b) =>
+            if (!bm.putBytes(metaId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
+              throw new SparkException(s"Failed to store meta $metaId in local BlockManager")
+            }
+            b.toByteBuffer
+          case None =>
+            throw new SparkException(s"Failed to get meta $metaId")
+        }
+    }
+    val meta = SparkEnv.get.serializer.newInstance().deserialize[ShardBlockMeta](metaBuf)
+
+    // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
+    // to the driver, so other executors can pull these chunks from this executor as well.
+    val blocks = new Array[BlockData](meta.numBlocks)
+    // Pieces are visited in random order; each one lands in its own slot of `blocks`.
+    for (pid <- Random.shuffle(Seq.range(0, meta.numBlocks))) {
+      val pieceId = ShardBlockId(setId, id, "piece" + pid)
+      logDebug(log"Reading piece ${MDC(SHARD_ID, pieceId)}")
+      // First try getLocalBytes because there is a chance that previous attempts to fetch the
+      // shard blocks have already fetched some of the blocks. In that case, some blocks
+      // would be available locally (on this executor).
+      bm.getLocalBytes(pieceId) match {
+        case Some(block) =>
+          blocks(pid) = block
+          releaseBlockManagerLock(pieceId)
+        case None =>
+          bm.getRemoteBytes(pieceId) match {
+            case Some(b) =>
+              // Only the first chunk of the fetched buffer is checksummed.
+              // NOTE(review): assumes a piece always arrives as a single chunk - confirm.
+              meta.checksums.foreach { c =>
+                val sum = calcChecksum(b.chunks(0))
+                if (sum != c(pid)) {
+                  throw new SparkException(s"corrupt remote block $pieceId: $sum != ${c(pid)}")
+                }
+              }
+
+              if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
+                throw new SparkException(s"Failed to store piece $pieceId in local BlockManager")
+              }
+              blocks(pid) = new ByteBufferBlockData(b, true)
+            case None =>
+              throw new SparkException(s"Failed to get piece $pieceId")
+          }
+      }
+    }
+    blocks
+  }
+
+  /**
+   * Adler-32 checksum of the bytes between the buffer's position and limit, truncated to
+   * an Int. The buffer's own position is left untouched (array access or a duplicate).
+   */
+  private def calcChecksum(block: ByteBuffer): Int = {
+    val adler = new Adler32()
+    if (block.hasArray) {
+      adler.update(
+        block.array,
+        block.arrayOffset + block.position(),
+        block.limit()
+          - block.position())
+    } else {
+      val bytes = new Array[Byte](block.remaining())
+      block.duplicate.get(bytes)
+      adler.update(bytes)
+    }
+    adler.getValue.toInt
+  }
+
+  /**
+   * If running in a task, register the given block's locks for release upon task completion.
+   * Otherwise, if not running in a task then immediately release the lock.
+   */
+  private def releaseBlockManagerLock(blockId: BlockId): Unit = {
+    val bm = SparkEnv.get.blockManager
+    Option(TaskContext.get()) match {
+      case Some(taskContext) =>
+        taskContext.addTaskCompletionListener[Unit](_ => bm.releaseLock(blockId))
+      case None =>
+        // This should only happen on the driver, where shard variables may be accessed
+        // outside of running tasks (e.g. when computing rdd.partitions()). In order to allow
+        // shard variables to be garbage collected we need to free the reference here
+        // which is slightly unsafe but is technically okay because shard variables aren't
+        // stored off-heap.
+        bm.releaseLock(blockId)
+    }
+  }
+
+  /**
+   * Serializes `obj` with a fresh serializer instance, optionally through the compression
+   * codec, and returns the chunks of the resulting buffer (written with chunk size
+   * `blockSize`).
+   */
+  private def blockifyObject[T: ClassTag](
+      obj: T,
+      blockSize: Int,
+      serializer: Serializer,
+      compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
+    val cbbos = new ChunkedByteBufferOutputStream(blockSize, ByteBuffer.allocate)
+    val out = compressionCodec.map(_.compressedOutputStream(cbbos)).getOrElse(cbbos)
+    val ser = serializer.newInstance()
+    val serOut = ser.serializeStream(out)
+    Utils.tryWithSafeFinally {
+      serOut.writeObject[T](obj)
+    } {
+      serOut.close()
+    }
+    cbbos.toChunkedByteBuffer.getChunks()
+  }
+
+  /**
+   * Inverse of `blockifyObject`: concatenates the piece streams in order, optionally
+   * decompresses them and deserializes a single object. Requires at least one stream.
+   */
+  private def unBlockifyObject[T: ClassTag](
+      blocks: Array[JInputStream],
+      serializer: Serializer,
+      compressionCodec: Option[CompressionCodec]): T = {
+    require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
+    val is = new SequenceInputStream(blocks.iterator.asJavaEnumeration)
+    val in: JInputStream = compressionCodec.map(_.compressedInputStream(is)).getOrElse(is)
+    val ser = serializer.newInstance()
+    val serIn = ser.deserializeStream(in)
+    val obj = Utils.tryWithSafeFinally {
+      serIn.readObject[T]()
+    } {
+      serIn.close()
+    }
+    obj
+  }
+}
+
+/**
+ * Content of a shard's "meta" block.
+ *
+ * `numBlocks` is the number of piece blocks and `checksums` holds one checksum per piece
+ * (see `calcChecksum`). The two filter fields carry the optional byte arrays passed to
+ * `writeShardBlock` as `filterBytes`.
+ * NOTE(review): presumably a serialized filter expression and its schema - confirm.
+ */
+private[spark] case class ShardBlockMeta(
+    numBlocks: Int,
+    checksums: Option[Array[Int]],
+    filterExprBytes: Option[Array[Byte]] = None,
+    filterSchemaBytes: Option[Array[Byte]] = None)
+
+/**
+ * Accumulator used by `mergeBloomFilter`: `add` is called once per shard bloom stream and
+ * `finish` writes the merged filter to the given stream.
+ */
+private[spark] trait BloomAccumulator extends Serializable {
+  def add(in: java.io.InputStream): Unit
+  def isEmpty: Boolean
+  def finish(out: java.io.OutputStream): Unit
+}
+
+// NOTE(review): presumably invoked after a shard set has been removed - confirm with callers.
+private[spark] trait ShardCleanupCallback {
+  def onShardSetRemoved(setId: Long): Unit
+}
+
+// NOTE(review): presumably the server-side handler of a batch lookup request - confirm.
+private[spark] trait ShardLookupAdapter extends Serializable {
+  def lookup(manager: ShardManager, reqMsg: ManagedBuffer): ManagedBuffer
+}
+
+private[spark] object ShardManager extends Logging {
+  // Distinguished from local mode
+  private val ID_GENERATOR = new IdGenerator
+}
diff --git a/core/src/main/scala/org/apache/spark/shard/ShardManagerEndpoint.scala b/core/src/main/scala/org/apache/spark/shard/ShardManagerEndpoint.scala
new file mode 100644
index 0000000000000..0900d383a35c4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shard/ShardManagerEndpoint.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shard
+
+import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService, Future}
+
+import org.apache.spark.internal.{Logging, MessageWithContext}
+import org.apache.spark.internal.LogKeys.{SHARD_ID, SHARD_SET_ID}
+import org.apache.spark.rpc.{IsolatedThreadSafeRpcEndpoint, RpcCallContext, RpcEnv}
+import org.apache.spark.shard.ShardManagerMessages.InstallReplica
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * RPC endpoint of a [[ShardManager]]. It accepts `InstallReplica` requests, runs the
+ * installation on its own thread pool and replies `true` once it has finished.
+ */
+private[spark] class ShardManagerEndpoint(val rpcEnv: RpcEnv, shardManager: ShardManager)
+  extends IsolatedThreadSafeRpcEndpoint
+  with Logging {
+
+  // Pool on which the request bodies run, so the RPC thread is not blocked by them.
+  private implicit val asyncEc: ExecutionContextExecutorService =
+    ExecutionContext.fromExecutorService(
+      ThreadUtils.newDaemonCachedThreadPool("shard-manager-async-thread-pool", 8))
+
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+    case InstallReplica(setId, shardId) =>
+      val action =
+        log"install replica (${MDC(SHARD_SET_ID, setId)}, ${MDC(SHARD_ID, shardId)})"
+      doAsync[Boolean](action, context) {
+        shardManager.installReplica(setId, shardId)
+        true
+      }
+  }
+
+  /**
+   * Evaluates `body` on `asyncEc` and answers the caller with its result, or with the
+   * failure if it throws.
+   *
+   * @param actionMessage description of the action, used in the log lines
+   * @param context       call context that receives the reply or the failure
+   */
+  private def doAsync[T](
+      actionMessage: MessageWithContext,
+      context: RpcCallContext)(body: => T): Unit =
+    Future {
+      logDebug(actionMessage.message)
+      body
+    }.onComplete { outcome =>
+      outcome match {
+        case scala.util.Failure(cause) =>
+          logError(log"Error in " + actionMessage, cause)
+          context.sendFailure(cause)
+        case scala.util.Success(result) =>
+          logDebug(s"Done ${actionMessage.message}, response is $result")
+          context.reply(result)
+      }
+    }
+
+  /** Shuts the worker pool down together with the endpoint. */
+  override def onStop(): Unit = asyncEc.shutdown()
+
+}
diff --git a/core/src/main/scala/org/apache/spark/shard/ShardManagerId.scala b/core/src/main/scala/org/apache/spark/shard/ShardManagerId.scala
new file mode 100644
index 0000000000000..afd03b6b0e8b1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shard/ShardManagerId.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shard
+
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
+
+import com.google.common.cache.{CacheBuilder, CacheLoader}
+
+import org.apache.spark.util.Utils
+
+/**
+ * Identity of a shard manager: executor id, host, port and optional topology info.
+ *
+ * The fields are `var`s only so that `readExternal` can fill an instance created by the
+ * no-arg constructor; use the companion `apply` methods to obtain (cached) instances.
+ */
+private[spark] class ShardManagerId(
+    private var executorId_ : String,
+    private var host_ : String,
+    private var port_ : Int,
+    private var topologyInfo_ : Option[String])
+  extends Externalizable {
+
+  private def this() = this(null, null, 0, None) // For deserialization only
+
+  def executorId: String = executorId_
+
+  // Validate eagerly, except for the empty instance used during deserialization.
+  if (null != host_) {
+    Utils.checkHost(host_)
+    require(port_ > 0, s"port must be positive: $port_")
+  }
+
+  /** `host:port`, after re-validating both parts. */
+  def hostPort: String = {
+    Utils.checkHost(host)
+    require(port > 0, s"port must be positive: $port")
+    s"$host:$port"
+  }
+
+  def host: String = host_
+
+  def port: Int = port_
+
+  def topologyInfo: Option[String] = topologyInfo_
+
+  /** Writes executor id, host, port, a presence flag and, if present, the topology info. */
+  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
+    out.writeUTF(executorId_)
+    out.writeUTF(host_)
+    out.writeInt(port_)
+    topologyInfo_ match {
+      case Some(info) =>
+        out.writeBoolean(true)
+        out.writeUTF(info)
+      case None =>
+        out.writeBoolean(false)
+    }
+  }
+
+  /** Reads the fields back in the order `writeExternal` wrote them. */
+  override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
+    executorId_ = in.readUTF()
+    host_ = in.readUTF()
+    port_ = in.readInt()
+    topologyInfo_ = if (in.readBoolean()) Option(in.readUTF()) else None
+  }
+
+  override def toString: String = s"ShardManagerId($executorId, $host, $port, $topologyInfo)"
+
+  // Same polynomial as ((executorId * 41 + host) * 41 + port) * 41 + topologyInfo.
+  override def hashCode: Int =
+    Seq(host.hashCode, port, topologyInfo.hashCode)
+      .foldLeft(executorId.hashCode)((acc, part) => acc * 41 + part)
+
+  override def equals(that: Any): Boolean = that match {
+    case other: ShardManagerId =>
+      val sameExecutor = executorId == other.executorId
+      val sameAddress = port == other.port && host == other.host
+      sameExecutor && sameAddress && topologyInfo == other.topologyInfo
+    case _ =>
+      false
+  }
+}
+
+private[spark] object ShardManagerId {
+
+  /** Returns the cached instance equal to the given coordinates. */
+  def apply(
+      execId: String,
+      host: String,
+      port: Int,
+      topologyInfo: Option[String] = None): ShardManagerId = {
+    val fresh = new ShardManagerId(execId, host, port, topologyInfo)
+    getCachedShardManagerId(fresh)
+  }
+
+  /** Reads an id from `in` and returns the cached instance equal to it. */
+  def apply(in: ObjectInput): ShardManagerId = {
+    val decoded = new ShardManagerId()
+    decoded.readExternal(in)
+    getCachedShardManagerId(decoded)
+  }
+
+  // Identity-loading cache: the first instance seen for a value is handed out afterwards.
+  private val shardManagerIdCache = CacheBuilder
+    .newBuilder()
+    .maximumSize(1000)
+    .build(new CacheLoader[ShardManagerId, ShardManagerId]() {
+      override def load(id: ShardManagerId): ShardManagerId = id
+    })
+
+  private def getCachedShardManagerId(id: ShardManagerId): ShardManagerId =
+    shardManagerIdCache.get(id)
+}
diff --git a/core/src/main/scala/org/apache/spark/shard/ShardManagerMaster.scala b/core/src/main/scala/org/apache/spark/shard/ShardManagerMaster.scala
new file mode 100644
index 0000000000000..616673b2f7f83
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shard/ShardManagerMaster.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shard
+
+import com.google.common.cache.{CacheBuilder, CacheLoader}
+
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.LogKeys.{EXECUTOR_ID, SHARD_ID, SHARD_MANAGER_ID, SHARD_SET_ID}
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.shard.ShardManagerMessages._
+
+/**
+ * Client-side proxy for the driver's [[ShardManagerMasterEndpoint]].
+ *
+ * Provides RPC methods for shard set creation, location queries, replica
+ * installation, and executor removal used by both driver and executor code.
+ */
+private[spark] class ShardManagerMaster(
+    val masterEndpoint: RpcEndpointRef,
+    conf: SparkConf,
+    isDriver: Boolean)
+  extends Logging {
+
+  private val LOCATION_CACHE_MAX_SIZE = 2048
+
+  // Per-shard location cache; a miss is loaded from the master with a blocking ask.
+  private val shardLocationsCache = CacheBuilder
+    .newBuilder()
+    .maximumSize(LOCATION_CACHE_MAX_SIZE)
+    .build(new CacheLoader[ShardKey, Seq[ShardManagerId]] {
+      override def load(key: ShardKey): Seq[ShardManagerId] = {
+        getLocations(key)
+      }
+    })
+
+  /** Registers a shard manager with the master and returns the id the master answers with. */
+  def registerShardManager(
+      id: ShardManagerId,
+      managerEndpoint: RpcEndpointRef): ShardManagerId = {
+    logInfo(log"Registering ShardManager ${MDC(SHARD_MANAGER_ID, id)}")
+    val updatedId =
+      masterEndpoint.askSync[ShardManagerId](RegisterShardManager(id, managerEndpoint))
+    logInfo(log"Registered ShardManager ${MDC(SHARD_MANAGER_ID, updatedId)}")
+    updatedId
+  }
+
+  /** Asks the master for a new shard set id (blocking). */
+  def newShardSet(numShards: Int, replicaCount: Int): Long = {
+    masterEndpoint.askSync[Long](NewShardSet(numShards, replicaCount))
+  }
+
+  /** Tells the master that `shardManagerId` now holds shard `(setId, id)` (blocking). */
+  def reportShard(shardManagerId: ShardManagerId, setId: Long, id: Int): Unit = {
+    masterEndpoint.askSync[Boolean](UpdateShardInfo(shardManagerId, setId, id))
+  }
+
+  /** Requests replica installation for a shard; the reply is not awaited. */
+  def installReplicaSet(setId: Long, shardId: Int): Unit = {
+    // non-blocking
+    masterEndpoint.ask[Boolean](InstallReplicaSet(setId, shardId))
+    logInfo(log"Install replica set of shard" +
+      log" (${MDC(SHARD_SET_ID, setId)}, ${MDC(SHARD_ID, shardId)}) requested")
+  }
+
+  /**
+   * Returns the known holders of shard `(setId, shardId)`, served from the local cache.
+   *
+   * An empty answer is returned to the caller but not kept in the cache: a shard that was
+   * queried before any holder reported it would otherwise stay "not found" for every later
+   * non-refreshing lookup.
+   *
+   * @param refresh drop the cached entry first and ask the master again
+   */
+  def getLocations(setId: Long, shardId: Int, refresh: Boolean = false): Seq[ShardManagerId] = {
+    val key = ShardKey(setId, shardId)
+    if (refresh) {
+      shardLocationsCache.invalidate(key)
+    }
+    val locations = shardLocationsCache.get(key)
+    if (locations.isEmpty) {
+      shardLocationsCache.invalidate(key)
+    }
+    locations
+  }
+
+  /** Requests removal of an executor's shard bookkeeping; the reply is not awaited. */
+  def removeExecutor(execId: String): Unit = {
+    masterEndpoint.ask[Boolean](RemoveExecutor(execId))
+    logInfo(log"Removal(shard) of executor ${MDC(EXECUTOR_ID, execId)} requested")
+  }
+
+  // Cache loader body: blocking location query against the master.
+  private def getLocations(key: ShardKey): Seq[ShardManagerId] = {
+    masterEndpoint.askSync[Seq[ShardManagerId]](GetLocations(key.setId, key.shardId))
+  }
+
+  /** On the driver, asks the master endpoint to stop; does nothing elsewhere. */
+  def stop(): Unit = {
+    if (masterEndpoint != null && isDriver) {
+      tell(StopShardManagerMaster)
+      logInfo(log"ShardManagerMaster stopped")
+    }
+  }
+
+  /** Sends `message` and fails with a SparkException unless the master replies `true`. */
+  private def tell(message: Any): Unit = {
+    if (!masterEndpoint.askSync[Boolean](message)) {
+      throw new SparkException("Unexpected ShardManagerMasterEndpoint result error")
+    }
+  }
+}
+
+private[spark] object ShardManagerMaster {
+  val DRIVER_ENDPOINT_NAME = "ShardManagerMaster"
+}
diff --git a/core/src/main/scala/org/apache/spark/shard/ShardManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/shard/ShardManagerMasterEndpoint.scala
new file mode 100644
index 0000000000000..f4bef2c192b2d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shard/ShardManagerMasterEndpoint.scala
@@ -0,0 +1,256 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shard
+
+import java.util.{HashMap => JHashMap}
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable
+import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService}
+import scala.jdk.CollectionConverters._
+import scala.util.Random
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.LogKeys.{EXECUTOR_ID, SHARD_ID, SHARD_MANAGER_ID, SHARD_SET_ID}
+import org.apache.spark.rpc.{IsolatedThreadSafeRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv}
+import org.apache.spark.shard.ShardManagerMessages._
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Driver-side RPC endpoint that tracks shard locations across executors
+ * and coordinates replica placement for distributed map join.
+ */
+private[spark] class ShardManagerMasterEndpoint(
+    val rpcEnv: RpcEnv,
+    val isLocal: Boolean,
+    conf: SparkConf,
+    shardManagerInfo: mutable.Map[ShardManagerId, ShardManagerInfo],
+    isDriver: Boolean)
+  extends IsolatedThreadSafeRpcEndpoint
+  with Logging {
+
+  // Weights of the placement score, see installReplicaToWorkers.
+  private val EXEC_LOAD_WEIGHT = 10
+  private val HOST_LOAD_WEIGHT = 5
+  private val SAME_HOST_PENALTY = 100000L
+  private val JITTER_BOUND = 256
+
+  private val nextShardSetId = new AtomicLong(0)
+
+  // setId -> shape of the set, and setId -> shardId -> managers holding that shard.
+  private val shardSetInfo = new mutable.HashMap[Long, ShardSetInfo]
+  private val shardSetLocations =
+    new JHashMap[Long, JHashMap[Int, mutable.HashSet[ShardManagerId]]]
+
+  // Runs the callbacks attached to outgoing asks.
+  private implicit val askEc: ExecutionContextExecutorService =
+    ExecutionContext.fromExecutorService(
+      ThreadUtils.newDaemonCachedThreadPool("shard-manager-ask-thread-pool", 8))
+
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+    case RegisterShardManager(id, endpoint) =>
+      context.reply(register(id, endpoint))
+
+    case NewShardSet(numShards, replicaCount) =>
+      context.reply(newShardSet(numShards, replicaCount))
+
+    case UpdateShardInfo(shardManagerId, setId, shardId) =>
+      updateShardInfo(shardManagerId, setId, shardId)
+      context.reply(true)
+
+    case InstallReplicaSet(setId, shardId) =>
+      installReplicaToWorkers(setId, shardId)
+      context.reply(true)
+
+    case GetLocations(setId, shardId) =>
+      context.reply(getLocations(setId, shardId))
+
+    case RemoveExecutor(execId) =>
+      removeExecutor(execId)
+      context.reply(true)
+
+    case StopShardManagerMaster =>
+      context.reply(true)
+      stop()
+  }
+
+  /** Records the manager (with topology info cleared) and returns the id stored for it. */
+  private def register(
+      idWithoutTopologyInfo: ShardManagerId,
+      managerEndpoint: RpcEndpointRef): ShardManagerId = {
+    val id = ShardManagerId(
+      idWithoutTopologyInfo.executorId,
+      idWithoutTopologyInfo.host,
+      idWithoutTopologyInfo.port,
+      None)
+    shardManagerInfo(id) = new ShardManagerInfo(id, managerEndpoint)
+    id
+  }
+
+  /**
+   * Forgets a registered manager and removes it from the holders of every shard it
+   * reported. Shards that still have a holder are re-replicated; shards that lost their
+   * last holder are dropped with a warning.
+   */
+  private def removeShardManager(shardManagerId: ShardManagerId): Unit = {
+    val info = shardManagerInfo(shardManagerId)
+    shardManagerInfo.remove(shardManagerId)
+
+    val iterator = info.shards.entrySet().iterator()
+    while (iterator.hasNext) {
+      val entry = iterator.next()
+      val setId = entry.getKey
+      val shardLocations = shardSetLocations.get(setId)
+      if (shardLocations != null) {
+        val valueIterator = entry.getValue.iterator
+        while (valueIterator.hasNext) {
+          val shardId = valueIterator.next()
+          val locations = shardLocations.get(shardId)
+          if (locations != null) {
+            locations -= shardManagerId
+            if (locations.isEmpty) {
+              shardLocations.remove(shardId)
+              logWarning(log"No more shard replicas available for" +
+                log" (${MDC(SHARD_SET_ID, setId)}, ${MDC(SHARD_ID, shardId)})")
+            } else {
+              // proactively replicate here.
+              installReplicaToWorkers(setId, shardId)
+            }
+          }
+        }
+        if (shardLocations.isEmpty) {
+          shardSetLocations.remove(setId)
+        }
+      }
+    }
+
+    logInfo(log"Removing shard manager ${MDC(SHARD_MANAGER_ID, shardManagerId)}")
+  }
+
+  /** Allocates the next shard set id and remembers its shard and replica counts. */
+  private def newShardSet(numShards: Int, replicaCount: Int): Long = {
+    val setId = nextShardSetId.getAndIncrement()
+    shardSetInfo(setId) = new ShardSetInfo(numShards, replicaCount)
+    setId
+  }
+
+  /** Records that a registered manager holds the shard; unknown managers are ignored. */
+  private def updateShardInfo(shardManagerId: ShardManagerId, setId: Long, shardId: Int): Unit = {
+    shardManagerInfo.get(shardManagerId).foreach { info =>
+      info.updateShardInfo(setId, shardId)
+
+      val shardLocations = shardSetLocations.computeIfAbsent(setId,
+        _ => new JHashMap[Int, mutable.HashSet[ShardManagerId]])
+      val locations = shardLocations.computeIfAbsent(shardId,
+        _ => new mutable.HashSet[ShardManagerId])
+      locations.add(shardManagerId)
+    }
+  }
+
+  /**
+   * Place shard replicas on executors. Prefers executors with lower shard load
+   * and hosts that don't already hold this shard. Score is
+   * execLoad*10 + hostLoad*5 + sameHostPenalty(100000) + jitter, lowest score wins.
+   *
+   * Only as many replicas as are missing (`replicaCount` minus current holders) are
+   * requested, each on an executor that does not hold the shard yet. Nothing happens for
+   * an unknown shard set or when enough holders exist.
+   */
+  private def installReplicaToWorkers(setId: Long, shardId: Int): Unit = {
+    shardSetInfo
+      .get(setId)
+      .foreach { setInfo =>
+        val targetReplica = setInfo.replicaCount
+        val currentHolders = getLocations(setId, shardId)
+
+        val have = currentHolders.size
+        val need = math.max(0, targetReplica - have)
+
+        // Nothing is missing: skip the load bookkeeping and scoring altogether.
+        if (need > 0) {
+          val perExecSetLoad = mutable.HashMap.empty[String, Int].withDefaultValue(0)
+          val perHostSetLoad = mutable.HashMap.empty[String, Int].withDefaultValue(0)
+
+          def inc(m: scala.collection.mutable.Map[String, Int], k: String, d: Int = 1): Unit = {
+            m.update(k, m.getOrElse(k, 0) + d)
+          }
+          def bumpLoads(smi: ShardManagerId, execDelta: Int = 1, hostDelta: Int = 1): Unit = {
+            inc(perExecSetLoad, smi.executorId, execDelta)
+            inc(perHostSetLoad, smi.host, hostDelta)
+          }
+
+          // Load = number of shards of this set already held, per executor and per host.
+          Option(shardSetLocations.get(setId)).foreach(locMap =>
+            locMap.values().asScala.foreach(holderSet => holderSet.foreach(smi => bumpLoads(smi))))
+
+          val usedHosts = mutable.HashSet.empty[String]
+          currentHolders.foreach(smi => usedHosts += smi.host)
+
+          val chosen = mutable.ArrayBuffer.empty[ShardManagerId]
+          val chosenExecIds = mutable.HashSet.empty[String]
+          val rng = new Random(setId ^ (shardId.toLong << 32))
+
+          def score(smi: ShardManagerId): Long = {
+            val execLoad = perExecSetLoad(smi.executorId)
+            val hostLoad = perHostSetLoad(smi.host)
+            val sameHost = if (usedHosts.contains(smi.host)) SAME_HOST_PENALTY else 0L
+            // nextInt(JITTER_BOUND) already lies in [0, JITTER_BOUND); no extra mask needed.
+            execLoad * EXEC_LOAD_WEIGHT + hostLoad * HOST_LOAD_WEIGHT +
+              sameHost + rng.nextInt(JITTER_BOUND).toLong
+          }
+
+          val currentExecIds = currentHolders.map(_.executorId).toSet
+          val candidates =
+            shardManagerInfo.keys.filterNot(smi => currentExecIds.contains(smi.executorId))
+
+          var remaining = need
+          while (remaining > 0) {
+            val legals = candidates.filterNot(smi => chosenExecIds.contains(smi.executorId))
+            if (legals.isEmpty) {
+              remaining = 0
+            } else {
+              val best = legals.minBy(score)
+              chosen += best
+              chosenExecIds += best.executorId
+              usedHosts += best.host
+              bumpLoads(best)
+              remaining -= 1
+            }
+          }
+
+          chosen.foreach { smi =>
+            shardManagerInfo.get(smi).foreach { sm =>
+              // Fire-and-forget, but a failed installation is logged instead of being lost.
+              sm.managerEndpoint
+                .ask[Boolean](InstallReplica(setId, shardId))
+                .failed
+                .foreach { e =>
+                  logWarning(log"Failed to install replica of shard" +
+                    log" (${MDC(SHARD_SET_ID, setId)}, ${MDC(SHARD_ID, shardId)})" +
+                    log" on ${MDC(SHARD_MANAGER_ID, smi)}", e)
+                }
+            }
+          }
+        }
+      }
+  }
+
+  /** Removes every shard manager registered for the given executor. */
+  private def removeExecutor(execId: String): Unit = {
+    logInfo(log"Trying to remove executor ${MDC(EXECUTOR_ID, execId)} from ShardManagerMaster.")
+    val ids = shardManagerInfo.keys.filter(_.executorId == execId).toSeq
+    ids.foreach(removeShardManager)
+  }
+
+  /** Current holders of the shard; empty when the set or the shard is unknown. */
+  private def getLocations(setId: Long, shardId: Int): Seq[ShardManagerId] = {
+    Option(shardSetLocations.get(setId))
+      .flatMap(shardLocations => Option(shardLocations.get(shardId)))
+      .map(_.toSeq)
+      .getOrElse(Seq.empty)
+  }
+
+  override def onStop(): Unit = {
+    askEc.shutdown()
+  }
+}
+
+/** Shard count and target replica count of one shard set. */
+private[spark] class ShardSetInfo(val numShards: Int, val replicaCount: Int) extends Logging {}
+
+/** Master-side record of one shard manager: its endpoint and the shards it reported. */
+private[spark] class ShardManagerInfo(
+    val shardManagerId: ShardManagerId,
+    val managerEndpoint: RpcEndpointRef)
+  extends Logging {
+  // setId -> ids of the shards of that set held by this manager
+  private val _shards = new JHashMap[Long, mutable.HashSet[Int]]()
+
+  def updateShardInfo(setId: Long, id: Int): Unit = {
+    _shards.computeIfAbsent(setId, _ => new mutable.HashSet[Int]()).add(id)
+  }
+
+  def shards: JHashMap[Long, mutable.HashSet[Int]] = _shards
+}
diff --git a/core/src/main/scala/org/apache/spark/shard/ShardManagerMessages.scala b/core/src/main/scala/org/apache/spark/shard/ShardManagerMessages.scala
new file mode 100644
index 0000000000000..9f8e03ddd4a83
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shard/ShardManagerMessages.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shard
+
+import org.apache.spark.rpc.RpcEndpointRef
+
+/** RPC messages exchanged between the shard manager master and the shard managers. */
+private[spark] object ShardManagerMessages {
+
+  // Messages handled by ShardManagerMasterEndpoint.
+  sealed trait ToShardManagerMasterEndpoint
+
+  // Registers a shard manager and its endpoint; answered with a ShardManagerId.
+  case class RegisterShardManager(id: ShardManagerId, managerEndpoint: RpcEndpointRef)
+    extends ToShardManagerMasterEndpoint
+
+  // Creates a shard set; answered with the new set id (Long).
+  case class NewShardSet(numShards: Int, replicaCount: Int) extends ToShardManagerMasterEndpoint
+
+  // Reports that `shardManagerId` holds shard (`setId`, `id`); answered with `true`.
+  case class UpdateShardInfo(shardManagerId: ShardManagerId, setId: Long, id: Int)
+    extends ToShardManagerMasterEndpoint
+
+  // Asks the master to place replicas of the shard on other managers; answered with `true`.
+  case class InstallReplicaSet(setId: Long, shardId: Int) extends ToShardManagerMasterEndpoint
+
+  // Queries the holders of a shard; answered with a Seq[ShardManagerId].
+  case class GetLocations(setId: Long, shardId: Int) extends ToShardManagerMasterEndpoint
+
+  // Drops the shard managers of an executor; answered with `true`.
+  case class RemoveExecutor(execId: String) extends ToShardManagerMasterEndpoint
+
+  // Stops the master endpoint; answered with `true` before it stops.
+  case object StopShardManagerMaster extends ToShardManagerMasterEndpoint
+
+  // Messages from the master to executor-side endpoints.
+  sealed trait ToShardManagerEndpoint
+
+  // Asks a shard manager to install a replica of the shard; answered with `true` when done.
+  case class InstallReplica(setId: Long, shardId: Int) extends ToShardManagerEndpoint
+}
diff --git a/core/src/main/scala/org/apache/spark/shard/package-info.java b/core/src/main/scala/org/apache/spark/shard/package-info.java
new file mode 100644
index 0000000000000..d6ed012e335ac
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shard/package-info.java
@@ -0,0 +1,18 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shard; diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 3e46a53ee082c..c6c385497802c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -161,6 +161,12 @@ case class BroadcastBlockId(broadcastId: Long, field: String = "") extends Block override def name: String = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) } +@DeveloperApi +case class ShardBlockId(setId: Long, shardId: Int, field: String = "") extends BlockId { + override def name: String = "shard_" + setId + "_" + shardId + + (if (field == "") "" else "_" + field) +} + @DeveloperApi case class TaskResultBlockId(taskId: Long) extends BlockId { override def name: String = "taskresult_" + taskId @@ -278,6 +284,7 @@ object BlockId { "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+)_([0-9]+).meta".r val SHUFFLE_CHUNK = "shuffleChunk_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r + val SHARD = "shard_([0-9]+)_(-?[0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r val PYTHON_STREAM = 
"python-stream-([0-9]+)-([0-9]+)".r @@ -316,6 +323,8 @@ object BlockId { chunkId.toInt) case BROADCAST(broadcastId, field) => BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_")) + case SHARD(setId, id, field) => + ShardBlockId(setId.toLong, id.toInt, field.stripPrefix("_")) case TASKRESULT(taskId) => TaskResultBlockId(taskId.toLong) case STREAM(streamId, uniqueId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 5fbc8dca74f68..ef90e547a879b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -2127,6 +2127,14 @@ private[spark] class BlockManager( blocksToRemove.size } + def removeShardSet(setId: Long, tellMaster: Boolean): Int = { + val blocksToRemove = blockInfoManager.entries.map(_._1).collect { + case sid @ ShardBlockId(`setId`, _, _) => sid + } + blocksToRemove.foreach { shardId => removeBlock(shardId, tellMaster) } + blocksToRemove.size + } + /** * Remove cache blocks that might be related to cached local relations. 
* diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 98bd52fc0886c..2d3b0aa0a7a55 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -229,6 +229,16 @@ class BlockManagerMaster( } } + def removeShardSet(setId: Long, blocking: Boolean): Unit = { + val future = driverEndpoint.askSync[Future[Seq[Int]]](RemoveShardSet(setId)) + future.failed.foreach(e => + logWarning(s"Failed to remove shard-set $setId - ${e.getMessage}", e) + )(ThreadUtils.sameThread) + if (blocking) { + waitBlockRemovalTimeout.awaitResult(future) + } + } + /** * Return the memory status for each block manager, in the form of a map from * the block manager's id to two long values. The first value is the maximum diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 9d6539e09f452..8c644adbad261 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -203,6 +203,9 @@ class BlockManagerMasterEndpoint( case RemoveBroadcast(broadcastId, removeFromDriver) => context.reply(removeBroadcast(broadcastId, removeFromDriver)) + case RemoveShardSet(setId) => + context.reply(removeShardSet(setId)) + case RemoveBlock(blockId) => removeBlockFromWorkers(blockId) context.reply(true) @@ -497,6 +500,18 @@ class BlockManagerMasterEndpoint( Future.sequence(futures) } + private def removeShardSet(setId: Long): Future[Seq[Int]] = { + val removeMsg = RemoveShardSet(setId) + val futures = blockManagerInfo.values.map { bm => + bm.storageEndpoint.ask[Int](removeMsg).recover { + // use 0 as default value means no blocks were removed + handleBlockRemovalFailure("shard-set", 
setId.toString, bm.blockManagerId, 0) + } + }.toSeq + + Future.sequence(futures) + } + private def removeBlockManager(blockManagerId: BlockManagerId): Unit = { val info = blockManagerInfo(blockManagerId) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 7fb145556a118..bbaf6635ecf08 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -51,6 +51,8 @@ private[spark] object BlockManagerMessages { // Mark a rdd block as visible. case class MarkRDDBlockAsVisible(blockId: RDDBlockId) extends ToBlockManagerMasterStorageEndpoint + case class RemoveShardSet(setId: Long) extends ToBlockManagerMasterStorageEndpoint + /** * Driver to Executor message to trigger a thread dump. */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala index 54329c5b1e514..138f1c578d551 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala @@ -21,7 +21,7 @@ import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService, Futu import org.apache.spark.{MapOutputTracker, SparkEnv} import org.apache.spark.internal.{Logging, MessageWithContext} -import org.apache.spark.internal.LogKeys.{BLOCK_ID, BROADCAST_ID, RDD_ID, SHUFFLE_ID} +import org.apache.spark.internal.LogKeys.{BLOCK_ID, BROADCAST_ID, RDD_ID, SHARD_ID, SHUFFLE_ID} import org.apache.spark.rpc.{IsolatedThreadSafeRpcEndpoint, RpcCallContext, RpcEnv} import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.{ThreadUtils, Utils} @@ -76,6 +76,12 @@ class BlockManagerStorageEndpoint( doAsync[Int](log"removing broadcast ${MDC(BROADCAST_ID, 
broadcastId)}", context) { blockManager.removeBroadcast(broadcastId, tellMaster = true) } + case RemoveShardSet(setId: Long) => + doAsync[Int](log"removing shard-set ${MDC(SHARD_ID, setId)}", context) { + val sm = SparkEnv.get.shardManager + if (sm != null) sm.invokeCleanupCallbacks(setId) + blockManager.removeShardSet(setId, tellMaster = true) + } case GetBlockStatus(blockId, _) => context.reply(blockManager.getStatus(blockId)) diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 813de4132ab2d..a0cfa7659327c 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -348,12 +348,14 @@ class CleanerTester( rddIds: Seq[Int] = Seq.empty, shuffleIds: Seq[Int] = Seq.empty, broadcastIds: Seq[Long] = Seq.empty, + shardSetIds: Seq[Long] = Seq.empty, checkpointIds: Seq[Long] = Seq.empty) extends Logging { val toBeCleanedRDDIds = new HashSet[Int] ++= rddIds val toBeCleanedShuffleIds = new HashSet[Int] ++= shuffleIds val toBeCleanedBroadcastIds = new HashSet[Long] ++= broadcastIds + val toBeCleanedShardSetIds = new HashSet[Long] ++= shardSetIds val toBeCheckpointIds = new HashSet[Long] ++= checkpointIds val isDistributed = !sc.isLocal @@ -373,6 +375,11 @@ class CleanerTester( logInfo("Broadcast " + broadcastId + " cleaned") } + def shardSetCleaned(setId: Long): Unit = { + toBeCleanedShardSetIds.synchronized { toBeCleanedShardSetIds -= setId } + logInfo("ShardSet " + setId + " cleaned") + } + def accumCleaned(accId: Long): Unit = { logInfo("Cleaned accId " + accId + " cleaned") } @@ -496,10 +503,14 @@ class CleanerTester( val s3 = toBeCleanedBroadcastIds.synchronized { toBeCleanedBroadcastIds.toSeq.sorted.mkString("[", ", ", "]") } + val s4 = toBeCleanedShardSetIds.synchronized { + toBeCleanedShardSetIds.toSeq.sorted.mkString("[", ", ", "]") + } s""" |\tRDDs = $s1 |\tShuffles = $s2 |\tBroadcasts 
= $s3 + |\tShardSets = $s4 """.stripMargin } @@ -507,7 +518,8 @@ class CleanerTester( toBeCleanedRDDIds.synchronized { toBeCleanedRDDIds.isEmpty } && toBeCleanedShuffleIds.synchronized { toBeCleanedShuffleIds.isEmpty } && toBeCleanedBroadcastIds.synchronized { toBeCleanedBroadcastIds.isEmpty } && - toBeCheckpointIds.synchronized { toBeCheckpointIds.isEmpty } + toBeCheckpointIds.synchronized { toBeCheckpointIds.isEmpty } && + toBeCleanedShardSetIds.synchronized { toBeCleanedShardSetIds.isEmpty } private def getRDDBlocks(rddId: Int): Seq[BlockId] = { blockManager.master.getMatchingBlockIds( _ match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index f6d20dd71b1f4..50ff9d70f9631 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Ascending, ByteLiteral, Expression, IntegerLiteral, ShortLiteral, SortOrder, StringLiteral} +import org.apache.spark.sql.catalyst.expressions.{Ascending, ByteLiteral, EqualTo, Expression, IntegerLiteral, Literal, ShortLiteral, SortOrder, StringLiteral} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin @@ -65,6 +65,45 @@ object ResolveHints { _.toUpperCase(Locale.ROOT)).contains(hintName.toUpperCase(Locale.ROOT)))) } + private def createHintInfo(ident: Seq[String], hint: UnresolvedHint): HintInfo = { + + if (DistributedMapJoinStrategy(None, None).hintAliases + .map(_.toUpperCase(Locale.ROOT)) + .contains(hint.name.toUpperCase(Locale.ROOT))) { + + def keyName(e: Expression): Option[String] = e match { + case UnresolvedAttribute(parts) => 
Some(parts.last.toLowerCase(Locale.ROOT)) + case _ => None + } + def intValue(e: Expression): Option[Int] = e match { + case IntegerLiteral(i) => Some(i) + case Literal(i: Int, _) => Some(i) + case _ => None + } + + val keyOpts = hint.parameters + .collectFirst { + case uf: UnresolvedFunction if matchedIdentifier(uf.nameParts, ident) => + uf.arguments + .collect { case EqualTo(lhs, rhs) => + for { + k <- keyName(lhs) + v <- intValue(rhs) + } yield k -> v + } + .flatten + .toMap + } + .getOrElse(Map.empty) + HintInfo(strategy = + Some(DistributedMapJoinStrategy( + keyOpts.get(DistributedMapJoinStrategy.KEY_SHARD_COUNT), + keyOpts.get(DistributedMapJoinStrategy.KEY_REPLICA_COUNT)))) + } else { + createHintInfo(hint.name) + } + } + // This method checks if given multi-part identifiers are matched with each other. // The [[ResolveJoinStrategyHints]] rule is applied before the resolution batch // in the analyzer and we cannot semantically compare them at this stage. @@ -95,7 +134,7 @@ object ResolveHints { plan: LogicalPlan, relationsInHint: Set[Seq[String]], relationsInHintWithMatch: mutable.HashSet[Seq[String]], - hintName: String): LogicalPlan = { + hint: UnresolvedHint): LogicalPlan = { // Whether to continue recursing down the tree var recurse = true @@ -106,19 +145,20 @@ object ResolveHints { val newNode = CurrentOrigin.withOrigin(plan.origin) { plan match { - case ResolvedHint(u @ UnresolvedRelation(ident, _, _), hint) + case ResolvedHint(u @ UnresolvedRelation(ident, _, _), info) if matchedIdentifierInHint(ident) => - ResolvedHint(u, createHintInfo(hintName).merge(hint, hintErrorHandler)) + ResolvedHint(u, createHintInfo(ident, hint).merge(info, hintErrorHandler)) - case ResolvedHint(r: SubqueryAlias, hint) + case ResolvedHint(r: SubqueryAlias, info) if matchedIdentifierInHint(extractIdentifier(r)) => - ResolvedHint(r, createHintInfo(hintName).merge(hint, hintErrorHandler)) + ResolvedHint(r, createHintInfo(extractIdentifier(r), hint) + .merge(info, 
hintErrorHandler)) case UnresolvedRelation(ident, _, _) if matchedIdentifierInHint(ident) => - ResolvedHint(plan, createHintInfo(hintName)) + ResolvedHint(plan, createHintInfo(ident, hint)) case r: SubqueryAlias if matchedIdentifierInHint(extractIdentifier(r)) => - ResolvedHint(plan, createHintInfo(hintName)) + ResolvedHint(plan, createHintInfo(extractIdentifier(r), hint)) case _: ResolvedHint | _: View | _: UnresolvedWith | _: SubqueryAlias => // Don't traverse down these nodes. @@ -137,7 +177,7 @@ object ResolveHints { if ((plan fastEquals newNode) && recurse) { newNode.mapChildren { child => - applyJoinStrategyHint(child, relationsInHint, relationsInHintWithMatch, hintName) + applyJoinStrategyHint(child, relationsInHint, relationsInHintWithMatch, hint) } } else { newNode @@ -155,12 +195,13 @@ object ResolveHints { val relationNamesInHint = h.parameters.map { case StringLiteral(tableName) => UnresolvedAttribute.parseAttributeName(tableName) case tableId: UnresolvedAttribute => tableId.nameParts + case uf: UnresolvedFunction => uf.nameParts case unsupported => throw QueryCompilationErrors.joinStrategyHintParameterNotSupportedError(unsupported) }.toSet val relationsInHintWithMatch = new mutable.HashSet[Seq[String]] val applied = applyJoinStrategyHint( - h.child, relationNamesInHint, relationsInHintWithMatch, h.name) + h.child, relationNamesInHint, relationsInHintWithMatch, h) // Filters unmatched relation identifiers in the hint val unmatchedIdents = relationNamesInHint -- relationsInHintWithMatch diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 13e3cb76805d8..9c3fe932812a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -315,6 +315,33 @@ trait JoinSelectionHelper extends Logging { ) } + def 
getDistributedMapJoinBuildSide( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + hint: JoinHint, + hintOnly: Boolean, + conf: SQLConf): Option[BuildSide] = { + + if (!hintOnly) { + return None + } + + val buildLeft = hintToDistributedMapJoinLeft(hint) + val buildRight = hintToDistributedMapJoinRight(hint) + + if (!buildLeft && !buildRight) { + return None + } + + getBuildSide( + canBuildDistributedMapJoinLeft(joinType) && buildLeft, + canBuildDistributedMapJoinRight(joinType) && buildRight, + left, + right + ) + } + def getShuffleHashJoinBuildSide( join: Join, hintOnly: Boolean, @@ -388,6 +415,20 @@ trait JoinSelectionHelper extends Logging { } } + def canBuildDistributedMapJoinLeft(joinType: JoinType): Boolean = { + joinType match { + case _: InnerLike | RightOuter => true + case _ => false + } + } + + def canBuildDistributedMapJoinRight(joinType: JoinType): Boolean = { + joinType match { + case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true + case _ => false + } + } + def canBuildShuffledHashJoinLeft(joinType: JoinType): Boolean = { joinType match { case _: InnerLike | LeftOuter | FullOuter | RightOuter => true @@ -465,6 +506,14 @@ trait JoinSelectionHelper extends Logging { } } + def hintToDistributedMapJoinLeft(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.exists(isDistributedMapJoin)) + } + + def hintToDistributedMapJoinRight(hint: JoinHint): Boolean = { + hint.rightHint.exists(_.strategy.exists(isDistributedMapJoin)) + } + def hintToShuffleHashJoinLeft(hint: JoinHint): Boolean = { hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH)) } @@ -559,4 +608,7 @@ trait JoinSelectionHelper extends Logging { Utils.isTesting && conf.getConfString("spark.sql.join.forceApplyShuffledHashJoin", "false") == "true" } + + private def isDistributedMapJoin(s: JoinStrategyHint): Boolean = + s.isInstanceOf[DistributedMapJoinStrategy] } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala index 8a6182b87b77c..9a585eab65b96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -125,6 +125,7 @@ object JoinStrategyHint { val strategies: Set[JoinStrategyHint] = Set( BROADCAST, + DistributedMapJoinStrategy(None, None), SHUFFLE_MERGE, SHUFFLE_HASH, SHUFFLE_REPLICATE_NL) @@ -142,6 +143,25 @@ case object BROADCAST extends JoinStrategyHint { "MAPJOIN") } +case class DistributedMapJoinStrategy( + shards: Option[Int], + replicas: Option[Int]) extends JoinStrategyHint { + + override val displayName: String = { + val params = List("shard_count" -> shards, "replica_count" -> replicas) + .collect { case (k, Some(v)) => s"$k=$v" } + .mkString(", ") + s"distmapjoin${if (params.isEmpty) "" else s"($params)"}" + } + + override val hintAliases: Set[String] = Set("DISTMAPJOIN") +} + +object DistributedMapJoinStrategy { + val KEY_SHARD_COUNT: String = "shard_count" + val KEY_REPLICA_COUNT: String = "replica_count" +} + /** * The hint for shuffle sort merge join. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index d2bb12d2053aa..230f3398edd5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.sql.types.{DataType, IntegerType, StructType} /** * Specifies how tuples that share common expressions will be distributed when a query is executed @@ -208,6 +208,34 @@ case class BroadcastDistribution(mode: BroadcastMode) extends Distribution { } } +/** + * A distribution where data is distributed by hashing the given expressions into a fixed number + * of shards. Used by DistributedMapJoin to ensure build side data is properly sharded for remote + * lookup. + * + * @param buildKeys + * The expressions to hash on + * @param numShards + * Number of shards (partitions) + */ +case class ShardDistribution( + buildKeys: Seq[Expression], + numShards: Int, + replicaCount: Int, + filterExpr: Option[Expression] = None, + filterSchema: Option[StructType] = None) + extends Distribution { + + override def requiredNumPartitions: Option[Int] = Some(numShards) + + /** + * Creates a default partitioning for this distribution, which can satisfy this distribution + * while matching the given number of partitions. + */ + override def createPartitioning(numPartitions: Int): Partitioning = + HashPartitioning(buildKeys, numShards) +} + /** * Describes how an operator's output is split across partitions. 
It has 2 major properties: * 1. number of partitions. @@ -929,6 +957,25 @@ case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { } } +case class ShardPartitioning( + buildKeys: Seq[Expression], + numShards: Int) extends Partitioning { + + override val numPartitions: Int = numShards + + override protected def satisfies0(required: Distribution): Boolean = required match { + case ShardDistribution(keys, shards, _, _, _) => + shards == numShards && sameKeys(keys) + case _ => false + } + + private def sameKeys(req: Seq[Expression]): Boolean = { + val a = buildKeys.map(_.canonicalized) + val b = req.map(_.canonicalized) + a.length == b.length && a.zip(b).forall { case (x, y) => x.semanticEquals(y) } + } +} + /** * This is used in the scenario where an operator has multiple children (e.g., join) and one or more * of which have their own requirement regarding whether its data can be considered as diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4f76dec323ef3..bff770556b3a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -7488,6 +7488,35 @@ object SQLConf { .booleanConf .createWithDefault(true) + val DISTRIBUTED_MAP_JOIN_MAX_IN_FLIGHT_NUM = + buildConf("spark.sql.execution.distributedMapJoin.maxInFlightNum") + .doc("Maximum number of concurrent RPC lookup batches per task on the probe side.") + .version("5.0.0") + .intConf + .createWithDefault(8) + + val DISTRIBUTED_MAP_JOIN_MAX_BATCH_SIZE = + buildConf("spark.sql.execution.distributedMapJoin.maxBatchSize") + .doc("Maximum number of probe-side keys per RPC lookup batch.") + .version("5.0.0") + .intConf + .createWithDefault(1024) + + val DISTRIBUTED_MAP_JOIN_BLOOM_FILTER_CAPACITY = + buildConf("spark.sql.execution.distributedMapJoin.bloomFilterCapacity") + .doc("Expected 
number of distinct keys per shard for the build-side bloom filter. " + + "The total bloom filter capacity is this value multiplied by the number of shards.") + .version("5.0.0") + .longConf + .createWithDefault(5L << 20) + + val DISTRIBUTED_MAP_JOIN_EXCHANGE_TIMEOUT = + buildConf("spark.sql.execution.distributedMapJoin.exchangeTimeout") + .doc("Timeout in seconds for building shard data in distributed map join.") + .version("5.0.0") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString(s"${30 * 60}") + /** * Holds information about keys that have been deprecated. * @@ -8063,6 +8092,12 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def broadcastHashJoinOutputPartitioningExpandLimit: Int = getConf(BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT) + def distributedMapJoinMaxInFlightNum: Int = getConf(DISTRIBUTED_MAP_JOIN_MAX_IN_FLIGHT_NUM) + + def distributedMapJoinMaxBatchSize: Int = getConf(DISTRIBUTED_MAP_JOIN_MAX_BATCH_SIZE) + + def distributedMapJoinExchangeTimeout: Long = getConf(DISTRIBUTED_MAP_JOIN_EXCHANGE_TIMEOUT) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedShardRowMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedShardRowMap.java new file mode 100644 index 0000000000000..d5f2f251585cd --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedShardRowMap.java @@ -0,0 +1,485 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.Map; +import java.util.NoSuchElementException; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.PooledByteBufAllocator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.map.HashMapGrowthStrategy; +import org.apache.spark.unsafe.memory.MemoryBlock; + +/** + * Probe-side key batching buffer for distributed map join. + * + * Accumulates streamed-side keys into per-shard batches backed by off-heap + * Tungsten pages, then serializes them into Netty buffers for RPC lookup + * against build-side executors. 
+ */ +public final class BufferedShardRowMap extends MemoryConsumer { + + private static final Logger logger = LoggerFactory.getLogger(BufferedShardRowMap.class); + + private final TaskMemoryManager taskMemoryManager; + + private final long setId; + private final int numShards; + private final int batchCapacity; + private final int mask; + private final int uaoSize; + + private final int numFields; + private final UnsafeRow valueUr; + + private final KeyValueBatch[] tailingBatches; + private final LinkedList pendingBatches; + + private final LinkedList pageList; + + private final PooledByteBufAllocator alloc; + + private KeyValuePage currentPage = null; + + public BufferedShardRowMap(TaskMemoryManager taskMemoryManager, long setId, int numShards, + int numFields, UnsafeRow valueUr, int batchSize) { + super(taskMemoryManager, taskMemoryManager.getTungstenMemoryMode()); + this.taskMemoryManager = taskMemoryManager; + this.setId = setId; + this.numShards = numShards; + this.tailingBatches = new KeyValueBatch[numShards]; + this.pendingBatches = new LinkedList<>(); + this.pageList = new LinkedList<>(); + this.numFields = numFields; + this.valueUr = valueUr; + this.alloc = NettyUtils.getSharedPooledByteBufAllocator(true, true); + this.batchCapacity = Math.max((int) ByteArrayMethods.nextPowerOf2(batchSize), 32); + this.mask = HashMapGrowthStrategy.DOUBLING.nextCapacity(batchCapacity) - 1; + this.uaoSize = UnsafeAlignedOffset.getUaoSize(); + } + + public void putRow(int shard, + Object kbase, long koff, int klen, int khash, + Object vbase, long voff, int vlen) { + final int idx = shard % numShards; + if (tailingBatches[idx] == null) { + tailingBatches[idx] = new KeyValueBatch(shard); + } + KeyValueBatch batch = tailingBatches[idx]; + if (batch.capReached()) { + pendingBatches.add(batch); + tailingBatches[idx] = new KeyValueBatch(shard); + batch = tailingBatches[idx]; + } + batch.append(kbase, koff, klen, khash, vbase, voff, vlen); + } + + @Override + public long spill(long 
size, MemoryConsumer trigger) throws IOException { + if (trigger != this) { + synchronized (this) { + final Iterator iter = pageList.iterator(); + long released = 0L; + while (iter.hasNext()) { + final KeyValuePage page = iter.next(); + if (page.nonRef()) { + iter.remove(); + if (page == currentPage) { + currentPage = null; + } + released += page.getSize(); + page.free(); + if (released >= size) { + break; + } + } + } + return released; + } + } + return 0L; + } + + public synchronized void free() { + Iterator valueIter = pageList.iterator(); + while (valueIter.hasNext()) { + KeyValuePage valuePage = valueIter.next(); + valueIter.remove(); + valuePage.free(); + } + } + + private synchronized KeyValuePage acquirePage(long required) { + for (KeyValuePage page : pageList) { + if (page.nonRef()) { + page.resetCursor(); + } + + if (required <= page.getWritableSize()) { + return page; + } + } + KeyValuePage page = new KeyValuePage(allocatePage(uaoSize + required)); + page.resetCursor(); + pageList.add(page); + return page; + } + + public boolean hasPending() { + return !pendingBatches.isEmpty(); + } + + public Iterator pendingIterator() { + return pendingBatches.iterator(); + } + + public Iterator tailingIterator() { + return new Iterator() { + private int nextIdx = advance(0); + private int lastIdx = -1; + private boolean canRemove = false; + + @Override + public boolean hasNext() { + return nextIdx != -1; + } + + @Override + public KeyValueBatch next() { + if (nextIdx == -1) throw new NoSuchElementException(); + lastIdx = nextIdx; + KeyValueBatch v = tailingBatches[lastIdx]; + nextIdx = advance(lastIdx + 1); + canRemove = true; + return v; + } + + @Override + public void remove() { + if (!canRemove) { + throw new IllegalStateException("next() not called or remove() already called"); + } + tailingBatches[lastIdx] = null; + canRemove = false; + } + + private int advance(int from) { + for (int i = from; i < tailingBatches.length; i++) { + if (tailingBatches[i] != null) 
return i; + } + return -1; + } + }; + } + + @FunctionalInterface + private interface ObjectLongIntConsumer { + void accept(Object base, long off, int len); + } + + private ObjectLongIntConsumer makeAdvanceWrite(final ByteBuf buf) { + if (buf.hasArray()) { + return (base, off, len) -> { + final int idx = buf.writerIndex(); + Platform.copyMemory(base, off, buf.array(), + Platform.BYTE_ARRAY_OFFSET + buf.arrayOffset() + idx, len); + buf.writerIndex(idx + len); + }; + } else if (buf.hasMemoryAddress()) { + return (base, off, len) -> { + final int idx = buf.writerIndex(); + Platform.copyMemory(base, off, null, buf.memoryAddress() + idx, len); + buf.writerIndex(idx + len); + }; + } else { + return (base, off, len) -> { + final byte[] arr = new byte[len]; + Platform.copyMemory(base, off, arr, Platform.BYTE_ARRAY_OFFSET, len); + buf.writeBytes(arr); + }; + } + } + + public final class KeyValueBatch { + private final int shard; + private final Map keyMap; + private final LinkedList innerPageList; + + private int numKeys; + private int numValues; + private int sizeKeys; + private KeyValuePage lastWritePage; + + private KeyValueBatch(int shard) { + this.shard = shard; + this.keyMap = new HashMap<>(); + this.innerPageList = new LinkedList<>(); + } + + public long getSetId() { + return setId; + } + + public int getShard() { + return shard; + } + + public ManagedBuffer wrapKeysBuffer() { + // header: setId(8) + shardId(4) + numKeyFields(4) + final int headerSize = Long.BYTES + Integer.BYTES * 2; + final int capacity = headerSize + Integer.BYTES * numKeys + sizeKeys; + final ByteBuf buf = alloc.buffer(capacity); + buf.writeLong(setId); + buf.writeInt(shard); + buf.writeInt(numFields); + + final ObjectLongIntConsumer advanceWrite = makeAdvanceWrite(buf); + + for (long address : keyMap.values()) { + final Object base = taskMemoryManager.getPage(address); + final long off = taskMemoryManager.getOffsetInPage(address); + final int klen = UnsafeAlignedOffset.getSize(base, off); + 
buf.writeInt(klen);
+        advanceWrite.accept(base, off + (uaoSize * 2L), klen);
+      }
+
+      return new NettyManagedBuffer(buf);
+    }
+
+    /**
+     * Iterates over the stored keys and, for each key, exposes an iterator over the values
+     * chained to it.
+     *
+     * Every inner iterator points the shared {@code valueUr} row at the current record and
+     * returns that same instance, so a returned row is only valid until the next call.
+     */
+    public Iterator<Iterator<UnsafeRow>> multiValuesIterator() {
+      return new Iterator<Iterator<UnsafeRow>>() {
+        private final Iterator<Long> iter = keyMap.values().iterator();
+
+        @Override
+        public boolean hasNext() {
+          return iter.hasNext();
+        }
+
+        @Override
+        public Iterator<UnsafeRow> next() {
+          // key record layout: (key length) (key hash) (key) (next)
+          long addr = iter.next();
+          Object base = taskMemoryManager.getPage(addr);
+          long off = taskMemoryManager.getOffsetInPage(addr);
+          int len = UnsafeAlignedOffset.getSize(base, off);
+          // skip the two size slots and the key bytes to reach the head-of-chain pointer
+          off += ((uaoSize * 2L) + len);
+          final long headAddr = Platform.getLong(base, off);
+
+          return new Iterator<UnsafeRow>() {
+
+            private long address = headAddr;
+
+            @Override
+            public boolean hasNext() {
+              // append() terminates a chain with 0L and tests it with ==, so compare for
+              // (in)equality here as well instead of relying on the sign of the address.
+              // NOTE(review): encoded addresses presumably can have the top bit set for
+              // large page numbers - confirm against TaskMemoryManager.
+              return address != 0L;
+            }
+
+            @Override
+            public UnsafeRow next() {
+              // (value length) (value) (next)
+              Object base = taskMemoryManager.getPage(address);
+              long off = taskMemoryManager.getOffsetInPage(address);
+              int vlen = UnsafeAlignedOffset.getSize(base, off);
+              off += uaoSize;
+              valueUr.pointTo(base, off, vlen);
+              off += vlen;
+              address = Platform.getLong(base, off);
+              return valueUr;
+            }
+          };
+        }
+      };
+    }
+
+    /** Returns true once the number of appended values has reached {@code batchCapacity}. */
+    private boolean capReached() {
+      return numValues >= batchCapacity;
+    }
+
+    /**
+     * Appends one (key, value) pair.
+     *
+     * The key is looked up in {@code keyMap} by probing from {@code khash & mask} with an
+     * increasing step. The value record is always written first; when the key already exists
+     * its "next" slot is repointed at the new value (which in turn links to the previous head),
+     * otherwise a new key record pointing at the value is written and registered.
+     */
+    private void append(Object kbase, long koff, int klen, int khash,
+        Object vbase, long voff, int vlen) {
+      assert numValues < batchCapacity : "capacity reached " + batchCapacity;
+      // lookup key
+      int pos = khash & mask;
+      boolean hasKey = false;
+      Object keyBaseObject = null;
+      long keyNextOffset = 0;
+      int step = 1;
+      while (true) {
+        long keyAddress = keyMap.getOrDefault(pos, 0L);
+        if (keyAddress == 0L) {
+          // fresh key
+          break;
+        } else {
+          final Object base = taskMemoryManager.getPage(keyAddress);
+          long offset = taskMemoryManager.getOffsetInPage(keyAddress);
+          keyBaseObject = base;
+          int keyLength = UnsafeAlignedOffset.getSize(base, offset);
+          if (keyLength == klen) {
+            offset += uaoSize;
+            int keyHash = UnsafeAlignedOffset.getSize(base, offset);
+            if (keyHash == khash) {
+              offset += uaoSize;
+              hasKey = ByteArrayMethods.arrayEquals(base, offset, kbase, koff, klen);
+              if (hasKey) {
+                // exists key
+                offset += klen;
+                keyNextOffset = offset;
+                break;
+              }
+            }
+          }
+        }
+        pos = (pos + step) & mask;
+        step++;
+      }
+
+      long nextAddress = hasKey ? Platform.getLong(keyBaseObject, keyNextOffset) : 0L;
+      // (value length) (value) (next)
+      final long valueRecordLength = uaoSize + vlen + NEXT_POINTER_BYTES;
+      if (currentPage == null || currentPage.getWritableSize() < valueRecordLength) {
+        currentPage = acquirePage(valueRecordLength);
+      }
+      if (lastWritePage != currentPage) {
+        ensurePageRef();
+      }
+      final long valueAddress = currentPage.putValue(vbase, voff, vlen, nextAddress);
+      numValues++;
+      if (hasKey) {
+        Platform.putLong(keyBaseObject, keyNextOffset, valueAddress);
+      } else {
+        numKeys++;
+        sizeKeys += klen;
+        // (key length) (key hash) (key) (next)
+        final long keyRecordLength = (uaoSize * 2L) + klen + NEXT_POINTER_BYTES;
+        if (currentPage.getWritableSize() < keyRecordLength) {
+          currentPage = acquirePage(keyRecordLength);
+        }
+        if (lastWritePage != currentPage) {
+          ensurePageRef();
+        }
+        long keyAddress = currentPage.putKey(kbase, koff, klen, khash, valueAddress);
+        keyMap.put(pos, keyAddress);
+      }
+    }
+
+    /** Drops this batch's reference on every page it wrote to and empties its page list. */
+    public void release() throws IOException {
+      Iterator<KeyValuePage> iter = innerPageList.iterator();
+      while (iter.hasNext()) {
+        KeyValuePage page = iter.next();
+        iter.remove();
+        page.release();
+      }
+    }
+
+    /** Marks {@code currentPage} as the last written page, retains it and records it. */
+    private void ensurePageRef() {
+      lastWritePage = currentPage;
+      lastWritePage.retain();
+      innerPageList.add(lastWritePage);
+    }
+  }
+
+  private static final int NEXT_POINTER_BYTES = Long.BYTES;
+
+  /**
+   * A memory page holding key and value records written sequentially at {@code cursor},
+   * together with a reference count maintained through retain()/release().
+   */
+  private final class KeyValuePage {
+
+    private final MemoryBlock page;
+    private final Object baseObject;
+    private final long baseOffset;
+    private int refCount;
+    private long cursor;
+
+    private KeyValuePage(MemoryBlock page) {
+      this.page = page;
+      this.baseObject = page.getBaseObject();
+      this.baseOffset = page.getBaseOffset();
+      UnsafeAlignedOffset.putSize(baseObject, baseOffset, 0);
+      this.refCount = 0;
+    }
+
+    /**
+     * Writes a key record, (key length) (key hash) (key) (next), at the cursor and returns
+     * its encoded address.
+     */
+    private long putKey(Object kbase, long koff, int klen, int khash, long next) {
+      long offset = baseOffset + cursor;
+      final long recordOffset = offset;
+      UnsafeAlignedOffset.putSize(baseObject, offset, klen);
+      offset += uaoSize;
+      UnsafeAlignedOffset.putSize(baseObject, offset, khash);
+      offset += uaoSize;
+      Platform.copyMemory(kbase, koff, baseObject, offset, klen);
+      offset += klen;
+      Platform.putLong(baseObject, offset, next);
+      cursor += ((uaoSize * 2L) + klen + NEXT_POINTER_BYTES);
+      return taskMemoryManager.encodePageNumberAndOffset(page, recordOffset);
+    }
+
+    /**
+     * Writes a value record, (value length) (value) (next), at the cursor and returns its
+     * encoded address.
+     */
+    private long putValue(Object vbase, long voff, int vlen, long next) {
+      long offset = baseOffset + cursor;
+      final long recordOffset = offset;
+      UnsafeAlignedOffset.putSize(baseObject, offset, vlen);
+      offset += uaoSize;
+      Platform.copyMemory(vbase, voff, baseObject, offset, vlen);
+      offset += vlen;
+      Platform.putLong(baseObject, offset, next);
+      cursor += (uaoSize + vlen + NEXT_POINTER_BYTES);
+      return taskMemoryManager.encodePageNumberAndOffset(page, recordOffset);
+    }
+
+    /** Rewinds the write cursor to just past the leading size slot; requires no references. */
+    private void resetCursor() {
+      assert refCount == 0 : "hasRef";
+      this.cursor = uaoSize;
+    }
+
+    private long getSize() {
+      return page.size();
+    }
+
+    /** Bytes between the cursor and the end of the page. */
+    private long getWritableSize() {
+      return getSize() - cursor;
+    }
+
+    private boolean nonRef() {
+      return refCount == 0;
+    }
+
+    private void retain() {
+      refCount++;
+    }
+
+    private void release() {
+      assert refCount > 0 : "refCount=" + refCount;
+      refCount--;
+    }
+
+    private void free() {
+      freePage(page);
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 6444b1ddbdc96..c21f7b80699f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -156,6 +156,15 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * small. However, broadcasting tables is a network-intensive operation and it could cause * OOM or perform badly in some cases, especially when the build/broadcast side is big. * + * - Distributed map join (DMJ): + * Only supported for equi-joins. + * Supported for all join types that allow right-side build + * (e.g., inner, left outer, left semi). + * Avoids shuffle by building a distributed hash table service for the build side. + * The probe side performs batched RPC lookups to fetch matching rows. + * Ideal for medium-sized dimension tables (e.g., 200MB ~ 2GB) that are too large for BHJ + * but much smaller than the probe side. Requires explicit `distmapjoin` hint. + * * - Shuffle hash join: * Only supported for equi-joins, while the join keys do not need to be sortable. * Supported for all join types. @@ -205,6 +214,27 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + private def checkDistributedMapJoinBuildSide( + onlyLookingAtHint: Boolean, + buildSide: Option[BuildSide], + joinType: JoinType, + hint: JoinHint): Unit = { + + def invalidBuildSide(hintInfo: HintInfo, side: String): Unit = { + hintErrorHandler.joinHintNotSupported(hintInfo, + s"build $side for ${joinType.sql.toLowerCase(Locale.ROOT)} join via distmapjoin") + } + + if (onlyLookingAtHint && buildSide.isEmpty) { + if (hintToDistributedMapJoinLeft(hint)) { + invalidBuildSide(hint.leftHint.get, "left") + } + if (hintToDistributedMapJoinRight(hint)) { + invalidBuildSide(hint.rightHint.get, "right") + } + } + } + private def checkHintNonEquiJoin(hint: JoinHint): Unit = { if (hintToShuffleHashJoin(hint) || hintToPreferShuffleHashJoin(hint) || hintToSortMergeJoin(hint)) { @@ -219,11 +249,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // If it is an equi-join, we first look at the join hints 
w.r.t. the following order: // 1. broadcast hint: pick broadcast hash join if the join type is supported. If both sides // have the broadcast hints, choose the smaller side (based on stats) to broadcast. - // 2. sort merge hint: pick sort merge join if join keys are sortable. - // 3. shuffle hash hint: We pick shuffle hash join if the join type is supported. If both + // 2. distmapjoin hint: pick distributed map join if the join type supports build right. + // 3. sort merge hint: pick sort merge join if join keys are sortable. + // 4. shuffle hash hint: We pick shuffle hash join if the join type is supported. If both // sides have the shuffle hash hints, choose the smaller side (based on stats) as the // build side. - // 4. shuffle replicate NL hint: pick cartesian product if join type is inner like. + // 5. shuffle replicate NL hint: pick cartesian product if join type is inner like. // // If there is no hint or the hints are not applicable, we follow these rules one by one: // 1. Pick broadcast hash join if one side is small enough to broadcast, and the join type @@ -258,6 +289,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + def createDistributedMapJoin(onlyLookingAtHint: Boolean) = { + val buildSide = getDistributedMapJoinBuildSide( + left, right, joinType, hint, onlyLookingAtHint, conf) + checkDistributedMapJoinBuildSide(onlyLookingAtHint, buildSide, joinType, hint) + buildSide.map { + buildSide => + Seq(joins.DistributedMapJoinExec( + leftKeys, + rightKeys, + joinType, + buildSide, + nonEquiCond, + planLater(left), + planLater(right), + hint)) + } + } + def createShuffleHashJoin(onlyLookingAtHint: Boolean) = { if (hashJoinSupport) { val buildSide = getShuffleHashJoinBuildSide(j, onlyLookingAtHint, conf) @@ -322,6 +371,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { createJoinWithoutHint() } else { createBroadcastHashJoin(true) + .orElse(createDistributedMapJoin(true)) .orElse { if (hintToSortMergeJoin(hint)) 
createSortMergeJoin() else None } .orElse(createShuffleHashJoin(true)) .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 9a483076ff567..1101fdf8874bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -690,7 +690,13 @@ case class AdaptiveSparkPlanExec( optimized, postStageCreationRules(outputsColumnar = plan.supportsColumnar), "AQE Post Stage Creation") - if (e.isInstanceOf[ShuffleExchangeLike]) { + if (e.isInstanceOf[ShardExchangeLike]) { + if (!newPlan.isInstanceOf[ShardExchangeLike]) { + throw SparkException.internalError( + "Custom columnar rules cannot transform shard node to something else.") + } + ShardQueryStageExec(currentStageId, newPlan, e.canonicalized) + } else if (e.isInstanceOf[ShuffleExchangeLike]) { if (!newPlan.isInstanceOf[ShuffleExchangeLike]) { throw SparkException.internalError( "Custom columnar rules cannot transform shuffle node to something else.") @@ -830,6 +836,8 @@ case class AdaptiveSparkPlanExec( val finalPlan = inputPlan match { case b: BroadcastExchangeLike if (!newPlan.isInstanceOf[BroadcastExchangeLike]) => b.withNewChildren(Seq(newPlan)) + case sh: ShardExchangeLike + if (!newPlan.isInstanceOf[ShardExchangeLike]) => sh.withNewChildren(Seq(newPlan)) case _ => newPlan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala index 4d33ed81641d9..316d32fa80d06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical.{BroadcastPartitioning, IdentityBroadcastMode} import org.apache.spark.sql.classic.Strategy import org.apache.spark.sql.execution.{joins, SparkPlan} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashedRelationBroadcastMode} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, DistributedMapJoinExec, HashedRelationBroadcastMode} /** * Strategy for plans containing [[LogicalQueryStage]] nodes: @@ -58,6 +58,11 @@ object LogicalQueryStageStrategy extends Strategy { case _ => false } + private def isShardStage(plan: LogicalPlan): Boolean = plan match { + case LogicalQueryStage(_, _: ShardQueryStageExec) => true + case _ => false + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, otherCondition, _, left, right, hint) @@ -73,6 +78,14 @@ object LogicalQueryStageStrategy extends Strategy { leftKeys, rightKeys, joinType, buildSide, otherCondition, planLater(left), planLater(right))) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, otherCondition, _, + left, right, hint) + if isShardStage(left) || isShardStage(right) => + val buildSide = if (isShardStage(left)) BuildLeft else BuildRight + Seq(DistributedMapJoinExec( + leftKeys, rightKeys, joinType, buildSide, otherCondition, planLater(left), + planLater(right), hint)) + case j @ ExtractSingleColumnNullAwareAntiJoin(leftKeys, rightKeys) if isBroadcastStageWithHashedBroadcastMode(j.right, isNullAware = true) => Seq(joins.BroadcastHashJoinExec(leftKeys, rightKeys, LeftAnti, BuildRight, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index be58bccd1489a..9962fd6e2b986 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -312,6 +312,50 @@ case class TableCacheQueryStageExec(
   override def getRuntimeStatistics: Statistics = inMemoryTableScan.runtimeStatistics
 }
 
+/**
+ * A query stage whose plan is a [[ShardExchangeLike]], either directly or wrapped in a
+ * [[ReusedExchangeExec]]. Materializing the stage submits the shard job; the result is
+ * read back through `shardSetRef`.
+ *
+ * @param id the query stage id.
+ * @param plan the underlying plan; any other shape raises an internal error.
+ * @param _canonicalized the canonicalized plan, passed through unchanged to reuse instances.
+ */
+case class ShardQueryStageExec(
+    override val id: Int,
+    override val plan: SparkPlan,
+    override val _canonicalized: SparkPlan) extends ExchangeQueryStageExec {
+
+  // The shard exchange behind this stage, unwrapped from a reused exchange if necessary.
+  @transient val shard: ShardExchangeLike = plan match {
+    case sh: ShardExchangeLike => sh
+    case ReusedExchangeExec(_, sh: ShardExchangeLike) => sh
+    case _ =>
+      throw SparkException.internalError(
+        s"wrong plan for shard stage:\n ${plan.treeString}")
+  }
+
+  override protected def doMaterialize(): Future[Any] = {
+    shard.submitShardJob
+  }
+
+  // The reuse instance wraps the same shard exchange and shares this stage's result and
+  // error holders rather than materializing again.
+  override def newReuseInstance(
+      newStageId: Int,
+      newOutput: Seq[Attribute]): ExchangeQueryStageExec = {
+    val reuse = ShardQueryStageExec(
+      newStageId,
+      ReusedExchangeExec(newOutput, shard),
+      _canonicalized)
+    reuse._resultOption = this._resultOption
+    reuse._error = this._error
+    reuse
+  }
+
+  // Only acts while the build future is still running: cancels the job group named after
+  // the exchange's runId, then cancels the future itself.
+  override protected def doCancel(reason: String): Unit = {
+    if (!shard.relationFuture.isDone) {
+      sparkContext.cancelJobGroup(shard.runId.toString)
+      shard.relationFuture.cancel(true)
+    }
+  }
+
+  override def getRuntimeStatistics: Statistics = shard.runtimeStatistics
+
+  /** The materialized shard set reference; asserts that the stage is already materialized. */
+  def shardSetRef: org.apache.spark.shard.ShardSetRef = {
+    assert(isMaterialized, s"${getClass.getSimpleName} should already be ready")
+    resultOption.get().get.asInstanceOf[org.apache.spark.shard.ShardSetRef]
+  }
+}
+
 case class ResultQueryStageExec(
     override val id: Int,
     override val plan: SparkPlan,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index c632b3d841e61..d8ee16f2601f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -125,6 +125,8 @@ case class EnsureRequirements( distribution match { case BroadcastDistribution(mode) => BroadcastExchangeExec(mode, child) + case ShardDistribution(keys, shards, replicas, filter, schema) => + ShardExchangeExec(keys, shards, replicas, child, filter, schema) case _: StatefulOpClusteredDistribution => ShuffleExchangeExec( distribution.createPartitioning(numPartitions), child, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShardExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShardExchangeExec.scala new file mode 100644 index 0000000000000..a0fe0bca2f58e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShardExchangeExec.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.exchange
+
+import java.io.{InputStream => JInputStream, OutputStream => JOutputStream}
+import java.util.UUID
+import java.util.concurrent.{Future => JFuture, TimeUnit}
+
+import scala.concurrent.{ExecutionContext, Promise, TimeoutException}
+
+import org.apache.spark.{shard, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.shard.BloomAccumulator
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, ShardPartitioning}
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.joins.HashedRelation
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.util.sketch.BloomFilter
+
+/**
+ * Common interface of exchanges that produce a [[shard.ShardSetRef]].
+ */
+trait ShardExchangeLike extends Exchange {
+  // NOTE(review): this default is a `def`, so every call yields a different random UUID.
+  // Implementations that use it as a job group id should override it with a stable value.
+  def runId: UUID = UUID.randomUUID
+
+  /** Java future that yields the shard set reference. */
+  def relationFuture: JFuture[shard.ShardSetRef]
+
+  /** Returns `completionFuture`, evaluated inside `executeQuery`. */
+  final def submitShardJob: scala.concurrent.Future[shard.ShardSetRef] = executeQuery {
+    completionFuture
+  }
+
+  /** Scala future handed out by `submitShardJob`. */
+  protected def completionFuture: scala.concurrent.Future[shard.ShardSetRef]
+
+  /** Statistics of the exchange output. */
+  def runtimeStatistics: Statistics
+}
+
+/**
+ * Shuffles the build side into shards, builds a [[HashedRelation]] per shard,
+ * installs them on executors via [[ShardManager]], and constructs a bloom
+ * filter for probe-side key filtering.
+ */
+case class ShardExchangeExec(
+    buildBoundKeys: Seq[Expression],
+    numShards: Int,
+    replicaCount: Int,
+    child: SparkPlan,
+    filterExpr: Option[Expression] = None,
+    filterSchema: Option[StructType] = None)
+  extends ShardExchangeLike {
+
+  // Stable id for this exchange instance; used as the job group id so the build can be
+  // cancelled.
+  override val runId: UUID = UUID.randomUUID
+
+  override lazy val metrics = Map(
+    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+
+  override def outputPartitioning: Partitioning = ShardPartitioning(buildBoundKeys, numShards)
+
+  override protected def doCanonicalize(): SparkPlan = {
+    copy(
+      buildBoundKeys = buildBoundKeys.map(_.canonicalized),
+      child = child.canonicalized,
+      filterExpr = filterExpr.map(_.canonicalized))
+  }
+
+  override def runtimeStatistics: Statistics = {
+    val dataSize = metrics("dataSize").value
+    val rowCount = metrics("numOutputRows").value
+    Statistics(dataSize, Some(rowCount))
+  }
+
+  @transient
+  private lazy val promise = Promise[shard.ShardSetRef]()
+
+  @transient
+  override lazy val completionFuture: scala.concurrent.Future[shard.ShardSetRef] =
+    promise.future
+
+  @transient
+  private val timeout: Long = conf.distributedMapJoinExchangeTimeout // seconds
+
+  /**
+   * Builds the shard set on the `shard-exchange` thread pool:
+   *  1. waits (best effort) for executors when dynamic allocation is enabled,
+   *  2. shuffles the child by `buildBoundKeys` into `numShards` partitions,
+   *  3. builds and installs one [[HashedRelation]] plus a bloom filter per shard,
+   *  4. merges the per-shard bloom filters and completes `promise` with the result.
+   * Any failure is propagated both through `promise` and through this future.
+   */
+  @transient
+  override lazy val relationFuture: JFuture[shard.ShardSetRef] =
+    SQLExecution
+      .withThreadLocalCaptured[shard.ShardSetRef](session, ShardExchangeExec.executionContext) {
+        try {
+          sparkContext.setJobGroup(
+            runId.toString,
+            s"shard exchange (runId $runId)",
+            interruptOnCancel = true)
+          if (!ensureExecutors(math.min(timeout / 2, 900), TimeUnit.SECONDS)) {
+            logWarning(
+              s"ensureExecutors timed out or SC stopped;" +
+                s" active executors may be < target, continue build")
+          }
+          val setId = sparkContext.env.shardManager.newShardSet(numShards, replicaCount)
+          val shuffled = ShuffleExchangeExec(
+            HashPartitioning(buildBoundKeys, numShards),
+            child,
+            REPARTITION_BY_NUM).execute()
+          val sharded = new ShardedRowRDD(shuffled, getPreferredHosts(shuffled.partitions.length))
+          val bfCapacityPerShard = conf.getConf(SQLConf.DISTRIBUTED_MAP_JOIN_BLOOM_FILTER_CAPACITY)
+          val filterBytes: Option[(Array[Byte], Array[Byte])] = filterExpr.map { expr =>
+            val ser = SparkEnv.get.serializer.newInstance()
+            // Copy exactly the serialized bytes: ByteBuffer.array() would hand out the whole
+            // backing array regardless of the buffer's position and limit.
+            val exprBytes = ShardExchangeExec.toByteArray(ser.serialize(expr))
+            val schemaBytes = filterSchema
+              .map(s => ShardExchangeExec.toByteArray(ser.serialize(s)))
+              .getOrElse(Array.empty[Byte])
+            (exprBytes, schemaBytes)
+          }
+          // Resolve the metrics once here instead of looking them up for every row inside
+          // the task closure.
+          val numOutputRows = longMetric("numOutputRows")
+          val dataSize = longMetric("dataSize")
+          val shardIds =
+            sharded
+              .mapPartitionsWithIndexInternal { case (shardId, rowIter) =>
+                // Every shard sizes its filter for the whole shard set, so all per-shard
+                // filters are created with identical parameters before being merged below.
+                val bloomCapacity = bfCapacityPerShard * numShards
+                val bf = BloomFilter.create(bloomCapacity, 0.03d)
+                val keyGenerator = UnsafeProjection.create(buildBoundKeys)
+                val iter = rowIter.map { row =>
+                  bf.put(keyGenerator(row).getBytes)
+                  numOutputRows.add(1)
+                  row
+                }
+                val relation = HashedRelation(
+                  iter,
+                  buildBoundKeys,
+                  taskMemoryManager = TaskContext.get().taskMemoryManager())
+
+                TaskContext.get().addTaskCompletionListener[Unit](_ => relation.close())
+
+                dataSize.add(relation.estimatedSize)
+
+                val sm = SparkEnv.get.shardManager
+                sm.installShard(relation, setId, shardId, filterBytes)(bfOutput =>
+                  bf.writeTo(bfOutput))
+                sm.installReplicaSet(setId, shardId)
+
+                Iterator.single(shardId)
+              }
+              .collect()
+          // per shard-set
+          val acc = new BloomAccumulator {
+            private var bloom: BloomFilter = _
+            override def add(input: JInputStream): Unit = {
+              val b = BloomFilter.readFrom(input)
+              if (bloom eq null) bloom = b else bloom.mergeInPlace(b)
+            }
+            override def isEmpty: Boolean = bloom eq null
+            override def finish(output: JOutputStream): Unit = {
+              if (!isEmpty) {
+                bloom.writeTo(output)
+              }
+            }
+          }
+          SparkEnv.get.shardManager.mergeBloomFilter(setId, shardIds, acc)
+          val setRef = shard.ShardSetRef(setId, shardIds)
+          promise.trySuccess(setRef)
+          sparkContext.cleaner.foreach(_.registerShardSetForCleanup(setRef))
+          setRef
+        } catch {
+          case e: Throwable =>
+            // Deliberately broad: the promise must be completed for every failure, and the
+            // error is rethrown unchanged.
+            promise.tryFailure(e)
+            throw e
+        }
+      }
+
+  override protected def doPrepare(): Unit = {
+    // Initialize metrics
+    metrics
+    // Materialize the future.
+    relationFuture
+  }
+
+  override protected def doExecute(): RDD[InternalRow] =
+    throw QueryExecutionErrors.executeCodePathUnsupportedError("ShardExchange")
+
+  /**
+   * Blocks for at most `timeout` seconds until the shard set is built.
+   *
+   * On timeout the job group is cancelled, the build future is cancelled, and a
+   * [[SparkException]] wrapping the timeout is thrown.
+   */
+  def buildShardSet(): shard.ShardSetRef = try {
+    relationFuture.get(timeout, TimeUnit.SECONDS)
+  } catch {
+    case ex: TimeoutException =>
+      logError(s"Could not execute shard in $timeout seconds.", ex)
+      if (!relationFuture.isDone) {
+        sparkContext.cancelJobGroup(runId.toString)
+        relationFuture.cancel(true)
+      }
+      throw new SparkException(s"shard exchange timeout.", ex)
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ShardExchangeExec =
+    copy(child = newChild)
+
+  /**
+   * With dynamic allocation enabled, requests `min(numShards, maxExecutors)` executors and
+   * polls until that many are active, the deadline passes, or the context stops.
+   *
+   * @return true when dynamic allocation is disabled or the target was reached; false on
+   *         timeout, on a stopped context, or as soon as a poll observes zero executors.
+   */
+  private def ensureExecutors(length: Long, unit: TimeUnit): Boolean = {
+    val sparkConf = sparkContext.getConf
+    if (Utils.isDynamicAllocationEnabled(sparkConf)) {
+      val minExecutors =
+        math.min(
+          numShards,
+          sparkConf.get(org.apache.spark.internal.config.DYN_ALLOCATION_MAX_EXECUTORS))
+      // Best effort: a failed request must not fail the build, but it should be visible.
+      try sparkContext.requestTotalExecutors(minExecutors, 0, Map.empty)
+      catch {
+        case scala.util.control.NonFatal(e) =>
+          logWarning(s"Failed to request $minExecutors executors for shard exchange.", e)
+      }
+
+      val deadlineNs = System.nanoTime() + unit.toNanos(length)
+      while (System.nanoTime() < deadlineNs && !sparkContext.isStopped) {
+        val active =
+          try sparkContext.getExecutorIds().size
+          catch { case scala.util.control.NonFatal(_) => 0 }
+        if (active == 0) {
+          logWarning(
+            "Active executors should not be 0, " +
+              "spark.dynamicAllocation.minExecutors should be greater than 0.")
+          return false
+        }
+        if (active >= minExecutors) return true
+        // Poll every 10 seconds, but never sleep past the deadline.
+        val remainingMs = TimeUnit.NANOSECONDS.toMillis(deadlineNs - System.nanoTime())
+        Thread.sleep(math.max(1L, math.min(10000L, remainingMs)))
+      }
+      false
+    } else true
+  }
+
+  /**
+   * Assigns one preferred host to each of the `prevPartLen` partitions, round-robin over a
+   * randomly shuffled list of distinct executor hosts other than `spark.driver.host`.
+   * Every entry is `Nil` when no such host is known.
+   */
+  private def getPreferredHosts(prevPartLen: Int): Array[Seq[String]] = {
+    def activeHosts: Array[String] = {
+      val dh = sparkContext.conf.get("spark.driver.host", "")
+      val infos = sparkContext.statusTracker.getExecutorInfos
+      infos.iterator
+        .map(_.host)
+        .filter(_ != dh)
+        .toArray
+        .distinct
+    }
+    val hosts = activeHosts
+    val shuffledHosts = scala.util.Random.shuffle(hosts.toSeq).toArray
+    Array.tabulate(prevPartLen) { i =>
+      if (shuffledHosts.nonEmpty) Seq(shuffledHosts(i % shuffledHosts.length)) else Nil
+    }
+  }
+}
+
+object ShardExchangeExec {
+
+  private[execution] val executionContext = ExecutionContext.fromExecutorService(
+    ThreadUtils.newDaemonCachedThreadPool("shard-exchange", 8))
+
+  /** Copies the remaining bytes of `buffer` into a new array without moving its position. */
+  private def toByteArray(buffer: java.nio.ByteBuffer): Array[Byte] = {
+    val bytes = new Array[Byte](buffer.remaining())
+    buffer.duplicate().get(bytes)
+    bytes
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShardedRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShardedRowRDD.scala
new file mode 100644
index 0000000000000..403dfb48fdf93
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShardedRowRDD.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.exchange
+
+import org.apache.spark.{OneToOneDependency, Partition, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+
+/**
+ * Pass-through RDD over `prev` that only adds location preferences.
+ *
+ * Partitions and rows are exactly those of `prev`; the preferred locations of a partition
+ * are taken from `preferredHosts`, indexed by partition index (wrapping around when the
+ * array is shorter than the number of partitions).
+ *
+ * @param prev the parent RDD, depended on one-to-one.
+ * @param preferredHosts preferred host lists; may be empty, meaning "no preference".
+ */
+private[spark] class ShardedRowRDD(prev: RDD[InternalRow], preferredHosts: Array[Seq[String]])
+  extends RDD[InternalRow](prev.sparkContext, Seq(new OneToOneDependency(prev))) {
+
+  override protected def getPartitions: Array[Partition] = prev.partitions
+
+  override protected def getPreferredLocations(split: Partition): Seq[String] = {
+    // An empty array would make the modulo below divide by zero; treat it as no preference.
+    if (preferredHosts.isEmpty) Nil
+    else preferredHosts(split.index % preferredHosts.length)
+  }
+
+  override def compute(split: Partition, ctx: TaskContext): Iterator[InternalRow] = {
+    prev.iterator(split, ctx)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/DistributedMapJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/DistributedMapJoinExec.scala
new file mode 100644
index 0000000000000..51cf3dc86fb9b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/DistributedMapJoinExec.scala
@@ -0,0 +1,567 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.joins + +import java.util + +import scala.annotation.tailrec +import scala.concurrent.ExecutionContextExecutorService + +import io.netty.buffer.Unpooled + +import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.rdd.RDD +import org.apache.spark.shard.ShardSetRef +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, BindReferences, Expression, GenericInternalRow, JoinedRow, Predicate, PredicateHelper, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.plans.logical.{DistributedMapJoinStrategy, JoinHint} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.{BufferedShardRowMap, SparkPlan} +import org.apache.spark.sql.execution.adaptive.ShardQueryStageExec +import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShardExchangeExec} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.util.sketch.BloomFilter + +/** + * Physical operator for Distributed MapJoin. + * + * This strategy avoids full shuffle by building a distributed hash table service for the build + * side (medium-sized table), and the probe side performs batched RPC lookups to complete the + * join. 
+ * + * Currently only supports: + * - Equi-join + * - BuildRight (right table as build side) + * - Explicit SQL hint: /*+distmapjoin(t(shard_count=5,replica_count=2))*/ + */ +case class DistributedMapJoinExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan, + hint: JoinHint, + isNullAwareAntiJoin: Boolean = false) + extends HashJoin with PredicateHelper { + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + + private val (numShards, replicaCount): (Int, Int) = { + val strategy = + (if (buildSide == BuildLeft) hint.leftHint else hint.rightHint).flatMap(_.strategy) + strategy match { + case Some(DistributedMapJoinStrategy(ns, rc)) => (ns.getOrElse(5), rc.getOrElse(1)) + case _ => (5, 1) + } + } + + @transient private lazy val (buildOnlyFilter, remainingCondition): + (Option[Expression], Option[Expression]) = { + condition match { + case Some(cond) => + val buildAttrs = AttributeSet(buildOutput) + val conjuncts = splitConjunctivePredicates(cond) + val (bo, rem) = conjuncts.partition(_.references.subsetOf(buildAttrs)) + (bo.reduceOption(And), rem.reduceOption(And)) + case None => (None, None) + } + } + + @transient override protected[this] lazy val boundCondition: InternalRow => Boolean = { + remainingCondition match { + case Some(cond) => + Predicate.create(cond, streamedPlan.output ++ buildPlan.output).eval _ + case None => + (_: InternalRow) => true + } + } + + override def supportCodegen: Boolean = false + + override def supportsColumnar: Boolean = false + + override def needCopyResult: Boolean = false + + override def requiredChildDistribution: Seq[Distribution] = { + val filter = buildOnlyFilter.map(BindReferences.bindReference(_, buildOutput)) + val filterSchema = buildOnlyFilter.map(_ => buildPlan.schema) + val sd = ShardDistribution(buildBoundKeys, numShards, 
replicaCount, filter, filterSchema) + buildSide match { + case BuildLeft => Seq(sd, UnspecifiedDistribution) + case BuildRight => Seq(UnspecifiedDistribution, sd) + } + } + + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + + override def inputRDDs(): Seq[RDD[InternalRow]] = + throw QueryExecutionErrors.executeCodePathUnsupportedError("DistributedMapJoin") + + override protected def prepareRelation(ctx: CodegenContext): HashedRelationInfo = + throw QueryExecutionErrors.executeCodePathUnsupportedError("DistributedMapJoin") + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, + newRight: SparkPlan): SparkPlan = copy(left = newLeft, right = newRight) + + override protected def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + val setRef = resolveShardSetRef(buildPlan) + streamedPlan.execute().mapPartitionsInternal { streamedIter => + join(streamedIter, setRef.setId, numOutputRows) + } + } + + @tailrec + private def resolveShardSetRef(plan: SparkPlan): ShardSetRef = plan match { + case s: ShardExchangeExec => s.buildShardSet() + case s: ShardQueryStageExec => s.shardSetRef + case r: ReusedExchangeExec => resolveShardSetRef(r.child) + case other => + throw new IllegalStateException(s"Unexpected build plan for DistributedMapJoin: $other") + } + + private def streamedBloomFilter(setId: Long): BloomFilter = { + SparkEnv.get.shardManager.fetchBloomFilter[BloomFilter](setId)(bfInput => + BloomFilter.readFrom(bfInput)) + } + + private def join( + streamedIter: Iterator[InternalRow], + setId: Long, + numOutputRows: SQLMetric): Iterator[InternalRow] = { + + val joinedRow = new JoinedRow + + val (scanBatch, onMissing): ( + BatchMatchReader => Iterator[InternalRow], + InternalRow => Option[InternalRow]) = joinType match { + case _: InnerLike => + (innerJoinScan(_, joinedRow), (_: InternalRow) => None) + case LeftOuter | RightOuter => + val nullRow = new GenericInternalRow(buildOutput.length) 
+ (outerJoinScan(_, joinedRow, nullRow), + (sr: InternalRow) => Some(joinedRow.withLeft(sr).withRight(nullRow))) + case LeftSemi => + (semiJoinScan(_, joinedRow), (_: InternalRow) => None) + case LeftAnti => + (antiJoinScan(_, joinedRow), (sr: InternalRow) => Some(sr)) + case _: ExistenceJoin => + val existsRow = new GenericInternalRow(Array[Any](null)) + (existenceJoinScan(_, joinedRow, existsRow), + (sr: InternalRow) => { existsRow.setBoolean(0, false); Some(joinedRow(sr, existsRow)) }) + case x => + throw new IllegalArgumentException( + s"DistributedMapJoin should not take $x as the JoinType") + } + + val iter = new LookupJoinIterator(streamedIter, setId, scanBatch, onMissing) + val resultProj = createResultProjection() + iter.map { row => + numOutputRows.add(1) + resultProj(row) + } + } + + // --------------------------------------------------------------------------- + // Scan methods: one per join type, each self-contained + // --------------------------------------------------------------------------- + + private def innerJoinScan( + reader: BatchMatchReader, + joinedRow: JoinedRow): Iterator[InternalRow] = { + new Iterator[InternalRow] { + private var _has = false + override def hasNext: Boolean = { if (!_has) _has = advance(); _has } + override def next(): InternalRow = { _has = false; joinedRow } + + private def advance(): Boolean = { + while (true) { + val br = reader.nextBuildRow() + if (br != null) { + if (boundCondition(joinedRow.withLeft(reader.curStreamed).withRight(br))) return true + } else if (!reader.advanceStreamedRow()) { + return false + } + } + false // unreachable + } + } + } + + private def outerJoinScan( + reader: BatchMatchReader, + joinedRow: JoinedRow, + nullRow: InternalRow): Iterator[InternalRow] = { + new Iterator[InternalRow] { + private var found = false + private var _has = false + override def hasNext: Boolean = { if (!_has) _has = advance(); _has } + override def next(): InternalRow = { _has = false; joinedRow } + + private def 
advance(): Boolean = { + while (true) { + val br = reader.nextBuildRow() + if (br != null) { + if (boundCondition(joinedRow.withLeft(reader.curStreamed).withRight(br))) { + found = true + return true + } + } else { + if (!found && reader.curStreamed != null) { + joinedRow.withLeft(reader.curStreamed).withRight(nullRow) + found = true + return true + } + if (!reader.advanceStreamedRow()) return false + found = false + } + } + false // unreachable + } + } + } + + private def semiJoinScan( + reader: BatchMatchReader, + joinedRow: JoinedRow): Iterator[InternalRow] = { + new Iterator[InternalRow] { + private var _has = false + override def hasNext: Boolean = { if (!_has) _has = advance(); _has } + override def next(): InternalRow = { _has = false; reader.curStreamed } + + private def advance(): Boolean = { + while (reader.advanceStreamedRow()) { + var br = reader.nextBuildRow() + while (br != null) { + if (boundCondition(joinedRow.withLeft(reader.curStreamed).withRight(br))) { + reader.skipRemainingBuildRows() + return true + } + br = reader.nextBuildRow() + } + } + false + } + } + } + + private def antiJoinScan( + reader: BatchMatchReader, + joinedRow: JoinedRow): Iterator[InternalRow] = { + new Iterator[InternalRow] { + private var _has = false + override def hasNext: Boolean = { if (!_has) _has = advance(); _has } + override def next(): InternalRow = { _has = false; reader.curStreamed } + + private def advance(): Boolean = { + while (reader.advanceStreamedRow()) { + if (findMatchingBuildRow(reader, joinedRow)) { + reader.skipRemainingBuildRows() + } else { + return true + } + } + false + } + } + } + + private def existenceJoinScan( + reader: BatchMatchReader, + joinedRow: JoinedRow, + existsRow: GenericInternalRow): Iterator[InternalRow] = { + new Iterator[InternalRow] { + private var _has = false + override def hasNext: Boolean = { if (!_has) _has = advance(); _has } + override def next(): InternalRow = { _has = false; joinedRow } + + private def advance(): Boolean 
= { + if (!reader.advanceStreamedRow()) return false + val exists = findMatchingBuildRow(reader, joinedRow) + if (exists) reader.skipRemainingBuildRows() + existsRow.setBoolean(0, exists) + joinedRow.withLeft(reader.curStreamed).withRight(existsRow) + true + } + } + } + + @tailrec + private def findMatchingBuildRow( + reader: BatchMatchReader, + joinedRow: JoinedRow): Boolean = { + val br = reader.nextBuildRow() + if (br == null) false + else if (boundCondition(joinedRow.withLeft(reader.curStreamed).withRight(br))) true + else findMatchingBuildRow(reader, joinedRow) + } + + // --------------------------------------------------------------------------- + // BatchMatchReader: reads batch response buffer, zero-copy + // --------------------------------------------------------------------------- + + private type PBatch = BufferedShardRowMap#KeyValueBatch + private type BBuffer = ManagedBuffer + + private class BatchMatchReader(batch: PBatch, buffer: BBuffer) extends AutoCloseable { + private val keyIter = batch.multiValuesIterator() + private val buf = Unpooled.wrappedBuffer(buffer.nioByteBuffer()) + private val buildUr: UnsafeRow = new UnsafeRow(buildOutput.length) + private val advanceRead = UnsafeRowBufCodec.makeAdvanceRead(buildUr, buf) + private var streamedIter: java.util.Iterator[UnsafeRow] = _ + private var buildRowsStart: Int = -1 + private var atSentinel = true + var curStreamed: UnsafeRow = _ + + assert(batch.getSetId == buf.readLong(), "setId mismatch") + assert(batch.getShard == buf.readInt(), "shardId mismatch") + + def advanceStreamedRow(): Boolean = { + if (streamedIter != null && streamedIter.hasNext) { + curStreamed = streamedIter.next() + rewindBuildRows() + true + } else { + skipRemainingBuildRows() + if (!keyIter.hasNext) return false + streamedIter = keyIter.next() + buildRowsStart = buf.readerIndex() + atSentinel = false + if (!streamedIter.hasNext) advanceStreamedRow() + else { curStreamed = streamedIter.next(); true } + } + } + + def 
nextBuildRow(): UnsafeRow = { + if (atSentinel) return null + val blen = buf.readInt() + if (blen == 0) { atSentinel = true; null } + else advanceRead(blen) + } + + def skipRemainingBuildRows(): Unit = { + while (!atSentinel) { + val blen = buf.readInt() + if (blen == 0) atSentinel = true + else buf.readerIndex(buf.readerIndex() + blen) + } + } + + private def rewindBuildRows(): Unit = { + buf.readerIndex(buildRowsStart) + atSentinel = false + } + + override def close(): Unit = { + batch.release() + buffer.release() + } + } + + // --------------------------------------------------------------------------- + // LookupJoinIterator: async pipeline management + // --------------------------------------------------------------------------- + + private class LookupJoinIterator( + streamedIter: Iterator[InternalRow], + setId: Long, + scanBatch: BatchMatchReader => Iterator[InternalRow], + onMissing: InternalRow => Option[InternalRow]) + extends Iterator[InternalRow] { + + private val maxInFlightNum = conf.distributedMapJoinMaxInFlightNum + private val keyGenerator: UnsafeProjection = UnsafeProjection.create(streamedBoundKeys) + private val valueGenerator: UnsafeProjection = UnsafeProjection.create(streamedPlan.schema) + private val shardGenerator = + UnsafeProjection.create( + HashPartitioning(streamedBoundKeys, numShards).partitionIdExpression :: Nil) + + private val bloom = streamedBloomFilter(setId) + private val probeUr: UnsafeRow = new UnsafeRow(streamedPlan.schema.length) + private val bufferedMap = { + val mm = TaskContext.get().taskMemoryManager() + val maxBatchSize = conf.distributedMapJoinMaxBatchSize + val map = + new BufferedShardRowMap( + mm, + setId, + numShards, + streamedBoundKeys.length, + probeUr, + maxBatchSize) + TaskContext.get().addTaskCompletionListener[Unit](_ => map.free()) + map + } + + private sealed trait Lookup + private case class LookupSuccess(batch: PBatch, buffer: BBuffer) extends Lookup + private case class LookupFailure(batch: PBatch, 
cause: Throwable) extends Lookup + + private val lookupQueue = new util.concurrent.LinkedBlockingQueue[Lookup] + + private var inputExhausted = false + private var prepared = false + private var nextRowVal: InternalRow = _ + private var numInFlight = 0 + private var currentReader: BatchMatchReader = _ + private var currentBatchIter: Iterator[InternalRow] = _ + + override def hasNext: Boolean = { + if (!prepared) { + processNext() + } + prepared + } + + override def next(): InternalRow = { + if (!prepared && !hasNext) { + throw QueryExecutionErrors.noSuchElementExceptionError() + } + prepared = false + nextRowVal + } + + private def prepareNextRow(row: InternalRow): Unit = { + nextRowVal = row + prepared = true + } + + private def processNext(): Unit = { + if (currentBatchIter != null) { + iterateLookup() + } + + var hasLookup = true + while (!prepared && hasLookup) { + val ele = lookupQueue.poll() + if (ele == null) { + hasLookup = false + } else { + pollLookup(ele) + if (currentBatchIter != null) { + iterateLookup() + } + } + } + + while (!prepared && !inputExhausted && streamedIter.hasNext) { + val streamedRow = streamedIter.next() + val keyUr = keyGenerator(streamedRow) + if (keyUr.anyNull || !bloom.mightContain(keyUr.getBytes)) { + onMissing(streamedRow).foreach(prepareNextRow) + } else { + val shardId = shardGenerator(streamedRow).getInt(0) + val valueUr = streamedRow match { + case ur: UnsafeRow => ur + case r => valueGenerator(r) + } + processLookup(shardId, keyUr, valueUr) + } + } + + if (!prepared) { + if (!inputExhausted) { + inputExhausted = true + flushLookup(bufferedMap.tailingIterator()) + } + while (!prepared && numInFlight > 0) { + pollLookup(lookupQueue.poll(200, util.concurrent.TimeUnit.MILLISECONDS)) + if (currentBatchIter != null) { + iterateLookup() + } + } + } + } + + private def processLookup(shardId: Int, keyUr: UnsafeRow, valueUr: UnsafeRow): Unit = { + bufferedMap.putRow( + shardId, + keyUr.getBaseObject, + keyUr.getBaseOffset, + 
keyUr.getSizeInBytes, + keyUr.hashCode(), + valueUr.getBaseObject, + valueUr.getBaseOffset, + valueUr.getSizeInBytes) + + if (bufferedMap.hasPending) { + flushLookup(bufferedMap.pendingIterator()) + } + + while (numInFlight >= maxInFlightNum) { + pollLookup(lookupQueue.poll(200, util.concurrent.TimeUnit.MILLISECONDS)) + if (currentBatchIter != null) { + iterateLookup() + } + } + } + + private def flushLookup[T <: PBatch](iter: util.Iterator[T]): Unit = { + val manager = SparkEnv.get.shardManager + implicit val ec: ExecutionContextExecutorService = manager.lookupEc + while (iter.hasNext) { + val batch = iter.next() + iter.remove() + val future = + manager.fetchRemoteBatch(setId, batch.getShard, () => batch.wrapKeysBuffer()) + + future.onComplete { + case scala.util.Success(buffer) => + lookupQueue.put(LookupSuccess(batch, buffer)) + case scala.util.Failure(cause) => + lookupQueue.put(LookupFailure(batch, cause)) + } + + numInFlight += 1 + } + } + + private def pollLookup(look: Lookup): Unit = { + look match { + case LookupSuccess(batch, buffer) => + currentReader = new BatchMatchReader(batch, buffer) + currentBatchIter = scanBatch(currentReader) + numInFlight -= 1 + case LookupFailure(batch, cause) => + throw new SparkException( + s"DistributedMapJoin batch lookup failed: ($setId, ${batch.getShard}) ", + cause) + case null => + } + } + + private def iterateLookup(): Unit = { + if (currentBatchIter.hasNext) { + prepareNextRow(currentBatchIter.next()) + } else { + currentReader.close() + currentReader = null + currentBatchIter = null + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelationAdapter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelationAdapter.scala new file mode 100644 index 0000000000000..261b4d4b27977 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelationAdapter.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation 
(ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicBoolean + +import io.netty.buffer.Unpooled + +import org.apache.spark.SparkEnv +import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer} +import org.apache.spark.network.util.NettyUtils +import org.apache.spark.shard.{ShardLookupAdapter, ShardManager} +import org.apache.spark.sql.catalyst.expressions.{BasePredicate, Expression, Predicate, UnsafeRow} +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.types.StructType + +/** + * Build-side RPC handler for distributed map join. + * + * Receives batched key lookups from probe-side executors, performs hash + * lookups against the local [[HashedRelation]], and returns matching rows. + * + * When a build-only filter is stored in the shard meta, it is loaded once + * per setId and evaluated server-side to reduce network transfer. + * + * Wire format (request): + * {{{ + * (setId:long)(shardId:int)(numKeyFields:int) + * [(keyLen:int)(keyBytes)]... 
+ * }}} + */ +private[spark] class HashedRelationAdapter extends ShardLookupAdapter { + + private val INITIAL_RESPONSE_BUFFER_BYTES = 1 << 20 + private val alloc = NettyUtils.getSharedPooledByteBufAllocator(true, true) + + private val filterCache = new ConcurrentHashMap[Long, Option[BasePredicate]]() + private val cleanupRegistered = new AtomicBoolean(false) + + override def lookup(manager: ShardManager, reqMsg: ManagedBuffer): ManagedBuffer = { + if (cleanupRegistered.compareAndSet(false, true)) { + manager.registerCleanupCallback(setId => filterCache.remove(setId)) + } + val keysBuf = Unpooled.wrappedBuffer(reqMsg.nioByteBuffer()) + val setId = keysBuf.readLong() + val shard = keysBuf.readInt() + val numKeyFields = keysBuf.readInt() + + val keyUr = new UnsafeRow(numKeyFields) + val rel = manager.getLocalValue[HashedRelation](setId, shard).asReadOnlyCopy() + val valuesBuf = alloc.buffer(INITIAL_RESPONSE_BUFFER_BYTES) + valuesBuf.writeLong(setId) + valuesBuf.writeInt(shard) + + val advanceReadKey = UnsafeRowBufCodec.makeAdvanceRead(keyUr, keysBuf) + val advanceWrite = UnsafeRowBufCodec.makeAdvanceWrite(valuesBuf) + + val predicate = filterCache.computeIfAbsent(setId, _ => { + manager.getFilterBytes(setId, shard).map { case (exprBytes, schemaBytes) => + val ser = SparkEnv.get.serializer.newInstance() + val expr = ser.deserialize[Expression](java.nio.ByteBuffer.wrap(exprBytes)) + if (schemaBytes.nonEmpty) { + val schema = ser.deserialize[StructType](java.nio.ByteBuffer.wrap(schemaBytes)) + Predicate.create(expr, DataTypeUtils.toAttributes(schema)) + } else { + Predicate.create(expr) + } + } + }) + + predicate match { + case Some(pred) => + while (keysBuf.isReadable) { + val klen = keysBuf.readInt() + val key = advanceReadKey(klen) + val iter = rel.get(key) + if (iter != null) { + while (iter.hasNext) { + val buildRow = iter.next().asInstanceOf[UnsafeRow] + if (pred.eval(buildRow)) { + valuesBuf.writeInt(buildRow.getSizeInBytes) + advanceWrite(buildRow) + } + } + } + 
valuesBuf.writeInt(0) + } + case None => + while (keysBuf.isReadable) { + val klen = keysBuf.readInt() + val key = advanceReadKey(klen) + val iter = rel.get(key) + if (iter != null) { + while (iter.hasNext) { + val buildRow = iter.next().asInstanceOf[UnsafeRow] + valuesBuf.writeInt(buildRow.getSizeInBytes) + advanceWrite(buildRow) + } + } + valuesBuf.writeInt(0) + } + } + + new NettyManagedBuffer(valuesBuf) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeRowBufCodec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeRowBufCodec.scala new file mode 100644 index 0000000000000..2401602fbad23 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeRowBufCodec.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import io.netty.buffer.ByteBuf + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.unsafe.Platform + +object UnsafeRowBufCodec { + + def makeAdvanceRead(ur: UnsafeRow, buf: ByteBuf): Int => UnsafeRow = + if (buf.hasArray) { (len: Int) => + val idx = buf.readerIndex() + ur.pointTo(buf.array(), Platform.BYTE_ARRAY_OFFSET + buf.arrayOffset() + idx, len) + buf.readerIndex(idx + len) + ur + } else if (buf.hasMemoryAddress) { (len: Int) => + val idx = buf.readerIndex() + ur.pointTo(null, buf.memoryAddress() + idx, len) + buf.readerIndex(idx + len) + ur + } else { (len: Int) => + // fallback to bad perf case + val arr = new Array[Byte](len) + buf.readBytes(arr) + ur.pointTo(arr, Platform.BYTE_ARRAY_OFFSET, len) + ur + } + + def makeAdvanceWrite(buf: ByteBuf): UnsafeRow => Unit = + if (buf.hasArray) { (ur: UnsafeRow) => + buf.ensureWritable(ur.getSizeInBytes) + val idx = buf.writerIndex() + val dstOffset = Platform.BYTE_ARRAY_OFFSET + buf.arrayOffset() + idx + Platform.copyMemory( + ur.getBaseObject, + ur.getBaseOffset, + buf.array(), + dstOffset, + ur.getSizeInBytes) + buf.writerIndex(idx + ur.getSizeInBytes) + } else if (buf.hasMemoryAddress) { (ur: UnsafeRow) => + buf.ensureWritable(ur.getSizeInBytes) + val idx = buf.writerIndex() + Platform.copyMemory( + ur.getBaseObject, + ur.getBaseOffset, + null, + buf.memoryAddress() + idx, + ur.getSizeInBytes) + buf.writerIndex(idx + ur.getSizeInBytes) + } else { (ur: UnsafeRow) => + // fallback to bad perf case + buf.writeBytes(ur.getBytes) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 106ee36594b38..4cbe62d507463 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -487,6 +487,7 @@ class CachedTableSuite extends 
SharedSparkSession def rddCleaned(rddId: Int): Unit = {} def shuffleCleaned(shuffleId: Int): Unit = {} def broadcastCleaned(broadcastId: Long): Unit = {} + def shardSetCleaned(setId: Long): Unit = {} def accumCleaned(accId: Long): Unit = { toBeCleanedAccIds.synchronized { toBeCleanedAccIds -= accId } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/DistributedMapJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/DistributedMapJoinSuite.scala new file mode 100644 index 0000000000000..ff4196143a99d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/DistributedMapJoinSuite.scala @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.internal.config.{EXECUTOR_MEMORY, SHARD_ENABLED} +import org.apache.spark.sql.{DataFrame, QueryTest, SparkSession} +import org.apache.spark.sql.catalyst.plans.logical.{DistributedMapJoinStrategy, JoinStrategyHint} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.{ + AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite +} +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShardExchangeExec} +import org.apache.spark.sql.internal.SQLConf + +abstract class DistributedMapJoinSuiteBase + extends QueryTest + with AdaptiveSparkPlanHelper { + + import testImplicits._ + + protected var spark: SparkSession = _ + + private val ensureReqs = EnsureRequirements() + + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession + .builder() + .master("local-cluster[2,1,512]") + .config(EXECUTOR_MEMORY.key, "512m") + .config(SHARD_ENABLED.key, "true") + .appName("dmj-testing") + .getOrCreate() + } + + override def afterAll(): Unit = { + try { + spark.stop() + spark = null + } finally { + super.afterAll() + } + } + + override def beforeEach(): Unit = { + super.beforeEach() + System.gc() + } + + // prepare test data + private def prepareTables(): Unit = { + val dim: DataFrame = Seq[(Option[Int], String)]( + (Some(1), "A"), + (Some(2), "B"), + (Some(3), "C"), + (Some(4), "D"), + (Some(5), "E"), + (Some(6), "F"), + (Some(7), "G"), + (None, "NULLK")).toDF("k", "v") + + val fact: DataFrame = Seq + .tabulate(1000) { i => + val k: Option[Int] = if (i % 20 == 0) None else Some(i % 9) + (k, s"r$i") + } + .toDF("k", "payload") + + dim.createOrReplaceTempView("dim") + fact.createOrReplaceTempView("fact") + + val dim2: DataFrame = Seq[(Option[Int], String, String)]( + (Some(1), "X", "AX"), + (Some(1), "Y", "AY"), + (Some(2), "X", "BX"), + (Some(3), "Z", "CZ"), + (None, "X", "NX")).toDF("k", "cat", 
"val") + + val fact2: DataFrame = Seq( + (Some(1), "X", "fx1"), + (Some(1), "Y", "fy1"), + (Some(1), null.asInstanceOf[String], "fnull1"), + (Some(2), "Y", "fy2"), + (Some(2), "X", "fx2"), + (Some(3), "Z", "fz3"), + (Some(4), "W", "fw4")).toDF("k", "cat", "payload") + + dim2.createOrReplaceTempView("dim2") + fact2.createOrReplaceTempView("fact2") + + val dim3: DataFrame = Seq((1, "P"), (2, "Q"), (3, "R")).toDF("k", "tag") + dim3.createOrReplaceTempView("dim3") + + val t: DataFrame = Seq.tabulate(20)(i => (i % 5, s"s$i")).toDF("k", "s") + t.createOrReplaceTempView("t") + } + + private def assertDMJPlan(plan: SparkPlan, expectBuildRight: Boolean = true): Unit = { + // After EnsureRequirements, we should see ShardExchangeExec and DistributedMapJoinExec + val p = ensureReqs.apply(plan) + val hasShard = p.collect { case _: ShardExchangeExec => 1 }.nonEmpty + val dmjOpt = p.collect { case j: DistributedMapJoinExec => j }.headOption + assert(hasShard, s"Plan should contain ShardExchangeExec:\n$p") + assert(dmjOpt.isDefined, s"Plan should contain DistributedMapJoinExec:\n$p") + if (expectBuildRight) { + assert(dmjOpt.get.buildSide == org.apache.spark.sql.catalyst.optimizer.BuildRight) + } else { + assert(dmjOpt.get.buildSide == org.apache.spark.sql.catalyst.optimizer.BuildLeft) + } + } + + private def checkDMJEquals( + dmjSql: String, + baselineSql: String, + expectBuildRight: Boolean = true): Unit = { + val df = sql(dmjSql) + assertDMJPlan(df.queryExecution.sparkPlan, expectBuildRight) + checkAnswer(df, sql(baselineSql)) + } + + private def assertDMJPlanCount(plan: SparkPlan, expect: Int): Unit = { + val applied = ensureReqs.apply(plan) + val dmjCount = applied.collect { case _: DistributedMapJoinExec => 1 }.sum + val shardCount = applied.collect { case _: ShardExchangeExec => 1 }.sum + assert(dmjCount == expect, s"Expected $expect DMJ, but found $dmjCount:\n$applied") + assert( + shardCount == expect, + s"Expected $expect ShardExchange, but found $shardCount:\n$applied") 
+ } + + test("inner join build right via DMJ hint") { + prepareTables() + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val dmj = + """ + |SELECT /*+ DISTMAPJOIN(d(shard_count=3, replica_count=1)) */ + | f.k, d.v, f.payload + |FROM fact f JOIN dim d ON f.k = d.k + |""".stripMargin + val base = + """SELECT f.k, d.v, f.payload FROM fact f JOIN dim d ON f.k = d.k""" + checkDMJEquals(dmj, base, expectBuildRight = true) + } + } + + test("DMJ: inner join on two keys (build right)") { + prepareTables() + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val dmj = + """ + |SELECT /*+ DISTMAPJOIN(d(shard_count=3,replica_count=1)) */ + | f.k, f.cat, d.val, f.payload + |FROM fact2 f JOIN dim2 d + |ON f.k = d.k AND f.cat = d.cat + |""".stripMargin + val base = + """SELECT f.k, f.cat, d.val, f.payload FROM fact2 f JOIN dim2 d + |ON f.k = d.k AND f.cat = d.cat""".stripMargin + val df = sql(dmj) + assertDMJPlanCount(df.queryExecution.sparkPlan, expect = 1) + checkAnswer(df, sql(base)) + } + } + + test("DMJ: two-key inner join with nulls on probe/build") { + prepareTables() + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val dmj = + """ + |SELECT /*+ DISTMAPJOIN(d) */ + | f.k, f.cat, d.val + |FROM fact2 f JOIN dim2 d + |ON f.k = d.k AND f.cat = d.cat + |""".stripMargin + val base = + """SELECT f.k, f.cat, d.val FROM fact2 f JOIN dim2 d + |ON f.k = d.k AND f.cat = d.cat""".stripMargin + val df = sql(dmj) + assertDMJPlanCount(df.queryExecution.sparkPlan, expect = 1) + checkAnswer(df, sql(base)) + } + } + + test("left outer join build right via DMJ hint (AQE on)") { + prepareTables() + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val dmj = + """ + |SELECT /*+ DISTMAPJOIN(d(shard_count=4, replica_count=1)) */ + | f.k, d.v + |FROM fact f LEFT OUTER JOIN dim d ON f.k = d.k + |""".stripMargin + val base = + """SELECT f.k, d.v FROM fact f LEFT OUTER JOIN dim d ON f.k = d.k""" + checkDMJEquals(dmj, base, 
expectBuildRight = true) + } + } + + test("DMJ: left outer join with null keys on probe, null-supplied rows preserved") { + prepareTables() + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val dmj = + """ + |SELECT /*+ DISTMAPJOIN(d(shard_count=4,replica_count=1)) */ + | f.k, d.v + |FROM fact f LEFT OUTER JOIN dim d ON f.k = d.k + |ORDER BY f.k NULLS FIRST, d.v NULLS FIRST + |""".stripMargin + val base = + """SELECT f.k, d.v FROM fact f LEFT OUTER JOIN dim d ON f.k = d.k + |ORDER BY f.k NULLS FIRST, d.v NULLS FIRST""".stripMargin + val df = sql(dmj) + assertDMJPlanCount(df.queryExecution.sparkPlan, expect = 1) + checkAnswer(df, sql(base)) + } + } + + test("right outer join build left via DMJ hint") { + prepareTables() + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val dmj = + """ + |SELECT /*+ DISTMAPJOIN(dleft(shard_count=2, replica_count=1)) */ + | dleft.k, dleft.v, f.payload + |FROM dim dleft RIGHT OUTER JOIN fact f ON dleft.k = f.k + |""".stripMargin + val base = + """SELECT d.k, d.v, f.payload FROM dim d RIGHT OUTER JOIN fact f ON d.k = f.k""" + val df = sql(dmj) + assertDMJPlan(df.queryExecution.sparkPlan, expectBuildRight = false) + checkAnswer(df, sql(base)) + } + } + + test("left semi join build right via DMJ hint") { + prepareTables() + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val dmj = + """ + |SELECT /*+ DISTMAPJOIN(d(shard_count=3, replica_count=1)) */ + | f.k + |FROM fact f LEFT SEMI JOIN dim d ON f.k = d.k + |""".stripMargin + val base = + """SELECT f.k FROM fact f LEFT SEMI JOIN dim d ON f.k = d.k""" + checkDMJEquals(dmj, base, expectBuildRight = true) + } + } + + test("left anti join build right via DMJ hint") { + prepareTables() + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val dmj = + """ + |SELECT /*+ DISTMAPJOIN(d(shard_count=3, replica_count=1)) */ + | f.k + |FROM fact f LEFT ANTI JOIN dim d ON f.k = d.k + |""".stripMargin + val base = + """SELECT f.k FROM 
fact f LEFT ANTI JOIN dim d ON f.k = d.k""" + checkDMJEquals(dmj, base, expectBuildRight = true) + } + } + + test("DMJ: multi-table join chain (two DMJs in one query)") { + prepareTables() + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val dmj = + """ + |SELECT /*+ DISTMAPJOIN(d), DISTMAPJOIN(d3) */ + | f.k, d.v, d3.tag, f.payload + |FROM fact f + |JOIN dim d ON f.k = d.k + |JOIN dim3 d3 ON f.k = d3.k + |""".stripMargin + val base = + """SELECT f.k, d.v, d3.tag, f.payload FROM fact f + |JOIN dim d ON f.k = d.k + |JOIN dim3 d3 ON f.k = d3.k""".stripMargin + val df = sql(dmj) + assertDMJPlanCount(df.queryExecution.sparkPlan, expect = 2) + checkAnswer(df, sql(base)) + } + } + + test("DMJ: self join with build left via hint") { + prepareTables() + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val dmj = + """ + |SELECT /*+ DISTMAPJOIN(a(shard_count=2,replica_count=1)) */ + | a.k, a.s, b.s + |FROM t a JOIN t b ON a.k = b.k + |""".stripMargin + val base = + """SELECT a.k, a.s, b.s FROM t a JOIN t b ON a.k = b.k""".stripMargin + val df = sql(dmj) + val applied = ensureReqs.apply(df.queryExecution.sparkPlan) + val dmjNode = applied.collect { case j: DistributedMapJoinExec => j }.head + assert( + dmjNode.buildSide == org.apache.spark.sql.catalyst.optimizer.BuildLeft, + s"Expected build left, but got ${dmjNode.buildSide}") + assertDMJPlanCount(df.queryExecution.sparkPlan, expect = 1) + checkAnswer(df, sql(base)) + } + } + + test("DMJ: subquery as build side with hint on alias") { + prepareTables() + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val dmj = + """ + |SELECT /*+ DISTMAPJOIN(ds) */ + | f.k, ds.v + |FROM fact f + |JOIN (SELECT k, v FROM dim WHERE k IS NOT NULL) ds + |ON f.k = ds.k + |""".stripMargin + val base = + """SELECT f.k, ds.v FROM fact f + |JOIN (SELECT k, v FROM dim WHERE k IS NOT NULL) ds + |ON f.k = ds.k""".stripMargin + val df = sql(dmj) + assertDMJPlanCount(df.queryExecution.sparkPlan, expect 
= 1) + checkAnswer(df, sql(base)) + } + } + + test("DMJ: heterogeneous two-key join (int + string)") { + prepareTables() + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val fhs = Seq((1, "X", "p1"), (2, "Y", "p2"), (3, "Z", "p3"), (3, "W", "p4")).toDF( + "k", + "cat", + "payload") + fhs.createOrReplaceTempView("facts") + val dmj = + """ + |SELECT /*+ DISTMAPJOIN(d) */ + | f.k, f.cat, d.val + |FROM facts f JOIN dim2 d ON f.k = d.k AND f.cat = d.cat + |""".stripMargin + val base = + """SELECT f.k, f.cat, d.val FROM facts f JOIN dim2 d + |ON f.k = d.k AND f.cat = d.cat""".stripMargin + val df = sql(dmj) + assertDMJPlanCount(df.queryExecution.sparkPlan, expect = 1) + checkAnswer(df, sql(base)) + } + } + + test("DMJ hint in SQL is applied to correct side") { + withTempView("t", "u") { + spark.range(10).createOrReplaceTempView("t") + spark.range(10).createOrReplaceTempView("u") + val plan1 = sql( + "SELECT /*+ DISTMAPJOIN(t) */ * FROM t JOIN u ON t.id = u.id").queryExecution.optimizedPlan + .asInstanceOf[org.apache.spark.sql.catalyst.plans.logical.Join] + val plan2 = sql( + "SELECT /*+ DISTMAPJOIN(u) */ * FROM t JOIN u ON t.id = u.id").queryExecution.optimizedPlan + .asInstanceOf[org.apache.spark.sql.catalyst.plans.logical.Join] + val plan3 = sql( + "SELECT /*+ DISTMAPJOIN(v) */ * FROM t JOIN u ON t.id = u.id").queryExecution.optimizedPlan + .asInstanceOf[org.apache.spark.sql.catalyst.plans.logical.Join] + assert(plan1.hint.leftHint.get.strategy.exists(isDistributedMapJoin)) + assert(plan1.hint.rightHint.isEmpty) + assert(plan2.hint.leftHint.isEmpty) + assert(plan2.hint.rightHint.get.strategy.exists(isDistributedMapJoin)) + assert(plan3.hint.leftHint.isEmpty && plan3.hint.rightHint.isEmpty) + } + } + + private def isDistributedMapJoin(s: JoinStrategyHint): Boolean = + s.isInstanceOf[DistributedMapJoinStrategy] +} + +// run with AQE disabled +class DistributedMapJoinSuite + extends DistributedMapJoinSuiteBase + with DisableAdaptiveExecutionSuite + +// 
run with AQE enabled +class DistributedMapJoinSuiteAE + extends DistributedMapJoinSuiteBase + with EnableAdaptiveExecutionSuite From fe1aebe87e9cef059ca9e344f546fc2887e69303 Mon Sep 17 00:00:00 2001 From: "zhongheng.gy" Date: Wed, 17 Jun 2026 10:02:43 +0800 Subject: [PATCH 2/3] Fix CI: binding policy, MiMa exclusion, CleanerTester arg order --- .../main/scala/org/apache/spark/internal/config/package.scala | 1 + .../src/test/scala/org/apache/spark/ContextCleanerSuite.scala | 3 ++- project/MimaExcludes.scala | 2 ++ .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 4 ++++ 4 files changed, 9 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 514f319744458..db64aff15692c 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -3001,6 +3001,7 @@ package object config { "started at application launch, enabling the distributed map join strategy " + "via SQL hints.") .version("5.0.0") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) .booleanConf .createWithDefault(false) } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index a0cfa7659327c..66490d812b840 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -209,7 +209,8 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { assert(fs.exists(path)) // the checkpoint is not cleaned by default (without the configuration set) - var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Seq(rddId)) + var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, + checkpointIds = Seq(rddId)) rdd = null // Make RDD out of scope, ok if collected earlier runGC() 
postGCTester.assertCleanup() diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 16ff9b40c16d1..d3c168bc755d4 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -38,6 +38,8 @@ object MimaExcludes { // Exclude rules for 4.2.x from 4.1.0 lazy val v42excludes = v41excludes ++ Seq( + // [SPARK-57487][SQL] Distributed map join adds shardManagerFactory parameter to SparkEnv + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.this"), // [SQL] SafeJsonSerializer.safeMapToJValue: second parameter widened from Function1 to // Function2 so the key is passed to the value serializer (progress.scala). Binary-incompatible // vs spark-sql-api 4.0.0; not part of the public supported API (private[streaming] package). diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index bff770556b3a6..80163f2811516 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -7492,6 +7492,7 @@ object SQLConf { buildConf("spark.sql.execution.distributedMapJoin.maxInFlightNum") .doc("Maximum number of concurrent RPC lookup batches per task on the probe side.") .version("5.0.0") + .withBindingPolicy(ConfigBindingPolicy.SESSION) .intConf .createWithDefault(8) @@ -7499,6 +7500,7 @@ object SQLConf { buildConf("spark.sql.execution.distributedMapJoin.maxBatchSize") .doc("Maximum number of probe-side keys per RPC lookup batch.") .version("5.0.0") + .withBindingPolicy(ConfigBindingPolicy.SESSION) .intConf .createWithDefault(1024) @@ -7507,6 +7509,7 @@ object SQLConf { .doc("Expected number of distinct keys per shard for the build-side bloom filter. 
" + "The total bloom filter capacity is this value multiplied by the number of shards.") .version("5.0.0") + .withBindingPolicy(ConfigBindingPolicy.SESSION) .longConf .createWithDefault(5L << 20) @@ -7514,6 +7517,7 @@ object SQLConf { buildConf("spark.sql.execution.distributedMapJoin.exchangeTimeout") .doc("Timeout in seconds for building shard data in distributed map join.") .version("5.0.0") + .withBindingPolicy(ConfigBindingPolicy.SESSION) .timeConf(TimeUnit.SECONDS) .createWithDefaultString(s"${30 * 60}") From 8f906d346744c658a52f8a5cded4f384702a7241 Mon Sep 17 00:00:00 2001 From: "zhongheng.gy" Date: Wed, 17 Jun 2026 14:22:45 +0800 Subject: [PATCH 3/3] Address review feedback: correctness, security, and resource safety --- .../spark/network/shard/ShardStoreClient.java | 7 ++- .../scala/org/apache/spark/SparkEnv.scala | 1 + .../network/ShardLookupServiceFactory.scala | 11 +++-- .../netty/NettyShardLookupService.scala | 29 ++++++++---- .../org/apache/spark/shard/ShardManager.scala | 7 ++- .../spark/shard/ShardManagerMaster.scala | 18 +++++-- .../shard/ShardManagerMasterEndpoint.scala | 41 ++++++++++++---- .../spark/shard/ShardManagerMessages.scala | 2 + .../plans/physical/partitioning.scala | 1 + .../apache/spark/sql/internal/SQLConf.scala | 2 + .../sql/execution/BufferedShardRowMap.java | 7 ++- .../spark/sql/execution/SparkStrategies.scala | 32 +++++++------ .../exchange/ShardExchangeExec.scala | 14 +++++- .../joins/DistributedMapJoinExec.scala | 47 +++++++++++++++++-- 14 files changed, 158 insertions(+), 61 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/shard/ShardStoreClient.java b/common/network-common/src/main/java/org/apache/spark/network/shard/ShardStoreClient.java index 9cd3b2be82e4f..f921131a555d2 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/shard/ShardStoreClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/shard/ShardStoreClient.java @@ -19,9 +19,8 @@ 
import java.io.Closeable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - +import org.apache.spark.internal.SparkLogger; +import org.apache.spark.internal.SparkLoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.util.TransportConf; @@ -32,7 +31,7 @@ * key lookups to build-side executors. */ public abstract class ShardStoreClient implements Closeable { - protected final Logger logger = LoggerFactory.getLogger(this.getClass()); + protected final SparkLogger logger = SparkLoggerFactory.getLogger(this.getClass()); protected volatile TransportClientFactory clientFactory; protected String appId; protected TransportConf transportConf; diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 084cf007ea8c4..e2a4d88b0bd30 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -604,6 +604,7 @@ object SparkEnv extends Logging { val lookupService = ShardLookupServiceFactory.create( conf, + securityManager, bindAddress, advertiseAddress, 0, diff --git a/core/src/main/scala/org/apache/spark/network/ShardLookupServiceFactory.scala b/core/src/main/scala/org/apache/spark/network/ShardLookupServiceFactory.scala index aadae80f413d2..d3d67d9c5a565 100644 --- a/core/src/main/scala/org/apache/spark/network/ShardLookupServiceFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/ShardLookupServiceFactory.scala @@ -17,7 +17,7 @@ package org.apache.spark.network -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.CLASS_NAME import org.apache.spark.network.netty.NettyShardLookupService @@ -41,6 +41,7 @@ private[spark] object ShardLookupServiceFactory extends Logging { def create( conf: SparkConf, + 
securityManager: SecurityManager, bindAddress: String, advertiseAddress: String, port: Int, @@ -48,23 +49,23 @@ private[spark] object ShardLookupServiceFactory extends Logging { masterEndpoint: RpcEndpointRef): ShardLookupService = { val className = conf.get(SHARD_LOOKUP_SERVICE_CLASS_KEY, DEFAULT_CLASS) if (className == DEFAULT_CLASS) { - new NettyShardLookupService(conf, bindAddress, advertiseAddress, port, + new NettyShardLookupService(conf, securityManager, bindAddress, advertiseAddress, port, numCores, masterEndpoint) } else { try { logInfo(log"Creating custom ShardLookupService: ${MDC(CLASS_NAME, className)}") Utils.classForName(className) .getDeclaredConstructor( - classOf[SparkConf], classOf[String], classOf[String], + classOf[SparkConf], classOf[SecurityManager], classOf[String], classOf[String], classOf[Int], classOf[Int], classOf[RpcEndpointRef]) - .newInstance(conf, bindAddress, advertiseAddress, + .newInstance(conf, securityManager, bindAddress, advertiseAddress, port.asInstanceOf[AnyRef], numCores.asInstanceOf[AnyRef], masterEndpoint) .asInstanceOf[ShardLookupService] } catch { case e: Exception => logWarning(log"Failed to create ${MDC(CLASS_NAME, className)}, falling back to Netty", e) - new NettyShardLookupService(conf, bindAddress, advertiseAddress, port, + new NettyShardLookupService(conf, securityManager, bindAddress, advertiseAddress, port, numCores, masterEndpoint) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyShardLookupService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyShardLookupService.scala index e724a5f2efed6..79be30ea39b54 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyShardLookupService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyShardLookupService.scala @@ -19,10 +19,11 @@ package org.apache.spark.network.netty import scala.jdk.CollectionConverters._ -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} 
import org.apache.spark.network.{ShardLookupService, TransportContext} import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.ManagedRpcResponseCallback +import org.apache.spark.network.client.{ManagedRpcResponseCallback, TransportClientBootstrap} +import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server.{TransportServer, TransportServerBootstrap} import org.apache.spark.network.shard.ShardLookupListener import org.apache.spark.rpc.RpcEndpointRef @@ -32,6 +33,7 @@ import org.apache.spark.util.Utils private[spark] class NettyShardLookupService( conf: SparkConf, + securityManager: SecurityManager, bindAddress: String, val hostName: String, _port: Int, @@ -40,6 +42,7 @@ private[spark] class NettyShardLookupService( extends ShardLookupService { private val serializer = new JavaSerializer(conf) + private val authEnabled = securityManager.isAuthenticationEnabled() private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ private[this] var rpcHandler: NettyShardRpcServer = _ @@ -65,23 +68,28 @@ private[spark] class NettyShardLookupService( cloned.setIfMissing("spark.network.waitForReachable", "false") cloned.setIfMissing("spark.network.sharedByteBufAllocators.enabled", "true") cloned.setIfMissing("spark.network.io.preferDirectBufs", "true") - transportConf = SparkTransportConf.fromSparkConf(cloned, "shard", numCores) + transportConf = SparkTransportConf.fromSparkConf( + cloned, "shard", numCores, + sslOptions = Some(securityManager.getRpcSSLOptions())) + var serverBootstrap: Option[TransportServerBootstrap] = None + var clientBootstrap: Option[TransportClientBootstrap] = None + if (authEnabled) { + serverBootstrap = Some(new AuthServerBootstrap(transportConf, securityManager)) + clientBootstrap = Some(new AuthClientBootstrap(transportConf, conf.getAppId, securityManager)) + } transportContext = new 
TransportContext(transportConf, rpcHandler) - clientFactory = transportContext.createClientFactory() - server = createNonAuthServer() + clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava) + server = createServer(serverBootstrap.toList) appId = conf.getAppId logger.info(s"Server created on $hostName $bindAddress:${server.getPort}") } override def port: Int = server.getPort - private def createNonAuthServer(): TransportServer = { + private def createServer(bootstraps: List[TransportServerBootstrap]): TransportServer = { def startService(port: Int): (TransportServer, Int) = { val server = - transportContext.createServer( - bindAddress, - port, - List.empty[TransportServerBootstrap].asJava) + transportContext.createServer(bindAddress, port, bootstraps.asJava) (server, server.getPort) } @@ -108,6 +116,7 @@ private[spark] class NettyShardLookupService( }) } catch { case e: Exception => + reqMsg.release() listener.onBatchFetchFailure(e) } diff --git a/core/src/main/scala/org/apache/spark/shard/ShardManager.scala b/core/src/main/scala/org/apache/spark/shard/ShardManager.scala index 3102816e8e087..a1b25e7c175e6 100644 --- a/core/src/main/scala/org/apache/spark/shard/ShardManager.scala +++ b/core/src/main/scala/org/apache/spark/shard/ShardManager.scala @@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentHashMap import java.util.zip.Adler32 import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService, Future} +import scala.concurrent.duration.FiniteDuration import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag import scala.util.Random @@ -136,8 +137,8 @@ private[spark] class ShardManager( installReplica(setId, id) } - def installReplicaSet(setId: Long, id: Int): Unit = { - master.installReplicaSet(setId, id) + def installReplicaSet(setId: Long, id: Int, timeout: FiniteDuration): Unit = { + master.installReplicaSet(setId, id, timeout) } def installReplica(setId: Long, id: Int): Unit = { @@ -323,6 +324,8 @@ 
private[spark] class ShardManager( def unpersist(setId: Long, blocking: Boolean): Unit = { logDebug(log"Unpersisting shard-set ${MDC(SHARD_SET_ID, setId)}") SparkEnv.get.blockManager.master.removeShardSet(setId, blocking) + master.removeShardSet(setId) + invokeCleanupCallbacks(setId) } def stop(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/shard/ShardManagerMaster.scala b/core/src/main/scala/org/apache/spark/shard/ShardManagerMaster.scala index 616673b2f7f83..aeeb31c7b5afd 100644 --- a/core/src/main/scala/org/apache/spark/shard/ShardManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/shard/ShardManagerMaster.scala @@ -17,12 +17,14 @@ package org.apache.spark.shard +import scala.concurrent.duration.FiniteDuration + import com.google.common.cache.{CacheBuilder, CacheLoader} import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.{EXECUTOR_ID, SHARD_ID, SHARD_MANAGER_ID, SHARD_SET_ID} -import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout} import org.apache.spark.shard.ShardManagerMessages._ /** @@ -66,11 +68,12 @@ private[spark] class ShardManagerMaster( masterEndpoint.askSync[Boolean](UpdateShardInfo(shardManagerId, setId, id)) } - def installReplicaSet(setId: Long, shardId: Int): Unit = { - // non-blocking - masterEndpoint.ask[Boolean](InstallReplicaSet(setId, shardId)) + def installReplicaSet(setId: Long, shardId: Int, timeout: FiniteDuration): Unit = { + val rpcTimeout = new RpcTimeout(timeout, + "spark.sql.execution.distributedMapJoin.exchangeTimeout") + masterEndpoint.askSync[Boolean](InstallReplicaSet(setId, shardId), rpcTimeout) logInfo(log"Install replica set of shard" + - log" (${MDC(SHARD_SET_ID, setId)}, ${MDC(SHARD_ID, shardId)}) requested") + log" (${MDC(SHARD_SET_ID, setId)}, ${MDC(SHARD_ID, shardId)}) completed") } def getLocations(setId: Long, shardId: Int, refresh: Boolean = false): 
Seq[ShardManagerId] = { @@ -81,6 +84,11 @@ private[spark] class ShardManagerMaster( shardLocationsCache.get(key) } + def removeShardSet(setId: Long): Unit = { + masterEndpoint.ask[Boolean](RemoveShardSet(setId)) + logInfo(log"Removal of shard set ${MDC(SHARD_SET_ID, setId)} requested") + } + def removeExecutor(execId: String): Unit = { masterEndpoint.ask[Boolean](RemoveExecutor(execId)) logInfo(log"Removal(shard) of executor ${MDC(EXECUTOR_ID, execId)} requested") diff --git a/core/src/main/scala/org/apache/spark/shard/ShardManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/shard/ShardManagerMasterEndpoint.scala index f4bef2c192b2d..2de3736e0a063 100644 --- a/core/src/main/scala/org/apache/spark/shard/ShardManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/shard/ShardManagerMasterEndpoint.scala @@ -21,11 +21,11 @@ import java.util.{HashMap => JHashMap} import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService} +import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService, Future} import scala.jdk.CollectionConverters._ import scala.util.Random -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.{EXECUTOR_ID, SHARD_ID, SHARD_MANAGER_ID, SHARD_SET_ID} import org.apache.spark.rpc.{IsolatedThreadSafeRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv} @@ -72,12 +72,21 @@ private[spark] class ShardManagerMasterEndpoint( context.reply(true) case InstallReplicaSet(setId, shardId) => - installReplicaToWorkers(setId, shardId) - context.reply(true) + installReplicaToWorkers(setId, shardId).onComplete { + case scala.util.Success(_) => context.reply(true) + case scala.util.Failure(e) => + logWarning(log"Replica installation failed for" + + log" (${MDC(SHARD_SET_ID, setId)}, ${MDC(SHARD_ID, shardId)})", e) + 
context.reply(true) + } case GetLocations(setId, shardId) => context.reply(getLocations(setId, shardId)) + case RemoveShardSet(setId) => + removeShardSet(setId) + context.reply(true) + case RemoveExecutor(execId) => removeExecutor(execId) context.reply(true) @@ -156,10 +165,10 @@ private[spark] class ShardManagerMasterEndpoint( * and hosts that don't already hold this shard. Score is * execLoad*10 + hostLoad*5 + sameHostPenalty(100000) + jitter. */ - private def installReplicaToWorkers(setId: Long, shardId: Int): Unit = { + private def installReplicaToWorkers(setId: Long, shardId: Int): Future[Unit] = { shardSetInfo .get(setId) - .foreach { setInfo => + .map { setInfo => val targetReplica = setInfo.replicaCount val currentHolders = getLocations(setId, shardId) @@ -196,8 +205,9 @@ private[spark] class ShardManagerMasterEndpoint( } val currentExecIds = currentHolders.map(_.executorId).toSet - val candidates = - shardManagerInfo.keys.filterNot(smi => currentExecIds.contains(smi.executorId)) + val candidates = shardManagerInfo.keys + .filterNot(smi => currentExecIds.contains(smi.executorId)) + .filterNot(smi => !isLocal && smi.executorId == SparkContext.DRIVER_IDENTIFIER) var remaining = need while (remaining > 0) { @@ -214,12 +224,23 @@ private[spark] class ShardManagerMasterEndpoint( } } - chosen.foreach { smi => - shardManagerInfo.get(smi).foreach { sm => + val futures = chosen.flatMap { smi => + shardManagerInfo.get(smi).map { sm => sm.managerEndpoint.ask[Boolean](InstallReplica(setId, shardId)) } } + Future.sequence(futures).map(_ => ()) } + .getOrElse(Future.successful(())) + } + + private def removeShardSet(setId: Long): Unit = { + shardSetInfo.remove(setId) + shardSetLocations.remove(setId) + shardManagerInfo.values.foreach { info => + info.shards.remove(setId) + } + logInfo(log"Removed shard set ${MDC(SHARD_SET_ID, setId)} metadata") } private def removeExecutor(execId: String): Unit = { diff --git 
a/core/src/main/scala/org/apache/spark/shard/ShardManagerMessages.scala b/core/src/main/scala/org/apache/spark/shard/ShardManagerMessages.scala index 9f8e03ddd4a83..a02c69e02bc2c 100644 --- a/core/src/main/scala/org/apache/spark/shard/ShardManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/shard/ShardManagerMessages.scala @@ -36,6 +36,8 @@ private[spark] object ShardManagerMessages { case class GetLocations(setId: Long, shardId: Int) extends ToShardManagerMasterEndpoint + case class RemoveShardSet(setId: Long) extends ToShardManagerMasterEndpoint + case class RemoveExecutor(execId: String) extends ToShardManagerMasterEndpoint case object StopShardManagerMaster extends ToShardManagerMasterEndpoint diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 230f3398edd5e..3141efcdafe36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -298,6 +298,7 @@ case object SinglePartition extends Partitioning { override def satisfies0(required: Distribution): Boolean = required match { case _: BroadcastDistribution => false + case _: ShardDistribution => false case _ => true } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 80163f2811516..4c49bd4e3969c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -7494,6 +7494,7 @@ object SQLConf { .version("5.0.0") .withBindingPolicy(ConfigBindingPolicy.SESSION) .intConf + .checkValue(_ > 0, "must be positive") .createWithDefault(8) val DISTRIBUTED_MAP_JOIN_MAX_BATCH_SIZE = @@ -7502,6 +7503,7 @@ 
object SQLConf { .version("5.0.0") .withBindingPolicy(ConfigBindingPolicy.SESSION) .intConf + .checkValue(_ > 0, "must be positive") .createWithDefault(1024) val DISTRIBUTED_MAP_JOIN_BLOOM_FILTER_CAPACITY = diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedShardRowMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedShardRowMap.java index d5f2f251585cd..a68144f50192b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedShardRowMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedShardRowMap.java @@ -26,9 +26,8 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.PooledByteBufAllocator; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - +import org.apache.spark.internal.SparkLogger; +import org.apache.spark.internal.SparkLoggerFactory; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.network.buffer.ManagedBuffer; @@ -50,7 +49,7 @@ */ public final class BufferedShardRowMap extends MemoryConsumer { - private static final Logger logger = LoggerFactory.getLogger(BufferedShardRowMap.class); + private static final SparkLogger logger = SparkLoggerFactory.getLogger(BufferedShardRowMap.class); private final TaskMemoryManager taskMemoryManager; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c21f7b80699f9..2e7ddf4fa2e92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -290,20 +290,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def createDistributedMapJoin(onlyLookingAtHint: Boolean) = { - val buildSide = getDistributedMapJoinBuildSide( - left, right, joinType, hint, onlyLookingAtHint, conf) - 
checkDistributedMapJoinBuildSide(onlyLookingAtHint, buildSide, joinType, hint) - buildSide.map { - buildSide => - Seq(joins.DistributedMapJoinExec( - leftKeys, - rightKeys, - joinType, - buildSide, - nonEquiCond, - planLater(left), - planLater(right), - hint)) + if (hashJoinSupport) { + val buildSide = getDistributedMapJoinBuildSide( + left, right, joinType, hint, onlyLookingAtHint, conf) + checkDistributedMapJoinBuildSide(onlyLookingAtHint, buildSide, joinType, hint) + buildSide.map { + buildSide => + Seq(joins.DistributedMapJoinExec( + leftKeys, + rightKeys, + joinType, + buildSide, + nonEquiCond, + planLater(left), + planLater(right), + hint)) + } + } else { + None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShardExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShardExchangeExec.scala index a0fe0bca2f58e..0848cb9971453 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShardExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShardExchangeExec.scala @@ -22,6 +22,7 @@ import java.util.UUID import java.util.concurrent.{Future => JFuture, TimeUnit} import scala.concurrent.{ExecutionContext, Promise, TimeoutException} +import scala.concurrent.duration.Duration import org.apache.spark.{shard, SparkEnv, SparkException, TaskContext} import org.apache.spark.rdd.RDD @@ -102,6 +103,7 @@ case class ShardExchangeExec( override lazy val relationFuture: JFuture[shard.ShardSetRef] = SQLExecution .withThreadLocalCaptured[shard.ShardSetRef](session, ShardExchangeExec.executionContext) { + var setId = -1L try { sparkContext.setJobGroup( runId.toString, @@ -112,7 +114,7 @@ case class ShardExchangeExec( s"ensureExecutors timed out or SC stopped;" + s" active executors may be < target, continue build") } - val setId = sparkContext.env.shardManager.newShardSet(numShards, replicaCount) + setId = sparkContext.env.shardManager.newShardSet(numShards, 
replicaCount) val shuffled = ShuffleExchangeExec( HashPartitioning(buildBoundKeys, numShards), child, @@ -127,6 +129,7 @@ case class ShardExchangeExec( .getOrElse(Array.empty[Byte]) (exprBytes, schemaBytes) } + val replicaTimeout = Duration(timeout, TimeUnit.SECONDS) val shardIds = sharded .mapPartitionsWithIndexInternal { case (shardId, rowIter) => @@ -150,7 +153,7 @@ case class ShardExchangeExec( val sm = SparkEnv.get.shardManager sm.installShard(relation, setId, shardId, filterBytes)(bfOutput => bf.writeTo(bfOutput)) - sm.installReplicaSet(setId, shardId) + sm.installReplicaSet(setId, shardId, replicaTimeout) Iterator.single(shardId) } @@ -177,6 +180,13 @@ case class ShardExchangeExec( } catch { case e: Throwable => promise.tryFailure(e) + if (setId >= 0) { + try { + sparkContext.env.shardManager.unpersist(setId, blocking = false) + } catch { + case scala.util.control.NonFatal(_) => + } + } throw e } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/DistributedMapJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/DistributedMapJoinExec.scala index 51cf3dc86fb9b..4733e3048715d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/DistributedMapJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/DistributedMapJoinExec.scala @@ -29,7 +29,7 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.rdd.RDD import org.apache.spark.shard.ShardSetRef import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, BindReferences, Expression, GenericInternalRow, JoinedRow, Predicate, PredicateHelper, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, BindReferences, Expression, GenericInternalRow, JoinedRow, Predicate, PredicateHelper, SortOrder, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext 
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} @@ -117,6 +117,8 @@ case class DistributedMapJoinExec( override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + override def outputOrdering: Seq[SortOrder] = Nil + override def inputRDDs(): Seq[RDD[InternalRow]] = throw QueryExecutionErrors.executeCodePathUnsupportedError("DistributedMapJoin") @@ -402,6 +404,8 @@ case class DistributedMapJoinExec( private val bloom = streamedBloomFilter(setId) private val probeUr: UnsafeRow = new UnsafeRow(streamedPlan.schema.length) + @volatile private var cancelled = false + private val bufferedMap = { val mm = TaskContext.get().taskMemoryManager() val maxBatchSize = conf.distributedMapJoinMaxBatchSize @@ -413,7 +417,24 @@ case class DistributedMapJoinExec( streamedBoundKeys.length, probeUr, maxBatchSize) - TaskContext.get().addTaskCompletionListener[Unit](_ => map.free()) + TaskContext.get().addTaskCompletionListener[Unit] { _ => + cancelled = true + if (currentReader != null) { + currentReader.close() + currentReader = null + } + var item = lookupQueue.poll() + while (item != null) { + item match { + case LookupSuccess(batch, buffer) => + batch.release() + buffer.release() + case _ => + } + item = lookupQueue.poll() + } + map.free() + } map } @@ -523,7 +544,14 @@ case class DistributedMapJoinExec( private def flushLookup[T <: PBatch](iter: util.Iterator[T]): Unit = { val manager = SparkEnv.get.shardManager implicit val ec: ExecutionContextExecutorService = manager.lookupEc - while (iter.hasNext) { + while (iter.hasNext && !cancelled) { + while (numInFlight >= maxInFlightNum && !cancelled) { + pollLookup(lookupQueue.poll(200, util.concurrent.TimeUnit.MILLISECONDS)) + if (currentBatchIter != null) { + iterateLookup() + } + } + if (cancelled) return val batch = iter.next() iter.remove() val future = @@ 
-531,9 +559,18 @@ case class DistributedMapJoinExec( future.onComplete { case scala.util.Success(buffer) => - lookupQueue.put(LookupSuccess(batch, buffer)) + if (cancelled) { + batch.release() + buffer.release() + } else { + lookupQueue.put(LookupSuccess(batch, buffer)) + } case scala.util.Failure(cause) => - lookupQueue.put(LookupFailure(batch, cause)) + if (!cancelled) { + lookupQueue.put(LookupFailure(batch, cause)) + } else { + batch.release() + } } numInFlight += 1