
Commit 5461f1c

Merge pull request #8 from tdas/kafka-refactor3
Refactor 3 = Refactor 2 + refactored KafkaStreamSuite further to eliminate KafkaTestUtils, and made the Java test suite more robust
2 parents: 2a20a01 + eae4ad6

7 files changed (+212, -184 lines)


external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala

Lines changed: 24 additions & 9 deletions
@@ -178,17 +178,23 @@ class ReliableKafkaReceiver[

   /** Store a Kafka message and the associated metadata as a tuple. */
   private def storeMessageAndMetadata(
-      msgAndMetadata: MessageAndMetadata[K, V]): Unit = synchronized {
+      msgAndMetadata: MessageAndMetadata[K, V]): Unit = {
     val topicAndPartition = TopicAndPartition(msgAndMetadata.topic, msgAndMetadata.partition)
-    blockGenerator += ((msgAndMetadata.key, msgAndMetadata.message))
-    topicPartitionOffsetMap.put(topicAndPartition, msgAndMetadata.offset)
+    val data = (msgAndMetadata.key, msgAndMetadata.message)
+    val metadata = (topicAndPartition, msgAndMetadata.offset)
+    blockGenerator.addDataWithCallback(data, metadata)
+  }
+
+  /** Update stored offset */
+  private def updateOffset(topicAndPartition: TopicAndPartition, offset: Long): Unit = {
+    topicPartitionOffsetMap.put(topicAndPartition, offset)
   }

   /**
    * Remember the current offsets for each topic and partition. This is called when a block is
    * generated.
    */
-  private def rememberBlockOffsets(blockId: StreamBlockId): Unit = synchronized {
+  private def rememberBlockOffsets(blockId: StreamBlockId): Unit = {
     // Get a snapshot of current offset map and store with related block id.
     val offsetSnapshot = topicPartitionOffsetMap.toMap
     blockOffsetMap.put(blockId, offsetSnapshot)
@@ -221,8 +227,9 @@ class ReliableKafkaReceiver[

         ZkUtils.updatePersistentPath(zkClient, zkPath, offset.toString)
       } catch {
-        case t: Throwable => logWarning(s"Exception during commit offset $offset for topic" +
-          s"${topicAndPart.topic}, partition ${topicAndPart.partition}", t)
+        case e: Exception =>
+          logWarning(s"Exception during commit offset $offset for topic" +
+            s"${topicAndPart.topic}, partition ${topicAndPart.partition}", e)
       }

       logInfo(s"Committed offset $offset for topic ${topicAndPart.topic}, " +
@@ -250,17 +257,25 @@ class ReliableKafkaReceiver[
   /** Class to handle blocks generated by the block generator. */
   private final class GeneratedBlockHandler extends BlockGeneratorListener {

-    override def onGenerateBlock(blockId: StreamBlockId): Unit = {
+    def onAddData(data: Any, metadata: Any): Unit = {
+      // Update the offset of the data that was added to the generator
+      if (metadata != null) {
+        val (topicAndPartition, offset) = metadata.asInstanceOf[(TopicAndPartition, Long)]
+        updateOffset(topicAndPartition, offset)
+      }
+    }
+
+    def onGenerateBlock(blockId: StreamBlockId): Unit = {
       // Remember the offsets of topics/partitions when a block has been generated
       rememberBlockOffsets(blockId)
     }

-    override def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = {
+    def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = {
       // Store block and commit the blocks offset
       storeBlockAndCommitOffset(blockId, arrayBuffer)
     }

-    override def onError(message: String, throwable: Throwable): Unit = {
+    def onError(message: String, throwable: Throwable): Unit = {
       reportError(message, throwable)
     }
   }
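Side note on the receiver change above: storeMessageAndMetadata and rememberBlockOffsets can drop their synchronized modifiers because offset bookkeeping now flows through the block generator's callbacks (addDataWithCallback invokes onAddData, and a block boundary invokes onGenerateBlock), so the generator's own locking appears to be what keeps the offset map consistent with the buffered data. The sketch below is a minimal, hypothetical illustration of that pattern; SimpleBlockGenerator and GeneratorListener are made-up names, not Spark's actual BlockGenerator API.

```scala
import scala.collection.mutable

// Simplified sketch (assumed names) of a generator that serializes data
// addition and block creation, invoking listener callbacks under one lock.
trait GeneratorListener {
  def onAddData(data: Any, metadata: Any): Unit
  def onGenerateBlock(blockId: Long): Unit
}

class SimpleBlockGenerator(listener: GeneratorListener) {
  private val buffer = new mutable.ArrayBuffer[Any]
  private var nextBlockId = 0L

  // Data and its metadata are recorded under the generator's lock, so the
  // listener's offset map cannot drift from the buffered data.
  def addDataWithCallback(data: Any, metadata: Any): Unit = synchronized {
    buffer += data
    listener.onAddData(data, metadata)
  }

  // Block creation runs under the same lock, so the offsets remembered in
  // onGenerateBlock correspond exactly to the data captured in the block.
  def generateBlock(): Seq[Any] = synchronized {
    val block = buffer.toVector
    buffer.clear()
    listener.onGenerateBlock(nextBlockId)
    nextBlockId += 1
    block
  }
}
```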

external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java

Lines changed: 17 additions & 8 deletions
@@ -22,6 +22,8 @@
 import java.util.List;
 import java.util.Random;

+import org.apache.spark.SparkConf;
+import org.apache.spark.streaming.Duration;
 import scala.Predef;
 import scala.Tuple2;
 import scala.collection.JavaConverters;
@@ -43,23 +45,25 @@

 public class JavaKafkaStreamSuite implements Serializable {
   private transient JavaStreamingContext ssc = null;
-  private Random random = new Random();
+  private transient Random random = new Random();
   private transient KafkaStreamSuiteBase suiteBase = null;

   @Before
   public void setUp() {
     suiteBase = new KafkaStreamSuiteBase() { };
-    suiteBase.beforeFunction();
+    suiteBase.setupKafka();
     System.clearProperty("spark.driver.port");
-    ssc = new JavaStreamingContext(suiteBase.sparkConf(), suiteBase.batchDuration());
+    SparkConf sparkConf = new SparkConf()
+      .setMaster("local[4]").setAppName(this.getClass().getSimpleName());
+    ssc = new JavaStreamingContext(sparkConf, new Duration(500));
   }

   @After
   public void tearDown() {
     ssc.stop();
     ssc = null;
     System.clearProperty("spark.driver.port");
-    suiteBase.afterFunction();
+    suiteBase.tearDownKafka();
   }

   @Test
@@ -76,8 +80,8 @@ public void testKafkaStream() throws InterruptedException {
     suiteBase.createTopic(topic);
     HashMap<String, Object> tmp = new HashMap<String, Object>(sent);
     suiteBase.produceAndSendMessage(topic,
-        JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap(
-            Predef.<Tuple2<String, Object>>conforms()));
+      JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap(
+        Predef.<Tuple2<String, Object>>conforms()));

     HashMap<String, String> kafkaParams = new HashMap<String, String>();
     kafkaParams.put("zookeeper.connect", suiteBase.zkAddress());
@@ -123,11 +127,16 @@ public Void call(JavaPairRDD<String, Long> rdd) throws Exception {
     );

     ssc.start();
-    ssc.awaitTermination(3000);
-
+    long startTime = System.currentTimeMillis();
+    boolean sizeMatches = false;
+    while (!sizeMatches && System.currentTimeMillis() - startTime < 20000) {
+      sizeMatches = sent.size() == result.size();
+      Thread.sleep(200);
+    }
     Assert.assertEquals(sent.size(), result.size());
     for (String k : sent.keySet()) {
       Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue());
     }
+    ssc.stop();
   }
 }
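The reworked Java test above no longer relies on a fixed ssc.awaitTermination(3000); it polls for up to 20 seconds and stops as soon as the received counts match what was sent. The Scala suite (next file) expresses the same wait declaratively with ScalaTest's Eventually. The snippet below is a small, self-contained sketch of that equivalence: the background thread only simulates the streaming job filling result, and apart from ScalaTest's eventually/timeout/interval the names are made up to mirror the test.

```scala
import scala.collection.mutable
import org.scalatest.Assertions._
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.{Millis, Span}

object WaitForCountsSketch extends App {
  val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
  val result = new mutable.HashMap[String, Long]()

  // Simulate the streaming job that fills `result` a little later.
  new Thread(new Runnable {
    def run(): Unit = {
      Thread.sleep(500)
      sent.foreach { case (k, v) => result.synchronized { result(k) = v.toLong } }
    }
  }).start()

  // Retry the assertions every 200 ms for at most 20 s -- the declarative
  // counterpart of the hand-rolled while/Thread.sleep loop in the Java test.
  eventually(timeout(Span(20000, Millis)), interval(Span(200, Millis))) {
    result.synchronized {
      assert(sent.size === result.size)
      sent.keys.foreach(k => assert(sent(k) === result(k).toInt))
    }
  }
}
```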

external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala

Lines changed: 73 additions & 89 deletions
@@ -42,15 +42,12 @@ import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.{Milliseconds, StreamingContext}
 import org.apache.spark.util.Utils

-abstract class KafkaStreamSuiteBase extends FunSuite with Logging {
-  import KafkaTestUtils._
+/**
+ * This is an abstract base class for Kafka testsuites. This has the functionality to set up
+ * and tear down local Kafka servers, and to push data using Kafka producers.
+ */
+abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Logging {

-  val sparkConf = new SparkConf()
-    .setMaster("local[4]")
-    .setAppName(this.getClass.getSimpleName)
-  val batchDuration = Milliseconds(500)
-  var ssc: StreamingContext = _
-
   var zkAddress: String = _
   var zkClient: ZkClient = _

@@ -64,7 +61,7 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Logging {
   private var server: KafkaServer = _
   private var producer: Producer[String, String] = _

-  def beforeFunction() {
+  def setupKafka() {
     // Zookeeper server startup
     zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort")
     // Get the actual zookeeper binding port
@@ -80,7 +77,7 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Logging {
     var bindSuccess: Boolean = false
     while(!bindSuccess) {
       try {
-        val brokerProps = getBrokerConfig(brokerPort, zkAddress)
+        val brokerProps = getBrokerConfig()
         brokerConf = new KafkaConfig(brokerProps)
         server = new KafkaServer(brokerConf)
         logInfo("==================== 2 ====================")
@@ -100,12 +97,7 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Logging {
     logInfo("==================== 4 ====================")
   }

-  def afterFunction() {
-    if (ssc != null) {
-      ssc.stop()
-      ssc = null
-    }
-
+  def tearDownKafka() {
     if (producer != null) {
       producer.close()
       producer = null
@@ -141,101 +133,43 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Logging {
     CreateTopicCommand.createTopic(zkClient, topic, 1, 1, "0")
     logInfo("==================== 5 ====================")
     // wait until metadata is propagated
-    waitUntilMetadataIsPropagated(Seq(server), topic, 0, 1000)
+    waitUntilMetadataIsPropagated(topic, 0)
   }

   def produceAndSendMessage(topic: String, sent: Map[String, Int]) {
-    val brokerAddr = brokerConf.hostName + ":" + brokerConf.port
-    if (producer == null) {
-      producer = new Producer[String, String](new ProducerConfig(getProducerConfig(brokerAddr)))
-    }
+    producer = new Producer[String, String](new ProducerConfig(getProducerConfig()))
     producer.send(createTestMessage(topic, sent): _*)
+    producer.close()
     logInfo("==================== 6 ====================")
   }
-}
-
-class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter with Eventually {
-
-  before { beforeFunction() }
-  after { afterFunction() }
-
-  test("Kafka input stream") {
-    ssc = new StreamingContext(sparkConf, batchDuration)
-    val topic = "topic1"
-    val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
-    createTopic(topic)
-    produceAndSendMessage(topic, sent)
-
-    val kafkaParams = Map("zookeeper.connect" -> zkAddress,
-      "group.id" -> s"test-consumer-${Random.nextInt(10000)}",
-      "auto.offset.reset" -> "smallest")
-
-    val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
-      ssc,
-      kafkaParams,
-      Map(topic -> 1),
-      StorageLevel.MEMORY_ONLY)
-    val result = new mutable.HashMap[String, Long]()
-    stream.map { case (k, v) => v }
-      .countByValue()
-      .foreachRDD { r =>
-        val ret = r.collect()
-        ret.toMap.foreach { kv =>
-          val count = result.getOrElseUpdate(kv._1, 0) + kv._2
-          result.put(kv._1, count)
-        }
-      }
-    ssc.start()
-    eventually(timeout(3000 milliseconds), interval(100 milliseconds)) {
-      assert(sent.size === result.size)
-      sent.keys.foreach { k => assert(sent(k) === result(k).toInt) }
-    }
-
-    ssc.stop()
-  }
-}
-

-object KafkaTestUtils {
-
-  def getBrokerConfig(port: Int, zkConnect: String): Properties = {
+  private def getBrokerConfig(): Properties = {
     val props = new Properties()
     props.put("broker.id", "0")
     props.put("host.name", "localhost")
-    props.put("port", port.toString)
+    props.put("port", brokerPort.toString)
     props.put("log.dir", Utils.createTempDir().getAbsolutePath)
-    props.put("zookeeper.connect", zkConnect)
+    props.put("zookeeper.connect", zkAddress)
     props.put("log.flush.interval.messages", "1")
     props.put("replica.socket.timeout.ms", "1500")
     props
   }

-  def getProducerConfig(brokerList: String): Properties = {
+  private def getProducerConfig(): Properties = {
+    val brokerAddr = brokerConf.hostName + ":" + brokerConf.port
     val props = new Properties()
-    props.put("metadata.broker.list", brokerList)
+    props.put("metadata.broker.list", brokerAddr)
     props.put("serializer.class", classOf[StringEncoder].getName)
     props
   }

-  def waitUntilTrue(condition: () => Boolean, waitTime: Long): Boolean = {
-    val startTime = System.currentTimeMillis()
-    while (true) {
-      if (condition())
-        return true
-      if (System.currentTimeMillis() > startTime + waitTime)
-        return false
-      Thread.sleep(waitTime.min(100L))
+  private def waitUntilMetadataIsPropagated(topic: String, partition: Int) {
+    eventually(timeout(1000 milliseconds), interval(100 milliseconds)) {
+      assert(
+        server.apis.leaderCache.keySet.contains(TopicAndPartition(topic, partition)),
+        s"Partition [$topic, $partition] metadata not propagated after timeout"
+      )
     }
-    // Should never go to here
-    throw new RuntimeException("unexpected error")
-  }
-
-  def waitUntilMetadataIsPropagated(servers: Seq[KafkaServer], topic: String, partition: Int,
-      timeout: Long) {
-    assert(waitUntilTrue(() =>
-      servers.foldLeft(true)(_ && _.apis.leaderCache.keySet.contains(
-        TopicAndPartition(topic, partition))), timeout),
-      s"Partition [$topic, $partition] metadata not propagated after timeout")
   }

   class EmbeddedZookeeper(val zkConnect: String) {
@@ -261,3 +195,53 @@ object KafkaTestUtils {
     }
   }
 }
+
+
+class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter {
+  var ssc: StreamingContext = _
+
+  before {
+    setupKafka()
+  }
+
+  after {
+    if (ssc != null) {
+      ssc.stop()
+      ssc = null
+    }
+    tearDownKafka()
+  }
+
+  test("Kafka input stream") {
+    val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
+    ssc = new StreamingContext(sparkConf, Milliseconds(500))
+    val topic = "topic1"
+    val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
+    createTopic(topic)
+    produceAndSendMessage(topic, sent)
+
+    val kafkaParams = Map("zookeeper.connect" -> zkAddress,
+      "group.id" -> s"test-consumer-${Random.nextInt(10000)}",
+      "auto.offset.reset" -> "smallest")
+
+    val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
+      ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY)
+    val result = new mutable.HashMap[String, Long]()
+    stream.map(_._2).countByValue().foreachRDD { r =>
+      val ret = r.collect()
+      ret.toMap.foreach { kv =>
+        val count = result.getOrElseUpdate(kv._1, 0) + kv._2
+        result.put(kv._1, count)
+      }
+    }
+    ssc.start()
+    eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
+      assert(sent.size === result.size)
+      sent.keys.foreach { k =>
+        assert(sent(k) === result(k).toInt)
+      }
+    }
+    ssc.stop()
+  }
+}
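With KafkaTestUtils folded into KafkaStreamSuiteBase, any Kafka test suite can get an embedded ZooKeeper/Kafka broker simply by extending the base class and calling setupKafka()/tearDownKafka(). A hypothetical additional suite might look like the sketch below; AnotherKafkaSuite and its topic name are made up, while the helpers (setupKafka, tearDownKafka, createTopic, produceAndSendMessage, zkAddress) are the ones the refactored base class above provides.

```scala
import scala.util.Random
import org.scalatest.BeforeAndAfterAll

// Hypothetical suite reusing the refactored KafkaStreamSuiteBase; a sketch of
// the intended usage, not code from this commit.
class AnotherKafkaSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll {

  override def beforeAll() { setupKafka() }    // start embedded ZooKeeper + Kafka once per suite
  override def afterAll() { tearDownKafka() }  // close producer, broker, ZooKeeper

  test("produced messages reach the embedded broker") {
    val topic = s"another-topic-${Random.nextInt(10000)}"
    createTopic(topic)                         // also waits for metadata propagation
    produceAndSendMessage(topic, Map("x" -> 1, "y" -> 2))
    // A real test would now attach a consumer or a Spark stream, pointing it
    // at the embedded ZooKeeper via zkAddress, and assert on what it reads.
  }
}
```

BeforeAndAfterAll is used here only to keep the relatively expensive broker startup to once per suite; the single-test KafkaStreamSuite in this commit uses BeforeAndAfter and restarts Kafka around its one test, which is simpler.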