@@ -5,15 +5,30 @@ import java.io.File
55import com .google .common .io .Files
66import com .indix .utils .spark .parquet .avro .ParquetAvroDataSource
77import org .apache .commons .io .FileUtils
8+ import org .apache .parquet .hadoop .metadata .CompressionCodecName
89import org .apache .spark .sql .SparkSession
10+ import org .scalactic .Equality
11+ import org .scalatest .Matchers .{be , convertToAnyShouldWrapper , equal }
912import org .scalatest .{BeforeAndAfterAll , FlatSpec }
10- import org .scalatest .Matchers .{be , convertToAnyShouldWrapper }
11- import org .apache .parquet .hadoop .metadata .CompressionCodecName
13+ import java .util .{Arrays => JArrays }
1214
13- case class SampleAvroRecord (a : Int , b : String , c : Seq [String ], d : Boolean , e : Double , f : collection.Map [String ,String ], g : Seq [Byte ])
/**
 * Sample record used by this spec as the row type of the DataFrame that is written to Parquet
 * and decoded back.
 *
 * Note: `g` is an `Array[Byte]`, so the compiler-generated `equals` compares that field by
 * reference rather than by content.
 *
 * @param a integer field (the spec sorts records by it)
 * @param b string field
 * @param c sequence-of-strings field
 * @param d boolean field
 * @param e double field
 * @param f string-to-string map field
 * @param g raw bytes field
 */
case class SampleAvroRecord(
  a: Int,
  b: String,
  c: Seq[String],
  d: Boolean,
  e: Double,
  f: collection.Map[String, String],
  g: Array[Byte]
)
1416
1517class ParquetAvroDataSourceSpec extends FlatSpec with BeforeAndAfterAll with ParquetAvroDataSource {
1618 private var spark : SparkSession = _
/**
 * Field-by-field equality for [[SampleAvroRecord]].
 *
 * The case-class `equals` compares the `Array[Byte]` field `g` by reference, so two records
 * holding identical bytes are not `==`. This instance compares `g` by content through
 * `java.util.Arrays.equals` and every other field by value.
 *
 * The explicit type annotation keeps implicit resolution predictable and independent of
 * where the definition sits in the file.
 */
implicit val sampleAvroRecordEq: Equality[SampleAvroRecord] = new Equality[SampleAvroRecord] {
  // `other` (rather than `b`) avoids confusion with the record field `b` compared below.
  override def areEqual(left: SampleAvroRecord, other: Any): Boolean = other match {
    case right: SampleAvroRecord =>
      left.a == right.a &&
        left.b == right.b &&
        Equality.default[Seq[String]].areEqual(left.c, right.c) &&
        left.d == right.d &&
        left.e == right.e &&
        Equality.default[collection.Map[String, String]].areEqual(left.f, right.f) &&
        JArrays.equals(left.g, right.g) // content comparison of the byte arrays
    case _ => false // anything that is not a SampleAvroRecord is never equal
  }
}
1732
1833 override protected def beforeAll (): Unit = {
1934 super .beforeAll()
@@ -33,11 +48,11 @@ class ParquetAvroDataSourceSpec extends FlatSpec with BeforeAndAfterAll with Par
3348 val outputLocation = Files .createTempDir().getAbsolutePath + " /output"
3449
3550 val sampleRecords : Seq [SampleAvroRecord ] = Seq (
36- SampleAvroRecord (1 , " 1" , List (" a1" ), true , 1.0d , Map (" a1" -> " b1" ), Seq ( " 1" .toByte) ),
37- SampleAvroRecord (2 , " 2" , List (" a2" ), false , 2.0d , Map (" a2" -> " b2" ), Seq ( " 2" .toByte) ),
38- SampleAvroRecord (3 , " 3" , List (" a3" ), true , 3.0d , Map (" a3" -> " b3" ), Seq ( " 3" .toByte) ),
39- SampleAvroRecord (4 , " 4" , List (" a4" ), true , 4.0d , Map (" a4" -> " b4" ), Seq ( " 4" .toByte) ),
40- SampleAvroRecord (5 , " 5" , List (" a5" ), false , 5.0d , Map (" a5" -> " b5" ), Seq ( " 5" .toByte) )
51+ SampleAvroRecord (1 , " 1" , List (" a1" ), true , 1.0d , Map (" a1" -> " b1" ), " 1" .getBytes ),
52+ SampleAvroRecord (2 , " 2" , List (" a2" ), false , 2.0d , Map (" a2" -> " b2" ), " 2" .getBytes ),
53+ SampleAvroRecord (3 , " 3" , List (" a3" ), true , 3.0d , Map (" a3" -> " b3" ), " 3" .getBytes ),
54+ SampleAvroRecord (4 , " 4" , List (" a4" ), true , 4.0d , Map (" a4" -> " b4" ), " 4" .getBytes ),
55+ SampleAvroRecord (5 , " 5" , List (" a5" ), false , 5.0d , Map (" a5" -> " b5" ), " 5" .getBytes )
4156 )
4257
4358 val sampleDf = spark.createDataFrame(sampleRecords)
@@ -51,7 +66,9 @@ class ParquetAvroDataSourceSpec extends FlatSpec with BeforeAndAfterAll with Par
5166 val records : Array [SampleAvroRecord ] = spark.read.parquet(outputLocation).as[SampleAvroRecord ].collect()
5267
5368 records.length should be(5 )
54- records.sortBy(_.a) should be (sampleRecords.sortBy(_.a))
// SampleAvroRecord carries an Array[Byte], which `==` compares by reference, so the records
// are checked pair by pair with `should equal`: that matcher resolves the implicit
// Equality[SampleAvroRecord] in scope. A bare `===` expression only produces a value that is
// discarded, so it asserts nothing; the element Equality is also not consulted when whole
// collections are compared.
// Ref - https://github.com/scalatest/scalatest/issues/491
// The length check above guarantees that `zip` does not silently drop unmatched records.
records.sortBy(_.a).zip(sampleRecords.sortBy(_.a)).foreach { case (actual, expected) =>
  actual should equal(expected)
}
5572
5673 FileUtils .deleteDirectory(new File (outputLocation))
5774 }
0 commit comments