diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 64102ccc05882..b9014a448fbf6 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -908,7 +908,7 @@ private[spark] object MapOutputTracker extends Logging { if (arr.length >= minBroadcastSize) { // Use broadcast instead. // Important arr(0) is the tag == DIRECT, ignore that while deserializing ! - val bcast = broadcastManager.newBroadcast(arr, isLocal) + val bcast = broadcastManager.newDriverBroadcast(arr, isLocal) // toByteArray creates copy, so we can reuse out out.reset() out.write(BROADCAST) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 85a24acb97c07..43f512cc147cf 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -40,7 +40,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.broadcast.Broadcast +import org.apache.spark.broadcast.{Broadcast, BroadcastMode} import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.executor.{ExecutorMetrics, ExecutorMetricsSource} import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat} @@ -1488,13 +1488,36 @@ class SparkContext(config: SparkConf) extends Logging { assertNotStopped() require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass), "Can not directly broadcast RDDs; instead, call collect() and broadcast the result.") - val bc = env.broadcastManager.newBroadcast[T](value, isLocal) - val callSite = getCallSite + val bc = env.broadcastManager.newDriverBroadcast[T](value, isLocal) + val callSite = getCallSite() logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm) cleaner.foreach(_.registerBroadcastForCleanup(bc)) bc } + /** + * :: DeveloperApi :: + * Broadcast a read-only variable to the cluster, returning a + * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. + * The variable will be sent to each cluster only once. + * + * Notice that the RDD to be broadcasted should be cached and materialized first so we can + * access its data on the executors. + */ + @DeveloperApi + def broadcast[T: ClassTag, U: ClassTag]( + rdd: RDD[T], mode: BroadcastMode[T]): Broadcast[U] = { + assertNotStopped() + require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass), + "Can not directly broadcast RDDs; instead, call collect() and broadcast the result.") + val bc = env.broadcastManager.newExecutorBroadcast[T, U](rdd, mode, isLocal) + rdd.broadcast(bc) + val callSite = getCallSite() + logInfo("Created executor broadcast " + bc.id + " from " + callSite.shortForm) + cleaner.foreach(_.registerBroadcastForCleanup(bc)) + bc + } + /** * Add a file to be downloaded with this Spark job on every node. 
* diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index ece4ae6ab0310..30fbf906e3770 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -19,8 +19,8 @@ package org.apache.spark.broadcast import scala.reflect.ClassTag -import org.apache.spark.SecurityManager -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.rdd.RDD /** * An interface for all the broadcast implementations in Spark (to allow @@ -38,7 +38,22 @@ private[spark] trait BroadcastFactory { * @param isLocal whether we are in local mode (single JVM process) * @param id unique id representing this broadcast variable */ - def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T] + def newDriverBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T] + + /** + * Creates a new broadcast variable which is broadcasted on executors without collecting first + * to the driver. + * + * @param rdd the RDD to be broadcasted among executors + * @param mode the broadcast mode used to transform the result of RDD to broadcasted object + * @param isLocal whether we are in local mode (single JVM process) + * @param id unique id representing this broadcast variable + */ + def newExecutorBroadcast[T: ClassTag, U: ClassTag]( + rdd: RDD[T], + mode: BroadcastMode[T], + isLocal: Boolean, + id: Long): Broadcast[U] def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index c93cadf1ab3e8..af35a1cacb22d 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -27,6 +27,7 @@ import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD private[spark] class BroadcastManager( val isDriver: Boolean, @@ -62,7 +63,7 @@ private[spark] class BroadcastManager( .asInstanceOf[java.util.Map[Any, Any]] ) - def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = { + def newDriverBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = { val bid = nextBroadcastId.getAndIncrement() value_ match { case pb: PythonBroadcast => @@ -74,7 +75,15 @@ private[spark] class BroadcastManager( case _ => // do nothing } - broadcastFactory.newBroadcast[T](value_, isLocal, bid) + broadcastFactory.newDriverBroadcast[T](value_, isLocal, bid) + } + + def newExecutorBroadcast[T: ClassTag, U: ClassTag]( + rdd_ : RDD[T], + mode: BroadcastMode[T], + isLocal: Boolean): Broadcast[U] = { + broadcastFactory.newExecutorBroadcast[T, U](rdd_, mode, isLocal, + nextBroadcastId.getAndIncrement()) } def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = { diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastMode.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastMode.scala new file mode 100644 index 0000000000000..ac42f03ecb226 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastMode.scala @@ -0,0 +1,30 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +/** + * The trait used in executor side broadcast. The implementation of `transform` identify the shape + * in which the results of a RDD are broadcasted. + * + * @tparam T The type of RDD elements. + */ +trait BroadcastMode[T] extends Serializable { + def transform(rows: Array[T]): Any + def transform(rows: Iterator[T], sizeHint: Option[Long]): Any + def canonicalized: BroadcastMode[T] = this +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 1024d9b5060bc..1d2b70d0963a3 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -20,11 +20,9 @@ package org.apache.spark.broadcast import java.io._ import java.lang.ref.SoftReference import java.nio.ByteBuffer -import java.util.zip.Adler32 import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import scala.util.Random import org.apache.spark._ import org.apache.spark.internal.{config, Logging} @@ -37,25 +35,9 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea /** * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]]. * - * The mechanism is as follows: - * - * The driver divides the serialized object into small chunks and - * stores those chunks in the BlockManager of the driver. - * - * On each executor, the executor first attempts to fetch the object from its BlockManager. If - * it does not exist, it then uses remote fetches to fetch the small chunks from the driver and/or - * other executors if available. Once it gets the chunks, it puts the chunks in its own - * BlockManager, ready for other executors to fetch from. - * - * This prevents the driver from being the bottleneck in sending out multiple copies of the - * broadcast data (one per executor). - * - * When initialized, TorrentBroadcast objects read SparkEnv.get.conf. - * - * @param obj object to broadcast * @param id A unique identifier for the broadcast variable. */ -private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) +private[spark] abstract class TorrentBroadcast[T: ClassTag](id: Long) extends Broadcast[T](id) with Logging with Serializable { /** @@ -68,34 +50,14 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) */ @transient private var _value: SoftReference[T] = _ - /** The compression codec to use, or None if compression is disabled */ - @transient private var compressionCodec: Option[CompressionCodec] = _ - /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. 
*/ - @transient private var blockSize: Int = _ - - private def setConf(conf: SparkConf): Unit = { - compressionCodec = if (conf.get(config.BROADCAST_COMPRESS)) { - Some(CompressionCodec.createCodec(conf)) - } else { - None - } - // Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided - blockSize = conf.get(config.BROADCAST_BLOCKSIZE).toInt * 1024 - checksumEnabled = conf.get(config.BROADCAST_CHECKSUM) - } - setConf(SparkEnv.get.conf) - - private val broadcastId = BroadcastBlockId(id) + protected val broadcastId: BroadcastBlockId = BroadcastBlockId(id) /** Total number of blocks this broadcast variable contains. */ - private val numBlocks: Int = writeBlocks(obj) + protected val numBlocks: Int - /** Whether to generate checksum for blocks or not. */ - private var checksumEnabled: Boolean = false - /** The checksum for all the blocks. */ - private var checksums: Array[Int] = _ + protected def readAndProcessBlocks(): T - override protected def getValue() = synchronized { + override protected def getValue(): T = synchronized { val memoized: T = if (_value == null) null.asInstanceOf[T] else _value.get if (memoized != null) { memoized @@ -106,101 +68,6 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } } - private def calcChecksum(block: ByteBuffer): Int = { - val adler = new Adler32() - if (block.hasArray) { - adler.update(block.array, block.arrayOffset + block.position(), block.limit() - - block.position()) - } else { - val bytes = new Array[Byte](block.remaining()) - block.duplicate.get(bytes) - adler.update(bytes) - } - adler.getValue.toInt - } - - /** - * Divide the object into multiple blocks and put those blocks in the block manager. - * - * @param value the object to divide - * @return number of blocks this broadcast variable is divided into - */ - private def writeBlocks(value: T): Int = { - import StorageLevel._ - // Store a copy of the broadcast variable in the driver so that tasks run on the driver - // do not create a duplicate copy of the broadcast variable's value. - val blockManager = SparkEnv.get.blockManager - if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) { - throw new SparkException(s"Failed to store $broadcastId in BlockManager") - } - try { - val blocks = - TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec) - if (checksumEnabled) { - checksums = new Array[Int](blocks.length) - } - blocks.zipWithIndex.foreach { case (block, i) => - if (checksumEnabled) { - checksums(i) = calcChecksum(block) - } - val pieceId = BroadcastBlockId(id, "piece" + i) - val bytes = new ChunkedByteBuffer(block.duplicate()) - if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) { - throw new SparkException(s"Failed to store $pieceId of $broadcastId " + - s"in local BlockManager") - } - } - blocks.length - } catch { - case t: Throwable => - logError(s"Store broadcast $broadcastId fail, remove all pieces of the broadcast") - blockManager.removeBroadcast(id, tellMaster = true) - throw t - } - } - - /** Fetch torrent blocks from the driver and/or other executors. */ - private def readBlocks(): Array[BlockData] = { - // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported - // to the driver, so other executors can pull these chunks from this executor as well. 
- val blocks = new Array[BlockData](numBlocks) - val bm = SparkEnv.get.blockManager - - for (pid <- Random.shuffle(Seq.range(0, numBlocks))) { - val pieceId = BroadcastBlockId(id, "piece" + pid) - logDebug(s"Reading piece $pieceId of $broadcastId") - // First try getLocalBytes because there is a chance that previous attempts to fetch the - // broadcast blocks have already fetched some of the blocks. In that case, some blocks - // would be available locally (on this executor). - bm.getLocalBytes(pieceId) match { - case Some(block) => - blocks(pid) = block - releaseBlockManagerLock(pieceId) - case None => - bm.getRemoteBytes(pieceId) match { - case Some(b) => - if (checksumEnabled) { - val sum = calcChecksum(b.chunks(0)) - if (sum != checksums(pid)) { - throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" + - s" $sum != ${checksums(pid)}") - } - } - // We found the block from remote executors/driver's BlockManager, so put the block - // in this executor's BlockManager. - if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) { - throw new SparkException( - s"Failed to store $pieceId of $broadcastId in local BlockManager") - } - blocks(pid) = new ByteBufferBlockData(b, true) - case None => - throw new SparkException(s"Failed to get $pieceId of $broadcastId") - } - } - } - blocks - } - /** * Remove all persisted state associated with this Torrent broadcast on the executors. */ @@ -222,14 +89,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) out.defaultWriteObject() } - private def readBroadcastBlock(): T = Utils.tryOrIOException { + protected def readBroadcastBlock(): T = Utils.tryOrIOException { TorrentBroadcast.torrentBroadcastLock.withLock(broadcastId) { // As we only lock based on `broadcastId`, whenever using `broadcastCache`, we should only // touch `broadcastId`. val broadcastCache = SparkEnv.get.broadcastManager.cachedValues Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse { - setConf(SparkEnv.get.conf) val blockManager = SparkEnv.get.blockManager blockManager.getLocalValues(broadcastId) match { case Some(blockResult) => @@ -246,31 +112,19 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId") } case None => - val estimatedTotalSize = Utils.bytesToString(numBlocks * blockSize) - logInfo(s"Started reading broadcast variable $id with $numBlocks pieces " + - s"(estimated total size $estimatedTotalSize)") - val startTimeNs = System.nanoTime() - val blocks = readBlocks() - logInfo(s"Reading broadcast variable $id took ${Utils.getUsedTimeNs(startTimeNs)}") - - try { - val obj = TorrentBroadcast.unBlockifyObject[T]( - blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) - // Store the merged copy in BlockManager so other tasks on this executor don't - // need to re-fetch it. - val storageLevel = StorageLevel.MEMORY_AND_DISK - if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { - throw new SparkException(s"Failed to store $broadcastId in BlockManager") - } - - if (obj != null) { - broadcastCache.put(broadcastId, obj) - } + val obj = readAndProcessBlocks() + // Store the merged copy in BlockManager so other tasks on this executor don't + // need to re-fetch it. 
+ val storageLevel = StorageLevel.MEMORY_AND_DISK + if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } - obj - } finally { - blocks.foreach(_.dispose()) + if (obj != null) { + broadcastCache.put(broadcastId, obj) } + + obj } } } @@ -280,7 +134,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) * If running in a task, register the given block's locks for release upon task completion. * Otherwise, if not running in a task then immediately release the lock. */ - private def releaseBlockManagerLock(blockId: BlockId): Unit = { + protected def releaseBlockManagerLock(blockId: BlockId): Unit = { val blockManager = SparkEnv.get.blockManager Option(TaskContext.get()) match { case Some(taskContext) => diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index 65fb5186afae1..2e2e7dbb7ef88 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -20,19 +20,32 @@ package org.apache.spark.broadcast import scala.reflect.ClassTag import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD /** - * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like + * A [[org.apache.spark.broadcast.BroadcastFactory]] implementation that uses a BitTorrent-like * protocol to do a distributed transfer of the broadcasted data to the executors. Refer to * [[org.apache.spark.broadcast.TorrentBroadcast]] for more details. */ -private[spark] class TorrentBroadcastFactory extends BroadcastFactory { +private[spark] class TorrentBroadcastFactory extends BroadcastFactory with Logging { override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit = { } - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = { - new TorrentBroadcast[T](value_, id) + override def newDriverBroadcast[T: ClassTag]( + value_ : T, + isLocal: Boolean, + id: Long): Broadcast[T] = { + new TorrentDriverBroadcast[T](value_, id) + } + + override def newExecutorBroadcast[T: ClassTag, U: ClassTag]( + rdd: RDD[T], + mode: BroadcastMode[T], + isLocal: Boolean, id: Long): Broadcast[U] = { + logInfo(s"Creating executor broadcast $id for rdd_${rdd.id}") + new TorrentExecutorBroadcast[T, U](rdd, mode, id) } override def stop(): Unit = { } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentDriverBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentDriverBroadcast.scala new file mode 100644 index 0000000000000..418e5fcc8bd56 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentDriverBroadcast.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import java.io._ +import java.nio.ByteBuffer +import java.util.zip.Adler32 + +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag +import scala.util.Random + +import org.apache.spark._ +import org.apache.spark.internal.{config, Logging} +import org.apache.spark.io.CompressionCodec +import org.apache.spark.storage._ +import org.apache.spark.util.Utils +import org.apache.spark.util.io.ChunkedByteBuffer + +/** + * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]]. + * + * The mechanism is as follows: + * + * The driver divides the serialized object into small chunks and + * stores those chunks in the BlockManager of the driver. + * + * On each executor, the executor first attempts to fetch the object from its BlockManager. If + * it does not exist, it then uses remote fetches to fetch the small chunks from the driver and/or + * other executors if available. Once it gets the chunks, it puts the chunks in its own + * BlockManager, ready for other executors to fetch from. + * + * This prevents the driver from being the bottleneck in sending out multiple copies of the + * broadcast data (one per executor). + * + * When initialized, TorrentBroadcast objects read SparkEnv.get.conf. + * + * @param obj object to broadcast + * @param id A unique identifier for the broadcast variable. + */ +private[spark] class TorrentDriverBroadcast[T: ClassTag](obj: T, id: Long) + extends TorrentBroadcast[T](id) with Logging with Serializable { + + /** The compression codec to use, or None if compression is disabled */ + @transient private var compressionCodec: Option[CompressionCodec] = _ + /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */ + @transient private var blockSize: Int = _ + + private def setConf(conf: SparkConf): Unit = { + compressionCodec = if (conf.get(config.BROADCAST_COMPRESS)) { + Some(CompressionCodec.createCodec(conf)) + } else { + None + } + // Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided + blockSize = conf.get(config.BROADCAST_BLOCKSIZE).toInt * 1024 + checksumEnabled = conf.get(config.BROADCAST_CHECKSUM) + } + setConf(SparkEnv.get.conf) + + /** Total number of blocks this broadcast variable contains. */ + override protected val numBlocks: Int = writeBlocks(obj) + + /** Whether to generate checksum for blocks or not. */ + private var checksumEnabled: Boolean = false + /** The checksum for all the blocks. */ + private var checksums: Array[Int] = _ + + private def calcChecksum(block: ByteBuffer): Int = { + val adler = new Adler32() + if (block.hasArray) { + adler.update(block.array, block.arrayOffset + block.position(), block.limit() + - block.position()) + } else { + val bytes = new Array[Byte](block.remaining()) + block.duplicate.get(bytes) + adler.update(bytes) + } + adler.getValue.toInt + } + + /** + * Divide the object into multiple blocks and put those blocks in the block manager. 
+ * + * @param value the object to divide + * @return number of blocks this broadcast variable is divided into + */ + private def writeBlocks(value: T): Int = { + import StorageLevel._ + // Store a copy of the broadcast variable in the driver so that tasks run on the driver + // do not create a duplicate copy of the broadcast variable's value. + val blockManager = SparkEnv.get.blockManager + if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } + try { + val blocks = + TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec) + if (checksumEnabled) { + checksums = new Array[Int](blocks.length) + } + blocks.zipWithIndex.foreach { case (block, i) => + if (checksumEnabled) { + checksums(i) = calcChecksum(block) + } + val pieceId = BroadcastBlockId(id, "piece" + i) + val bytes = new ChunkedByteBuffer(block.duplicate()) + if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) { + throw new SparkException(s"Failed to store $pieceId of $broadcastId " + + s"in local BlockManager") + } + } + blocks.length + } catch { + case t: Throwable => + logError(s"Store broadcast $broadcastId fail, remove all pieces of the broadcast") + blockManager.removeBroadcast(id, tellMaster = true) + throw t + } + } + + /** Fetch torrent blocks from the driver and/or other executors. */ + private def readBlocks(): Array[BlockData] = { + // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported + // to the driver, so other executors can pull these chunks from this executor as well. + val blocks = new Array[BlockData](numBlocks) + val bm = SparkEnv.get.blockManager + + for (pid <- Random.shuffle(Seq.range(0, numBlocks))) { + val pieceId = BroadcastBlockId(id, "piece" + pid) + logDebug(s"Reading piece $pieceId of $broadcastId") + // First try getLocalBytes because there is a chance that previous attempts to fetch the + // broadcast blocks have already fetched some of the blocks. In that case, some blocks + // would be available locally (on this executor). + bm.getLocalBytes(pieceId) match { + case Some(block) => + blocks(pid) = block + releaseBlockManagerLock(pieceId) + case None => + bm.getRemoteBytes(pieceId) match { + case Some(b) => + if (checksumEnabled) { + val sum = calcChecksum(b.chunks(0)) + if (sum != checksums(pid)) { + throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" + + s" $sum != ${checksums(pid)}") + } + } + // We found the block from remote executors/driver's BlockManager, so put the block + // in this executor's BlockManager. + if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) { + throw new SparkException( + s"Failed to store $pieceId of $broadcastId in local BlockManager") + } + blocks(pid) = new ByteBufferBlockData(b, true) + case None => + throw new SparkException(s"Failed to get $pieceId of $broadcastId") + } + } + } + blocks + } + + /** Used by the JVM when serializing this object. 
*/ + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { + assertValid() + out.defaultWriteObject() + } + + override protected def readAndProcessBlocks(): T = { + setConf(SparkEnv.get.conf) + val estimatedTotalSize = Utils.bytesToString(numBlocks * blockSize) + logInfo(s"Started reading broadcast variable $id with $numBlocks pieces " + + s"(estimated total size $estimatedTotalSize)") + + val startTimeNs = System.nanoTime() + val blocks = readBlocks() + logInfo("Reading broadcast variable " + id + " took " + Utils.getUsedTimeNs(startTimeNs)) + try { + val obj = TorrentBroadcast.unBlockifyObject[T]( + blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) + obj + } finally { + blocks.foreach(_.dispose()) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentExecutorBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentExecutorBroadcast.scala new file mode 100644 index 0000000000000..1d4c58703eacf --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentExecutorBroadcast.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import java.io.ObjectOutputStream + +import scala.reflect.ClassTag +import scala.util.Random + +import org.apache.spark._ +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.{BlockResult, RDDBlockId, StorageLevel} +import org.apache.spark.util.Utils + +/** + * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]]. + * + * Unlike [[TorrentDriverBroadcast]], this implementation does not divide the object to + * broadcast. Instead, it performs the broadcast on the executor side for an RDD, so the results + * of the RDD do not need to be collected back to the driver before broadcasting. + * + * The mechanism is as follows: + * + * On each executor, the executor first attempts to fetch the object from its BlockManager. If + * it does not exist, it then uses remote fetches to fetch the blocks of the RDD from other + * executors if available. Once it gets the blocks, it puts the blocks in its own BlockManager, + * ready for other executors to fetch from. + * + * @tparam T The type of the elements of the RDD to be broadcasted. + * @tparam U The type of object transformed from the collection of elements of the RDD. + * + * @param rdd The RDD to be broadcasted on executors. + * @param mode The [[org.apache.spark.broadcast.BroadcastMode]] object used to transform the result + * of the RDD to the object which will be stored in the block manager. + * @param id A unique identifier for the broadcast variable. 
+ */ +private[spark] class TorrentExecutorBroadcast[T: ClassTag, U: ClassTag]( + @transient private val rdd: RDD[T], + mode: BroadcastMode[T], + id: Long) extends TorrentBroadcast[U](id) with Logging with Serializable { + + // Total number of blocks this broadcast variable contains. + override protected val numBlocks: Int = rdd.getNumPartitions + // The id of the RDD to be broadcasted on executors. + private val rddId: Int = rdd.id + + /** Fetch torrent blocks from other executors. */ + private def readBlocks(): Array[T] = { + // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported + // to the driver, so other executors can pull these chunks from this executor as well. + val blocks = new Array[Array[T]](numBlocks) + val bm = SparkEnv.get.blockManager + + for (pid <- Random.shuffle(Seq.range(0, numBlocks))) { + val pieceId = RDDBlockId(rddId, pid) + // First try getLocalValues because there is a chance that previous attempts to fetch the + // broadcast blocks have already fetched some of the blocks. In that case, some blocks + // would be available locally (on this executor). + bm.getLocalValues(pieceId) match { + case Some(block: BlockResult) => + blocks(pid) = block.data.asInstanceOf[Iterator[T]].toArray + case None => + bm.get[T](pieceId) match { + case Some(b) => + val data = b.data.asInstanceOf[Iterator[T]].toArray + // We found the block from remote executors' BlockManager, so put the block + // in this executor's BlockManager. + if (!bm.putIterator(pieceId, data.toIterator, + StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) { + throw new SparkException( + s"Failed to store $pieceId of $broadcastId in local BlockManager") + } + blocks(pid) = data + case None => + logWarning(s"Failed to get $pieceId of $broadcastId") + throw new SparkException(s"Failed to get $pieceId of $broadcastId") + } + } + } + blocks.flatMap(x => x) + } + + /** Used by the JVM when serializing this object. */ + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { + assertValid() + out.defaultWriteObject() + } + + override protected def readAndProcessBlocks(): U = { + logInfo(s"Started reading executor broadcast variable $id with $numBlocks pieces") + val startTimeNs = System.nanoTime() + val rawInput = readBlocks() + logInfo("Reading executor broadcast variable " + id + " took " + + Utils.getUsedTimeNs(startTimeNs)) + + mode.transform(rawInput).asInstanceOf[U] + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 6095042de7f0c..8f57890c90285 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -36,6 +36,7 @@ import org.apache.spark._ import org.apache.spark.Partitioner._ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.internal.config.RDD_LIMIT_SCALE_UP_FACTOR @@ -211,6 +212,24 @@ abstract class RDD[T: ClassTag]( */ def cache(): this.type = persist() + /** + * Broadcast this RDD on executors. The executor side broadcast variable is created by + * [[SparkContext]]. This RDD should be cached and materialized first before calling + * this method. 
+ */ + private[spark] def broadcast[U: ClassTag](broadcasted: Broadcast[U]): Unit = { + // The RDD should be cached and materialized before it can be executor side broadcasted. + // We do the checking here. + if (storageLevel == StorageLevel.NONE) { + throw new SparkException("To broadcast this RDD on executors, it should be cached first.") + } + // Create the executor side broadcast object on executors. + mapPartitionsInternal { iter: Iterator[T] => + broadcasted.value + Iterator.empty.asInstanceOf[Iterator[T]] + }.count + } + /** * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. * diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 5e8b25f425166..6f7ee13e7521c 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -49,6 +49,11 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { } } +class IntArrayBroadcastMode extends BroadcastMode[Int] { + override def transform(rows: Array[Int]): Array[Int] = rows + override def transform(rows: Iterator[Int], sizeHint: Option[Long]): Array[Int] = rows.toArray +} + class BroadcastSuite extends SparkFunSuite with LocalSparkContext with EncryptionFunSuite { test("Using TorrentBroadcast locally") { @@ -144,7 +149,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio sc.stop() } - encryptionTest("Cache broadcast to disk") { conf => + encryptionTest("Cache broadcast to disk - driver side") { conf => conf.setMaster("local") .setAppName("test") .set(config.MEMORY_STORAGE_FRACTION, 0.0) @@ -154,7 +159,19 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio assert(broadcast.value.sum === 10) } - test("One broadcast value instance per executor") { + encryptionTest("Cache broadcast to disk - executor side") { conf => + conf.setMaster("local") + .setAppName("test") + .set("spark.memory.useLegacyMode", "true") + .set("spark.storage.memoryFraction", "0.0") + sc = new SparkContext(conf) + val rdd = sc.parallelize(1 to 4, 2).persist(StorageLevel.MEMORY_AND_DISK) + rdd.count() + val executorBroadcast = sc.broadcast[Int, Array[Int]](rdd, new IntArrayBroadcastMode) + assert(executorBroadcast.value.sum === 10) + } + + test("One broadcast value instance per executor - driver side") { val conf = new SparkConf() .setMaster("local[4]") .setAppName("test") @@ -170,7 +187,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio assert(instances.size === 1) } - test("One broadcast value instance per executor when memory is constrained") { + test("One broadcast value instance per executor when memory is constrained - driver side") { val conf = new SparkConf() .setMaster("local[4]") .setAppName("test") @@ -187,6 +204,25 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio assert(instances.size === 1) } + test("One broadcast value instance per executor when memory is constrained - executor side") { + val conf = new SparkConf() + .setMaster("local[4]") + .setAppName("test") + .set("spark.memory.useLegacyMode", "true") + .set("spark.storage.memoryFraction", "0.0") + + sc = new SparkContext(conf) + val rdd = sc.parallelize(1 to 4, 2).persist(StorageLevel.MEMORY_AND_DISK) + rdd.count() + val executorBroadcast = sc.broadcast[Int, Array[Int]](rdd, new IntArrayBroadcastMode) + val executorInstances = 
sc.parallelize(1 to 10) + .map(x => System.identityHashCode(executorBroadcast.value)) + .collect() + .toSet + + assert(executorInstances.size === 1) + } + /** * Verify the persistence of state associated with a TorrentBroadcast in a local-cluster. * diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 8962fd6740bf6..9c259a7d6204b 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -32,8 +32,10 @@ import org.scalatest.concurrent.Eventually import org.apache.spark._ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.broadcast.BroadcastMode import org.apache.spark.internal.config.RDD_PARALLEL_LISTING_THRESHOLD import org.apache.spark.rdd.RDDSuiteUtils._ +import org.apache.spark.storage.BroadcastBlockId import org.apache.spark.util.{ThreadUtils, Utils} class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { @@ -1161,6 +1163,58 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { assert(totalPartitionCount == 10) } + test("executor side broadcast for RDD") { + // Materialize and cache the RDD to be broadcasted on executors. + val rdd = sc.parallelize(1 to 4, 2).cache() + rdd.count() + val mode = new BroadcastMode[Int] { + override def transform(rows: Array[Int]): Array[Int] = rows + override def transform(rows: Iterator[Int], sizeHint: Option[Long]): Array[Int] = rows.toArray + } + val broadcastedVal = sc.broadcast[Int, Array[Int]](rdd, mode) + val collected = sc.parallelize(1 to 2, 2).map { _ => + broadcastedVal.value.reduce(_ + _) // 1 + 2 + 3 + 4 = 10 + }.collect() + assert(broadcastedVal.value.sum == 10) + assert(collected.sum == 20) + } + + test("executor side broadcast for RDD: unbroadcast") { + // Materialize and cache the RDD to be broadcasted on executors. + val rdd = sc.parallelize(1 to 4, 2).cache() + rdd.count() + val mode = new BroadcastMode[Int] { + override def transform(rows: Array[Int]): Int = 1 + override def transform(rows: Iterator[Int], sizeHint: Option[Long]): Int = 1 + } + val broadcastedVal = sc.broadcast[Int, Int](rdd, mode) + val collected = sc.parallelize(1 to 2, 2).map { _ => + broadcastedVal.value + }.collect() + val blockId = BroadcastBlockId(broadcastedVal.id) + assert(sc.env.blockManager.getSingle(blockId).isDefined) + sc.env.blockManager.releaseLock(blockId) + // Unbroadcast it. + sc.env.broadcastManager.unbroadcast(broadcastedVal.id, true, true) + assert(sc.env.blockManager.getSingle(blockId).isEmpty) + } + + test("executor side broadcast for RDD: unpersist RDD") { + // Materialize and cache the RDD to be broadcasted on executors. 
+ val rdd = sc.parallelize(1 to 4, 2).cache() + rdd.count() + val mode = new BroadcastMode[Int] { + override def transform(rows: Array[Int]): Int = 1 + override def transform(rows: Iterator[Int], sizeHint: Option[Long]): Int = 1 + } + val broadcastedVal = sc.broadcast[Int, Int](rdd, mode) + rdd.unpersist() + val collected = sc.parallelize(1 to 2, 2).map { _ => + broadcastedVal.value + }.collect() + assert(collected.sum == 2) + } + test("SPARK-18406: race between end-of-task and completion iterator read lock release") { val rdd = sc.parallelize(1 to 1000, 10) rdd.cache() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala index 9fac95aed8f12..9307172aa3377 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -17,30 +17,25 @@ package org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.broadcast.BroadcastMode import org.apache.spark.sql.catalyst.InternalRow /** * Marker trait to identify the shape in which tuples are broadcasted. Typical examples of this are * identity (tuples remain unchanged) or hashed (tuples are converted into some hash index). */ -trait BroadcastMode { - def transform(rows: Array[InternalRow]): Any - - def transform(rows: Iterator[InternalRow], sizeHint: Option[Long]): Any - - def canonicalized: BroadcastMode +trait RowBroadcastMode extends BroadcastMode[InternalRow] { + override def canonicalized: RowBroadcastMode = this } /** * IdentityBroadcastMode requires that rows are broadcasted in their original form. */ -case object IdentityBroadcastMode extends BroadcastMode { +case object IdentityBroadcastMode extends RowBroadcastMode { // TODO: pack the UnsafeRows into single bytes array. override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows override def transform( rows: Iterator[InternalRow], sizeHint: Option[Long]): Array[InternalRow] = rows.toArray - - override def canonicalized: BroadcastMode = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 17e1cb416fc8a..5d5811fb65abf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} @@ -139,7 +140,7 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { * Represents data where tuples are broadcasted to every node. It is quite common that the * entire set of tuples is transformed into different data structure. 
*/ -case class BroadcastDistribution(mode: BroadcastMode) extends Distribution { +case class BroadcastDistribution(mode: RowBroadcastMode) extends Distribution { override def requiredNumPartitions: Option[Int] = Some(1) override def createPartitioning(numPartitions: Int): Partitioning = { @@ -332,7 +333,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) * Represents a partitioning where rows are collected, transformed and broadcasted to each * node in the cluster. */ -case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { +case class BroadcastPartitioning(mode: RowBroadcastMode) extends Partitioning { override val numPartitions: Int = 1 override def satisfies0(required: Distribution): Boolean = required match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 8003012f30ca5..b4839aee587bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -27,6 +27,7 @@ import org.json4s.JsonAST._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.broadcast.BroadcastMode import org.apache.spark.sql.catalyst.{AliasIdentifier, IdentifierWithDatabase} import org.apache.spark.sql.catalyst.ScalaReflection._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource} @@ -34,7 +35,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.internal.SQLConf @@ -822,7 +823,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case catalog: CatalogTable => true case partition: Partitioning => true case resource: FunctionResource => true - case broadcast: BroadcastMode => true + case broadcast: BroadcastMode[_] => true case table: CatalogTableType => true case storage: CatalogStorageFormat => true case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0d1a3e365c918..f48af90c3815e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -878,6 +878,16 @@ object SQLConf { .timeConf(TimeUnit.SECONDS) .createWithDefaultString(s"${5 * 60}") + val EXECUTOR_SIDE_BROADCAST_ENABLED = buildConf("spark.sql.executorSideBroadcast.enabled") + .doc("When true, executor-side broadcast is used for broadcast-based joins in SQL. " + + "Note that broadcasted pieces of data in executor-side broadcast are not persisted " + + "on the driver, but fetched from RDD blocks persisted on other executors. " + + "If one executor is lost before its piece is fetched by other executors, " + + "it cannot be recovered and the broadcast will fail. 
Thus it is not " + + "guaranteed completely safe when using with dynamic allocation.") + .booleanConf + .createWithDefault(true) + // This is only used for the thriftserver val THRIFTSERVER_POOL = buildConf("spark.sql.thriftserver.scheduler.pool") .doc("Set a Fair Scheduler pool for a JDBC client session.") @@ -3145,6 +3155,8 @@ class SQLConf extends Serializable with Logging { if (timeoutValue < 0) Long.MaxValue else timeoutValue } + def executorSideBroadcastEnabled: Boolean = getConf(EXECUTOR_SIDE_BROADCAST_ENABLED) + def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME) def convertCTAS: Boolean = getConf(CONVERT_CTAS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala index 6973f55e8dca0..a675c39952d9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeSeq, BindReferences, DynamicPruningExpression, DynamicPruningSubquery, Expression, ListQuery, Literal, PredicateHelper} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode +import org.apache.spark.sql.catalyst.plans.physical.RowBroadcastMode import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan, SubqueryBroadcastExec} import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec @@ -40,7 +40,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) /** * Identify the shape in which keys of a given plan are broadcasted. 
*/ - private def broadcastMode(keys: Seq[Expression], output: AttributeSeq): BroadcastMode = { + private def broadcastMode(keys: Seq[Expression], output: AttributeSeq): RowBroadcastMode = { val packedKeys = BindReferences.bindReferences(HashJoin.rewriteKeyExpr(keys), output) HashedRelationBroadcastMode(packedKeys) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 6d8d37022ea42..8fd95d3a967c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -22,6 +22,7 @@ import java.util.concurrent._ import scala.concurrent.{ExecutionContext, Promise} import scala.concurrent.duration.NANOSECONDS +import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.{broadcast, SparkException} @@ -30,11 +31,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.plans.logical.Statistics -import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastPartitioning, Partitioning, + RowBroadcastMode} import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.joins.HashedRelation import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.{SparkFatalException, ThreadUtils} @@ -68,19 +71,27 @@ trait BroadcastExchangeLike extends Exchange { /** * A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of * a transformed SparkPlan. + * + * @tparam T The type of object transformed from the result of RDD by [[broadcast.BroadcastMode]]. 
*/ -case class BroadcastExchangeExec( - mode: BroadcastMode, +case class BroadcastExchangeExec[T: ClassTag]( + mode: RowBroadcastMode, child: SparkPlan) extends BroadcastExchangeLike { import BroadcastExchangeExec._ override val runId: UUID = UUID.randomUUID - override lazy val metrics = Map( - "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), - "collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"), - "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build"), - "broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast")) + override lazy val metrics = if (sqlContext.conf.executorSideBroadcastEnabled) { + Map( + "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build"), + "broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast")) + } else { + Map( + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), + "collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"), + "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build"), + "broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast")) + } override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) @@ -89,10 +100,83 @@ case class BroadcastExchangeExec( } override def runtimeStatistics: Statistics = { - val dataSize = metrics("dataSize").value + val dataSize = if (sqlContext.conf.executorSideBroadcastEnabled) { + Long.MaxValue + } else { + metrics("dataSize").value + } Statistics(dataSize) } + // Private variable used to hold the reference of RDD created during executor-side broadcasting. + // If we don't keep its reference, it will be cleaned up. + private var childRDD: RDD[InternalRow] = null + + private def executorSideBroadcast(): broadcast.Broadcast[Any] = { + val beforeBuild = System.nanoTime() + // Call persist on the RDD because we want to broadcast the RDD blocks on executors. + childRDD = child.execute().mapPartitionsInternal { rowIterator => + rowIterator.map(_.copy()) + }.persist(StorageLevel.MEMORY_AND_DISK) + + val numOfRows = childRDD.count() + if (numOfRows >= MAX_BROADCAST_TABLE_ROWS) { + throw new SparkException( + s"Cannot broadcast the table with more than 512 millions rows: ${numOfRows} rows") + } + + // Broadcast the relation on executors. + val beforeBroadcast = System.nanoTime() + longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild) + + val broadcasted = sparkContext.broadcast[InternalRow, T](childRDD, mode) + .asInstanceOf[broadcast.Broadcast[Any]] + + longMetric("broadcastTime") += NANOSECONDS.toMillis(System.nanoTime() - beforeBroadcast) + broadcasted + } + + private def driverSideBroadcast(): broadcast.Broadcast[Any] = { + val beforeCollect = System.nanoTime() + // Use executeCollect/executeCollectIterator to avoid conversion to Scala types + val (numRows, input) = child.executeCollectIterator() + if (numRows >= MAX_BROADCAST_TABLE_ROWS) { + throw new SparkException( + s"Cannot broadcast the table over $MAX_BROADCAST_TABLE_ROWS rows: $numRows rows") + } + + val beforeBuild = System.nanoTime() + longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect) + + // Construct the relation. 
+ val relation = mode.transform(input, Some(numRows)) + + val dataSize = relation match { + case map: HashedRelation => + map.estimatedSize + case arr: Array[InternalRow] => + arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum + case _ => + throw new SparkException("[BUG] BroadcastMode.transform returned unexpected " + + s"type: ${relation.getClass.getName}") + } + + longMetric("dataSize") += dataSize + if (dataSize >= MAX_BROADCAST_TABLE_BYTES) { + throw new SparkException( + s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB") + } + + val beforeBroadcast = System.nanoTime() + longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild) + + // Broadcast the relation + val broadcasted = sparkContext.broadcast(relation) + longMetric("broadcastTime") += NANOSECONDS.toMillis( + System.nanoTime() - beforeBroadcast) + broadcasted + } + @transient private lazy val promise = Promise[broadcast.Broadcast[Any]]() @@ -111,43 +195,11 @@ case class BroadcastExchangeExec( // Setup a job group here so later it may get cancelled by groupId if necessary. sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)", interruptOnCancel = true) - val beforeCollect = System.nanoTime() - // Use executeCollect/executeCollectIterator to avoid conversion to Scala types - val (numRows, input) = child.executeCollectIterator() - if (numRows >= MAX_BROADCAST_TABLE_ROWS) { - throw new SparkException( - s"Cannot broadcast the table over $MAX_BROADCAST_TABLE_ROWS rows: $numRows rows") + val broadcasted = if (sqlContext.conf.executorSideBroadcastEnabled) { + executorSideBroadcast() + } else { + driverSideBroadcast() } - - val beforeBuild = System.nanoTime() - longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect) - - // Construct the relation. 
- val relation = mode.transform(input, Some(numRows)) - - val dataSize = relation match { - case map: HashedRelation => - map.estimatedSize - case arr: Array[InternalRow] => - arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum - case _ => - throw new SparkException("[BUG] BroadcastMode.transform returned unexpected " + - s"type: ${relation.getClass.getName}") - } - - longMetric("dataSize") += dataSize - if (dataSize >= MAX_BROADCAST_TABLE_BYTES) { - throw new SparkException( - s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB") - } - - val beforeBroadcast = System.nanoTime() - longMetric("buildTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeBuild) - - // Broadcast the relation - val broadcasted = sparkContext.broadcast(relation) - longMetric("broadcastTime") += NANOSECONDS.toMillis( - System.nanoTime() - beforeBroadcast) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) promise.trySuccess(broadcasted) @@ -202,6 +254,8 @@ case class BroadcastExchangeExec( ex) } } + + override protected def otherCopyArgs: Seq[AnyRef] = Seq(implicitly[ClassTag[T]]) } object BroadcastExchangeExec { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index b176598ed8c2c..c298dc6f851a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -20,11 +20,13 @@ package org.apache.spark.sql.execution.exchange import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.execution.joins.{HashedRelation, HashedRelationBroadcastMode, + ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SQLConf /** @@ -48,7 +50,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { case (child, distribution) if child.outputPartitioning.satisfies(distribution) => child case (child, BroadcastDistribution(mode)) => - BroadcastExchangeExec(mode, child) + mode match { + case IdentityBroadcastMode => BroadcastExchangeExec[Array[InternalRow]](mode, child) + case _: HashedRelationBroadcastMode => BroadcastExchangeExec[HashedRelation](mode, child) + } case (child, distribution) => val numPartitions = distribution.requiredNumPartitions .getOrElse(conf.numShufflePartitions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 2a9e15851e9f1..4fe3dfa9e69f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, 
UnspecifiedDistribution} import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf /** * Performs an inner hash join of two child relations. When the output RDD of this operator is @@ -183,7 +184,13 @@ case class BroadcastHashJoinExec( // For inner and outer joins, one row from the streamed side may produce multiple result rows, // if the build side has duplicated keys. Note that here we wait for the broadcast to be // finished, which is a no-op because it's already finished when we wait it in `doProduce`. - !buildPlan.executeBroadcast[HashedRelation]().value.keyIsUnique + if (SQLConf.get.executorSideBroadcastEnabled) { + logInfo(s"Executor side broadcast is enabled, assuming multiple result rows per input row") + true + } else { + logInfo(s"Fetching broadcasted build side to driver to check if key is unique") + !buildPlan.executeBroadcast[HashedRelation]().value.keyIsUnique + } // Other joins types(semi, anti, existence) can at most produce one result row for one input // row from the streamed side. @@ -202,7 +209,7 @@ case class BroadcastHashJoinExec( // create a name for HashedRelation val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) - val clsName = broadcastRelation.value.getClass.getName + val clsName = classOf[HashedRelation].getName // Inline mutable state since not many join operations in a task val relationTerm = ctx.addMutableState(clsName, "relation", @@ -215,9 +222,13 @@ case class BroadcastHashJoinExec( protected override def prepareRelation(ctx: CodegenContext): HashedRelationInfo = { val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) - HashedRelationInfo(relationTerm, + if (SQLConf.get.executorSideBroadcastEnabled) { + HashedRelationInfo(relationTerm, keyIsUnique = false, isEmpty = false) + } else { + HashedRelationInfo(relationTerm, broadcastRelation.value.keyIsUnique, broadcastRelation.value == EmptyHashedRelation) + } } /** @@ -229,24 +240,37 @@ case class BroadcastHashJoinExec( val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) val numOutput = metricTerm(ctx, "numOutputRows") - - if (broadcastRelation.value == EmptyHashedRelation) { + val executorSideBroadcast = SQLConf.get.executorSideBroadcastEnabled + val consumeCode = s""" - |// If the right side is empty, NAAJ simply returns the left side. |$numOutput.add(1); |${consume(ctx, input)} """.stripMargin - } else if (broadcastRelation.value == HashedRelationWithAllNullKeys) { + + if (!executorSideBroadcast && broadcastRelation.value == EmptyHashedRelation) { + s""" + |// If the right side is empty, NAAJ simply returns the left side. + |$consumeCode + """.stripMargin + } else if (!executorSideBroadcast && + broadcastRelation.value == HashedRelationWithAllNullKeys) { s""" |// If the right side contains any all-null key, NAAJ simply returns Nothing. """.stripMargin } else { s""" - |// generate join key for stream side - |${keyEv.code} - |if (!$anyNull && $relationTerm.getValue(${keyEv.value}) == null) { - | $numOutput.add(1); - | ${consume(ctx, input)} + |if ($relationTerm == ${EmptyHashedRelation.getClass.getCanonicalName}.MODULE$$) { + | // If the right side is empty, NAAJ simply returns the left side. 
+ | $consumeCode + |} else if ($relationTerm + | == ${HashedRelationWithAllNullKeys.getClass.getCanonicalName}.MODULE$$) { + | // If the right side contains any all-null key, NAAJ simply returns Nothing. + |} else { + | // generate join key for stream side + | ${keyEv.code} + | if (!$anyNull && $relationTerm.getValue(${keyEv.value}) == null) { + | $consumeCode + | } |} """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 3c5ed40551206..173c9e6db2524 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -27,7 +27,7 @@ import org.apache.spark.internal.config.{BUFFER_PAGESIZE, MEMORY_OFFHEAP_ENABLED import org.apache.spark.memory._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode +import org.apache.spark.sql.catalyst.plans.physical.RowBroadcastMode import org.apache.spark.sql.types.LongType import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap @@ -1127,7 +1127,7 @@ case object HashedRelationWithAllNullKeys extends HashedRelation { /** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. */ case class HashedRelationBroadcastMode(key: Seq[Expression], isNullAware: Boolean = false) - extends BroadcastMode { + extends RowBroadcastMode { override def transform(rows: Array[InternalRow]): HashedRelation = { transform(rows.iterator, Some(rows.length)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index b463a76a74026..396762e3a7a62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -344,7 +344,7 @@ class DataFrameJoinSuite extends QueryTest val broadcastHashJoins = sparkPlan.collect { case p: BroadcastHashJoinExec => p } assert(broadcastHashJoins.size == 1) val broadcastExchanges = broadcastHashJoins.head.collect { - case p: BroadcastExchangeExec => p + case p: BroadcastExchangeExec[_] => p } assert(broadcastExchanges.size == 1) val tables = broadcastExchanges.head.collect { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 321f4966178d7..933916bbd14e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1852,7 +1852,7 @@ class DataFrameSuite extends QueryTest case e: ShuffleExchangeExec => true }.size == 1) assert( collect(join2.queryExecution.executedPlan) { - case e: BroadcastExchangeExec => true }.size === 1) + case e: BroadcastExchangeExec[_] => true }.size === 1) assert( collect(join2.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size == 4) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala index 55437aaa47298..fcbd81dbfe3b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala @@ -193,7 +193,7 @@ abstract class DynamicPartitionPruningSuiteBase subqueryBroadcast.foreach { s => s.child match { case _: ReusedExchangeExec => // reuse check ok. - case b: BroadcastExchangeExec => + case b: BroadcastExchangeExec[_] => val hasReuse = plan.find { case ReusedExchangeExec(_, e) => e eq b case _ => false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index e5e8bc6917799..422dc8c2e3171 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -733,11 +733,11 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] { // be replaced. val replaced = e.withNewChildren(e.children.map(replaceWithColumnarPlan)) MyShuffleExchangeExec(replaced.asInstanceOf[ShuffleExchangeExec]) - case e: BroadcastExchangeExec => + case e: BroadcastExchangeExec[_] => // note that this is not actually columnar but demonstrates that exchanges can // be replaced. val replaced = e.withNewChildren(e.children.map(replaceWithColumnarPlan)) - MyBroadcastExchangeExec(replaced.asInstanceOf[BroadcastExchangeExec]) + MyBroadcastExchangeExec(replaced.asInstanceOf[BroadcastExchangeExec[_]]) case plan: ProjectExec => new ColumnarProjectExec(plan.projectList.map((exp) => replaceWithColumnarExpression(exp).asInstanceOf[NamedExpression]), @@ -778,7 +778,8 @@ case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleE * Custom Exchange used in tests to demonstrate that broadcasts can be replaced regardless of * whether AQE is enabled. 
*/ -case class MyBroadcastExchangeExec(delegate: BroadcastExchangeExec) extends BroadcastExchangeLike { +case class MyBroadcastExchangeExec(delegate: BroadcastExchangeExec[_]) + extends BroadcastExchangeLike { override def runId: UUID = delegate.runId override def relationFuture: java.util.concurrent.Future[Broadcast[Any]] = delegate.relationFuture diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala index 7d6306b65ff47..f0a23fdd1e0bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BroadcastExchangeSuite.scala @@ -57,7 +57,7 @@ class BroadcastExchangeSuite extends SparkPlanTest // get the exchange physical plan val hashExchange = collect( - df.queryExecution.executedPlan) { case p: BroadcastExchangeExec => p }.head + df.queryExecution.executedPlan) { case p: BroadcastExchangeExec[_] => p }.head // materialize the future and wait for the job being scheduled hashExchange.prepare() @@ -88,7 +88,7 @@ class BroadcastExchangeSuite extends SparkPlanTest val df = spark.range(1).toDF() val joinDF = df.join(broadcast(df), "id") val broadcastExchangeExec = collect( - joinDF.queryExecution.executedPlan) { case p: BroadcastExchangeExec => p } + joinDF.queryExecution.executedPlan) { case p: BroadcastExchangeExec[_] => p } assert(broadcastExchangeExec.size == 1, "one and only BroadcastExchangeExec") assert(joinDF.collect().length == 1) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index fb97e15e4df63..5a13651f67f73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} import org.apache.spark.sql.execution.exchange._ -import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode +import org.apache.spark.sql.execution.joins.{HashedRelation, HashedRelationBroadcastMode} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.vectorized.ColumnarBatch @@ -73,12 +73,12 @@ class ExchangeSuite extends SparkPlanTest with SharedSparkSession { val output = plan.output assert(plan sameResult plan) - val exchange1 = BroadcastExchangeExec(IdentityBroadcastMode, plan) + val exchange1 = BroadcastExchangeExec[Array[InternalRow]](IdentityBroadcastMode, plan) val hashMode = HashedRelationBroadcastMode(output) - val exchange2 = BroadcastExchangeExec(hashMode, plan) + val exchange2 = BroadcastExchangeExec[HashedRelation](hashMode, plan) val hashMode2 = HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil) - val exchange3 = BroadcastExchangeExec(hashMode2, plan) + val exchange3 = BroadcastExchangeExec[HashedRelation](hashMode2, plan) val exchange4 = ReusedExchangeExec(output, exchange3) assert(exchange1 sameResult exchange1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala index 
5bcec9b1e517c..3366ba7190cca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala @@ -84,7 +84,7 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite with DisableAdaptiv // The exchange related nodes are created after the planning, they don't have corresponding // logical plan. - case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec => + case _: ShuffleExchangeExec | _: BroadcastExchangeExec[_] | _: ReusedExchangeExec => assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty) // The subquery exec nodes are just wrappers of the actual nodes, they don't have diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 8799dbb14ef34..0a47bd6c16504 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -519,7 +519,7 @@ class AdaptiveQueryExecSuite // Even with local shuffle reader, the query stage reuse can also work. val ex = findReusedExchange(adaptivePlan) assert(ex.nonEmpty) - assert(ex.head.child.isInstanceOf[BroadcastExchangeExec]) + assert(ex.head.child.isInstanceOf[BroadcastExchangeExec[_]]) val sub = findReusedSubquery(adaptivePlan) assert(sub.isEmpty) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 44ab3f7d023d3..7bb68e2671594 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -129,26 +129,30 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin) } - testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin (build=left)") { _ => + def usingBroadcastHashJoin(buildSide: BuildSide): Unit = { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeBroadcastHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, BuildLeft), + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, buildSide), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } } } + testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin (build=left)") { _ => + Seq("true", "false").foreach { executorSideBroadcast => + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) { + usingBroadcastHashJoin(BuildLeft) + } + } + } + testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin (build=right)") { _ => - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => - makeBroadcastHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, BuildRight), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + Seq("true", "false").foreach { executorSideBroadcast => + 
withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) { + usingBroadcastHashJoin(BuildRight) } } } @@ -198,21 +202,28 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { } } - test(s"$testName using BroadcastNestedLoopJoin build left") { + def usingBroadcastNestedLoopJoin(buildSide: BuildSide): Unit = { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoinExec(left, right, BuildLeft, Inner, Some(condition())), + BroadcastNestedLoopJoinExec(left, right, buildSide, Inner, Some(condition())), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } } + test(s"$testName using BroadcastNestedLoopJoin build left") { + Seq("true", "false").foreach { executorSideBroadcast => + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) { + usingBroadcastNestedLoopJoin(BuildLeft) + } + } + } + test(s"$testName using BroadcastNestedLoopJoin build right") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoinExec(left, right, BuildRight, Inner, Some(condition())), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + Seq("true", "false").foreach { executorSideBroadcast => + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) { + usingBroadcastNestedLoopJoin(BuildRight) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index a466e05816ad8..77c2746174dff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -120,20 +120,28 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession { } } + def usingBroadcastHashJoin(): Unit = { + val buildSide = joinType match { + case LeftOuter => BuildRight + case RightOuter => BuildLeft + case _ => fail(s"Unsupported join type $joinType") + } + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastHashJoinExec( + leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + if (joinType != FullOuter) { testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin") { _ => - val buildSide = joinType match { - case LeftOuter => BuildRight - case RightOuter => BuildLeft - case _ => fail(s"Unsupported join type $joinType") - } - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastHashJoinExec( - leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + Seq("true", "false").foreach { executorSideBroadcast => + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) { + usingBroadcastHashJoin() } } } @@ -152,20 +160,28 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession { } test(s"$testName using BroadcastNestedLoopJoin build left") { - 
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoinExec(left, right, BuildLeft, joinType, Some(condition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + Seq("true", "false").foreach { executorSideBroadcast => + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoinExec(left, right, BuildLeft, joinType, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } } test(s"$testName using BroadcastNestedLoopJoin build right") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoinExec(left, right, BuildRight, joinType, Some(condition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + Seq("true", "false").foreach { executorSideBroadcast => + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorSideBroadcast) { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoinExec(left, right, BuildRight, joinType, Some(condition)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 4e10c27edb0e9..d06b1a5255d49 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -166,25 +166,27 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // HashAggregate(nodeId = 4) // Exchange(nodeId = 5) // LocalTableScan(nodeId = 6) - Seq(true, false).foreach { enableWholeStage => - val df = generateRandomBytesDF().repartition(1).groupBy('a).count() - val nodeIds = if (enableWholeStage) { - Set(4L, 1L) - } else { - Set(2L, 0L) - } - val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get - nodeIds.foreach { nodeId => - val probes = metrics(nodeId)._2("avg hash probe bucket list iters").toString - if (!probes.contains("\n")) { - // It's a single metrics value - assert(probes.toDouble > 1.0) + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> "false") { + Seq(true, false).foreach { enableWholeStage => + val df = generateRandomBytesDF().repartition(1).groupBy('a).count() + val nodeIds = if (enableWholeStage) { + Set(4L, 1L) } else { - val mainValue = probes.split("\n").apply(1).stripPrefix("(").stripSuffix(")") - // Extract min, med, max from the string and strip off everthing else. - val index = mainValue.indexOf(" (", 0) - mainValue.slice(0, index).split(", ").foreach { - probe => assert(probe.toDouble > 1.0) + Set(2L, 0L) + } + val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe bucket list iters").toString + if (!probes.contains("\n")) { + // It's a single metrics value + assert(probes.toDouble > 1.0) + } else { + val mainValue = probes.split("\n").apply(1).stripPrefix("(").stripSuffix(")") + // Extract min, med, max from the string and strip off everthing else. 
+ val index = mainValue.indexOf(" (", 0) + mainValue.slice(0, index).split(", ").foreach { + probe => assert(probe.toDouble > 1.0) + } } } } @@ -301,11 +303,13 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // ... -> BroadcastHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) Seq((1L, false), (2L, true)).foreach { case (nodeId, enableWholeStage) => val df = df1.join(broadcast(df2), "key") - testSparkPlanMetrics(df, 2, Map( - nodeId -> (("BroadcastHashJoin", Map( - "number of output rows" -> 2L)))), - enableWholeStage - ) + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> "false") { + testSparkPlanMetrics(df, 2, Map( + nodeId -> (("BroadcastHashJoin", Map( + "number of output rows" -> 2L)))), + enableWholeStage + ) + } } } @@ -405,33 +409,37 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils Seq(("left_outer", 0L, 5L, false), ("right_outer", 0L, 6L, false), ("left_outer", 1L, 5L, true), ("right_outer", 1L, 6L, true)).foreach { case (joinType, nodeId, numRows, enableWholeStage) => val df = df1.join(broadcast(df2), $"key" === $"key2", joinType) - testSparkPlanMetrics(df, 2, Map( - nodeId -> (("BroadcastHashJoin", Map( - "number of output rows" -> numRows)))), - enableWholeStage - ) + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> "false") { + testSparkPlanMetrics(df, 2, Map( + nodeId -> (("BroadcastHashJoin", Map( + "number of output rows" -> numRows)))), + enableWholeStage + ) + } } } test("BroadcastNestedLoopJoin metrics") { val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.createOrReplaceTempView("testDataForJoin") - withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - withTempView("testDataForJoin") { - // Assume the execution plan is - // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val leftQuery = "SELECT * FROM testData2 LEFT JOIN testDataForJoin ON " + - "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a" - val rightQuery = "SELECT * FROM testData2 RIGHT JOIN testDataForJoin ON " + - "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a" - Seq((leftQuery, false), (rightQuery, false), (leftQuery, true), (rightQuery, true)) - .foreach { case (query, enableWholeStage) => - val df = spark.sql(query) - testSparkPlanMetrics(df, 2, Map( - 0L -> (("BroadcastNestedLoopJoin", Map( - "number of output rows" -> 12L)))), - enableWholeStage - ) + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> "false") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + withTempView("testDataForJoin") { + // Assume the execution plan is + // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val leftQuery = "SELECT * FROM testData2 LEFT JOIN testDataForJoin ON " + + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a" + val rightQuery = "SELECT * FROM testData2 RIGHT JOIN testDataForJoin ON " + + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a" + Seq((leftQuery, false), (rightQuery, false), (leftQuery, true), (rightQuery, true)) + .foreach { case (query, enableWholeStage) => + val df = spark.sql(query) + testSparkPlanMetrics(df, 2, Map( + 0L -> (("BroadcastNestedLoopJoin", Map( + "number of output rows" -> 12L)))), + enableWholeStage + ) + } } } } @@ -444,11 +452,13 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // ... 
-> BroadcastHashJoin(nodeId = 1) Seq((1L, false), (2L, true)).foreach { case (nodeId, enableWholeStage) => val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi") - testSparkPlanMetrics(df, 2, Map( - nodeId -> (("BroadcastHashJoin", Map( - "number of output rows" -> 2L)))), - enableWholeStage - ) + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> "false") { + testSparkPlanMetrics(df, 2, Map( + nodeId -> (("BroadcastHashJoin", Map( + "number of output rows" -> 2L)))), + enableWholeStage + ) + } } } @@ -457,11 +467,13 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") Seq((1L, false), (2L, true)).foreach { case (nodeId, enableWholeStage) => val df = df2.join(broadcast(df1), $"key" === $"key2", "left_anti") - testSparkPlanMetrics(df, 2, Map( - nodeId -> (("BroadcastHashJoin", Map( - "number of output rows" -> 2L)))), - enableWholeStage - ) + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> "false") { + testSparkPlanMetrics(df, 2, Map( + nodeId -> (("BroadcastHashJoin", Map( + "number of output rows" -> 2L)))), + enableWholeStage + ) + } } } @@ -557,30 +569,42 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // TODO: test file source V2 as well when its statistics is correctly computed. withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "parquet") { - withTempDir { tempDir => - withTempView("pqS") { - val dir = new File(tempDir, "pqS").getCanonicalPath - - spark.range(10).write.parquet(dir) - spark.read.parquet(dir).createOrReplaceTempView("pqS") - - // The executed plan looks like: - // Exchange RoundRobinPartitioning(2) - // +- BroadcastNestedLoopJoin BuildLeft, Cross - // :- BroadcastExchange IdentityBroadcastMode - // : +- Exchange RoundRobinPartitioning(3) - // : +- *Range (0, 30, step=1, splits=2) - // +- *FileScan parquet [id#465L] Batched: true, Format: Parquet, Location: ...(ignored) - val res3 = InputOutputMetricsHelper.run( - spark.range(30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2).toDF() - ) - // The query above is executed in the following stages: - // 1. range(30) => (30, 0, 30) - // 2. sql("select * from pqS") => (0, 30, 0) - // 3. crossJoin(...) of 1. and 2. => (10, 0, 300) - // 4. shuffle & return results => (0, 300, 0) - assert(res3 === (30L, 0L, 30L) :: (0L, 30L, 0L) :: (10L, 0L, 300L) :: (0L, 300L, 0L) :: - Nil) + Seq(true, false).foreach { executorBroadcast => + withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> executorBroadcast.toString) { + withTempDir { tempDir => + withTempView("pqS") { + val dir = new File(tempDir, "pqS").getCanonicalPath + + spark.range(10).write.parquet(dir) + spark.read.parquet(dir).createOrReplaceTempView("pqS") + + // The executed plan looks like: + // Exchange RoundRobinPartitioning(2) + // +- BroadcastNestedLoopJoin BuildLeft, Cross + // :- BroadcastExchange IdentityBroadcastMode + // : +- Exchange RoundRobinPartitioning(3) + // : +- *Range (0, 30, step=1, splits=2) + // +- *FileScan parquet [id#465L] Batched: true, Format: Parquet, Location: ... + val res3 = InputOutputMetricsHelper.run( + spark.range(30).repartition(3) + .crossJoin(sql("select * from pqS")).repartition(2).toDF() + ) + // The query above is executed in the following stages: + // 1. range(30) => (30, 0, 30) + // 2a. sql("select * from pqS") => (0, 30, 0) + // 2b. (only when `SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED` is enabled) + // executor-side-broadcast => (0, 0, 0) + // 3. crossJoin(...) 
of 1. and 2. => (10, 0, 300) + // 4. shuffle & return results => (0, 300, 0) + val expected = if (executorBroadcast) { + (30L, 0L, 30L) :: (0L, 30L, 0L) :: (0L, 0L, 0L) :: (10L, 0L, 300L) :: + (0L, 300L, 0L) :: Nil + } else { + (30L, 0L, 30L) :: (0L, 30L, 0L) :: (10L, 0L, 300L) :: (0L, 300L, 0L) :: Nil + } + assert(res3 === expected) + } + } } } }
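
For reference, a minimal end-to-end sketch of how the new flag is exercised, mirroring the withSQLConf(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key -> ...) pattern used in the test changes above. Only that config key comes from this patch; everything else here (the local master, the made-up data, the object and app names) is illustrative, and whether the executor-side path is actually taken still depends on the planner choosing a broadcast hash join for the query.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast
import org.apache.spark.sql.internal.SQLConf

object ExecutorSideBroadcastSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("executor-side-broadcast-sketch") // hypothetical app, local mode for illustration
      .master("local[2]")
      .getOrCreate()
    import spark.implicits._

    // Toggle the flag introduced by this patch (same key object the suites above use).
    spark.conf.set(SQLConf.EXECUTOR_SIDE_BROADCAST_ENABLED.key, "true")

    val large = spark.range(0L, 1000000L).toDF("id")
    val small = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "tag")

    // broadcast() hints the small side; with the flag on, EnsureRequirements should plan a
    // BroadcastExchangeExec[HashedRelation] whose relation is built and broadcast on the
    // executors rather than collected to the driver first.
    val joined = large.join(broadcast(small), "id")
    joined.explain()
    joined.show(5)

    spark.stop()
  }
}

With the flag set to "false" the same query goes through the existing driver-side collect-and-broadcast path, which is why the suites above run each join scenario under both settings.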