Skip to content

Commit 20751f8

Browse files
committed
Naive Bayes classifier bug fixes and enhancement after testing.
1 parent a0a37d5 commit 20751f8

File tree

3 files changed

+67
-37
lines changed

3 files changed

+67
-37
lines changed

edu.usc.cssl.tacit.classify.naivebayes.ui/src/edu/usc/cssl/tacit/classify/naivebayes/ui/NaiveBayesClassifierView.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,7 @@ public void run() {
528528
int kValue = Integer.parseInt(tempkValue);
529529
monitor.worked(1); // done with the validation
530530
if (isPreprocessEnabled) {
531+
ConsoleView.printlInConsole("Preprocessing input data...");
531532
monitor.subTask("Preprocessing...");
532533
try {
533534
preprocessTask = new Preprocess("NB_Classifier");
@@ -699,6 +700,7 @@ public void run() {
699700
job.schedule(); // schedule the job
700701
job.addJobChangeListener(new JobChangeAdapter() {
701702

703+
@Override
702704
public void done(IJobChangeEvent event) {
703705
if (!event.getResult().isOK()) {
704706
TacitFormComposite
@@ -793,6 +795,12 @@ private boolean canItProceed(Map<String, List<String>> classPaths) {
793795
"Provide valid K-Value for cross validation", null,
794796
IMessageProvider.ERROR);
795797
return false;
798+
799+
} else if(Integer.parseInt(kValueText.getText())<2) {
800+
form.getMessageManager().addMessage("kvalue",
801+
"K-Value must be atleast 2", null,
802+
IMessageProvider.ERROR);
803+
return false;
796804
} else {
797805
form.getMessageManager().removeMessage("kvalue");
798806
}

edu.usc.cssl.tacit.classify.naivebayes/src/edu/usc/cssl/tacit/classify/naivebayes/services/NaiveBayesClassifier.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,6 @@ public NaiveBayesClassifier() {
3535
new File(tempoutputDir).mkdir();
3636
}
3737
this.currTime = System.currentTimeMillis();
38-
// this.tmpLocation =
39-
// "F:\\NLP\\Naive Bayes Classifier\\2 Class Analysis\\preprocess\\NB_Classifier";
40-
4138
String outputDir = this.outputDir;
4239
if (!new File(outputDir).exists()) {
4340
new File(outputDir).mkdirs();

edu.usc.cssl.tacit.classify.naivebayes/src/edu/usc/cssl/tacit/classify/naivebayes/weka/NaiveBayesClassifierWeka.java

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import java.text.DateFormat;
1010
import java.text.SimpleDateFormat;
1111
import java.util.Date;
12+
import java.util.HashMap;
1213
import java.util.List;
1314
import java.util.Map;
1415
import java.util.Random;
@@ -44,30 +45,10 @@ public void initializeInstances() throws Exception{
4445
filter = new StringToWordVector();
4546
filter.setInputFormat(dataRaw);
4647
dataFiltered = Filter.useFilter(dataRaw, filter);
47-
nbc = createClassifier(dataFiltered);
48+
nbc = createClassifier(dataFiltered);
4849
}
49-
// public static void main(String[] args) throws Exception {
50-
// String[] classes = {
51-
// "F:\\NLP\\Naive Bayes Classifier\\2 Class Analysis\\Train\\Ham",
52-
// "F:\\NLP\\Naive Bayes Classifier\\2 Class Analysis\\Train\\Spam" };
53-
// // DirectoryToArff.createTrainInstances(classes);
54-
//
55-
// Instances dataRaw =null;// DirectoryToArff.loadArff();
56-
// StringToWordVector filter = new StringToWordVector();
57-
// filter.setInputFormat(dataRaw);
58-
// Instances dataFiltered = Filter.useFilter(dataRaw, filter);
59-
//
60-
// final Classifier nbc = createClassifier(dataFiltered);
61-
// crossValidate(nbc, dataFiltered, 2);
62-
//
63-
//// classify(
64-
//// nbc,
65-
//// "F:\\NLP\\Naive Bayes Classifier\\2 Class Analysis\\Classify\\Input",
66-
//// dataFiltered, filter);
67-
// }
6850

6951
public boolean doCrossValidate(int k, IProgressMonitor monitor, Date dateObj)throws Exception {
70-
7152
crossValidate(nbc, dataFiltered, k);
7253
return true;
7354
}
@@ -76,22 +57,21 @@ public boolean doClassify(String classificationInputDir, String classificationO
7657
IProgressMonitor monitor,Date dateObj) throws Exception {
7758
DateFormat df = new SimpleDateFormat("MM-dd-yy-HH-mm-ss");
7859
ConsoleView.printlInConsoleln("Classification starts ..");
79-
String outputPath = classificationOutputDir
80-
+ System.getProperty("file.separator") +"Naive_Bayes_classification_results"
81-
+ "-" + df.format(dateObj);
82-
BufferedWriter bw = new BufferedWriter(new FileWriter(new File(
83-
outputPath + "-output.csv")));
8460
Instances rawTestData = new DirectoryToArff().createTestInstances(classificationInputDir);
8561
Instances filteredTestData = Filter.useFilter(rawTestData, filter);
8662
Evaluation testEval = new Evaluation(dataFiltered);
8763
testEval.evaluateModel(nbc, filteredTestData);
8864
FastVector predictions = testEval.predictions();
65+
66+
String outputPath = classificationOutputDir + System.getProperty("file.separator") +"Naive_Bayes_classification_results" + "-" + df.format(dateObj);
67+
BufferedWriter bw = new BufferedWriter(new FileWriter(new File(outputPath + "-output.csv")));
68+
bw.write("Filename,Predicted Class\n");
8969
for (int i = 0; i < predictions.size(); i++) {
9070
NominalPrediction np = (NominalPrediction) predictions.elementAt(i);
9171
int pred = (int) np.predicted();
92-
bw.write(DirectoryToArff.instanceIdNameMap.get(i) + "\t"
93-
+ dataFiltered.classAttribute().value(pred) +"\n");
72+
bw.write(DirectoryToArff.instanceIdNameMap.get(i) + "," + getClassName(dataFiltered.classAttribute().value(pred)) +"\n");
9473
}
74+
bw.close();
9575
return true;
9676
}
9777

@@ -102,18 +82,63 @@ private static Classifier createClassifier(Instances dataFiltered)
10282
return classifier;
10383
}
10484

105-
private static void crossValidate(Classifier nbc, Instances dataFiltered,
106-
int k) throws Exception {
85+
private static void crossValidate(Classifier nbc, Instances dataFiltered, int k) throws Exception {
86+
ConsoleView.printlInConsoleln("Cross Validating...");
10787
Evaluation eval = new Evaluation(dataFiltered);
10888
eval.crossValidateModel(nbc, dataFiltered, k, new Random(1));
109-
ConsoleView.printlInConsoleln(eval.toSummaryString("\nResults\n======\n", false));
110-
double[][] confusion = eval.confusionMatrix();
89+
ConsoleView.printlInConsoleln(eval.toSummaryString("\nK-fold Cross Validation Results\n", false));
90+
91+
//printConfusionMatrix(eval.confusionMatrix()); //Not required
92+
93+
String[] attributes = {"TP Rate", "FP Rate", "Precision", "Recall", "F-Measure", "ROC Area"};
94+
HashMap<String, HashMap<String, String>> detailedResults = new HashMap<String, HashMap<String, String>>();
95+
String[] temp = eval.toClassDetailsString().toString().split("\\n");
96+
for(int i = 3; i<temp.length-1; i++) {
97+
String[] tmp = temp[i].split("\\s+");
98+
String cName = getClassName(tmp[tmp.length-1]);
99+
HashMap<String, String> classDetails = new HashMap<String, String>();
100+
int index = 0;
101+
for(String val : tmp) {
102+
val = val.replaceAll("\\s", "");
103+
if(val.length() != 0) {
104+
classDetails.put(attributes[index], val);
105+
index++;
106+
}
107+
if(index == attributes.length) break;
108+
}
109+
detailedResults.put(cName, classDetails);
110+
}
111+
112+
StringBuilder header = new StringBuilder();
113+
header.append("Class" + "\t");
114+
for(String attr : attributes)
115+
header.append(attr + "\t");
116+
ConsoleView.printlInConsoleln(new String(header));
117+
118+
for(String cName : detailedResults.keySet()) {
119+
StringBuilder cDetails = new StringBuilder();
120+
cDetails.append(cName + "\t");
121+
for(String attr : attributes) {
122+
cDetails.append(detailedResults.get(cName).get(attr) + "\t");
123+
}
124+
ConsoleView.printlInConsoleln(new String(cDetails));
125+
}
126+
ConsoleView.printlInConsoleln();
127+
ConsoleView.printlInConsoleln("\nAccuracy: " + calculateAccuracy(eval.predictions()));
128+
}
129+
130+
private static void printConfusionMatrix(double[][] confusionMatrix) {
131+
double[][] confusion = confusionMatrix;
111132
for (int i = 0; i < confusion.length; i++) {
112133
for (int j = 0; j < confusion[0].length; j++)
113134
ConsoleView.printlInConsole(confusion[i][j] + "\t");
114135
ConsoleView.printlInConsoleln();
115-
}
116-
ConsoleView.printlInConsoleln("Accuracy:" + calculateAccuracy(eval.predictions()));
136+
}
137+
}
138+
139+
private static String getClassName(String path) {
140+
String[] temp = path.split("\\\\");
141+
return temp[temp.length-1];
117142
}
118143

119144
public static double calculateAccuracy(FastVector predictions) {

0 commit comments

Comments
 (0)