SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Thoralf Klein 00008 * Written (W) 2012 Fernando José Iglesias García 00009 * Copyright (C) 2012 Fernando José Iglesias García 00010 */ 00011 00012 #include <shogun/structure/StructuredModel.h> 00013 00014 using namespace shogun; 00015 00016 CResultSet::CResultSet() : CSGObject(), argmax(NULL) 00017 { 00018 } 00019 00020 CResultSet::~CResultSet() 00021 { 00022 SG_UNREF(argmax) 00023 } 00024 00025 CStructuredLabels* CStructuredModel::structured_labels_factory(int32_t num_labels) 00026 { 00027 return new CStructuredLabels(num_labels); 00028 } 00029 00030 const char* CResultSet::get_name() const 00031 { 00032 return "ResultSet"; 00033 } 00034 00035 CStructuredModel::CStructuredModel() : CSGObject() 00036 { 00037 init(); 00038 } 00039 00040 CStructuredModel::CStructuredModel( 00041 CFeatures* features, 00042 CStructuredLabels* labels) 00043 : CSGObject() 00044 { 00045 init(); 00046 00047 set_labels(labels); 00048 set_features(features); 00049 } 00050 00051 CStructuredModel::~CStructuredModel() 00052 { 00053 SG_UNREF(m_labels); 00054 SG_UNREF(m_features); 00055 } 00056 00057 void CStructuredModel::init_primal_opt( 00058 float64_t regularization, 00059 SGMatrix< float64_t > & A, 00060 SGVector< float64_t > a, 00061 SGMatrix< float64_t > B, 00062 SGVector< float64_t > & b, 00063 SGVector< float64_t > lb, 00064 SGVector< float64_t > ub, 00065 SGMatrix< float64_t > & C) 00066 { 00067 SG_ERROR("init_primal_opt is not implemented for %s!\n", get_name()) 00068 } 00069 00070 void CStructuredModel::set_labels(CStructuredLabels* labels) 00071 { 00072 SG_REF(labels); 00073 SG_UNREF(m_labels); 00074 m_labels = labels; 00075 } 00076 00077 CStructuredLabels* CStructuredModel::get_labels() 00078 { 00079 SG_REF(m_labels); 00080 return m_labels; 00081 } 00082 00083 void CStructuredModel::set_features(CFeatures* features) 00084 { 00085 SG_REF(features); 00086 SG_UNREF(m_features); 00087 m_features = features; 00088 } 00089 00090 CFeatures* CStructuredModel::get_features() 00091 { 00092 SG_REF(m_features); 00093 return m_features; 00094 } 00095 00096 SGVector< float64_t > CStructuredModel::get_joint_feature_vector( 00097 int32_t feat_idx, 00098 int32_t lab_idx) 00099 { 00100 CStructuredData* label = m_labels->get_label(lab_idx); 00101 SGVector< float64_t > ret = get_joint_feature_vector(feat_idx, label); 00102 SG_UNREF(label); 00103 00104 return ret; 00105 } 00106 00107 SGVector< float64_t > CStructuredModel::get_joint_feature_vector( 00108 int32_t feat_idx, 00109 CStructuredData* y) 00110 { 00111 SG_ERROR("compute_joint_feature(int32_t, CStructuredData*) is not " 00112 "implemented for %s!\n", get_name()); 00113 00114 return SGVector< float64_t >(); 00115 } 00116 00117 float64_t CStructuredModel::delta_loss(int32_t ytrue_idx, CStructuredData* ypred) 00118 { 00119 REQUIRE(ytrue_idx >= 0 || ytrue_idx < m_labels->get_num_labels(), 00120 "The label index must be inside [0, num_labels-1]\n"); 00121 00122 CStructuredData* ytrue = m_labels->get_label(ytrue_idx); 00123 float64_t ret = delta_loss(ytrue, ypred); 00124 SG_UNREF(ytrue); 00125 00126 return ret; 00127 } 00128 00129 float64_t CStructuredModel::delta_loss(CStructuredData* y1, CStructuredData* y2) 00130 { 00131 SG_ERROR("delta_loss(CStructuredData*, CStructuredData*) is not " 00132 "implemented for %s!\n", get_name()); 00133 00134 return 0.0; 00135 } 00136 00137 void CStructuredModel::init() 00138 { 00139 SG_ADD((CSGObject**) &m_labels, "m_labels", "Structured labels", 00140 MS_NOT_AVAILABLE); 00141 SG_ADD((CSGObject**) &m_features, "m_features", "Feature vectors", 00142 MS_NOT_AVAILABLE); 00143 00144 m_features = NULL; 00145 m_labels = NULL; 00146 } 00147 00148 void CStructuredModel::init_training() 00149 { 00150 // Nothing to do here 00151 } 00152 00153 bool CStructuredModel::check_training_setup() const 00154 { 00155 // Nothing to do here 00156 return true; 00157 } 00158 00159 int32_t CStructuredModel::get_num_aux() const 00160 { 00161 return 0; 00162 } 00163 00164 int32_t CStructuredModel::get_num_aux_con() const 00165 { 00166 return 0; 00167 }