Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,22 @@ final class QuantileDiscretizer(override val uid: String)

@Since("1.6.0")
object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {

/**
 * Minimum number of samples required for finding splits, regardless of number of bins. If
 * the dataset has fewer rows than this value, the entire dataset will be used.
 *
 * This is the lower bound on the target sample size that `getSampledInput` derives from the
 * number of bins. It is `private[spark]` rather than private so that code elsewhere in the
 * `spark` package (for example the test suite) can size datasets relative to it.
 */
private[spark] val minSamplesRequired: Int = 10000

/**
 * Draws a random sample (without replacement) from the given dataset to collect quantile
 * statistics.
 *
 * The target sample size is `max(numBins * numBins, minSamplesRequired)`. The sampling
 * fraction is that target divided by the dataset's row count, capped at 1.0.
 *
 * @param dataset input data; must contain at least one row
 * @param numBins number of bins the quantiles will be computed for
 * @return the sampled rows, collected into a local array
 * @throws IllegalArgumentException if `dataset` is empty
 */
private[feature] def getSampledInput(dataset: DataFrame, numBins: Int): Array[Row] = {
  // Count once and reuse the result: every count() is a full pass over the dataset.
  val totalSamples = dataset.count()
  require(totalSamples > 0,
    "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
  // Square in Long so that a large numBins cannot overflow Int and turn negative.
  val requiredSamples = math.max(numBins.toLong * numBins, minSamplesRequired.toLong)
  // Divide as Double: integral division would truncate the fraction to 0 whenever the
  // dataset has more rows than requiredSamples, producing an empty sample.
  val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
  dataset.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,25 @@ class QuantileDiscretizerSuite
}
}

// One row more than minSamplesRequired makes the sampling fraction drop below 1.0; fitting on
// such input must still produce exactly the requested number of buckets.
test("Test splits on dataset larger than minSamplesRequired") {
  val ctx = SQLContext.getOrCreate(sc)
  import ctx.implicits._

  val expectedBuckets = 5
  val rowCount = QuantileDiscretizer.minSamplesRequired + 1
  val values = (1.0 to rowCount by 1.0).map(Tuple1.apply)
  val input = sc.parallelize(values).toDF("input")

  val model = new QuantileDiscretizer()
    .setInputCol("input")
    .setOutputCol("result")
    .setNumBuckets(expectedBuckets)
    .fit(input)

  val distinctBuckets = model.transform(input).select("result").distinct.count
  assert(distinctBuckets === expectedBuckets,
    "Observed number of buckets does not equal expected number of buckets.")
}

test("read/write") {
val t = new QuantileDiscretizer()
.setInputCol("myInputCol")
Expand Down