@@ -31,32 +31,24 @@ import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord}
3131import org .apache .avro .generic .GenericData .{EnumSymbol , Fixed }
3232import org .apache .commons .io .FileUtils
3333
34- import org .apache .spark .SparkFunSuite
3534import org .apache .spark .sql ._
3635import org .apache .spark .sql .avro .SchemaConverters .IncompatibleSchemaException
36+ import org .apache .spark .sql .test .{SharedSQLContext , SQLTestUtils }
3737import org .apache .spark .sql .types ._
3838
39- class AvroSuite extends SparkFunSuite {
39+ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
4040 val episodesFile = " src/test/resources/episodes.avro"
4141 val testFile = " src/test/resources/test.avro"
4242
43- private var spark : SparkSession = _
44-
4543 override protected def beforeAll (): Unit = {
4644 super .beforeAll()
47- spark = SparkSession .builder()
48- .master(" local[2]" )
49- .appName(" AvroSuite" )
50- .config(" spark.sql.files.maxPartitionBytes" , 1024 )
51- .getOrCreate()
52- }
53-
54- override protected def afterAll (): Unit = {
55- try {
56- spark.sparkContext.stop()
57- } finally {
58- super .afterAll()
59- }
45+ spark.conf.set(" spark.sql.files.maxPartitionBytes" , 1024 )
46+ }
47+
/**
 * Checks that the Avro data found at `newFile` matches the Avro data found at `originalFile`.
 *
 * The rows of `originalFile` are collected eagerly and used as the expected answer; the
 * DataFrame loaded from `newFile` is then compared against them with `checkAnswer`.
 *
 * @param originalFile path of the Avro input that serves as the expected data
 * @param newFile path of the Avro output that was written and is being verified
 */
def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = {
  // Read the path supplied by the caller; the parameter was previously ignored in favour of
  // the suite-level `testFile`, which made the helper wrong for any other input file.
  val originalEntries = spark.read.avro(originalFile).collect()
  val newEntries = spark.read.avro(newFile)
  checkAnswer(newEntries, originalEntries)
}
6153
6254 test(" reading from multiple paths" ) {
@@ -68,7 +60,7 @@ class AvroSuite extends SparkFunSuite {
6860 val df = spark.read.avro(episodesFile)
6961 val fields = List (" title" , " air_date" , " doctor" )
7062 for (field <- fields) {
71- TestUtils .withTempDir { dir =>
63+ withTempPath { dir =>
7264 val outputDir = s " $dir/ ${UUID .randomUUID}"
7365 df.write.partitionBy(field).avro(outputDir)
7466 val input = spark.read.avro(outputDir)
@@ -82,28 +74,29 @@ class AvroSuite extends SparkFunSuite {
8274
8375 test(" request no fields" ) {
8476 val df = spark.read.avro(episodesFile)
85- df.registerTempTable (" avro_table" )
77+ df.createOrReplaceTempView (" avro_table" )
8678 assert(spark.sql(" select count(*) from avro_table" ).collect().head === Row (8 ))
8779 }
8880
8981 test(" convert formats" ) {
90- TestUtils .withTempDir { dir =>
82+ withTempPath { dir =>
9183 val df = spark.read.avro(episodesFile)
9284 df.write.parquet(dir.getCanonicalPath)
9385 assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count)
9486 }
9587 }
9688
9789 test(" rearrange internal schema" ) {
98- TestUtils .withTempDir { dir =>
90+ withTempPath { dir =>
9991 val df = spark.read.avro(episodesFile)
10092 df.select(" doctor" , " title" ).write.avro(dir.getCanonicalPath)
10193 }
10294 }
10395
10496 test(" test NULL avro type" ) {
105- TestUtils .withTempDir { dir =>
106- val fields = Seq (new Field (" null" , Schema .create(Type .NULL ), " doc" , null )).asJava
97+ withTempPath { dir =>
98+ val fields =
99+ Seq (new Field (" null" , Schema .create(Type .NULL ), " doc" , null .asInstanceOf [Any ])).asJava
107100 val schema = Schema .createRecord(" name" , " docs" , " namespace" , false )
108101 schema.setFields(fields)
109102 val datumWriter = new GenericDatumWriter [GenericRecord ](schema)
@@ -122,11 +115,11 @@ class AvroSuite extends SparkFunSuite {
122115 }
123116
124117 test(" union(int, long) is read as long" ) {
125- TestUtils .withTempDir { dir =>
118+ withTempPath { dir =>
126119 val avroSchema : Schema = {
127120 val union =
128121 Schema .createUnion(List (Schema .create(Type .INT ), Schema .create(Type .LONG )).asJava)
129- val fields = Seq (new Field (" field1" , union, " doc" , null )).asJava
122+ val fields = Seq (new Field (" field1" , union, " doc" , null . asInstanceOf [ Any ] )).asJava
130123 val schema = Schema .createRecord(" name" , " docs" , " namespace" , false )
131124 schema.setFields(fields)
132125 schema
@@ -150,11 +143,11 @@ class AvroSuite extends SparkFunSuite {
150143 }
151144
152145 test(" union(float, double) is read as double" ) {
153- TestUtils .withTempDir { dir =>
146+ withTempPath { dir =>
154147 val avroSchema : Schema = {
155148 val union =
156149 Schema .createUnion(List (Schema .create(Type .FLOAT ), Schema .create(Type .DOUBLE )).asJava)
157- val fields = Seq (new Field (" field1" , union, " doc" , null )).asJava
150+ val fields = Seq (new Field (" field1" , union, " doc" , null . asInstanceOf [ Any ] )).asJava
158151 val schema = Schema .createRecord(" name" , " docs" , " namespace" , false )
159152 schema.setFields(fields)
160153 schema
@@ -178,15 +171,15 @@ class AvroSuite extends SparkFunSuite {
178171 }
179172
180173 test(" union(float, double, null) is read as nullable double" ) {
181- TestUtils .withTempDir { dir =>
174+ withTempPath { dir =>
182175 val avroSchema : Schema = {
183176 val union = Schema .createUnion(
184177 List (Schema .create(Type .FLOAT ),
185178 Schema .create(Type .DOUBLE ),
186179 Schema .create(Type .NULL )
187180 ).asJava
188181 )
189- val fields = Seq (new Field (" field1" , union, " doc" , null )).asJava
182+ val fields = Seq (new Field (" field1" , union, " doc" , null . asInstanceOf [ Any ] )).asJava
190183 val schema = Schema .createRecord(" name" , " docs" , " namespace" , false )
191184 schema.setFields(fields)
192185 schema
@@ -210,9 +203,9 @@ class AvroSuite extends SparkFunSuite {
210203 }
211204
212205 test(" Union of a single type" ) {
213- TestUtils .withTempDir { dir =>
206+ withTempPath { dir =>
214207 val UnionOfOne = Schema .createUnion(List (Schema .create(Type .INT )).asJava)
215- val fields = Seq (new Field (" field1" , UnionOfOne , " doc" , null )).asJava
208+ val fields = Seq (new Field (" field1" , UnionOfOne , " doc" , null . asInstanceOf [ Any ] )).asJava
216209 val schema = Schema .createRecord(" name" , " docs" , " namespace" , false )
217210 schema.setFields(fields)
218211
@@ -233,16 +226,16 @@ class AvroSuite extends SparkFunSuite {
233226 }
234227
235228 test(" Complex Union Type" ) {
236- TestUtils .withTempDir { dir =>
229+ withTempPath { dir =>
237230 val fixedSchema = Schema .createFixed(" fixed_name" , " doc" , " namespace" , 4 )
238231 val enumSchema = Schema .createEnum(" enum_name" , " doc" , " namespace" , List (" e1" , " e2" ).asJava)
239232 val complexUnionType = Schema .createUnion(
240233 List (Schema .create(Type .INT ), Schema .create(Type .STRING ), fixedSchema, enumSchema).asJava)
241234 val fields = Seq (
242- new Field (" field1" , complexUnionType, " doc" , null ),
243- new Field (" field2" , complexUnionType, " doc" , null ),
244- new Field (" field3" , complexUnionType, " doc" , null ),
245- new Field (" field4" , complexUnionType, " doc" , null )
235+ new Field (" field1" , complexUnionType, " doc" , null . asInstanceOf [ Any ] ),
236+ new Field (" field2" , complexUnionType, " doc" , null . asInstanceOf [ Any ] ),
237+ new Field (" field3" , complexUnionType, " doc" , null . asInstanceOf [ Any ] ),
238+ new Field (" field4" , complexUnionType, " doc" , null . asInstanceOf [ Any ] )
246239 ).asJava
247240 val schema = Schema .createRecord(" name" , " docs" , " namespace" , false )
248241 schema.setFields(fields)
@@ -271,7 +264,7 @@ class AvroSuite extends SparkFunSuite {
271264 }
272265
273266 test(" Lots of nulls" ) {
274- TestUtils .withTempDir { dir =>
267+ withTempPath { dir =>
275268 val schema = StructType (Seq (
276269 StructField (" binary" , BinaryType , true ),
277270 StructField (" timestamp" , TimestampType , true ),
@@ -290,7 +283,7 @@ class AvroSuite extends SparkFunSuite {
290283 }
291284
292285 test(" Struct field type" ) {
293- TestUtils .withTempDir { dir =>
286+ withTempPath { dir =>
294287 val schema = StructType (Seq (
295288 StructField (" float" , FloatType , true ),
296289 StructField (" short" , ShortType , true ),
@@ -309,7 +302,7 @@ class AvroSuite extends SparkFunSuite {
309302 }
310303
311304 test(" Date field type" ) {
312- TestUtils .withTempDir { dir =>
305+ withTempPath { dir =>
313306 val schema = StructType (Seq (
314307 StructField (" float" , FloatType , true ),
315308 StructField (" date" , DateType , true )
@@ -329,7 +322,7 @@ class AvroSuite extends SparkFunSuite {
329322 }
330323
331324 test(" Array data types" ) {
332- TestUtils .withTempDir { dir =>
325+ withTempPath { dir =>
333326 val testSchema = StructType (Seq (
334327 StructField (" byte_array" , ArrayType (ByteType ), true ),
335328 StructField (" short_array" , ArrayType (ShortType ), true ),
@@ -363,13 +356,12 @@ class AvroSuite extends SparkFunSuite {
363356 }
364357
365358 test(" write with compression" ) {
366- TestUtils .withTempDir { dir =>
359+ withTempPath { dir =>
367360 val AVRO_COMPRESSION_CODEC = " spark.sql.avro.compression.codec"
368361 val AVRO_DEFLATE_LEVEL = " spark.sql.avro.deflate.level"
369362 val uncompressDir = s " $dir/uncompress "
370363 val deflateDir = s " $dir/deflate "
371364 val snappyDir = s " $dir/snappy "
372- val fakeDir = s " $dir/fake "
373365
374366 val df = spark.read.avro(testFile)
375367 spark.conf.set(AVRO_COMPRESSION_CODEC , " uncompressed" )
@@ -439,7 +431,7 @@ class AvroSuite extends SparkFunSuite {
439431 test(" sql test" ) {
440432 spark.sql(
441433 s """
442- |CREATE TEMPORARY TABLE avroTable
434+ |CREATE TEMPORARY VIEW avroTable
443435 |USING avro
444436 |OPTIONS (path " $episodesFile")
445437 """ .stripMargin.replaceAll(" \n " , " " ))
@@ -450,24 +442,24 @@ class AvroSuite extends SparkFunSuite {
450442 test(" conversion to avro and back" ) {
451443 // Note that test.avro includes a variety of types, some of which are nullable. We expect to
452444 // get the same values back.
453- TestUtils .withTempDir { dir =>
445+ withTempPath { dir =>
454446 val avroDir = s " $dir/avro "
455447 spark.read.avro(testFile).write.avro(avroDir)
456- TestUtils . checkReloadMatchesSaved(spark, testFile, avroDir)
448+ checkReloadMatchesSaved(testFile, avroDir)
457449 }
458450 }
459451
460452 test(" conversion to avro and back with namespace" ) {
461453 // Note that test.avro includes a variety of types, some of which are nullable. We expect to
462454 // get the same values back.
463- TestUtils .withTempDir { tempDir =>
455+ withTempPath { tempDir =>
464456 val name = " AvroTest"
465457 val namespace = " com.databricks.spark.avro"
466458 val parameters = Map (" recordName" -> name, " recordNamespace" -> namespace)
467459
468460 val avroDir = tempDir + " /namedAvro"
469461 spark.read.avro(testFile).write.options(parameters).avro(avroDir)
470- TestUtils . checkReloadMatchesSaved(spark, testFile, avroDir)
462+ checkReloadMatchesSaved(testFile, avroDir)
471463
472464 // Look at raw file and make sure has namespace info
473465 val rawSaved = spark.sparkContext.textFile(avroDir)
@@ -478,7 +470,7 @@ class AvroSuite extends SparkFunSuite {
478470 }
479471
480472 test(" converting some specific sparkSQL types to avro" ) {
481- TestUtils .withTempDir { tempDir =>
473+ withTempPath { tempDir =>
482474 val testSchema = StructType (Seq (
483475 StructField (" Name" , StringType , false ),
484476 StructField (" Length" , IntegerType , true ),
@@ -520,7 +512,7 @@ class AvroSuite extends SparkFunSuite {
520512 }
521513
522514 test(" correctly read long as date/timestamp type" ) {
523- TestUtils .withTempDir { tempDir =>
515+ withTempPath { tempDir =>
524516 val sparkSession = spark
525517 import sparkSession .implicits ._
526518
@@ -549,7 +541,7 @@ class AvroSuite extends SparkFunSuite {
549541 }
550542
551543 test(" does not coerce null date/timestamp value to 0 epoch." ) {
552- TestUtils .withTempDir { tempDir =>
544+ withTempPath { tempDir =>
553545 val sparkSession = spark
554546 import sparkSession .implicits ._
555547
@@ -610,7 +602,7 @@ class AvroSuite extends SparkFunSuite {
610602
611603 // Directory given has no avro files
612604 intercept[AnalysisException ] {
613- TestUtils .withTempDir (dir => spark.read.avro(dir.getCanonicalPath))
605+ withTempPath (dir => spark.read.avro(dir.getCanonicalPath))
614606 }
615607
616608 intercept[AnalysisException ] {
@@ -624,7 +616,7 @@ class AvroSuite extends SparkFunSuite {
624616 }
625617
626618 intercept[FileNotFoundException ] {
627- TestUtils .withTempDir { dir =>
619+ withTempPath { dir =>
628620 FileUtils .touch(new File (dir, " test" ))
629621 spark.read.avro(dir.toString)
630622 }
@@ -633,19 +625,19 @@ class AvroSuite extends SparkFunSuite {
633625 }
634626
635627 test(" SQL test insert overwrite" ) {
636- TestUtils .withTempDir { tempDir =>
628+ withTempPath { tempDir =>
637629 val tempEmptyDir = s " $tempDir/sqlOverwrite "
638630 // Create a temp directory for table that will be overwritten
639631 new File (tempEmptyDir).mkdirs()
640632 spark.sql(
641633 s """
642- |CREATE TEMPORARY TABLE episodes
634+ |CREATE TEMPORARY VIEW episodes
643635 |USING avro
644636 |OPTIONS (path " $episodesFile")
645637 """ .stripMargin.replaceAll(" \n " , " " ))
646638 spark.sql(
647639 s """
648- |CREATE TEMPORARY TABLE episodesEmpty
640+ |CREATE TEMPORARY VIEW episodesEmpty
649641 |(name string, air_date string, doctor int)
650642 |USING avro
651643 |OPTIONS (path " $tempEmptyDir")
@@ -665,7 +657,7 @@ class AvroSuite extends SparkFunSuite {
665657
666658 test(" test save and load" ) {
667659 // Test if load works as expected
668- TestUtils .withTempDir { tempDir =>
660+ withTempPath { tempDir =>
669661 val df = spark.read.avro(episodesFile)
670662 assert(df.count == 8 )
671663
@@ -679,7 +671,7 @@ class AvroSuite extends SparkFunSuite {
679671
680672 test(" test load with non-Avro file" ) {
681673 // Test if load works as expected
682- TestUtils .withTempDir { tempDir =>
674+ withTempPath { tempDir =>
683675 val df = spark.read.avro(episodesFile)
684676 assert(df.count == 8 )
685677
@@ -737,7 +729,7 @@ class AvroSuite extends SparkFunSuite {
737729 }
738730
739731 test(" read avro file partitioned" ) {
740- TestUtils .withTempDir { dir =>
732+ withTempPath { dir =>
741733 val sparkSession = spark
742734 import sparkSession .implicits ._
743735 val df = (0 to 1024 * 3 ).toDS.map(i => s " record ${i}" ).toDF(" records" )
@@ -756,7 +748,7 @@ class AvroSuite extends SparkFunSuite {
756748 case class NestedTop (id : Int , data : NestedMiddle )
757749
758750 test(" saving avro that has nested records with the same name" ) {
759- TestUtils .withTempDir { tempDir =>
751+ withTempPath { tempDir =>
760752 // Save avro file on output folder path
761753 val writeDf = spark.createDataFrame(List (NestedTop (1 , NestedMiddle (2 , NestedBottom (3 , " 1" )))))
762754 val outputFolder = s " $tempDir/duplicate_names/ "
@@ -773,7 +765,7 @@ class AvroSuite extends SparkFunSuite {
773765 case class NestedTopArray (id : Int , data : NestedMiddleArray )
774766
775767 test(" saving avro that has nested records with the same name inside an array" ) {
776- TestUtils .withTempDir { tempDir =>
768+ withTempPath { tempDir =>
777769 // Save avro file on output folder path
778770 val writeDf = spark.createDataFrame(
779771 List (NestedTopArray (1 , NestedMiddleArray (2 , Array (
@@ -794,7 +786,7 @@ class AvroSuite extends SparkFunSuite {
794786 case class NestedTopMap (id : Int , data : NestedMiddleMap )
795787
796788 test(" saving avro that has nested records with the same name inside a map" ) {
797- TestUtils .withTempDir { tempDir =>
789+ withTempPath { tempDir =>
798790 // Save avro file on output folder path
799791 val writeDf = spark.createDataFrame(
800792 List (NestedTopMap (1 , NestedMiddleMap (2 , Map (
0 commit comments