Skip to content

Commit c1b62e4

Browse files
gengliangwang authored and gatorsmile committed
[SPARK-24776][SQL] Avro unit test: use SQLTestUtils and replace deprecated methods
## What changes were proposed in this pull request?

Improve Avro unit test:
1. Use QueryTest/SharedSQLContext/SQLTestUtils instead of the duplicated test utils.
2. Replace deprecated methods.

## How was this patch tested?

Unit test

Author: Gengliang Wang <[email protected]>

Closes #21760 from gengliangwang/improve_avro_test.
1 parent dfd7ac9 commit c1b62e4

File tree

2 files changed: +53 −217 lines changed

external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala

Lines changed: 53 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -31,32 +31,24 @@ import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord}
3131
import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
3232
import org.apache.commons.io.FileUtils
3333

34-
import org.apache.spark.SparkFunSuite
3534
import org.apache.spark.sql._
3635
import org.apache.spark.sql.avro.SchemaConverters.IncompatibleSchemaException
36+
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
3737
import org.apache.spark.sql.types._
3838

39-
class AvroSuite extends SparkFunSuite {
39+
class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
4040
val episodesFile = "src/test/resources/episodes.avro"
4141
val testFile = "src/test/resources/test.avro"
4242

43-
private var spark: SparkSession = _
44-
4543
override protected def beforeAll(): Unit = {
4644
super.beforeAll()
47-
spark = SparkSession.builder()
48-
.master("local[2]")
49-
.appName("AvroSuite")
50-
.config("spark.sql.files.maxPartitionBytes", 1024)
51-
.getOrCreate()
52-
}
53-
54-
override protected def afterAll(): Unit = {
55-
try {
56-
spark.sparkContext.stop()
57-
} finally {
58-
super.afterAll()
59-
}
45+
spark.conf.set("spark.sql.files.maxPartitionBytes", 1024)
46+
}
47+
48+
def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = {
49+
val originalEntries = spark.read.avro(testFile).collect()
50+
val newEntries = spark.read.avro(newFile)
51+
checkAnswer(newEntries, originalEntries)
6052
}
6153

6254
test("reading from multiple paths") {
@@ -68,7 +60,7 @@ class AvroSuite extends SparkFunSuite {
6860
val df = spark.read.avro(episodesFile)
6961
val fields = List("title", "air_date", "doctor")
7062
for (field <- fields) {
71-
TestUtils.withTempDir { dir =>
63+
withTempPath { dir =>
7264
val outputDir = s"$dir/${UUID.randomUUID}"
7365
df.write.partitionBy(field).avro(outputDir)
7466
val input = spark.read.avro(outputDir)
@@ -82,28 +74,29 @@ class AvroSuite extends SparkFunSuite {
8274

8375
test("request no fields") {
8476
val df = spark.read.avro(episodesFile)
85-
df.registerTempTable("avro_table")
77+
df.createOrReplaceTempView("avro_table")
8678
assert(spark.sql("select count(*) from avro_table").collect().head === Row(8))
8779
}
8880

8981
test("convert formats") {
90-
TestUtils.withTempDir { dir =>
82+
withTempPath { dir =>
9183
val df = spark.read.avro(episodesFile)
9284
df.write.parquet(dir.getCanonicalPath)
9385
assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count)
9486
}
9587
}
9688

9789
test("rearrange internal schema") {
98-
TestUtils.withTempDir { dir =>
90+
withTempPath { dir =>
9991
val df = spark.read.avro(episodesFile)
10092
df.select("doctor", "title").write.avro(dir.getCanonicalPath)
10193
}
10294
}
10395

10496
test("test NULL avro type") {
105-
TestUtils.withTempDir { dir =>
106-
val fields = Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava
97+
withTempPath { dir =>
98+
val fields =
99+
Seq(new Field("null", Schema.create(Type.NULL), "doc", null.asInstanceOf[Any])).asJava
107100
val schema = Schema.createRecord("name", "docs", "namespace", false)
108101
schema.setFields(fields)
109102
val datumWriter = new GenericDatumWriter[GenericRecord](schema)
@@ -122,11 +115,11 @@ class AvroSuite extends SparkFunSuite {
122115
}
123116

124117
test("union(int, long) is read as long") {
125-
TestUtils.withTempDir { dir =>
118+
withTempPath { dir =>
126119
val avroSchema: Schema = {
127120
val union =
128121
Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava)
129-
val fields = Seq(new Field("field1", union, "doc", null)).asJava
122+
val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava
130123
val schema = Schema.createRecord("name", "docs", "namespace", false)
131124
schema.setFields(fields)
132125
schema
@@ -150,11 +143,11 @@ class AvroSuite extends SparkFunSuite {
150143
}
151144

152145
test("union(float, double) is read as double") {
153-
TestUtils.withTempDir { dir =>
146+
withTempPath { dir =>
154147
val avroSchema: Schema = {
155148
val union =
156149
Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava)
157-
val fields = Seq(new Field("field1", union, "doc", null)).asJava
150+
val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava
158151
val schema = Schema.createRecord("name", "docs", "namespace", false)
159152
schema.setFields(fields)
160153
schema
@@ -178,15 +171,15 @@ class AvroSuite extends SparkFunSuite {
178171
}
179172

180173
test("union(float, double, null) is read as nullable double") {
181-
TestUtils.withTempDir { dir =>
174+
withTempPath { dir =>
182175
val avroSchema: Schema = {
183176
val union = Schema.createUnion(
184177
List(Schema.create(Type.FLOAT),
185178
Schema.create(Type.DOUBLE),
186179
Schema.create(Type.NULL)
187180
).asJava
188181
)
189-
val fields = Seq(new Field("field1", union, "doc", null)).asJava
182+
val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava
190183
val schema = Schema.createRecord("name", "docs", "namespace", false)
191184
schema.setFields(fields)
192185
schema
@@ -210,9 +203,9 @@ class AvroSuite extends SparkFunSuite {
210203
}
211204

212205
test("Union of a single type") {
213-
TestUtils.withTempDir { dir =>
206+
withTempPath { dir =>
214207
val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava)
215-
val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava
208+
val fields = Seq(new Field("field1", UnionOfOne, "doc", null.asInstanceOf[Any])).asJava
216209
val schema = Schema.createRecord("name", "docs", "namespace", false)
217210
schema.setFields(fields)
218211

@@ -233,16 +226,16 @@ class AvroSuite extends SparkFunSuite {
233226
}
234227

235228
test("Complex Union Type") {
236-
TestUtils.withTempDir { dir =>
229+
withTempPath { dir =>
237230
val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4)
238231
val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava)
239232
val complexUnionType = Schema.createUnion(
240233
List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava)
241234
val fields = Seq(
242-
new Field("field1", complexUnionType, "doc", null),
243-
new Field("field2", complexUnionType, "doc", null),
244-
new Field("field3", complexUnionType, "doc", null),
245-
new Field("field4", complexUnionType, "doc", null)
235+
new Field("field1", complexUnionType, "doc", null.asInstanceOf[Any]),
236+
new Field("field2", complexUnionType, "doc", null.asInstanceOf[Any]),
237+
new Field("field3", complexUnionType, "doc", null.asInstanceOf[Any]),
238+
new Field("field4", complexUnionType, "doc", null.asInstanceOf[Any])
246239
).asJava
247240
val schema = Schema.createRecord("name", "docs", "namespace", false)
248241
schema.setFields(fields)
@@ -271,7 +264,7 @@ class AvroSuite extends SparkFunSuite {
271264
}
272265

273266
test("Lots of nulls") {
274-
TestUtils.withTempDir { dir =>
267+
withTempPath { dir =>
275268
val schema = StructType(Seq(
276269
StructField("binary", BinaryType, true),
277270
StructField("timestamp", TimestampType, true),
@@ -290,7 +283,7 @@ class AvroSuite extends SparkFunSuite {
290283
}
291284

292285
test("Struct field type") {
293-
TestUtils.withTempDir { dir =>
286+
withTempPath { dir =>
294287
val schema = StructType(Seq(
295288
StructField("float", FloatType, true),
296289
StructField("short", ShortType, true),
@@ -309,7 +302,7 @@ class AvroSuite extends SparkFunSuite {
309302
}
310303

311304
test("Date field type") {
312-
TestUtils.withTempDir { dir =>
305+
withTempPath { dir =>
313306
val schema = StructType(Seq(
314307
StructField("float", FloatType, true),
315308
StructField("date", DateType, true)
@@ -329,7 +322,7 @@ class AvroSuite extends SparkFunSuite {
329322
}
330323

331324
test("Array data types") {
332-
TestUtils.withTempDir { dir =>
325+
withTempPath { dir =>
333326
val testSchema = StructType(Seq(
334327
StructField("byte_array", ArrayType(ByteType), true),
335328
StructField("short_array", ArrayType(ShortType), true),
@@ -363,13 +356,12 @@ class AvroSuite extends SparkFunSuite {
363356
}
364357

365358
test("write with compression") {
366-
TestUtils.withTempDir { dir =>
359+
withTempPath { dir =>
367360
val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec"
368361
val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level"
369362
val uncompressDir = s"$dir/uncompress"
370363
val deflateDir = s"$dir/deflate"
371364
val snappyDir = s"$dir/snappy"
372-
val fakeDir = s"$dir/fake"
373365

374366
val df = spark.read.avro(testFile)
375367
spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed")
@@ -439,7 +431,7 @@ class AvroSuite extends SparkFunSuite {
439431
test("sql test") {
440432
spark.sql(
441433
s"""
442-
|CREATE TEMPORARY TABLE avroTable
434+
|CREATE TEMPORARY VIEW avroTable
443435
|USING avro
444436
|OPTIONS (path "$episodesFile")
445437
""".stripMargin.replaceAll("\n", " "))
@@ -450,24 +442,24 @@ class AvroSuite extends SparkFunSuite {
450442
test("conversion to avro and back") {
451443
// Note that test.avro includes a variety of types, some of which are nullable. We expect to
452444
// get the same values back.
453-
TestUtils.withTempDir { dir =>
445+
withTempPath { dir =>
454446
val avroDir = s"$dir/avro"
455447
spark.read.avro(testFile).write.avro(avroDir)
456-
TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir)
448+
checkReloadMatchesSaved(testFile, avroDir)
457449
}
458450
}
459451

460452
test("conversion to avro and back with namespace") {
461453
// Note that test.avro includes a variety of types, some of which are nullable. We expect to
462454
// get the same values back.
463-
TestUtils.withTempDir { tempDir =>
455+
withTempPath { tempDir =>
464456
val name = "AvroTest"
465457
val namespace = "com.databricks.spark.avro"
466458
val parameters = Map("recordName" -> name, "recordNamespace" -> namespace)
467459

468460
val avroDir = tempDir + "/namedAvro"
469461
spark.read.avro(testFile).write.options(parameters).avro(avroDir)
470-
TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir)
462+
checkReloadMatchesSaved(testFile, avroDir)
471463

472464
// Look at raw file and make sure has namespace info
473465
val rawSaved = spark.sparkContext.textFile(avroDir)
@@ -478,7 +470,7 @@ class AvroSuite extends SparkFunSuite {
478470
}
479471

480472
test("converting some specific sparkSQL types to avro") {
481-
TestUtils.withTempDir { tempDir =>
473+
withTempPath { tempDir =>
482474
val testSchema = StructType(Seq(
483475
StructField("Name", StringType, false),
484476
StructField("Length", IntegerType, true),
@@ -520,7 +512,7 @@ class AvroSuite extends SparkFunSuite {
520512
}
521513

522514
test("correctly read long as date/timestamp type") {
523-
TestUtils.withTempDir { tempDir =>
515+
withTempPath { tempDir =>
524516
val sparkSession = spark
525517
import sparkSession.implicits._
526518

@@ -549,7 +541,7 @@ class AvroSuite extends SparkFunSuite {
549541
}
550542

551543
test("does not coerce null date/timestamp value to 0 epoch.") {
552-
TestUtils.withTempDir { tempDir =>
544+
withTempPath { tempDir =>
553545
val sparkSession = spark
554546
import sparkSession.implicits._
555547

@@ -610,7 +602,7 @@ class AvroSuite extends SparkFunSuite {
610602

611603
// Directory given has no avro files
612604
intercept[AnalysisException] {
613-
TestUtils.withTempDir(dir => spark.read.avro(dir.getCanonicalPath))
605+
withTempPath(dir => spark.read.avro(dir.getCanonicalPath))
614606
}
615607

616608
intercept[AnalysisException] {
@@ -624,7 +616,7 @@ class AvroSuite extends SparkFunSuite {
624616
}
625617

626618
intercept[FileNotFoundException] {
627-
TestUtils.withTempDir { dir =>
619+
withTempPath { dir =>
628620
FileUtils.touch(new File(dir, "test"))
629621
spark.read.avro(dir.toString)
630622
}
@@ -633,19 +625,19 @@ class AvroSuite extends SparkFunSuite {
633625
}
634626

635627
test("SQL test insert overwrite") {
636-
TestUtils.withTempDir { tempDir =>
628+
withTempPath { tempDir =>
637629
val tempEmptyDir = s"$tempDir/sqlOverwrite"
638630
// Create a temp directory for table that will be overwritten
639631
new File(tempEmptyDir).mkdirs()
640632
spark.sql(
641633
s"""
642-
|CREATE TEMPORARY TABLE episodes
634+
|CREATE TEMPORARY VIEW episodes
643635
|USING avro
644636
|OPTIONS (path "$episodesFile")
645637
""".stripMargin.replaceAll("\n", " "))
646638
spark.sql(
647639
s"""
648-
|CREATE TEMPORARY TABLE episodesEmpty
640+
|CREATE TEMPORARY VIEW episodesEmpty
649641
|(name string, air_date string, doctor int)
650642
|USING avro
651643
|OPTIONS (path "$tempEmptyDir")
@@ -665,7 +657,7 @@ class AvroSuite extends SparkFunSuite {
665657

666658
test("test save and load") {
667659
// Test if load works as expected
668-
TestUtils.withTempDir { tempDir =>
660+
withTempPath { tempDir =>
669661
val df = spark.read.avro(episodesFile)
670662
assert(df.count == 8)
671663

@@ -679,7 +671,7 @@ class AvroSuite extends SparkFunSuite {
679671

680672
test("test load with non-Avro file") {
681673
// Test if load works as expected
682-
TestUtils.withTempDir { tempDir =>
674+
withTempPath { tempDir =>
683675
val df = spark.read.avro(episodesFile)
684676
assert(df.count == 8)
685677

@@ -737,7 +729,7 @@ class AvroSuite extends SparkFunSuite {
737729
}
738730

739731
test("read avro file partitioned") {
740-
TestUtils.withTempDir { dir =>
732+
withTempPath { dir =>
741733
val sparkSession = spark
742734
import sparkSession.implicits._
743735
val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records")
@@ -756,7 +748,7 @@ class AvroSuite extends SparkFunSuite {
756748
case class NestedTop(id: Int, data: NestedMiddle)
757749

758750
test("saving avro that has nested records with the same name") {
759-
TestUtils.withTempDir { tempDir =>
751+
withTempPath { tempDir =>
760752
// Save avro file on output folder path
761753
val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1")))))
762754
val outputFolder = s"$tempDir/duplicate_names/"
@@ -773,7 +765,7 @@ class AvroSuite extends SparkFunSuite {
773765
case class NestedTopArray(id: Int, data: NestedMiddleArray)
774766

775767
test("saving avro that has nested records with the same name inside an array") {
776-
TestUtils.withTempDir { tempDir =>
768+
withTempPath { tempDir =>
777769
// Save avro file on output folder path
778770
val writeDf = spark.createDataFrame(
779771
List(NestedTopArray(1, NestedMiddleArray(2, Array(
@@ -794,7 +786,7 @@ class AvroSuite extends SparkFunSuite {
794786
case class NestedTopMap(id: Int, data: NestedMiddleMap)
795787

796788
test("saving avro that has nested records with the same name inside a map") {
797-
TestUtils.withTempDir { tempDir =>
789+
withTempPath { tempDir =>
798790
// Save avro file on output folder path
799791
val writeDf = spark.createDataFrame(
800792
List(NestedTopMap(1, NestedMiddleMap(2, Map(

0 commit comments

Comments
 (0)