本文共 2730 字,大约阅读时间需要 9 分钟。
package cn.edu.xmu.bdm.wekainjava.test;

import java.io.File;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.functions.LibSVM;
import weka.classifiers.meta.Vote;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.SelectedTag;

import cn.edu.xmu.bdm.wekainjava.utils.WekaFactory;
import cn.edu.xmu.bdm.wekainjava.utils.WekaFactoryImpl;

/**
 * Demo that trains a Weka {@link Vote} ensemble (J48 + NaiveBayes + LibSVM)
 * on the "segment-challenge" data set and evaluates it on "segment-test".
 *
 * <p>The evaluation summary, confusion matrix, per-class details and the
 * overall accuracy are printed to standard output.
 */
public class EnsembleTest {

    /**
     * Builds the ensemble, trains it and prints the evaluation results.
     *
     * @param args unused
     * @throws Exception if loading the data, building a classifier or
     *                   evaluating the model fails
     */
    public static void main(String[] args) throws Exception {
        File trainFile = new File("C://Program Files//Weka-3-6//data//segment-challenge.arff");
        File testFile = new File("C://Program Files//Weka-3-6//data//segment-test.arff");

        /*
         * 1. Obtain the Weka factory.
         */
        WekaFactory wi = WekaFactoryImpl.getInstance();

        /*
         * 2. Load the training and test instances through the factory.
         *    The last attribute of each data set is used as the class label.
         */
        Instances instancesTrain = wi.getInstance(trainFile);
        instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1);

        Instances instancesTest = wi.getInstance(testFile);
        instancesTest.setClassIndex(instancesTest.numAttributes() - 1);

        /*
         * 3. Obtain the base classifiers from the factory.
         *    The first member is a J48 tree; it was previously created from
         *    LibSVM.class, which put two identical LibSVM learners in the
         *    ensemble and let them outvote NaiveBayes every time.
         */
        Classifier j48 = (Classifier) wi.getClassifier(J48.class);
        Classifier naiveBayes = (Classifier) wi.getClassifier(NaiveBayes.class);
        Classifier libSVM = (Classifier) wi.getClassifier(LibSVM.class);

        /*
         * 3.1 Collect the base classifiers for the ensemble.
         */
        Classifier[] cfsArray = new Classifier[] { j48, naiveBayes, libSVM };

        /*
         * 3.2 Choose how the ensemble combines the members' predictions.
         *     Available rules: AVERAGE_RULE, PRODUCT_RULE,
         *     MAJORITY_VOTING_RULE, MIN_RULE, MAX_RULE, MEDIAN_RULE
         *     (see the Weka documentation for their exact semantics).
         *     Majority voting is the usual choice.
         */
        Vote ensemble = new Vote();
        SelectedTag tag = new SelectedTag(Vote.MAJORITY_VOTING_RULE, Vote.TAGS_RULES);
        ensemble.setCombinationRule(tag);
        ensemble.setClassifiers(cfsArray);
        // Fixed random seed.
        ensemble.setSeed(2);
        // Train the ensemble on the training instances.
        ensemble.buildClassifier(instancesTrain);

        /*
         * 4. Obtain an Evaluation from the factory and test the trained
         *    ensemble on every test instance.
         */
        Evaluation testingEvaluation = wi.getEvaluation(ensemble, instancesTest);
        int length = instancesTest.numInstances();
        for (int i = 0; i < length; i++) {
            // Evaluate one test instance and record the prediction.
            testingEvaluation.evaluateModelOnceAndRecordPrediction(ensemble, instancesTest.instance(i));
        }

        System.out.println(testingEvaluation.toSummaryString());
        System.out.println(testingEvaluation.toMatrixString());
        System.out.println(testingEvaluation.toClassDetailsString());
        System.out.println("分类器的正确率:" + (1 - testingEvaluation.errorRate()));
    }
}
转载地址:http://zvwob.baihongyu.com/