20 changes: 9 additions & 11 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -69,14 +69,10 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
}
}

/**
* Creates a param pair with the given value (for Java).
*/
/** Creates a param pair with the given value (for Java). */
def w(value: T): ParamPair[T] = this -> value

/**
* Creates a param pair with the given value (for Scala).
*/
/** Creates a param pair with the given value (for Scala). */
def ->(value: T): ParamPair[T] = ParamPair(this, value)

override final def toString: String = s"${parent}__$name"
@@ -190,6 +186,7 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double =>

def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

/** Creates a param pair with the given value (for Java). */
Contributor comment: Interesting, it doesn't inherit the parent doc.

override def w(value: Double): ParamPair[Double] = super.w(value)
}

@@ -209,6 +206,7 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea

def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

/** Creates a param pair with the given value (for Java). */
override def w(value: Int): ParamPair[Int] = super.w(value)
}

@@ -228,6 +226,7 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo

def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

/** Creates a param pair with the given value (for Java). */
override def w(value: Float): ParamPair[Float] = super.w(value)
}

@@ -247,6 +246,7 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool

def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

/** Creates a param pair with the given value (for Java). */
override def w(value: Long): ParamPair[Long] = super.w(value)
}

@@ -260,6 +260,7 @@ class BooleanParam(parent: String, name: String, doc: String) // No need for isV

def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)

/** Creates a param pair with the given value (for Java). */
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
}

@@ -274,8 +275,6 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array
def this(parent: Params, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)

override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)

/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
}
@@ -291,10 +290,9 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
def this(parent: Params, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)

override def w(value: Array[Double]): ParamPair[Array[Double]] = super.w(value)

/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[Double]): ParamPair[Array[Double]] = w(value.asScala.toArray)
def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] =
w(value.asScala.map(_.asInstanceOf[Double]).toArray)
}

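For context on why each concrete Param subclass re-declares `w` at its specific type, a minimal Java sketch (the owner and param names below are made up, not from this patch) of building a ParamPair without Scala's `->`:

```java
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.ParamPair;

public class ParamPairSketch {
  public static void main(String[] args) {
    // "myOwner" and "threshold" are illustrative names only.
    DoubleParam threshold = new DoubleParam("myOwner", "threshold", "decision threshold");
    // Java-friendly equivalent of the Scala expression `threshold -> 0.5`.
    ParamPair<Object> pair = threshold.w(0.5);
    System.out.println(pair);  // prints something like ParamPair(myOwner__threshold,0.5)
  }
}
```

Presumably the explicit overrides exist so that javac sees a primitive-typed `w(double)`, `w(int)`, etc. on each subclass rather than only the erased `w(Object)` inherited from `Param[T]`.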
@@ -22,6 +22,7 @@ import scala.collection.mutable.IndexedSeq
import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV}

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
@@ -188,6 +189,9 @@ class GaussianMixture private (
new GaussianMixtureModel(weights, gaussians)
}

/** Java-friendly version of [[run()]] */
def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd)

/** Average of dense breeze vectors */
private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
val v = BDV.zeros[Double](x(0).length)
@@ -25,6 +25,7 @@ import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable}
@@ -46,7 +47,7 @@ import org.apache.spark.sql.{SQLContext, Row}
@Experimental
class GaussianMixtureModel(
val weights: Array[Double],
val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{
val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable {

require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")

@@ -65,6 +66,10 @@ class GaussianMixtureModel(
responsibilityMatrix.map(r => r.indexOf(r.max))
}

/** Java-friendly version of [[predict()]] */
def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]

/**
* Given the input vectors, return the membership value of each vector
* to all mixture components.
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.graphx.{VertexId, EdgeContext, Graph}
import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
import org.apache.spark.rdd.RDD
@@ -345,6 +346,11 @@ class DistributedLDAModel private (
}
}

/** Java-friendly version of [[topicDistributions]] */
def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = {
JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]])
}

// TODO:
// override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???

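As a usage sketch for the new accessor (tiny made-up corpus; assumes the default EM optimizer, whose `run` result can be cast to `DistributedLDAModel`):

```java
import java.util.Arrays;

import scala.Tuple2;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.DistributedLDAModel;
import org.apache.spark.mllib.clustering.LDA;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

public class JavaTopicDistributionsSketch {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "JavaTopicDistributionsSketch");
    // Corpus of (docId, termCountVector) pairs -- the values here are arbitrary.
    JavaPairRDD<Long, Vector> corpus = JavaPairRDD.fromJavaRDD(sc.parallelize(Arrays.asList(
        new Tuple2<>(0L, Vectors.dense(1.0, 2.0, 0.0)),
        new Tuple2<>(1L, Vectors.dense(0.0, 1.0, 3.0)))));
    DistributedLDAModel model =
        (DistributedLDAModel) new LDA().setK(2).setMaxIterations(5).run(corpus);
    // The new Java-friendly accessor: per-document topic mixtures as a JavaPairRDD.
    JavaPairRDD<Long, Vector> topicDists = model.javaTopicDistributions();
    for (Tuple2<Long, Vector> t : topicDists.collect()) {
      System.out.println("doc " + t._1() + " -> " + t._2());
    }
    sc.stop();
  }
}
```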
@@ -21,8 +21,10 @@ import scala.reflect.ClassTag

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaSparkContext._
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream}
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
@@ -234,6 +236,9 @@ class StreamingKMeans(
}
}

/** Java-friendly version of `trainOn`. */
def trainOn(data: JavaDStream[Vector]): Unit = trainOn(data.dstream)

/**
* Use the clustering model to make predictions on batches of data from a DStream.
*
@@ -245,6 +250,11 @@
data.map(model.predict)
}

/** Java-friendly version of `predictOn`. */
def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Integer] = {
JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Integer]])
}

/**
* Use the model to make predictions on the values of a DStream and carry over its keys.
*
@@ -257,6 +267,14 @@
data.mapValues(model.predict)
}

/** Java-friendly version of `predictOnValues`. */
def predictOnValues[K](
data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Integer] = {
implicit val tag = fakeClassTag[K]
JavaPairDStream.fromPairDStream(
predictOnValues(data.dstream).asInstanceOf[DStream[(K, java.lang.Integer)]])
}

/** Check whether cluster centers have been initialized. */
private[this] def assertInitialized(): Unit = {
if (model.clusterCenters == null) {
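A hedged end-to-end sketch of the new Java entry points (the input paths and text format are assumptions: each line is a vector in `Vectors.parse` form, e.g. `[1.0,2.0,3.0]`):

```java
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.clustering.StreamingKMeans;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

public class JavaStreamingKMeansSketch {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("StreamingKMeansSketch");
    JavaStreamingContext jssc = new JavaStreamingContext(conf, Durations.seconds(1));

    Function<String, Vector> parse = new Function<String, Vector>() {
      public Vector call(String s) { return Vectors.parse(s); }
    };
    // Placeholder directories; new files dropped here are picked up each batch.
    JavaDStream<Vector> trainingData = jssc.textFileStream("/tmp/kmeans/train").map(parse);
    JavaDStream<Vector> testData = jssc.textFileStream("/tmp/kmeans/test").map(parse);

    StreamingKMeans model = new StreamingKMeans()
        .setK(2)
        .setDecayFactor(1.0)
        .setRandomCenters(3, 0.0, 1234L);  // dimension, initial weight, seed

    model.trainOn(trainingData);         // Java-friendly trainOn added above
    model.predictOn(testData).print();   // JavaDStream<Integer> of cluster assignments

    jssc.start();
    jssc.awaitTermination();
  }
}
```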
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.stat

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.linalg.{Matrix, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -80,6 +81,10 @@ object Statistics {
*/
def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y)

/** Java-friendly version of [[corr()]] */
def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double =
corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]])

/**
* Compute the correlation for the input RDDs using the specified method.
* Methods currently supported: `pearson` (default), `spearman`.
@@ -96,6 +101,10 @@
*/
def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method)

/** Java-friendly version of [[corr()]] */
def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double =
corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method)

Contributor comment: missing doc

/**
* Conduct Pearson's chi-squared goodness of fit test of the observed data against the
* expected distribution.
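A minimal sketch of calling these overloads from Java (the data values are arbitrary):

```java
import java.util.Arrays;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.stat.Statistics;

public class JavaCorrSketch {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "JavaCorrSketch");
    JavaRDD<Double> x = sc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0));
    JavaRDD<Double> y = sc.parallelize(Arrays.asList(2.0, 4.0, 6.5, 8.1));
    double pearson = Statistics.corr(x, y);               // default method: Pearson
    double spearman = Statistics.corr(x, y, "spearman");  // explicit method overload
    System.out.println("pearson=" + pearson + " spearman=" + spearman);
    sc.stop();
  }
}
```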
@@ -50,6 +50,7 @@ public void testParams() {
testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
Assert.assertEquals(testParams.getMyStringParam(), "a");
Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0);
}

@Test
29 changes: 24 additions & 5 deletions mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -51,7 +51,8 @@ public String uid() {
public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); }

public JavaTestParams setMyIntParam(int value) {
set(myIntParam_, value); return this;
set(myIntParam_, value);
return this;
}

private DoubleParam myDoubleParam_;
@@ -60,7 +61,8 @@ public JavaTestParams setMyIntParam(int value) {
public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); }

public JavaTestParams setMyDoubleParam(double value) {
set(myDoubleParam_, value); return this;
set(myDoubleParam_, value);
return this;
}

private Param<String> myStringParam_;
@@ -69,7 +71,18 @@ public JavaTestParams setMyDoubleParam(double value) {
public String getMyStringParam() { return getOrDefault(myStringParam_); }

public JavaTestParams setMyStringParam(String value) {
set(myStringParam_, value); return this;
set(myStringParam_, value);
return this;
}

private DoubleArrayParam myDoubleArrayParam_;
public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; }

public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); }

public JavaTestParams setMyDoubleArrayParam(double[] value) {
set(myDoubleArrayParam_, value);
return this;
}

private void init() {
@@ -79,8 +92,14 @@ private void init() {
List<String> validStrings = Lists.newArrayList("a", "b");
myStringParam_ = new Param<String>(this, "myStringParam", "this is a string param",
ParamValidators.inArray(validStrings));
setDefault(myIntParam_, 1);
setDefault(myDoubleParam_, 0.5);
myDoubleArrayParam_ =
new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param");

setDefault(myIntParam(), 1);
setDefault(myIntParam().w(1));
setDefault(myDoubleParam(), 0.5);
setDefault(myIntParam().w(1), myDoubleParam().w(0.5));
setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0}));
}
}
@@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.ml.classification;
package org.apache.spark.mllib.classification;
Contributor comment: Thanks for catching this!


import java.io.Serializable;
import java.util.List;
@@ -28,7 +28,6 @@
import org.junit.Test;

import org.apache.spark.SparkConf;
import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.clustering;

import java.io.Serializable;
import java.util.List;

import com.google.common.collect.Lists;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import static org.junit.Assert.assertEquals;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

public class JavaGaussianMixtureSuite implements Serializable {
private transient JavaSparkContext sc;

@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaGaussianMixture");
}

@After
public void tearDown() {
sc.stop();
sc = null;
}

@Test
public void runGaussianMixture() {
List<Vector> points = Lists.newArrayList(
Vectors.dense(1.0, 2.0, 6.0),
Vectors.dense(1.0, 3.0, 0.0),
Vectors.dense(1.0, 4.0, 6.0)
);

JavaRDD<Vector> data = sc.parallelize(points, 2);
GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234)
.run(data);
assertEquals(model.gaussians().length, 2);
JavaRDD<Integer> predictions = model.predict(data);
predictions.first();
}
}
@@ -107,6 +107,10 @@ public void distributedLDAModel() {
// Check: log probabilities
assert(model.logLikelihood() < 0.0);
assert(model.logPrior() < 0.0);

// Check: topic distributions
JavaPairRDD<Long, Vector> topicDistributions = model.javaTopicDistributions();
assertEquals(topicDistributions.count(), corpus.count());
}

@Test