SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Thoralf Klein 00008 * Written (W) 2012 Fernando José Iglesias García 00009 * Copyright (C) 2012 Fernando José Iglesias García 00010 */ 00011 00012 #include <shogun/features/DotFeatures.h> 00013 #include <shogun/mathematics/Math.h> 00014 #include <shogun/structure/MulticlassModel.h> 00015 #include <shogun/structure/MulticlassSOLabels.h> 00016 00017 using namespace shogun; 00018 00019 CMulticlassModel::CMulticlassModel() 00020 : CStructuredModel() 00021 { 00022 init(); 00023 } 00024 00025 CMulticlassModel::CMulticlassModel(CFeatures* features, CStructuredLabels* labels) 00026 : CStructuredModel(features, labels) 00027 { 00028 init(); 00029 } 00030 00031 CMulticlassModel::~CMulticlassModel() 00032 { 00033 } 00034 00035 CStructuredLabels* CMulticlassModel::structured_labels_factory(int32_t num_labels) 00036 { 00037 return new CMulticlassSOLabels(num_labels); 00038 } 00039 00040 int32_t CMulticlassModel::get_dim() const 00041 { 00042 // TODO make the casts safe! 00043 int32_t num_classes = ((CMulticlassSOLabels*) m_labels)->get_num_classes(); 00044 int32_t feats_dim = ((CDotFeatures*) m_features)->get_dim_feature_space(); 00045 00046 return feats_dim*num_classes; 00047 } 00048 00049 SGVector< float64_t > CMulticlassModel::get_joint_feature_vector(int32_t feat_idx, CStructuredData* y) 00050 { 00051 SGVector< float64_t > psi( get_dim() ); 00052 psi.zero(); 00053 00054 SGVector< float64_t > x = ((CDotFeatures*) m_features)-> 00055 get_computed_dot_feature_vector(feat_idx); 00056 CRealNumber* r = CRealNumber::obtain_from_generic(y); 00057 ASSERT(r != NULL) 00058 float64_t label_value = r->value; 00059 00060 for ( index_t i = 0, j = label_value*x.vlen ; i < x.vlen ; ++i, ++j ) 00061 psi[j] = x[i]; 00062 00063 return psi; 00064 } 00065 00066 CResultSet* CMulticlassModel::argmax( 00067 SGVector< float64_t > w, 00068 int32_t feat_idx, 00069 bool const training) 00070 { 00071 CDotFeatures* df = (CDotFeatures*) m_features; 00072 int32_t feats_dim = df->get_dim_feature_space(); 00073 00074 if ( training ) 00075 { 00076 CMulticlassSOLabels* ml = (CMulticlassSOLabels*) m_labels; 00077 m_num_classes = ml->get_num_classes(); 00078 } 00079 else 00080 { 00081 REQUIRE(m_num_classes > 0, "The model needs to be trained before " 00082 "using it for prediction\n"); 00083 } 00084 00085 int32_t dim = get_dim(); 00086 ASSERT(dim == w.vlen) 00087 00088 // Find the class that gives the maximum score 00089 00090 float64_t score = 0, ypred = 0; 00091 float64_t max_score = -CMath::INFTY; 00092 00093 for ( int32_t c = 0 ; c < m_num_classes ; ++c ) 00094 { 00095 score = df->dense_dot(feat_idx, w.vector+c*feats_dim, feats_dim); 00096 if ( training ) 00097 score += delta_loss(feat_idx, c); 00098 00099 if ( score > max_score ) 00100 { 00101 max_score = score; 00102 ypred = c; 00103 } 00104 } 00105 00106 // Build the CResultSet object to return 00107 CResultSet* ret = new CResultSet(); 00108 SG_REF(ret); 00109 CRealNumber* y = new CRealNumber(ypred); 00110 SG_REF(y); 00111 00112 ret->psi_pred = get_joint_feature_vector(feat_idx, y); 00113 ret->score = max_score; 00114 ret->argmax = y; 00115 if ( training ) 00116 { 00117 ret->delta = CStructuredModel::delta_loss(feat_idx, y); 00118 ret->psi_truth = CStructuredModel::get_joint_feature_vector( 00119 feat_idx, feat_idx); 00120 ret->score -= SGVector< float64_t >::dot(w.vector, 00121 ret->psi_truth.vector, dim); 00122 } 00123 00124 return ret; 00125 } 00126 00127 float64_t CMulticlassModel::delta_loss(CStructuredData* y1, CStructuredData* y2) 00128 { 00129 CRealNumber* rn1 = CRealNumber::obtain_from_generic(y1); 00130 CRealNumber* rn2 = CRealNumber::obtain_from_generic(y2); 00131 ASSERT(rn1 != NULL) 00132 ASSERT(rn2 != NULL) 00133 00134 return delta_loss(rn1->value, rn2->value); 00135 } 00136 00137 float64_t CMulticlassModel::delta_loss(int32_t y1_idx, float64_t y2) 00138 { 00139 REQUIRE(y1_idx >= 0 || y1_idx < m_labels->get_num_labels(), 00140 "The label index must be inside [0, num_labels-1]\n"); 00141 00142 CRealNumber* rn1 = CRealNumber::obtain_from_generic(m_labels->get_label(y1_idx)); 00143 float64_t ret = delta_loss(rn1->value, y2); 00144 SG_UNREF(rn1); 00145 00146 return ret; 00147 } 00148 00149 float64_t CMulticlassModel::delta_loss(float64_t y1, float64_t y2) 00150 { 00151 return (y1 == y2) ? 0 : 1; 00152 } 00153 00154 void CMulticlassModel::init_primal_opt( 00155 float64_t regularization, 00156 SGMatrix< float64_t > & A, 00157 SGVector< float64_t > a, 00158 SGMatrix< float64_t > B, 00159 SGVector< float64_t > & b, 00160 SGVector< float64_t > lb, 00161 SGVector< float64_t > ub, 00162 SGMatrix< float64_t > & C) 00163 { 00164 C = SGMatrix< float64_t >::create_identity_matrix(get_dim(), regularization); 00165 } 00166 00167 void CMulticlassModel::init() 00168 { 00169 SG_ADD(&m_num_classes, "m_num_classes", "The number of classes", 00170 MS_NOT_AVAILABLE); 00171 00172 m_num_classes = 0; 00173 } 00174