SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Chiyuan Zhang 00008 * Copyright (C) 2012 Chiyuan Zhang 00009 */ 00010 00011 #include <shogun/evaluation/CrossValidationSplitting.h> 00012 #include <shogun/multiclass/tree/RelaxedTreeUtil.h> 00013 #include <shogun/evaluation/MulticlassAccuracy.h> 00014 00015 using namespace shogun; 00016 00017 SGMatrix<float64_t> RelaxedTreeUtil::estimate_confusion_matrix(CBaseMulticlassMachine *machine, CFeatures *X, CMulticlassLabels *Y, int32_t num_classes) 00018 { 00019 const int32_t N_splits = 2; // 5 00020 CCrossValidationSplitting *split = new CCrossValidationSplitting(Y, N_splits); 00021 split->build_subsets(); 00022 00023 SGMatrix<float64_t> conf_mat(num_classes, num_classes), tmp_mat(num_classes, num_classes); 00024 conf_mat.zero(); 00025 00026 machine->set_labels(Y); 00027 machine->set_store_model_features(true); 00028 00029 for (int32_t i=0; i < N_splits; ++i) 00030 { 00031 // subset for training 00032 SGVector<index_t> inverse_subset_indices = split->generate_subset_inverse(i); 00033 X->add_subset(inverse_subset_indices); 00034 Y->add_subset(inverse_subset_indices); 00035 00036 machine->train(X); 00037 X->remove_subset(); 00038 Y->remove_subset(); 00039 00040 // subset for predicting 00041 SGVector<index_t> subset_indices = split->generate_subset_indices(i); 00042 X->add_subset(subset_indices); 00043 Y->add_subset(subset_indices); 00044 00045 CMulticlassLabels *pred = machine->apply_multiclass(X); 00046 00047 get_confusion_matrix(tmp_mat, Y, pred); 00048 00049 for (index_t j=0; j < tmp_mat.num_rows; ++j) 00050 { 00051 for (index_t k=0; k < tmp_mat.num_cols; ++k) 00052 { 00053 conf_mat(j, k) += tmp_mat(j, k); 00054 } 00055 } 00056 00057 SG_UNREF(pred); 00058 00059 X->remove_subset(); 00060 Y->remove_subset(); 00061 } 00062 00063 SG_UNREF(split); 00064 00065 for (index_t j=0; j < tmp_mat.num_rows; ++j) 00066 { 00067 for (index_t k=0; k < tmp_mat.num_cols; ++k) 00068 { 00069 conf_mat(j, k) /= N_splits; 00070 } 00071 } 00072 00073 return conf_mat; 00074 } 00075 00076 void RelaxedTreeUtil::get_confusion_matrix(SGMatrix<float64_t> &conf_mat, CMulticlassLabels *gt, CMulticlassLabels *pred) 00077 { 00078 SGMatrix<int32_t> conf_mat_int = CMulticlassAccuracy::get_confusion_matrix(pred, gt); 00079 00080 for (index_t i=0; i < conf_mat.num_rows; ++i) 00081 { 00082 float64_t n=0; 00083 for (index_t j=0; j < conf_mat.num_cols; ++j) 00084 { 00085 conf_mat(i, j) = conf_mat_int(i, j); 00086 n += conf_mat(i, j); 00087 } 00088 00089 if (n != 0) 00090 { 00091 for (index_t j=0; j < conf_mat.num_cols; ++j) 00092 conf_mat(i, j) /= n; 00093 } 00094 } 00095 }