1818package  org .apache .spark .mllib .export .pmml 
1919
2020import  org .apache .spark .mllib .clustering .KMeansModel 
21+ import  org .dmg .pmml .DataDictionary 
22+ import  org .dmg .pmml .FieldName 
23+ import  org .dmg .pmml .DataField 
24+ import  org .dmg .pmml .OpType 
25+ import  org .dmg .pmml .DataType 
26+ import  org .dmg .pmml .MiningSchema 
27+ import  org .dmg .pmml .MiningField 
28+ import  org .dmg .pmml .FieldUsageType 
29+ import  org .dmg .pmml .ComparisonMeasure 
30+ import  org .dmg .pmml .ComparisonMeasure .Kind 
31+ import  org .dmg .pmml .SquaredEuclidean 
32+ import  org .dmg .pmml .ClusteringModel 
33+ import  org .dmg .pmml .MiningFunctionType 
34+ import  org .dmg .pmml .ClusteringModel .ModelClass 
35+ import  org .dmg .pmml .ClusteringField 
36+ import  org .dmg .pmml .CompareFunctionType 
37+ import  org .dmg .pmml .Cluster 
38+ import  org .dmg .pmml .Array .Type 
2139
2240/** 
2341 * PMML Model Export for KMeansModel class 
@@ -30,9 +48,48 @@ class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{
3048  populateKMeansPMML(model);
3149
/**
 * Fills the inherited `pmml` document with a center-based clustering model
 * built from the centroids of `model`.
 *
 * The header description is always set. When the model has at least one
 * centroid, one data/mining/clustering field is emitted per dimension of the
 * first centroid and one PMML `Cluster` is emitted per centroid.
 *
 * @param model the k-means model whose `clusterCenters` are exported
 */
private def populateKMeansPMML(model: KMeansModel): Unit = {

  pmml.getHeader().setDescription(" k-means clustering")

  if (model.clusterCenters.length > 0) {

    // The first centroid determines how many fields are declared.
    val clusterCenter = model.clusterCenters(0)
    val dimensions = clusterCenter.size

    val fields = new Array[FieldName](dimensions)
    val dataDictionary = new DataDictionary()
    val miningSchema = new MiningSchema()

    // One continuous double-valued, active field per dimension.
    for (i <- 0 until dimensions) {
      fields(i) = FieldName.create(" field_" + i)
      dataDictionary.withDataFields(
        new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
      miningSchema.withMiningFields(
        new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE))
    }

    val comparisonMeasure = new ComparisonMeasure()
      .withKind(Kind.DISTANCE)
      .withMeasure(new SquaredEuclidean())

    dataDictionary.withNumberOfFields(dataDictionary.getDataFields().size())

    pmml.setDataDictionary(dataDictionary)

    val clusteringModel = new ClusteringModel(
      miningSchema,
      comparisonMeasure,
      MiningFunctionType.CLUSTERING,
      ModelClass.CENTER_BASED,
      model.clusterCenters.length)
      .withModelName(" k-means")

    // Clustering fields are per dimension ...
    for (i <- 0 until dimensions) {
      clusteringModel.withClusteringFields(
        new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF))
    }

    // ... whereas clusters are per centroid. Iterating centroids by their own
    // index (not by the dimension index) keeps every centroid in the output and
    // avoids indexing past the end of clusterCenters when there are more
    // dimensions than clusters.
    for (k <- model.clusterCenters.indices) {
      val centroidValues = model.clusterCenters(k).toArray.mkString("  ")
      val cluster = new Cluster()
        .withName(" cluster_" + k)
        .withArray(
          new org.dmg.pmml.Array()
            .withType(Type.REAL)
            .withN(dimensions)
            .withValue(centroidValues))
      // cluster.withSize(value) // we don't have the size of the single cluster but only the centroids (withValue)
      clusteringModel.withClusters(cluster)
    }

    pmml.withModels(clusteringModel)
  }
}
3794
3895}
0 commit comments