
Commit 1d18928

Author: Marcelo Vanzin
[SPARK-18547][core] Propagate I/O encryption key when executors register.
This change modifies the method used to propagate encryption keys used during shuffle. Instead of relying on YARN's UserGroupInformation credential propagation, this change explicitly distributes the key using the messages exchanged between driver and executor during registration. When RPC encryption is enabled, this means key propagation is also secure.

This allows shuffle encryption to work in non-YARN mode, which means that it's easier to write unit tests for areas of the code that are affected by the feature.

The key is stored in the SecurityManager; because there are many instances of that class used in the code, the key is only guaranteed to exist in the instance managed by the SparkEnv. This path was chosen to avoid storing the key in the SparkConf, which would risk having the key written to disk as part of the configuration (as is done, for example, when starting YARN applications).

Tested by new and existing unit tests (which were moved from the YARN module to core), and by running apps with shuffle encryption enabled.
1 parent bdc8153 commit 1d18928
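
In short, the key now rides along with the executor registration handshake instead of the Hadoop credential cache. A rough sketch of the flow, pieced together from the diffs below (names such as conf, driver, driverConf, executorId, hostname, port, and cores come from the surrounding contexts shown in each file and are not repeated here):

// Driver side (SparkEnv.createDriverEnv): generate the key once, if enabled,
// and hand it to the SecurityManager owned by the driver's SparkEnv.
val ioEncryptionKey: Option[Array[Byte]] =
  if (conf.get(IO_ENCRYPTION_ENABLED)) Some(CryptoStreamUtils.createKey(conf)) else None
val securityManager = new SecurityManager(conf, ioEncryptionKey)

// Executor side (CoarseGrainedExecutorBackend): the key arrives in the
// registration reply rather than being read from the Hadoop UGI.
val cfg = driver.askWithRetry[SparkAppConfig](RetrieveSparkAppConfig)
val env = SparkEnv.createExecutorEnv(
  driverConf, executorId, hostname, port, cores, cfg.ioEncryptionKey, isLocal = false)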

File tree

11 files changed (+143, -224 lines)


core/src/main/scala/org/apache/spark/SecurityManager.scala

Lines changed: 5 additions & 18 deletions
@@ -21,7 +21,6 @@ import java.lang.{Byte => JByte}
 import java.net.{Authenticator, PasswordAuthentication}
 import java.security.{KeyStore, SecureRandom}
 import java.security.cert.X509Certificate
-import javax.crypto.KeyGenerator
 import javax.net.ssl._
 
 import com.google.common.hash.HashCodes
@@ -33,7 +32,6 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.network.sasl.SecretKeyHolder
-import org.apache.spark.security.CryptoStreamUtils._
 import org.apache.spark.util.Utils
 
 /**
@@ -185,7 +183,9 @@ import org.apache.spark.util.Utils
  * setting `spark.ssl.useNodeLocalConf` to `true`.
  */
 
-private[spark] class SecurityManager(sparkConf: SparkConf)
+private[spark] class SecurityManager(
+    sparkConf: SparkConf,
+    ioEncryptionKey: Option[Array[Byte]] = None)
   extends Logging with SecretKeyHolder {
 
   import SecurityManager._
@@ -415,6 +415,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
     logInfo("Changing acls enabled to: " + aclsOn)
   }
 
+  def getIOEncryptionKey(): Option[Array[Byte]] = ioEncryptionKey
+
   /**
    * Generates or looks up the secret key.
    *
@@ -559,19 +561,4 @@ private[spark] object SecurityManager {
   // key used to store the spark secret in the Hadoop UGI
   val SECRET_LOOKUP_KEY = "sparkCookie"
 
-  /**
-   * Setup the cryptographic key used by IO encryption in credentials. The key is generated using
-   * [[KeyGenerator]]. The algorithm and key length is specified by the [[SparkConf]].
-   */
-  def initIOEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = {
-    if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) {
-      val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS)
-      val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM)
-      val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm)
-      keyGen.init(keyLen)
-
-      val ioKey = keyGen.generateKey()
-      credentials.addSecretKey(SPARK_IO_TOKEN, ioKey.getEncoded)
-    }
-  }
 }

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 0 additions & 4 deletions
@@ -422,10 +422,6 @@ class SparkContext(config: SparkConf) extends Logging {
     }
 
     if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true")
-    if (_conf.get(IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) {
-      throw new SparkException("IO encryption is only supported in YARN mode, please disable it " +
-        s"by setting ${IO_ENCRYPTION_ENABLED.key} to false")
-    }
 
     // "_jobProgressListener" should be set up before creating SparkEnv because when creating
     // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them.
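
With the YARN-only guard above removed, I/O encryption can be exercised outside YARN. A minimal smoke test along these lines (hypothetical object and app names, public API only) would have thrown a SparkException at SparkContext creation before this change:

import org.apache.spark.{SparkConf, SparkContext}

object IoEncryptionSmokeTest {
  def main(args: Array[String]): Unit = {
    // Enabling I/O encryption in local mode no longer fails fast.
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("io-encryption-smoke-test")
      .set("spark.io.encryption.enabled", "true")
    val sc = new SparkContext(conf)
    try {
      // Force a shuffle so encrypted shuffle data is actually written and read back.
      val sums = sc.parallelize(1 to 1000, 4).map(i => (i % 10, i)).reduceByKey(_ + _).collect()
      assert(sums.length == 10)
    } finally {
      sc.stop()
    }
  }
}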

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 18 additions & 9 deletions
@@ -36,6 +36,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService
 import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}
 import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator}
 import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
+import org.apache.spark.security.CryptoStreamUtils
 import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
 import org.apache.spark.shuffle.ShuffleManager
 import org.apache.spark.storage._
@@ -165,15 +166,20 @@ object SparkEnv extends Logging {
     val bindAddress = conf.get(DRIVER_BIND_ADDRESS)
     val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS)
     val port = conf.get("spark.driver.port").toInt
+    val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) {
+      Some(CryptoStreamUtils.createKey(conf))
+    } else {
+      None
+    }
     create(
       conf,
       SparkContext.DRIVER_IDENTIFIER,
       bindAddress,
       advertiseAddress,
       port,
-      isDriver = true,
-      isLocal = isLocal,
-      numUsableCores = numCores,
+      isLocal,
+      numCores,
+      ioEncryptionKey,
       listenerBus = listenerBus,
       mockOutputCommitCoordinator = mockOutputCommitCoordinator
     )
@@ -189,16 +195,17 @@ object SparkEnv extends Logging {
      hostname: String,
      port: Int,
      numCores: Int,
+      ioEncryptionKey: Option[Array[Byte]],
      isLocal: Boolean): SparkEnv = {
    val env = create(
      conf,
      executorId,
      hostname,
      hostname,
      port,
-      isDriver = false,
-      isLocal = isLocal,
-      numUsableCores = numCores
+      isLocal,
+      numCores,
+      ioEncryptionKey
    )
    SparkEnv.set(env)
    env
@@ -213,18 +220,20 @@ object SparkEnv extends Logging {
      bindAddress: String,
      advertiseAddress: String,
      port: Int,
-      isDriver: Boolean,
      isLocal: Boolean,
      numUsableCores: Int,
+      ioEncryptionKey: Option[Array[Byte]],
      listenerBus: LiveListenerBus = null,
      mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
 
+    val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER
+
    // Listener bus is only used on the driver
    if (isDriver) {
      assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!")
    }
 
-    val securityManager = new SecurityManager(conf)
+    val securityManager = new SecurityManager(conf, ioEncryptionKey)
 
    val systemName = if (isDriver) driverSystemName else executorSystemName
    val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf,
@@ -270,7 +279,7 @@ object SparkEnv extends Logging {
      "spark.serializer", "org.apache.spark.serializer.JavaSerializer")
    logDebug(s"Using serializer: ${serializer.getClass}")
 
-    val serializerManager = new SerializerManager(serializer, conf)
+    val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey)
 
    val closureSerializer = new JavaSerializer(conf)

core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala

Lines changed: 3 additions & 3 deletions
@@ -200,8 +200,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
        new SecurityManager(executorConf),
        clientMode = true)
      val driver = fetcher.setupEndpointRefByURI(driverUrl)
-      val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++
-        Seq[(String, String)](("spark.app.id", appId))
+      val cfg = driver.askWithRetry[SparkAppConfig](RetrieveSparkAppConfig)
+      val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", appId))
      fetcher.shutdown()
 
      // Create SparkEnv using properties we fetched from the driver.
@@ -221,7 +221,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
      }
 
      val env = SparkEnv.createExecutorEnv(
-        driverConf, executorId, hostname, port, cores, isLocal = false)
+        driverConf, executorId, hostname, port, cores, cfg.ioEncryptionKey, isLocal = false)
 
      env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
        env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env))

core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala

Lines changed: 6 additions & 1 deletion
@@ -28,7 +28,12 @@ private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable
 
 private[spark] object CoarseGrainedClusterMessages {
 
-  case object RetrieveSparkProps extends CoarseGrainedClusterMessage
+  case object RetrieveSparkAppConfig extends CoarseGrainedClusterMessage
+
+  case class SparkAppConfig(
+      sparkProperties: Seq[(String, String)],
+      ioEncryptionKey: Option[Array[Byte]])
+    extends CoarseGrainedClusterMessage
 
   case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage
 

core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala

Lines changed: 4 additions & 2 deletions
@@ -206,8 +206,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
        removeExecutor(executorId, reason)
        context.reply(true)
 
-      case RetrieveSparkProps =>
-        context.reply(sparkProperties)
+      case RetrieveSparkAppConfig =>
+        val reply = SparkAppConfig(sparkProperties,
+          SparkEnv.get.securityManager.getIOEncryptionKey())
+        context.reply(reply)
    }
 
    // Make fake resource offers on all executors

core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala

Lines changed: 16 additions & 12 deletions
@@ -18,25 +18,20 @@ package org.apache.spark.security
 
 import java.io.{InputStream, OutputStream}
 import java.util.Properties
+import javax.crypto.KeyGenerator
 import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}
 
 import org.apache.commons.crypto.random._
 import org.apache.commons.crypto.stream._
-import org.apache.hadoop.io.Text
 
 import org.apache.spark.SparkConf
-import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 
 /**
  * A util class for manipulating IO encryption and decryption streams.
  */
 private[spark] object CryptoStreamUtils extends Logging {
-  /**
-   * Constants and variables for spark IO encryption
-   */
-  val SPARK_IO_TOKEN = new Text("SPARK_IO_TOKEN")
 
   // The initialization vector length in bytes.
   val IV_LENGTH_IN_BYTES = 16
@@ -50,12 +45,11 @@ private[spark] object CryptoStreamUtils extends Logging {
   */
  def createCryptoOutputStream(
      os: OutputStream,
-      sparkConf: SparkConf): OutputStream = {
+      sparkConf: SparkConf,
+      key: Array[Byte]): OutputStream = {
    val properties = toCryptoConf(sparkConf)
    val iv = createInitializationVector(properties)
    os.write(iv)
-    val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
-    val key = credentials.getSecretKey(SPARK_IO_TOKEN)
    val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
    new CryptoOutputStream(transformationStr, properties, os,
      new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
@@ -66,12 +60,11 @@ private[spark] object CryptoStreamUtils extends Logging {
   */
  def createCryptoInputStream(
      is: InputStream,
-      sparkConf: SparkConf): InputStream = {
+      sparkConf: SparkConf,
+      key: Array[Byte]): InputStream = {
    val properties = toCryptoConf(sparkConf)
    val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
    is.read(iv, 0, iv.length)
-    val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
-    val key = credentials.getSecretKey(SPARK_IO_TOKEN)
    val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
    new CryptoInputStream(transformationStr, properties, is,
      new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
@@ -91,6 +84,17 @@ private[spark] object CryptoStreamUtils extends Logging {
    props
  }
 
+  /**
+   * Creates a new encryption key.
+   */
+  def createKey(conf: SparkConf): Array[Byte] = {
+    val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS)
+    val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM)
+    val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm)
+    keyGen.init(keyLen)
+    keyGen.generateKey().getEncoded()
+  }
+
  /**
   * This method to generate an IV (Initialization Vector) using secure random.
   */
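
With the key passed explicitly, these stream helpers no longer need any Hadoop credentials in scope. A rough round-trip sketch under the new signatures (CryptoStreamUtils is private[spark], so code like this only compiles inside Spark's own packages, e.g. in the core tests):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.util.Arrays

import org.apache.spark.SparkConf
import org.apache.spark.security.CryptoStreamUtils

val conf = new SparkConf()
// Generate a key once, as the driver now does in SparkEnv.createDriverEnv.
val key = CryptoStreamUtils.createKey(conf)

// Encrypt: createCryptoOutputStream writes the IV at the head of the wrapped stream.
val plain = "some shuffle bytes".getBytes("UTF-8")
val cipherBytes = new ByteArrayOutputStream()
val encrypting = CryptoStreamUtils.createCryptoOutputStream(cipherBytes, conf, key)
encrypting.write(plain)
encrypting.close()

// Decrypt with the same key on the receiving side.
val decrypting = CryptoStreamUtils.createCryptoInputStream(
  new ByteArrayInputStream(cipherBytes.toByteArray), conf, key)
val recovered = new ByteArrayOutputStream()
val chunk = new Array[Byte](1024)
var n = decrypting.read(chunk)
while (n != -1) {
  recovered.write(chunk, 0, n)
  n = decrypting.read(chunk)
}
assert(Arrays.equals(plain, recovered.toByteArray))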

core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala

Lines changed: 12 additions & 6 deletions
@@ -33,7 +33,12 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea
  * Component which configures serialization, compression and encryption for various Spark
  * components, including automatic selection of which [[Serializer]] to use for shuffles.
  */
-private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) {
+private[spark] class SerializerManager(
+    defaultSerializer: Serializer,
+    conf: SparkConf,
+    encryptionKey: Option[Array[Byte]]) {
+
+  def this(defaultSerializer: Serializer, conf: SparkConf) = this(defaultSerializer, conf, None)
 
   private[this] val kryoSerializer = new KryoSerializer(conf)
 
@@ -63,9 +68,6 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
   // Whether to compress shuffle output temporarily spilled to disk
   private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)
 
-  // Whether to enable IO encryption
-  private[this] val enableIOEncryption = conf.get(IO_ENCRYPTION_ENABLED)
-
   /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
    * the initialization of the compression codec until it is first used. The reason is that a Spark
    * program could be using a user-defined codec in a third party jar, which is loaded in
@@ -125,14 +127,18 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
   * Wrap an input stream for encryption if shuffle encryption is enabled
   */
  private[this] def wrapForEncryption(s: InputStream): InputStream = {
-    if (enableIOEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s
+    encryptionKey
+      .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) }
+      .getOrElse(s)
  }
 
  /**
   * Wrap an output stream for encryption if shuffle encryption is enabled
   */
  private[this] def wrapForEncryption(s: OutputStream): OutputStream = {
-    if (enableIOEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s
+    encryptionKey
+      .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) }
+      .getOrElse(s)
  }
 
  /**
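
The auxiliary constructor keeps the old two-argument call sites compiling while letting SparkEnv inject the key explicitly. A quick illustration (conf and ioEncryptionKey assumed to be in scope, as in SparkEnv.create above):

// Old call sites still work: the key defaults to None, so wrapForEncryption is a no-op.
val plainManager = new SerializerManager(new JavaSerializer(conf), conf)
// SparkEnv now passes the key explicitly; Some(key) turns the wrappers into crypto streams.
val encryptingManager = new SerializerManager(new JavaSerializer(conf), conf, ioEncryptionKey)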
