File tree Expand file tree Collapse file tree 1 file changed +10
-2
lines changed
mllib/src/main/scala/org/apache/spark/mllib/tree/loss Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Original file line number Diff line number Diff line change @@ -60,8 +60,16 @@ object LogLoss extends Loss {
6060 override def computeError (model : TreeEnsembleModel , data : RDD [LabeledPoint ]): Double = {
6161 data.map { case point =>
6262 val prediction = model.predict(point.features)
63- // Use log1p since it is more stable than explicitly writing log(1 + exp()).
64- 2.0 * math.log1p(math.exp(- 2.0 * point.label * prediction))
63+ val margin = 2.0 * point.label * prediction
64+ // The following are equivalent to 2.0 * log(1 + exp(-margin)) but are more numerically
65+ // stable.
66+ if (margin >= 0 ) {
67+ 2.0 * math.log1p(math.exp(- margin))
68+ // math.log1p(math.exp(w))
69+ } else {
70+ // w + math.log1p(math.exp(-w))
71+ 2.0 * (- margin + math.log1p(math.exp(margin)))
72+ }
6573 }.mean()
6674 }
6775}
You can’t perform that action at this time.
0 commit comments