diff --git a/Classification/Core/src/main/java/org/tribuo/classification/evaluation/ConfusionMetrics.java b/Classification/Core/src/main/java/org/tribuo/classification/evaluation/ConfusionMetrics.java index 029a75cc6..8283a6d40 100644 --- a/Classification/Core/src/main/java/org/tribuo/classification/evaluation/ConfusionMetrics.java +++ b/Classification/Core/src/main/java/org/tribuo/classification/evaluation/ConfusionMetrics.java @@ -60,7 +60,7 @@ public static > double accuracy(T label, ConfusionMatr double support = cm.support(label); // handle div-by-zero if (support == 0d) { - logger.warning("No predictions: accuracy ill-defined"); + logger.warning("No predictions for " + label + ": accuracy ill-defined"); return Double.NaN; } return cm.tp(label) / cm.support(label); diff --git a/MultiLabel/Core/src/main/java/org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrix.java b/MultiLabel/Core/src/main/java/org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrix.java index c46b11b3e..115d01b44 100644 --- a/MultiLabel/Core/src/main/java/org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrix.java +++ b/MultiLabel/Core/src/main/java/org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrix.java @@ -28,6 +28,7 @@ import java.util.List; import java.util.Set; import java.util.function.Function; +import java.util.stream.Collectors; /** * A {@link ConfusionMatrix} which accepts {@link MultiLabel}s. @@ -158,15 +159,18 @@ public double confusion(MultiLabel predicted, MultiLabel truth) { @Override public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append("["); - for (int i = 0; i < mcm.length; i++) { - DenseMatrix cm = mcm[i]; - sb.append(cm.toString()); - sb.append("\n"); - } - sb.append("]"); - return sb.toString(); + return getDomain().getDomain().stream() + .map(multiLabel -> { + final int tp = (int) tp(multiLabel); + final int fn = (int) fn(multiLabel); + final int fp = (int) fp(multiLabel); + final int tn = (int) tn(multiLabel); + return String.join("\n", + multiLabel.toString(), + String.format(" [tn: %,d fn: %,d]", tn, fn), + String.format(" [fp: %,d tp: %,d]", fp, tp)); + } + ).collect(Collectors.joining("\n")); } static ConfusionMatrixTuple tabulate(ImmutableOutputInfo domain, List> predictions) {