Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.client;

import org.apache.spark.network.buffer.ManagedBuffer;

public interface ManagedRpcResponseCallback extends BaseResponseCallback {

/**
* Successful response body.
* Ownership of {@code response} is transferred to the callback.
* The callback implementation MUST ensure {@code response} is released exactly once:
* either hand it off to the transport (e.g. wrap it in {@code RpcResponse} and write it to
* the channel, letting Netty release it after write completes), or call
* {@code response.release()} itself if it is not handed off.
* It may call {@code retain()} if it needs to pass the buffer to another thread.
*/
void onSuccess(ManagedBuffer response);
}
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,21 @@ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) {
return requestId;
}

public long sendManagedRpc(ManagedBuffer message, ManagedRpcResponseCallback callback) {
if (logger.isTraceEnabled()) {
logger.trace("Sending RPC to {}", getRemoteAddress(channel));
}

long requestId = requestId();
handler.addRpcRequest(requestId, callback);

RpcChannelListener listener = new RpcChannelListener(requestId, callback);
channel.writeAndFlush(new RpcRequest(requestId, message))
.addListener(listener);

return requestId;
}

/**
* Sends a MergedBlockMetaRequest message to the server. The response of this message is
* either a {@link MergedBlockMetaSuccess} or {@link RpcFailure}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ public void handle(ResponseMessage message) throws Exception {
"Failure while fetching " + resp.streamChunkId + ": " + resp.errorString));
}
} else if (message instanceof RpcResponse resp) {
RpcResponseCallback listener = (RpcResponseCallback) outstandingRpcs.get(resp.requestId);
BaseResponseCallback listener = outstandingRpcs.get(resp.requestId);
if (listener == null) {
logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
MDC.of(LogKeys.REQUEST_ID, resp.requestId),
Expand All @@ -198,10 +198,14 @@ public void handle(ResponseMessage message) throws Exception {
resp.body().release();
} else {
outstandingRpcs.remove(resp.requestId);
try {
listener.onSuccess(resp.body().nioByteBuffer());
} finally {
resp.body().release();
if (listener instanceof ManagedRpcResponseCallback) {
((ManagedRpcResponseCallback) listener).onSuccess(resp.body());
} else {
try {
((RpcResponseCallback) listener).onSuccess(resp.body().nioByteBuffer());
} finally {
resp.body().release();
}
}
}
} else if (message instanceof RpcFailure resp) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import org.apache.spark.internal.SparkLogger;
import org.apache.spark.internal.SparkLoggerFactory;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ManagedRpcResponseCallback;
import org.apache.spark.network.client.MergedBlockMetaResponseCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallbackWithID;
Expand Down Expand Up @@ -136,6 +138,14 @@ public void onFailure(Throwable e) {

}

public interface ManagedRpcHandler {

void receive(
TransportClient client,
ManagedBuffer message,
ManagedRpcResponseCallback callback);
}

/**
* Handler for {@link MergedBlockMetaRequest}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,25 +163,58 @@ private void processStreamRequest(final StreamRequest req) {
}

private void processRpcRequest(final RpcRequest req) {
try {
rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() {
@Override
public void onSuccess(ByteBuffer response) {
respond(new RpcResponse(req.requestId, new NioManagedBuffer(response)));
if (rpcHandler instanceof RpcHandler.ManagedRpcHandler) {
boolean handedOff = false;
try {
((RpcHandler.ManagedRpcHandler) rpcHandler).receive(reverseClient, req.body(),
new ManagedRpcResponseCallback() {
@Override
public void onSuccess(ManagedBuffer response) {
boolean sent = false;
try {
respond(new RpcResponse(req.requestId, response));
sent = true;
} finally {
if (!sent) {
response.release();
}
}
}

@Override
public void onFailure(Throwable e) {
respond(new RpcFailure(req.requestId, JavaUtils.stackTraceToString(e)));
}
});
handedOff = true;
} catch (Exception e) {
respond(new RpcFailure(req.requestId, JavaUtils.stackTraceToString(e)));
} finally {
if (!handedOff) {
req.body().release();
}
}
} else {
try {
rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() {
@Override
public void onSuccess(ByteBuffer response) {
respond(new RpcResponse(req.requestId, new NioManagedBuffer(response)));
}

@Override
public void onFailure(Throwable e) {
respond(new RpcFailure(req.requestId, JavaUtils.stackTraceToString(e)));
}
});
} catch (Exception e) {
logger.error("Error while invoking RpcHandler#receive() on RPC id {} from {}", e,
MDC.of(LogKeys.REQUEST_ID, req.requestId),
MDC.of(LogKeys.HOST_PORT, getRemoteAddress(channel)));
respond(new RpcFailure(req.requestId, JavaUtils.stackTraceToString(e)));
} finally {
req.body().release();
@Override
public void onFailure(Throwable e) {
respond(new RpcFailure(req.requestId, JavaUtils.stackTraceToString(e)));
}
});
} catch (Exception e) {
logger.error("Error while invoking RpcHandler#receive() on RPC id {} from {}", e,
MDC.of(LogKeys.REQUEST_ID, req.requestId),
MDC.of(LogKeys.HOST_PORT, getRemoteAddress(channel)));
respond(new RpcFailure(req.requestId, JavaUtils.stackTraceToString(e)));
} finally {
req.body().release();
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.shard;

import java.util.EventListener;

import org.apache.spark.network.buffer.ManagedBuffer;

/**
* Listener for asynchronous shard batch lookup results in distributed map join.
*/
public interface ShardLookupListener extends EventListener {

/** Called when a batch lookup RPC completes successfully. */
void onBatchFetchSuccess(ManagedBuffer response);

/** Called when a batch lookup RPC fails. */
void onBatchFetchFailure(Throwable exception);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.shard;

import java.io.Closeable;

import org.apache.spark.internal.SparkLogger;
import org.apache.spark.internal.SparkLoggerFactory;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.util.TransportConf;

/**
* Client interface for shard-based RPC lookups in distributed map join.
* Each executor has one instance, used by the probe side to send batched
* key lookups to build-side executors.
*/
public abstract class ShardStoreClient implements Closeable {
protected final SparkLogger logger = SparkLoggerFactory.getLogger(this.getClass());
protected volatile TransportClientFactory clientFactory;
protected String appId;
protected TransportConf transportConf;

protected void checkInit() {
assert appId != null : "Called before init()";
}

/**
* Send a batched key lookup request to the specified host.
*
* @param host the target executor's hostname
* @param port the target executor's shard service port
* @param reqMsg the serialized batch of probe keys
* @param listener callback for success or failure
*/
public abstract void fetchBatch(
String host,
int port,
ManagedBuffer reqMsg,
ShardLookupListener listener);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.shard.protocol;

import java.util.Arrays;
import java.util.Objects;

import io.netty.buffer.ByteBuf;

import org.apache.spark.network.protocol.Encoders;

/**
* Request message for a batched key lookup in distributed map join.
* Contains the shard set ID, target shard, number of key fields,
* and the serialized key data.
*/
public class BatchLookupReq extends ShardLookupMessage {

public final long setId;
public final int shardId;
public final int numFields;
public final byte[] keysData;

public BatchLookupReq(long setId, int shardId, int numFields, byte[] keysData) {
this.setId = setId;
this.shardId = shardId;
this.numFields = numFields;
this.keysData = keysData;
}

@Override
protected Type type() {
return Type.BATCH_LOOKUP_REQ;
}

@Override
public int hashCode() {
return (Objects.hash(setId, shardId) * 31 + Objects.hash(numFields)) * 31 +
Arrays.hashCode(keysData);
}

@Override
public String toString() {
return String.format("BatchLookupReq[setId=%d,shardId=%d,numFields=%d,keysSize=%d]",
setId, shardId, numFields, keysData.length);
}

@Override
public boolean equals(Object other) {
if (other instanceof BatchLookupReq) {
BatchLookupReq o = (BatchLookupReq) other;
return setId == o.setId
&& shardId == o.shardId
&& numFields == o.numFields
&& Arrays.equals(keysData, o.keysData);
}
return false;
}

@Override
public int encodedLength() {
return Long.BYTES
+ Integer.BYTES
+ Integer.BYTES
+ Encoders.ByteArrays.encodedLength(keysData);
}

@Override
public void encode(ByteBuf buf) {
buf.writeLong(setId);
buf.writeInt(shardId);
buf.writeInt(numFields);
Encoders.ByteArrays.encode(buf, keysData);
}

public static BatchLookupReq decode(ByteBuf buf) {
long setId = buf.readLong();
int shardId = buf.readInt();
int numFields = buf.readInt();
byte[] keysData = Encoders.ByteArrays.decode(buf);
return new BatchLookupReq(setId, shardId, numFields, keysData);
}
}
Loading