SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Written (W) 2011-2012 Heiko Strathmann 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #include <shogun/machine/Machine.h> 00013 #include <shogun/base/Parameter.h> 00014 #include <shogun/mathematics/Math.h> 00015 #include <shogun/base/ParameterMap.h> 00016 00017 using namespace shogun; 00018 00019 CMachine::CMachine() : CSGObject(), m_max_train_time(0), m_labels(NULL), 00020 m_solver_type(ST_AUTO) 00021 { 00022 m_data_locked=false; 00023 m_store_model_features=false; 00024 00025 SG_ADD(&m_max_train_time, "max_train_time", 00026 "Maximum training time.", MS_NOT_AVAILABLE); 00027 SG_ADD((machine_int_t*) &m_solver_type, "solver_type", 00028 "Type of solver.", MS_NOT_AVAILABLE); 00029 SG_ADD((CSGObject**) &m_labels, "labels", 00030 "Labels to be used.", MS_NOT_AVAILABLE); 00031 SG_ADD(&m_store_model_features, "store_model_features", 00032 "Should feature data of model be stored after training?", MS_NOT_AVAILABLE); 00033 SG_ADD(&m_data_locked, "data_locked", 00034 "Indicates whether data is locked", MS_NOT_AVAILABLE); 00035 00036 m_parameter_map->put( 00037 new SGParamInfo("data_locked", CT_SCALAR, ST_NONE, PT_BOOL, 1), 00038 new SGParamInfo() 00039 ); 00040 00041 m_parameter_map->finalize_map(); 00042 } 00043 00044 CMachine::~CMachine() 00045 { 00046 SG_UNREF(m_labels); 00047 } 00048 00049 bool CMachine::train(CFeatures* data) 00050 { 00051 /* not allowed to train on locked data */ 00052 if (m_data_locked) 00053 { 00054 SG_ERROR("%s::train data_lock() was called, only train_locked() is" 00055 " possible. Call data_unlock if you want to call train()\n", 00056 get_name()); 00057 } 00058 00059 if (train_require_labels()) 00060 { 00061 if (m_labels == NULL) 00062 SG_ERROR("%s@%p: No labels given", get_name(), this) 00063 00064 m_labels->ensure_valid(get_name()); 00065 } 00066 00067 bool result = train_machine(data); 00068 00069 if (m_store_model_features) 00070 store_model_features(); 00071 00072 return result; 00073 } 00074 00075 void CMachine::set_labels(CLabels* lab) 00076 { 00077 if (lab != NULL) 00078 if (!is_label_valid(lab)) 00079 SG_ERROR("Invalid label for %s", get_name()) 00080 00081 SG_REF(lab); 00082 SG_UNREF(m_labels); 00083 m_labels = lab; 00084 } 00085 00086 CLabels* CMachine::get_labels() 00087 { 00088 SG_REF(m_labels); 00089 return m_labels; 00090 } 00091 00092 void CMachine::set_max_train_time(float64_t t) 00093 { 00094 m_max_train_time = t; 00095 } 00096 00097 float64_t CMachine::get_max_train_time() 00098 { 00099 return m_max_train_time; 00100 } 00101 00102 EMachineType CMachine::get_classifier_type() 00103 { 00104 return CT_NONE; 00105 } 00106 00107 void CMachine::set_solver_type(ESolverType st) 00108 { 00109 m_solver_type = st; 00110 } 00111 00112 ESolverType CMachine::get_solver_type() 00113 { 00114 return m_solver_type; 00115 } 00116 00117 void CMachine::set_store_model_features(bool store_model) 00118 { 00119 m_store_model_features = store_model; 00120 } 00121 00122 void CMachine::data_lock(CLabels* labs, CFeatures* features) 00123 { 00124 SG_DEBUG("entering %s::data_lock\n", get_name()) 00125 if (!supports_locking()) 00126 { 00127 { 00128 SG_ERROR("%s::data_lock(): Machine does not support data locking!\n", 00129 get_name()); 00130 } 00131 } 00132 00133 if (!labs) 00134 { 00135 SG_ERROR("%s::data_lock() is not possible will NULL labels!\n", 00136 get_name()); 00137 } 00138 00139 /* first set labels */ 00140 set_labels(labs); 00141 00142 if (m_data_locked) 00143 { 00144 SG_ERROR("%s::data_lock() was already called. Dont lock twice!", 00145 get_name()); 00146 } 00147 00148 m_data_locked=true; 00149 post_lock(labs,features); 00150 SG_DEBUG("leaving %s::data_lock\n", get_name()) 00151 } 00152 00153 void CMachine::data_unlock() 00154 { 00155 SG_DEBUG("entering %s::data_lock\n", get_name()) 00156 if (m_data_locked) 00157 m_data_locked=false; 00158 00159 SG_DEBUG("leaving %s::data_lock\n", get_name()) 00160 } 00161 00162 CLabels* CMachine::apply(CFeatures* data) 00163 { 00164 SG_DEBUG("entering %s::apply(%s at %p)\n", 00165 get_name(), data ? data->get_name() : "NULL", data); 00166 00167 CLabels* result=NULL; 00168 00169 switch (get_machine_problem_type()) 00170 { 00171 case PT_BINARY: 00172 result=apply_binary(data); 00173 break; 00174 case PT_REGRESSION: 00175 result=apply_regression(data); 00176 break; 00177 case PT_MULTICLASS: 00178 result=apply_multiclass(data); 00179 break; 00180 case PT_STRUCTURED: 00181 result=apply_structured(data); 00182 break; 00183 case PT_LATENT: 00184 result=apply_latent(data); 00185 break; 00186 default: 00187 SG_ERROR("Unknown problem type") 00188 break; 00189 } 00190 00191 SG_DEBUG("leaving %s::apply(%s at %p)\n", 00192 get_name(), data ? data->get_name() : "NULL", data); 00193 00194 return result; 00195 } 00196 00197 CLabels* CMachine::apply_locked(SGVector<index_t> indices) 00198 { 00199 switch (get_machine_problem_type()) 00200 { 00201 case PT_BINARY: 00202 return apply_locked_binary(indices); 00203 case PT_REGRESSION: 00204 return apply_locked_regression(indices); 00205 case PT_MULTICLASS: 00206 return apply_locked_multiclass(indices); 00207 case PT_STRUCTURED: 00208 return apply_locked_structured(indices); 00209 case PT_LATENT: 00210 return apply_locked_latent(indices); 00211 default: 00212 SG_ERROR("Unknown problem type") 00213 break; 00214 } 00215 return NULL; 00216 } 00217 00218 CBinaryLabels* CMachine::apply_binary(CFeatures* data) 00219 { 00220 SG_ERROR("This machine does not support apply_binary()\n") 00221 return NULL; 00222 } 00223 00224 CRegressionLabels* CMachine::apply_regression(CFeatures* data) 00225 { 00226 SG_ERROR("This machine does not support apply_regression()\n") 00227 return NULL; 00228 } 00229 00230 CMulticlassLabels* CMachine::apply_multiclass(CFeatures* data) 00231 { 00232 SG_ERROR("This machine does not support apply_multiclass()\n") 00233 return NULL; 00234 } 00235 00236 CStructuredLabels* CMachine::apply_structured(CFeatures* data) 00237 { 00238 SG_ERROR("This machine does not support apply_structured()\n") 00239 return NULL; 00240 } 00241 00242 CLatentLabels* CMachine::apply_latent(CFeatures* data) 00243 { 00244 SG_ERROR("This machine does not support apply_latent()\n") 00245 return NULL; 00246 } 00247 00248 CBinaryLabels* CMachine::apply_locked_binary(SGVector<index_t> indices) 00249 { 00250 SG_ERROR("apply_locked_binary(SGVector<index_t>) is not yet implemented " 00251 "for %s\n", get_name()); 00252 return NULL; 00253 } 00254 00255 CRegressionLabels* CMachine::apply_locked_regression(SGVector<index_t> indices) 00256 { 00257 SG_ERROR("apply_locked_regression(SGVector<index_t>) is not yet implemented " 00258 "for %s\n", get_name()); 00259 return NULL; 00260 } 00261 00262 CMulticlassLabels* CMachine::apply_locked_multiclass(SGVector<index_t> indices) 00263 { 00264 SG_ERROR("apply_locked_multiclass(SGVector<index_t>) is not yet implemented " 00265 "for %s\n", get_name()); 00266 return NULL; 00267 } 00268 00269 CStructuredLabels* CMachine::apply_locked_structured(SGVector<index_t> indices) 00270 { 00271 SG_ERROR("apply_locked_structured(SGVector<index_t>) is not yet implemented " 00272 "for %s\n", get_name()); 00273 return NULL; 00274 } 00275 00276 CLatentLabels* CMachine::apply_locked_latent(SGVector<index_t> indices) 00277 { 00278 SG_ERROR("apply_locked_latent(SGVector<index_t>) is not yet implemented " 00279 "for %s\n", get_name()); 00280 return NULL; 00281 } 00282 00283