@@ -22,29 +22,38 @@ import org.apache.spark.mllib.linalg.Vectors
2222import org .apache .spark .mllib .export .ModelExportFactory
2323import org .apache .spark .mllib .clustering .KMeansModel
2424import org .apache .spark .mllib .export .ModelExportType
25+ import org .dmg .pmml .ClusteringModel
26+ import javax .xml .parsers .DocumentBuilderFactory
27+ import java .io .ByteArrayOutputStream
2528
2629class KMeansPMMLModelExportSuite extends FunSuite {
2730
2831 test(" KMeansPMMLModelExport generate PMML format" ) {
2932
33+ // arrange model to test
3034 val clusterCenters = Array (
3135 Vectors .dense(1.0 , 2.0 , 6.0 ),
3236 Vectors .dense(1.0 , 3.0 , 0.0 ),
3337 Vectors .dense(1.0 , 4.0 , 6.0 )
3438 )
35-
3639 val kmeansModel = new KMeansModel (clusterCenters);
3740
41+ // act by exporting the model to the PMML format
3842 val modelExport = ModelExportFactory .createModelExport(kmeansModel, ModelExportType .PMML )
39-
43+
44+ // assert that the PMML format is as expected
4045 assert(modelExport.isInstanceOf [PMMLModelExport ])
46+ var pmml = modelExport.asInstanceOf [PMMLModelExport ].getPmml()
47+ assert(pmml.getHeader().getDescription() === " k-means clustering" )
48+ // check that the number of fields match the single vector size
49+ assert(pmml.getDataDictionary().getNumberOfFields() === clusterCenters(0 ).size)
50+ // this verify that there is a model attached to the pmml object and the model is a clustering one
51+ // it also verifies that the pmml model has the same number of clusters of the spark model
52+ assert(pmml.getModels().get(0 ).asInstanceOf [ClusteringModel ].getNumberOfClusters() === clusterCenters.size)
4153
42- // TODO: asserts
43- // compare pmml fields to strings
44- modelExport.asInstanceOf [PMMLModelExport ].getPmml()
45- // use document builder to load the xml generated and validated the notes by looking for them
46- modelExport.asInstanceOf [PMMLModelExport ].save(System .out)
47- // saveLocalFile too??? search how to unit test file creating in java
54+ // manual checking
55+ // modelExport.asInstanceOf[PMMLModelExport].save(System.out)
56+ // modelExport.asInstanceOf[PMMLModelExport].saveLocalFile("/tmp/kmeans.xml")
4857
4958 }
5059
0 commit comments