SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Philippe Tillet 00008 */ 00009 00010 #include <shogun/classifier/NearestCentroid.h> 00011 #include <shogun/labels/MulticlassLabels.h> 00012 #include <shogun/features/Features.h> 00013 #include <shogun/features/FeatureTypes.h> 00014 00015 00016 00017 namespace shogun{ 00018 00019 CNearestCentroid::CNearestCentroid() : CDistanceMachine() 00020 { 00021 init(); 00022 } 00023 00024 CNearestCentroid::CNearestCentroid(CDistance* d, CLabels* trainlab) : CDistanceMachine() 00025 { 00026 init(); 00027 ASSERT(d) 00028 ASSERT(trainlab) 00029 set_distance(d); 00030 set_labels(trainlab); 00031 } 00032 00033 CNearestCentroid::~CNearestCentroid() 00034 { 00035 if(m_is_trained) 00036 distance->remove_lhs(); 00037 else 00038 delete m_centroids; 00039 } 00040 00041 void CNearestCentroid::init() 00042 { 00043 m_shrinking=0; 00044 m_is_trained=false; 00045 m_centroids = new CDenseFeatures<float64_t>(); 00046 } 00047 00048 00049 bool CNearestCentroid::train_machine(CFeatures* data) 00050 { 00051 ASSERT(m_labels) 00052 ASSERT(m_labels->get_label_type() == LT_MULTICLASS) 00053 ASSERT(distance) 00054 ASSERT( data->get_feature_class() == C_DENSE) 00055 if (data) 00056 { 00057 if (m_labels->get_num_labels() != data->get_num_vectors()) 00058 SG_ERROR("Number of training vectors does not match number of labels\n") 00059 distance->init(data, data); 00060 } 00061 else 00062 { 00063 data = distance->get_lhs(); 00064 } 00065 int32_t num_vectors = data->get_num_vectors(); 00066 int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes(); 00067 int32_t num_feats = ((CDenseFeatures<float64_t>*) data)->get_num_features(); 00068 SGMatrix<float64_t> centroids(num_feats,num_classes); 00069 centroids.zero(); 00070 00071 m_centroids->set_num_features(num_feats); 00072 m_centroids->set_num_vectors(num_classes); 00073 00074 int64_t* num_per_class = new int64_t[num_classes]; 00075 for (int32_t i=0 ; i<num_classes ; i++) 00076 { 00077 num_per_class[i]=0; 00078 } 00079 00080 for (int32_t idx=0 ; idx<num_vectors ; idx++) 00081 { 00082 int32_t current_len; 00083 bool current_free; 00084 int32_t current_class = ((CMulticlassLabels*) m_labels)->get_label(idx); 00085 float64_t* target = centroids.matrix + num_feats*current_class; 00086 float64_t* current = ((CDenseFeatures<float64_t>*)data)->get_feature_vector(idx,current_len,current_free); 00087 SGVector<float64_t>::add(target,1.0,target,1.0,current,current_len); 00088 num_per_class[current_class]++; 00089 ((CDenseFeatures<float64_t>*)data)->free_feature_vector(current, current_len, current_free); 00090 } 00091 00092 00093 for (int32_t i=0 ; i<num_classes ; i++) 00094 { 00095 float64_t* target = centroids.matrix + num_feats*i; 00096 int32_t total = num_per_class[i]; 00097 float64_t scale = 0; 00098 if(total>1) 00099 scale = 1.0/((float64_t)(total-1)); 00100 else 00101 scale = 1.0/(float64_t)total; 00102 00103 SGVector<float64_t>::scale_vector(scale,target,num_feats); 00104 } 00105 00106 m_centroids->free_feature_matrix(); 00107 m_centroids->set_feature_matrix(centroids); 00108 00109 00110 m_is_trained=true; 00111 distance->init(m_centroids,distance->get_rhs()); 00112 00113 SG_FREE(num_per_class); 00114 00115 return true; 00116 } 00117 00118 }