@@ -55,9 +55,26 @@ class GaussianMixtureModelEM private (
5555 // number of samples per cluster to use when initializing Gaussians
5656 private val nSamples = 5
5757
58+ // an initializing GMM can be provided rather than using the
59+ // default random starting point
60+ private var initialGmm : Option [GaussianMixtureModel ] = None
61+
5862 /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
5963 def this () = this (2 , 0.01 , 100 )
6064
65+ /** Set the initial GMM starting point, bypassing the random initialization */
66+ def setInitialGmm (gmm : GaussianMixtureModel ): this .type = {
67+ if (gmm.k == k) {
68+ initialGmm = Some (gmm)
69+ } else {
70+ throw new IllegalArgumentException (" initialing GMM has mismatched cluster count (gmm.k != k)" )
71+ }
72+ this
73+ }
74+
75+ /** Return the user supplied initial GMM, if supplied */
76+ def getInitialiGmm : Option [GaussianMixtureModel ] = initialGmm
77+
6178 /** Set the number of Gaussians in the mixture model. Default: 2 */
6279 def setK (k : Int ): this .type = {
6380 this .k = k
@@ -103,20 +120,35 @@ class GaussianMixtureModelEM private (
103120 // Get length of the input vectors
104121 val d = breezeData.first.length
105122
106- // For each Gaussian, we will initialize the mean as the average
107- // of some random samples from the data
108- val samples = breezeData.takeSample(true , k * nSamples, scala.util.Random .nextInt)
109-
110- // gaussians will be array of (weight, mean, covariance) tuples
123+ // gaussians will be array of (weight, mean, covariance) tuples.
124+ // If the user supplied an initial GMM, we use those values, otherwise
111125 // we start with uniform weights, a random mean from the data, and
112126 // diagonal covariance matrices using component variances
113127 // derived from the samples
114- var gaussians = (0 until k).map{ i =>
128+ var gaussians = initialGmm match {
129+ case Some (gmm) => (0 until k).map{ i =>
130+ (gmm.weight(i), gmm.mu(i).toBreeze.toDenseVector, gmm.sigma(i).toBreeze.toDenseMatrix)
131+ }.toArray
132+
133+ case None => {
134+ // For each Gaussian, we will initialize the mean as the average
135+ // of some random samples from the data
136+ val samples = breezeData.takeSample(true , k * nSamples, scala.util.Random .nextInt)
137+
138+ (0 until k).map{ i =>
139+ (1.0 / k,
140+ vectorMean(samples.slice(i * nSamples, (i + 1 ) * nSamples)),
141+ initCovariance(samples.slice(i * nSamples, (i + 1 ) * nSamples)))
142+ }.toArray
143+ }
144+ }
145+
146+ /* var gaussians = (0 until k).map{ i =>
115147 (1.0 / k,
116148 vectorMean(samples.slice(i * nSamples, (i + 1) * nSamples)),
117149 initCovariance(samples.slice(i * nSamples, (i + 1) * nSamples)))
118150 }.toArray
119-
151+ */
120152 val accW = new Array [Accumulator [Double ]](k)
121153 val accMu = new Array [Accumulator [DenseDoubleVector ]](k)
122154 val accSigma = new Array [Accumulator [DenseDoubleMatrix ]](k)
0 commit comments