@@ -17,6 +17,11 @@

 package org.apache.spark.mllib.util

+import scala.reflect.runtime.universe.TypeTag
+
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types.{DataType, StructType, StructField}
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}

 /**
  * :: DeveloperApi ::
@@ -46,11 +51,7 @@ trait Exportable {

 }

-/**
- * :: DeveloperApi ::
- */
-@DeveloperApi
-object Exportable {
+private[mllib] object Exportable {

   /** Current version of model import/export format. */
   val latestVersion: String = "1.0"
@@ -79,34 +80,32 @@ trait Importable[Model <: Exportable] {

 }

-/*
-/**
- * :: DeveloperApi ::
- *
- * Trait for models and transformers which may be saved as files.
- * This should be inherited by the class which implements model instances.
- *
- * This specializes [[Exportable]] for local models which can be stored on a single machine.
- * This provides helper functionality, but developers can choose to use [[Exportable]] instead,
- * even for local models.
- */
-@DeveloperApi
-trait LocalExportable {
+private[mllib] object Importable {

   /**
-   * Save this model to the given path.
-   *
-   * This saves:
-   *  - human-readable (JSON) model metadata to path/metadata/
-   *  - Parquet formatted data to path/data/
+   * Check the schema of loaded model data.
    *
-   * The model may be loaded using [[Importable.load]].
+   * This checks every field in the expected schema to make sure that a field with the same
+   * name and DataType appears in the loaded schema.  Note that this does NOT check metadata
+   * or containsNull.
    *
-   * @param sc  Spark context used to save model data.
-   * @param path  Path specifying the directory in which to save this model.
-   *              This directory and any intermediate directory will be created if needed.
+   * @param loadedSchema  Schema for model data loaded from file.
+   * @tparam Data  Expected data type from which an expected schema can be derived.
    */
-  def save(sc: SparkContext, path: String): Unit
+  def checkSchema[Data: TypeTag](loadedSchema: StructType): Unit = {
+    // Check schema explicitly since erasure makes it hard to use match-case for checking.
+    val expectedFields: Array[StructField] =
+      ScalaReflection.schemaFor[Data].dataType.asInstanceOf[StructType].fields
+    val loadedFields: Map[String, DataType] =
+      loadedSchema.map(field => field.name -> field.dataType).toMap
+    expectedFields.foreach { field =>
+      assert(loadedFields.contains(field.name), s"Unable to parse model data." +
+        s"  Expected field with name ${field.name} was missing in loaded schema:" +
+        s" ${loadedFields.mkString(", ")}")
+      assert(loadedFields(field.name) == field.dataType,
+        s"Unable to parse model data.  Expected field $field but found field" +
+          s" with different type: ${loadedFields(field.name)}")
+    }
+  }

 }
-*/
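
For context, here is a minimal sketch of how a loader might exercise the new `checkSchema` helper. The `CheckSchemaExample` object and `Data` case class are hypothetical (not part of this patch), and the example sits in `org.apache.spark.mllib.util` only because `Importable` is now `private[mllib]`:

```scala
// Sketch only: placed in this package because Importable is private[mllib].
package org.apache.spark.mllib.util

import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

object CheckSchemaExample {

  // Hypothetical record type describing a model's saved Parquet rows.
  case class Data(weight: Double, intercept: Double)

  def main(args: Array[String]): Unit = {
    // Schema as it might come back from reading the model's data files.
    val loadedSchema = StructType(Seq(
      StructField("weight", DoubleType, nullable = false),
      StructField("intercept", DoubleType, nullable = false)))

    // Succeeds: every field derived from Data (name and DataType) is
    // present.  checkSchema deliberately ignores metadata and containsNull.
    Importable.checkSchema[Data](loadedSchema)

    // Renaming or retyping a field in loadedSchema would trip one of the
    // asserts with an "Unable to parse model data. ..." message.
  }
}
```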