23 changes: 5 additions & 18 deletions core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -21,7 +21,6 @@ import java.lang.{Byte => JByte}
import java.net.{Authenticator, PasswordAuthentication}
import java.security.{KeyStore, SecureRandom}
import java.security.cert.X509Certificate
import javax.crypto.KeyGenerator
import javax.net.ssl._

import com.google.common.hash.HashCodes
@@ -33,7 +32,6 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.sasl.SecretKeyHolder
import org.apache.spark.security.CryptoStreamUtils._
import org.apache.spark.util.Utils

/**
@@ -185,7 +183,9 @@ import org.apache.spark.util.Utils
* setting `spark.ssl.useNodeLocalConf` to `true`.
*/

private[spark] class SecurityManager(sparkConf: SparkConf)
private[spark] class SecurityManager(
sparkConf: SparkConf,
ioEncryptionKey: Option[Array[Byte]] = None)
extends Logging with SecretKeyHolder {

import SecurityManager._
@@ -415,6 +415,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
logInfo("Changing acls enabled to: " + aclsOn)
}

def getIOEncryptionKey(): Option[Array[Byte]] = ioEncryptionKey

/**
* Generates or looks up the secret key.
*
@@ -559,19 +561,4 @@ private[spark] object SecurityManager {
// key used to store the spark secret in the Hadoop UGI
val SECRET_LOOKUP_KEY = "sparkCookie"

/**
* Setup the cryptographic key used by IO encryption in credentials. The key is generated using
* [[KeyGenerator]]. The algorithm and key length is specified by the [[SparkConf]].
*/
def initIOEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = {
if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) {
val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS)
val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM)
val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm)
keyGen.init(keyLen)

val ioKey = keyGen.generateKey()
credentials.addSecretKey(SPARK_IO_TOKEN, ioKey.getEncoded)
}
}
}
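
The key is no longer minted inside SecurityManager; it is generated elsewhere (see CryptoStreamUtils.createKey later in this diff) and handed to the constructor. A minimal usage sketch, not part of the patch, assuming code that lives in the org.apache.spark package (these APIs are private[spark]) and a hypothetical IOEncryptionKeySketch object:

package org.apache.spark

import org.apache.spark.internal.config._
import org.apache.spark.security.CryptoStreamUtils

object IOEncryptionKeySketch {
  // Mirrors what SparkEnv.createDriverEnv now does: generate a key only when
  // I/O encryption is enabled, then hand it to the SecurityManager.
  def buildSecurityManager(conf: SparkConf): SecurityManager = {
    val ioKey: Option[Array[Byte]] =
      if (conf.get(IO_ENCRYPTION_ENABLED)) Some(CryptoStreamUtils.createKey(conf)) else None
    val sm = new SecurityManager(conf, ioKey)
    assert(sm.getIOEncryptionKey() == ioKey)  // the key is exposed through the new accessor
    sm
  }
}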
4 changes: 0 additions & 4 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -422,10 +422,6 @@ class SparkContext(config: SparkConf) extends Logging {
}

if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true")
if (_conf.get(IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) {
throw new SparkException("IO encryption is only supported in YARN mode, please disable it " +
s"by setting ${IO_ENCRYPTION_ENABLED.key} to false")
}

// "_jobProgressListener" should be set up before creating SparkEnv because when creating
// "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them.
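
With the YARN-only guard removed, I/O encryption can be requested in any deploy mode through configuration alone. A minimal sketch, not part of the patch, assuming the spark.io.encryption.enabled key that the deleted check referred to via IO_ENCRYPTION_ENABLED:

import org.apache.spark.{SparkConf, SparkContext}

object LocalIOEncryptionSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("io-encryption-sketch")
      .set("spark.io.encryption.enabled", "true")  // previously rejected outside YARN by the removed check
    val sc = new SparkContext(conf)
    sc.stop()
  }
}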
33 changes: 24 additions & 9 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -36,6 +36,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator}
import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage._
@@ -165,15 +166,20 @@ object SparkEnv extends Logging {
val bindAddress = conf.get(DRIVER_BIND_ADDRESS)
val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS)
val port = conf.get("spark.driver.port").toInt
val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) {
Some(CryptoStreamUtils.createKey(conf))
} else {
None
}
create(
conf,
SparkContext.DRIVER_IDENTIFIER,
bindAddress,
advertiseAddress,
port,
isDriver = true,
isLocal = isLocal,
numUsableCores = numCores,
isLocal,
numCores,
ioEncryptionKey,
listenerBus = listenerBus,
mockOutputCommitCoordinator = mockOutputCommitCoordinator
)
@@ -189,16 +195,17 @@ object SparkEnv extends Logging {
hostname: String,
port: Int,
numCores: Int,
ioEncryptionKey: Option[Array[Byte]],
isLocal: Boolean): SparkEnv = {
val env = create(
conf,
executorId,
hostname,
hostname,
port,
isDriver = false,
isLocal = isLocal,
numUsableCores = numCores
isLocal,
numCores,
ioEncryptionKey
)
SparkEnv.set(env)
env
@@ -213,18 +220,26 @@ object SparkEnv extends Logging {
bindAddress: String,
advertiseAddress: String,
port: Int,
isDriver: Boolean,
isLocal: Boolean,
numUsableCores: Int,
ioEncryptionKey: Option[Array[Byte]],
listenerBus: LiveListenerBus = null,
mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {

val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER

// Listener bus is only used on the driver
if (isDriver) {
assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!")
}

val securityManager = new SecurityManager(conf)
val securityManager = new SecurityManager(conf, ioEncryptionKey)
ioEncryptionKey.foreach { _ =>
if (!securityManager.isSaslEncryptionEnabled()) {
logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " +
"wire.")
}
}

val systemName = if (isDriver) driverSystemName else executorSystemName
val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf,
@@ -270,7 +285,7 @@ object SparkEnv extends Logging {
"spark.serializer", "org.apache.spark.serializer.JavaSerializer")
logDebug(s"Using serializer: ${serializer.getClass}")

val serializerManager = new SerializerManager(serializer, conf)
val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey)

val closureSerializer = new JavaSerializer(conf)

core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -200,8 +200,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
new SecurityManager(executorConf),
clientMode = true)
val driver = fetcher.setupEndpointRefByURI(driverUrl)
val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++
Seq[(String, String)](("spark.app.id", appId))
val cfg = driver.askWithRetry[SparkAppConfig](RetrieveSparkAppConfig)
val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", appId))
fetcher.shutdown()

// Create SparkEnv using properties we fetched from the driver.
@@ -221,7 +221,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
}

val env = SparkEnv.createExecutorEnv(
driverConf, executorId, hostname, port, cores, isLocal = false)
driverConf, executorId, hostname, port, cores, cfg.ioEncryptionKey, isLocal = false)

env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env))
core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -28,7 +28,12 @@ private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable

private[spark] object CoarseGrainedClusterMessages {

case object RetrieveSparkProps extends CoarseGrainedClusterMessage
case object RetrieveSparkAppConfig extends CoarseGrainedClusterMessage

case class SparkAppConfig(
sparkProperties: Seq[(String, String)],
ioEncryptionKey: Option[Array[Byte]])
extends CoarseGrainedClusterMessage

case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage

core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -206,8 +206,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
removeExecutor(executorId, reason)
context.reply(true)

case RetrieveSparkProps =>
context.reply(sparkProperties)
case RetrieveSparkAppConfig =>
val reply = SparkAppConfig(sparkProperties,
SparkEnv.get.securityManager.getIOEncryptionKey())
context.reply(reply)
}

// Make fake resource offers on all executors
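
Executors now learn the key over RPC: a single RetrieveSparkAppConfig ask replaces the old RetrieveSparkProps and returns both the Spark properties and the optional key. A condensed sketch of the executor-side call, not part of the patch; the helper object and its package are hypothetical, and the real code is the CoarseGrainedExecutorBackend hunk above:

package org.apache.spark.sketch

import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RetrieveSparkAppConfig, SparkAppConfig}

object AppConfigFetchSketch {
  // One round trip yields the driver's properties plus the I/O encryption key (if any),
  // which CoarseGrainedExecutorBackend then feeds into SparkEnv.createExecutorEnv.
  def fetch(driver: RpcEndpointRef, appId: String): (Seq[(String, String)], Option[Array[Byte]]) = {
    val cfg = driver.askWithRetry[SparkAppConfig](RetrieveSparkAppConfig)
    val props = cfg.sparkProperties ++ Seq(("spark.app.id", appId))
    (props, cfg.ioEncryptionKey)
  }
}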
core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
@@ -18,25 +18,20 @@ package org.apache.spark.security

import java.io.{InputStream, OutputStream}
import java.util.Properties
import javax.crypto.KeyGenerator
import javax.crypto.spec.{IvParameterSpec, SecretKeySpec}

import org.apache.commons.crypto.random._
import org.apache.commons.crypto.stream._
import org.apache.hadoop.io.Text

import org.apache.spark.SparkConf
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._

/**
* A util class for manipulating IO encryption and decryption streams.
*/
private[spark] object CryptoStreamUtils extends Logging {
/**
* Constants and variables for spark IO encryption
*/
val SPARK_IO_TOKEN = new Text("SPARK_IO_TOKEN")

// The initialization vector length in bytes.
val IV_LENGTH_IN_BYTES = 16
@@ -50,12 +45,11 @@ private[spark] object CryptoStreamUtils extends Logging {
*/
def createCryptoOutputStream(
os: OutputStream,
sparkConf: SparkConf): OutputStream = {
sparkConf: SparkConf,
key: Array[Byte]): OutputStream = {
val properties = toCryptoConf(sparkConf)
val iv = createInitializationVector(properties)
os.write(iv)
val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
val key = credentials.getSecretKey(SPARK_IO_TOKEN)
val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
new CryptoOutputStream(transformationStr, properties, os,
new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
@@ -66,12 +60,11 @@ private[spark] object CryptoStreamUtils extends Logging {
*/
def createCryptoInputStream(
is: InputStream,
sparkConf: SparkConf): InputStream = {
sparkConf: SparkConf,
key: Array[Byte]): InputStream = {
val properties = toCryptoConf(sparkConf)
val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
is.read(iv, 0, iv.length)
val credentials = SparkHadoopUtil.get.getCurrentUserCredentials()
val key = credentials.getSecretKey(SPARK_IO_TOKEN)
val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION)
new CryptoInputStream(transformationStr, properties, is,
new SecretKeySpec(key, "AES"), new IvParameterSpec(iv))
@@ -91,6 +84,17 @@ private[spark] object CryptoStreamUtils extends Logging {
props
}

/**
* Creates a new encryption key.
*/
def createKey(conf: SparkConf): Array[Byte] = {
val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS)
val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM)
val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm)
keyGen.init(keyLen)
keyGen.generateKey().getEncoded()
}

/**
* This method to generate an IV (Initialization Vector) using secure random.
*/
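
CryptoStreamUtils no longer reads the key from Hadoop credentials; callers create it with createKey and pass it to the stream factories. A round-trip sketch under the new API, not part of the patch; the object name is hypothetical and it is placed in org.apache.spark.security because the utility is private[spark]:

package org.apache.spark.security

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets

import org.apache.spark.SparkConf

object CryptoRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    val key = CryptoStreamUtils.createKey(conf)         // key is now created explicitly...
    val plaintext = "hello".getBytes(StandardCharsets.UTF_8)

    val bos = new ByteArrayOutputStream()
    val out = CryptoStreamUtils.createCryptoOutputStream(bos, conf, key)  // ...and passed in
    out.write(plaintext)
    out.close()

    val in = CryptoStreamUtils.createCryptoInputStream(
      new ByteArrayInputStream(bos.toByteArray), conf, key)
    val roundTripped = new Array[Byte](plaintext.length)
    in.read(roundTripped)  // a single read suffices for this tiny in-memory payload
    in.close()
    assert(roundTripped.sameElements(plaintext))
  }
}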
core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -33,7 +33,12 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea
* Component which configures serialization, compression and encryption for various Spark
* components, including automatic selection of which [[Serializer]] to use for shuffles.
*/
private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) {
private[spark] class SerializerManager(
defaultSerializer: Serializer,
conf: SparkConf,
encryptionKey: Option[Array[Byte]]) {

def this(defaultSerializer: Serializer, conf: SparkConf) = this(defaultSerializer, conf, None)

Review thread on this auxiliary constructor:

Member: nit: why not use the parameter default value?

Contributor (author): Hmm, I thought I was calling this from Java code but must have removed it. I guess I can use a default argument now.

Contributor (author): Actually, I'm calling this from UnsafeShuffleWriterSuite.java in the accompanying change, and default values make it hard to call the code from Java.

Member: I see. Thanks for clarifying.

private[this] val kryoSerializer = new KryoSerializer(conf)

@@ -63,9 +68,6 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
// Whether to compress shuffle output temporarily spilled to disk
private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)

// Whether to enable IO encryption
private[this] val enableIOEncryption = conf.get(IO_ENCRYPTION_ENABLED)

/* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
* the initialization of the compression codec until it is first used. The reason is that a Spark
* program could be using a user-defined codec in a third party jar, which is loaded in
@@ -125,14 +127,18 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
* Wrap an input stream for encryption if shuffle encryption is enabled
*/
private[this] def wrapForEncryption(s: InputStream): InputStream = {
if (enableIOEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s
encryptionKey
.map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) }
.getOrElse(s)
}

/**
* Wrap an output stream for encryption if shuffle encryption is enabled
*/
private[this] def wrapForEncryption(s: OutputStream): OutputStream = {
if (enableIOEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s
encryptionKey
.map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) }
.getOrElse(s)
}

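Tying the review thread above together: the three-argument primary constructor carries the optional key, while the two-argument auxiliary constructor stays for Java callers (such as the UnsafeShuffleWriterSuite the author mentions), which cannot conveniently use Scala default arguments. A construction sketch, not part of the patch, with a hypothetical helper object:

package org.apache.spark.serializer

import org.apache.spark.SparkConf
import org.apache.spark.security.CryptoStreamUtils

object SerializerManagerSketch {
  // With a key, wrapForEncryption wraps shuffle streams in crypto streams;
  // with None it passes streams through untouched.
  def build(conf: SparkConf, encrypt: Boolean): SerializerManager = {
    val serializer = new JavaSerializer(conf)
    if (encrypt) {
      new SerializerManager(serializer, conf, Some(CryptoStreamUtils.createKey(conf)))
    } else {
      new SerializerManager(serializer, conf)  // auxiliary constructor: key = None
    }
  }
}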