16 | 16 | # |
17 | 17 |
18 | 18 | """ |
19 | | -An example of how to use DataFrame as a dataset for ML. Run with:: |
20 | | -    bin/spark-submit examples/src/main/python/mllib/dataset_example.py |
| 19 | +An example of how to use DataFrame for ML. Run with:: |
| 20 | +    bin/spark-submit examples/src/main/python/ml/dataframe_example.py <input> |
21 | 21 | """ |
22 | 22 | from __future__ import print_function |
23 | 23 |
28 | 28 |
29 | 29 | from pyspark import SparkContext |
30 | 30 | from pyspark.sql import SQLContext |
31 | | -from pyspark.mllib.util import MLUtils |
32 | 31 | from pyspark.mllib.stat import Statistics |
33 | 32 |
34 | | - |
35 | | -def summarize(dataset): |
36 | | -    print("schema: %s" % dataset.schema().json()) |
37 | | -    labels = dataset.map(lambda r: r.label) |
38 | | -    print("label average: %f" % labels.mean()) |
39 | | -    features = dataset.map(lambda r: r.features) |
40 | | -    summary = Statistics.colStats(features) |
41 | | -    print("features average: %r" % summary.mean()) |
42 | | - |
43 | 33 | if __name__ == "__main__": |
44 | 34 |     if len(sys.argv) > 2: |
45 | | -        print("Usage: dataset_example.py <libsvm file>", file=sys.stderr) |
| 35 | +        print("Usage: dataframe_example.py <libsvm file>", file=sys.stderr) |
46 | 36 |         exit(-1) |
47 | | -    sc = SparkContext(appName="DatasetExample") |
| 37 | +    sc = SparkContext(appName="DataFrameExample") |
48 | 38 |     sqlContext = SQLContext(sc) |
49 | 39 |     if len(sys.argv) == 2: |
50 | 40 |         input = sys.argv[1] |
51 | 41 |     else: |
52 | 42 |         input = "data/mllib/sample_libsvm_data.txt" |
53 | | -    points = MLUtils.loadLibSVMFile(sc, input) |
54 | | -    dataset0 = sqlContext.inferSchema(points).setName("dataset0").cache() |
55 | | -    summarize(dataset0) |
| 43 | + |
| 44 | +    # Load input data |
| 45 | +    print("Loading LIBSVM file with UDT from " + input + ".") |
| 46 | +    df = sqlContext.read.format("libsvm").load(input).cache() |
| 47 | +    print("Schema from LIBSVM:") |
| 48 | +    df.printSchema() |
| 49 | +    print("Loaded training data as a DataFrame with " + |
| 50 | +          str(df.count()) + " records.") |
| 51 | + |
| 52 | +    # Show statistical summary of labels. |
| 53 | +    labelSummary = df.describe("label") |
| 54 | +    labelSummary.show() |
| 55 | + |
| 56 | +    # Convert features column to an RDD of vectors. |
| 57 | +    features = df.select("features").map(lambda r: r.features) |
| 58 | +    summary = Statistics.colStats(features) |
| 59 | +    print("Selected features column with average values:\n" + |
| 60 | +          str(summary.mean())) |
| 61 | + |
| 62 | +    # Save the records in a parquet file. |
56 | 63 |     tempdir = tempfile.NamedTemporaryFile(delete=False).name |
57 | 64 |     os.unlink(tempdir) |
58 | | -    print("Save dataset as a Parquet file to %s." % tempdir) |
59 | | -    dataset0.saveAsParquetFile(tempdir) |
60 | | -    print("Load it back and summarize it again.") |
61 | | -    dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache() |
62 | | -    summarize(dataset1) |
| 65 | +    print("Saving to " + tempdir + " as Parquet file.") |
| 66 | +    df.write.parquet(tempdir) |
| 67 | + |
| 68 | +    # Load the records back. |
| 69 | +    print("Loading Parquet file with UDT from " + tempdir) |
| 70 | +    newDF = sqlContext.read.parquet(tempdir) |
| 71 | +    print("Schema from Parquet:") |
| 72 | +    newDF.printSchema() |
63 | 73 |     shutil.rmtree(tempdir) |
| 74 | + |
| 75 | +    sc.stop() |
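
For reference, a minimal interactive sketch of the same load-and-summarize flow, assuming the Spark 1.6-era SQLContext and "libsvm" data source used above (`sc` is the SparkContext provided by the pyspark shell, and the path is the example's default input):

    from pyspark.sql import SQLContext
    from pyspark.mllib.stat import Statistics

    sqlContext = SQLContext(sc)  # sc comes from the pyspark shell
    # Load the LIBSVM file as a DataFrame with label/features columns.
    df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
    df.describe("label").show()  # statistical summary of the label column
    # Turn the features column into an RDD of vectors and compute column stats.
    features = df.select("features").map(lambda r: r.features)
    print(Statistics.colStats(features).mean())  # per-feature means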