Skip to content

Commit fdb5467

Browse files
committed
Merge branch 'master' into SPARK-31474
2 parents 746eedf + 74aed8c commit fdb5467

File tree

24 files changed

+771
-507
lines changed

24 files changed

+771
-507
lines changed

common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java

Lines changed: 16 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,11 @@
2929
import org.slf4j.LoggerFactory;
3030

3131
import org.apache.spark.network.client.RpcResponseCallback;
32-
import org.apache.spark.network.client.StreamCallbackWithID;
3332
import org.apache.spark.network.client.TransportClient;
3433
import org.apache.spark.network.sasl.SecretKeyHolder;
3534
import org.apache.spark.network.sasl.SaslRpcHandler;
35+
import org.apache.spark.network.server.AbstractAuthRpcHandler;
3636
import org.apache.spark.network.server.RpcHandler;
37-
import org.apache.spark.network.server.StreamManager;
3837
import org.apache.spark.network.util.TransportConf;
3938

4039
/**
@@ -46,7 +45,7 @@
4645
* The delegate will only receive messages if the given connection has been successfully
4746
* authenticated. A connection may be authenticated at most once.
4847
*/
49-
class AuthRpcHandler extends RpcHandler {
48+
class AuthRpcHandler extends AbstractAuthRpcHandler {
5049
private static final Logger LOG = LoggerFactory.getLogger(AuthRpcHandler.class);
5150

5251
/** Transport configuration. */
@@ -55,36 +54,31 @@ class AuthRpcHandler extends RpcHandler {
5554
/** The client channel. */
5655
private final Channel channel;
5756

58-
/**
59-
* RpcHandler we will delegate to for authenticated connections. When falling back to SASL
60-
* this will be replaced with the SASL RPC handler.
61-
*/
62-
@VisibleForTesting
63-
RpcHandler delegate;
64-
6557
/** Class which provides secret keys which are shared by server and client on a per-app basis. */
6658
private final SecretKeyHolder secretKeyHolder;
6759

68-
/** Whether auth is done and future calls should be delegated. */
60+
/** RPC handler for auth handshake when falling back to SASL auth. */
6961
@VisibleForTesting
70-
boolean doDelegate;
62+
SaslRpcHandler saslHandler;
7163

7264
AuthRpcHandler(
7365
TransportConf conf,
7466
Channel channel,
7567
RpcHandler delegate,
7668
SecretKeyHolder secretKeyHolder) {
69+
super(delegate);
7770
this.conf = conf;
7871
this.channel = channel;
79-
this.delegate = delegate;
8072
this.secretKeyHolder = secretKeyHolder;
8173
}
8274

8375
@Override
84-
public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
85-
if (doDelegate) {
86-
delegate.receive(client, message, callback);
87-
return;
76+
protected boolean doAuthChallenge(
77+
TransportClient client,
78+
ByteBuffer message,
79+
RpcResponseCallback callback) {
80+
if (saslHandler != null) {
81+
return saslHandler.doAuthChallenge(client, message, callback);
8882
}
8983

9084
int position = message.position();
@@ -98,18 +92,17 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb
9892
if (conf.saslFallback()) {
9993
LOG.warn("Failed to parse new auth challenge, reverting to SASL for client {}.",
10094
channel.remoteAddress());
101-
delegate = new SaslRpcHandler(conf, channel, delegate, secretKeyHolder);
95+
saslHandler = new SaslRpcHandler(conf, channel, null, secretKeyHolder);
10296
message.position(position);
10397
message.limit(limit);
104-
delegate.receive(client, message, callback);
105-
doDelegate = true;
98+
return saslHandler.doAuthChallenge(client, message, callback);
10699
} else {
107100
LOG.debug("Unexpected challenge message from client {}, closing channel.",
108101
channel.remoteAddress());
109102
callback.onFailure(new IllegalArgumentException("Unknown challenge message."));
110103
channel.close();
111104
}
112-
return;
105+
return false;
113106
}
114107

115108
// Here we have the client challenge, so perform the new auth protocol and set up the channel.
@@ -131,7 +124,7 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb
131124
LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress());
132125
callback.onFailure(new IllegalArgumentException("Authentication failed."));
133126
channel.close();
134-
return;
127+
return false;
135128
} finally {
136129
if (engine != null) {
137130
try {
@@ -143,40 +136,6 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb
143136
}
144137

145138
LOG.debug("Authorization successful for client {}.", channel.remoteAddress());
146-
doDelegate = true;
147-
}
148-
149-
@Override
150-
public void receive(TransportClient client, ByteBuffer message) {
151-
delegate.receive(client, message);
152-
}
153-
154-
@Override
155-
public StreamCallbackWithID receiveStream(
156-
TransportClient client,
157-
ByteBuffer message,
158-
RpcResponseCallback callback) {
159-
return delegate.receiveStream(client, message, callback);
139+
return true;
160140
}
161-
162-
@Override
163-
public StreamManager getStreamManager() {
164-
return delegate.getStreamManager();
165-
}
166-
167-
@Override
168-
public void channelActive(TransportClient client) {
169-
delegate.channelActive(client);
170-
}
171-
172-
@Override
173-
public void channelInactive(TransportClient client) {
174-
delegate.channelInactive(client);
175-
}
176-
177-
@Override
178-
public void exceptionCaught(Throwable cause, TransportClient client) {
179-
delegate.exceptionCaught(cause, client);
180-
}
181-
182141
}

common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java

Lines changed: 11 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,9 @@
2828
import org.slf4j.LoggerFactory;
2929

3030
import org.apache.spark.network.client.RpcResponseCallback;
31-
import org.apache.spark.network.client.StreamCallbackWithID;
3231
import org.apache.spark.network.client.TransportClient;
32+
import org.apache.spark.network.server.AbstractAuthRpcHandler;
3333
import org.apache.spark.network.server.RpcHandler;
34-
import org.apache.spark.network.server.StreamManager;
3534
import org.apache.spark.network.util.JavaUtils;
3635
import org.apache.spark.network.util.TransportConf;
3736

@@ -43,7 +42,7 @@
4342
* Note that the authentication process consists of multiple challenge-response pairs, each of
4443
* which are individual RPCs.
4544
*/
46-
public class SaslRpcHandler extends RpcHandler {
45+
public class SaslRpcHandler extends AbstractAuthRpcHandler {
4746
private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
4847

4948
/** Transport configuration. */
@@ -52,37 +51,28 @@ public class SaslRpcHandler extends RpcHandler {
5251
/** The client channel. */
5352
private final Channel channel;
5453

55-
/** RpcHandler we will delegate to for authenticated connections. */
56-
private final RpcHandler delegate;
57-
5854
/** Class which provides secret keys which are shared by server and client on a per-app basis. */
5955
private final SecretKeyHolder secretKeyHolder;
6056

6157
private SparkSaslServer saslServer;
62-
private boolean isComplete;
63-
private boolean isAuthenticated;
6458

6559
public SaslRpcHandler(
6660
TransportConf conf,
6761
Channel channel,
6862
RpcHandler delegate,
6963
SecretKeyHolder secretKeyHolder) {
64+
super(delegate);
7065
this.conf = conf;
7166
this.channel = channel;
72-
this.delegate = delegate;
7367
this.secretKeyHolder = secretKeyHolder;
7468
this.saslServer = null;
75-
this.isComplete = false;
76-
this.isAuthenticated = false;
7769
}
7870

7971
@Override
80-
public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
81-
if (isComplete) {
82-
// Authentication complete, delegate to base handler.
83-
delegate.receive(client, message, callback);
84-
return;
85-
}
72+
public boolean doAuthChallenge(
73+
TransportClient client,
74+
ByteBuffer message,
75+
RpcResponseCallback callback) {
8676
if (saslServer == null || !saslServer.isComplete()) {
8777
ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
8878
SaslMessage saslMessage;
@@ -118,55 +108,28 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb
118108
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
119109
logger.debug("SASL authentication successful for channel {}", client);
120110
complete(true);
121-
return;
111+
return true;
122112
}
123113

124114
logger.debug("Enabling encryption for channel {}", client);
125115
SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
126116
complete(false);
127-
return;
117+
return true;
128118
}
129-
}
130-
131-
@Override
132-
public void receive(TransportClient client, ByteBuffer message) {
133-
delegate.receive(client, message);
134-
}
135-
136-
@Override
137-
public StreamCallbackWithID receiveStream(
138-
TransportClient client,
139-
ByteBuffer message,
140-
RpcResponseCallback callback) {
141-
return delegate.receiveStream(client, message, callback);
142-
}
143-
144-
@Override
145-
public StreamManager getStreamManager() {
146-
return delegate.getStreamManager();
147-
}
148-
149-
@Override
150-
public void channelActive(TransportClient client) {
151-
delegate.channelActive(client);
119+
return false;
152120
}
153121

154122
@Override
155123
public void channelInactive(TransportClient client) {
156124
try {
157-
delegate.channelInactive(client);
125+
super.channelInactive(client);
158126
} finally {
159127
if (saslServer != null) {
160128
saslServer.dispose();
161129
}
162130
}
163131
}
164132

165-
@Override
166-
public void exceptionCaught(Throwable cause, TransportClient client) {
167-
delegate.exceptionCaught(cause, client);
168-
}
169-
170133
private void complete(boolean dispose) {
171134
if (dispose) {
172135
try {
@@ -177,7 +140,6 @@ private void complete(boolean dispose) {
177140
}
178141

179142
saslServer = null;
180-
isComplete = true;
181143
}
182144

183145
}
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.network.server;
19+
20+
import java.nio.ByteBuffer;
21+
22+
import org.apache.spark.network.client.RpcResponseCallback;
23+
import org.apache.spark.network.client.StreamCallbackWithID;
24+
import org.apache.spark.network.client.TransportClient;
25+
26+
/**
27+
* RPC Handler which performs authentication, and when it's successful, delegates further
28+
* calls to another RPC handler. The authentication handshake itself should be implemented
29+
* by subclasses.
30+
*/
31+
public abstract class AbstractAuthRpcHandler extends RpcHandler {
32+
/** RpcHandler we will delegate to for authenticated connections. */
33+
private final RpcHandler delegate;
34+
35+
private boolean isAuthenticated;
36+
37+
protected AbstractAuthRpcHandler(RpcHandler delegate) {
38+
this.delegate = delegate;
39+
}
40+
41+
/**
42+
* Responds to an authentication challenge.
43+
*
44+
* @return Whether the client is authenticated.
45+
*/
46+
protected abstract boolean doAuthChallenge(
47+
TransportClient client,
48+
ByteBuffer message,
49+
RpcResponseCallback callback);
50+
51+
@Override
52+
public final void receive(
53+
TransportClient client,
54+
ByteBuffer message,
55+
RpcResponseCallback callback) {
56+
if (isAuthenticated) {
57+
delegate.receive(client, message, callback);
58+
} else {
59+
isAuthenticated = doAuthChallenge(client, message, callback);
60+
}
61+
}
62+
63+
@Override
64+
public final void receive(TransportClient client, ByteBuffer message) {
65+
if (isAuthenticated) {
66+
delegate.receive(client, message);
67+
} else {
68+
throw new SecurityException("Unauthenticated call to receive().");
69+
}
70+
}
71+
72+
@Override
73+
public final StreamCallbackWithID receiveStream(
74+
TransportClient client,
75+
ByteBuffer message,
76+
RpcResponseCallback callback) {
77+
if (isAuthenticated) {
78+
return delegate.receiveStream(client, message, callback);
79+
} else {
80+
throw new SecurityException("Unauthenticated call to receiveStream().");
81+
}
82+
}
83+
84+
@Override
85+
public StreamManager getStreamManager() {
86+
return delegate.getStreamManager();
87+
}
88+
89+
@Override
90+
public void channelActive(TransportClient client) {
91+
delegate.channelActive(client);
92+
}
93+
94+
@Override
95+
public void channelInactive(TransportClient client) {
96+
delegate.channelInactive(client);
97+
}
98+
99+
@Override
100+
public void exceptionCaught(Throwable cause, TransportClient client) {
101+
delegate.exceptionCaught(cause, client);
102+
}
103+
104+
public boolean isAuthenticated() {
105+
return isAuthenticated;
106+
}
107+
}

0 commit comments

Comments
 (0)