Skip to content

Commit a27eb6d

Browse files
committed
corrections to last log loss commit
1 parent ed5da2c commit a27eb6d

File tree

1 file changed

+10
-2
lines changed
  • mllib/src/main/scala/org/apache/spark/mllib/tree/loss

1 file changed

+10
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,16 @@ object LogLoss extends Loss {
6060
override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
6161
data.map { case point =>
6262
val prediction = model.predict(point.features)
63-
// Use log1p since it is more stable than explicitly writing log(1 + exp()).
64-
2.0 * math.log1p(math.exp(-2.0 * point.label * prediction))
63+
val margin = 2.0 * point.label * prediction
64+
// The following are equivalent to 2.0 * log(1 + exp(-margin)) but are more numerically
65+
// stable.
66+
if (margin >= 0) {
67+
2.0 * math.log1p(math.exp(-margin))
68+
//math.log1p(math.exp(w))
69+
} else {
70+
//w + math.log1p(math.exp(-w))
71+
2.0 * (-margin + math.log1p(math.exp(margin)))
72+
}
6573
}.mean()
6674
}
6775
}

0 commit comments

Comments
 (0)