
Commit 1c88474

Avoid using ExecutorClassLoader to load Netty generated classes
1 parent e33aaa2 commit 1c88474

File tree

6 files changed: +43 −28 lines


common/network-common/src/main/java/org/apache/spark/network/TransportContext.java

Lines changed: 16 additions & 6 deletions
@@ -62,8 +62,20 @@ public class TransportContext {
   private final RpcHandler rpcHandler;
   private final boolean closeIdleConnections;
 
-  private final MessageEncoder encoder;
-  private final MessageDecoder decoder;
+  /**
+   * Force the creation of MessageEncoder and MessageDecoder so that we can make sure they are
+   * created before switching the current context class loader to ExecutorClassLoader.
+   *
+   * Netty's MessageToMessageEncoder uses Javassist to generate a matcher class, and the
+   * implementation calls "Class.forName" to check whether this class has already been generated.
+   * If the following two objects are created in "ExecutorClassLoader.findClass", it will cause
+   * "ClassCircularityError". This is because loading this Netty generated class will call
+   * "ExecutorClassLoader.findClass" to search for this class, and "ExecutorClassLoader" will try
+   * to use RPC to load it, which triggers loading the not-yet-generated matcher class again. The
+   * JVM reports `ClassCircularityError` to prevent such infinite recursion. (See SPARK-17714)
+   */
+  private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
+  private static final MessageDecoder DECODER = MessageDecoder.INSTANCE;
 
   public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
     this(conf, rpcHandler, false);
@@ -75,8 +87,6 @@ public TransportContext(
     boolean closeIdleConnections) {
     this.conf = conf;
     this.rpcHandler = rpcHandler;
-    this.encoder = new MessageEncoder();
-    this.decoder = new MessageDecoder();
     this.closeIdleConnections = closeIdleConnections;
   }
 
@@ -135,9 +145,9 @@ public TransportChannelHandler initializePipeline(
     try {
       TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
       channel.pipeline()
-        .addLast("encoder", encoder)
+        .addLast("encoder", ENCODER)
         .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
-        .addLast("decoder", decoder)
+        .addLast("decoder", DECODER)
         .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
         // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
         // would require more logic to guarantee if this were not part of the same event loop.
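The fix leans on a basic JVM guarantee: a static final field is initialized exactly once, when its declaring class is initialized, under whatever class loader is in effect at that moment. A minimal, self-contained sketch of that behavior (names are illustrative, not Spark code):

    // Illustrative only: static initialization runs once, eagerly on first
    // use, under the loader that defines the declaring class -- so any class
    // generation it triggers cannot happen later under a REPL class loader.
    public final class EagerInitDemo {

      static final class Codec {
        static final Codec INSTANCE = new Codec();

        private Codec() {
          System.out.println("Codec initialized by: " + Codec.class.getClassLoader());
        }
      }

      public static void main(String[] args) {
        Codec first = Codec.INSTANCE;   // triggers Codec's class initialization
        Codec second = Codec.INSTANCE;  // reuses the same instance
        System.out.println("singleton: " + (first == second)); // singleton: true
      }
    }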

common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java

Lines changed: 6 additions & 0 deletions
@@ -35,6 +35,12 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
 
   private static final Logger logger = LoggerFactory.getLogger(MessageDecoder.class);
 
+  public static final MessageDecoder INSTANCE = new MessageDecoder();
+
+  private MessageDecoder() {
+    super();
+  }
+
   @Override
   public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
     Message.Type msgType = Message.Type.decode(in);
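Publishing a single shared decoder is safe only because the handler keeps no per-channel state; in Netty that contract is what @ChannelHandler.Sharable expresses. The same pattern on a toy decoder (a sketch assuming Netty 4.x on the classpath, not Spark's class):

    import java.util.List;

    import io.netty.buffer.ByteBuf;
    import io.netty.channel.ChannelHandler;
    import io.netty.channel.ChannelHandlerContext;
    import io.netty.handler.codec.MessageToMessageDecoder;

    // Illustrative only: a stateless decoder published as a singleton so the
    // same instance can sit in many pipelines at once.
    @ChannelHandler.Sharable
    final class PassThroughDecoder extends MessageToMessageDecoder<ByteBuf> {
      static final PassThroughDecoder INSTANCE = new PassThroughDecoder();

      private PassThroughDecoder() {}  // funnel all users through INSTANCE

      @Override
      protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
        out.add(in.retain());  // forward the buffer unchanged, retaining it
      }
    }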

common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java

Lines changed: 6 additions & 0 deletions
@@ -35,6 +35,12 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
 
   private static final Logger logger = LoggerFactory.getLogger(MessageEncoder.class);
 
+  public static final MessageEncoder INSTANCE = new MessageEncoder();
+
+  private MessageEncoder() {
+    super();
+  }
+
   /***
    * Encodes a Message by invoking its encode() method. For non-data messages, we will add one
    * ByteBuf to 'out' containing the total frame length, the message type, and the message itself.
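Why "created before switching the context class loader" matters: constructing a MessageToMessageEncoder resolves its type parameter, which on Netty versions that use Javassist can define a generated matcher class through the current loader chain. A rough, self-contained sketch of the ordering the fix enforces (hypothetical setup, not Spark code):

    import java.net.URL;
    import java.net.URLClassLoader;

    // Illustrative only: anything initialized before the swap is resolved by
    // the original loader; the singletons above guarantee the encoder/decoder
    // (and any classes their construction generates) fall on this side.
    public final class ContextSwapDemo {
      public static void main(String[] args) throws Exception {
        ClassLoader original = Thread.currentThread().getContextClassLoader();

        // Stand-in for ExecutorClassLoader: a child loader installed later.
        ClassLoader custom = new URLClassLoader(new URL[0], original);

        Thread.currentThread().setContextClassLoader(custom);
        try {
          // Work that may load user/REPL classes runs here.
          System.out.println("context loader: " + Thread.currentThread().getContextClassLoader());
        } finally {
          Thread.currentThread().setContextClassLoader(original);
        }
      }
    }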

common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java

Lines changed: 6 additions & 5 deletions
@@ -18,15 +18,14 @@
 package org.apache.spark.network.server;
 
 import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.channel.ChannelInboundHandlerAdapter;
 import io.netty.handler.timeout.IdleState;
 import io.netty.handler.timeout.IdleStateEvent;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.client.TransportResponseHandler;
-import org.apache.spark.network.protocol.Message;
 import org.apache.spark.network.protocol.RequestMessage;
 import org.apache.spark.network.protocol.ResponseMessage;
 import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
@@ -48,7 +47,7 @@
  * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not
  * timeout if the client is continuously sending but getting no responses, for simplicity.
  */
-public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> {
+public class TransportChannelHandler extends ChannelInboundHandlerAdapter {
   private static final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class);
 
   private final TransportClient client;
@@ -114,11 +113,13 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
   }
 
   @Override
-  public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception {
+  public void channelRead(ChannelHandlerContext ctx, Object request) throws Exception {
     if (request instanceof RequestMessage) {
       requestHandler.handle((RequestMessage) request);
-    } else {
+    } else if (request instanceof ResponseMessage) {
       responseHandler.handle((ResponseMessage) request);
+    } else {
+      ctx.fireChannelRead(request);
     }
   }
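SimpleChannelInboundHandler<Message> also resolves its type parameter through a generated matcher, so the handler now extends ChannelInboundHandlerAdapter and does the type dispatch itself; the new else branch forwards anything unrecognized instead of dropping it. The same pattern in isolation (toy types, assumes Netty 4.x):

    import io.netty.channel.ChannelHandlerContext;
    import io.netty.channel.ChannelInboundHandlerAdapter;

    // Illustrative only: without the generated type matcher, the handler
    // checks inbound message types itself and forwards everything else.
    final class ManualDispatchHandler extends ChannelInboundHandlerAdapter {
      @Override
      public void channelRead(ChannelHandlerContext ctx, Object msg) {
        if (msg instanceof String) {
          System.out.println("handled: " + msg);  // our message type
        } else {
          ctx.fireChannelRead(msg);               // pass to the next handler
        }
      }
    }

One semantic difference to keep in mind: SimpleChannelInboundHandler releases matched messages automatically, while ChannelInboundHandlerAdapter leaves reference counting entirely to the handler.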

common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java

Lines changed: 4 additions & 4 deletions
@@ -49,11 +49,11 @@
 public class ProtocolSuite {
   private void testServerToClient(Message msg) {
     EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
-      new MessageEncoder());
+      MessageEncoder.INSTANCE);
     serverChannel.writeOutbound(msg);
 
     EmbeddedChannel clientChannel = new EmbeddedChannel(
-      NettyUtils.createFrameDecoder(), new MessageDecoder());
+      NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
 
     while (!serverChannel.outboundMessages().isEmpty()) {
       clientChannel.writeInbound(serverChannel.readOutbound());
@@ -65,11 +65,11 @@ private void testServerToClient(Message msg) {
 
   private void testClientToServer(Message msg) {
     EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(),
-      new MessageEncoder());
+      MessageEncoder.INSTANCE);
     clientChannel.writeOutbound(msg);
 
     EmbeddedChannel serverChannel = new EmbeddedChannel(
-      NettyUtils.createFrameDecoder(), new MessageDecoder());
+      NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
 
     while (!clientChannel.outboundMessages().isEmpty()) {
       serverChannel.writeInbound(clientChannel.readOutbound());
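EmbeddedChannel is what makes this codec test self-contained: data written outbound on one channel is replayed inbound on a peer channel, with no real I/O. A minimal sketch of the round-trip idiom using Netty's built-in string codec (toy example, assumes Netty 4.x):

    import io.netty.channel.embedded.EmbeddedChannel;
    import io.netty.handler.codec.string.StringDecoder;
    import io.netty.handler.codec.string.StringEncoder;

    // Illustrative only: encode on a "server" channel, replay the bytes into
    // a "client" channel, and read back the decoded message.
    public final class RoundTripDemo {
      public static void main(String[] args) {
        EmbeddedChannel server = new EmbeddedChannel(new StringEncoder());
        server.writeOutbound("hello");

        EmbeddedChannel client = new EmbeddedChannel(new StringDecoder());
        while (!server.outboundMessages().isEmpty()) {
          client.writeInbound(server.readOutbound());
        }

        String decoded = client.readInbound();
        System.out.println(decoded.equals("hello"));  // true
      }
    }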

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 5 additions & 13 deletions
@@ -2599,14 +2599,10 @@ private[spark] object Utils extends Logging {
 
 private[util] object CallerContext extends Logging {
   val callerContextSupported: Boolean = {
-    SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && {
+    SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", true) && {
       try {
-        // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in
-        // master Maven build, so do not use it before resolving SPARK-17714.
-        // scalastyle:off classforname
-        Class.forName("org.apache.hadoop.ipc.CallerContext")
-        Class.forName("org.apache.hadoop.ipc.CallerContext$Builder")
-        // scalastyle:on classforname
+        Utils.classForName("org.apache.hadoop.ipc.CallerContext")
+        Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
         true
       } catch {
         case _: ClassNotFoundException =>
@@ -2681,12 +2677,8 @@ private[spark] class CallerContext(
   def setCurrentContext(): Unit = {
     if (CallerContext.callerContextSupported) {
       try {
-        // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in
-        // master Maven build, so do not use it before resolving SPARK-17714.
-        // scalastyle:off classforname
-        val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext")
-        val builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder")
-        // scalastyle:on classforname
+        val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext")
+        val builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
         val builderInst = builder.getConstructor(classOf[String]).newInstance(context)
         val hdfsContext = builder.getMethod("build").invoke(builderInst)
         callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext)
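The reflective pattern restored here, loading a class by name and driving it through getConstructor / getMethod, is what lets Spark compile against Hadoop versions that lack CallerContext. The same idiom in a self-contained Java sketch (java.util.ArrayList stands in for the optional class; Utils.classForName itself adds Spark-specific loader logic on top of this):

    import java.lang.reflect.Method;

    // Illustrative only: construct and invoke an "optional" class
    // reflectively, so the code compiles even when the class is absent
    // at build time.
    public final class ReflectiveCallDemo {
      public static void main(String[] args) throws Exception {
        Class<?> listClass = Class.forName("java.util.ArrayList");
        Object list = listClass.getConstructor().newInstance();

        Method add = listClass.getMethod("add", Object.class);
        add.invoke(list, "hello");

        System.out.println(list);  // prints: [hello]
      }
    }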
