SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2008-2010 Soeren Sonnenburg 00008 * Written (W) 2011-2013 Heiko Strathmann 00009 * Written (W) 2013 Thoralf Klein 00010 * Copyright (C) 2008-2010 Fraunhofer Institute FIRST and Max Planck Society 00011 */ 00012 00013 #ifndef __SGOBJECT_H__ 00014 #define __SGOBJECT_H__ 00015 00016 #include <shogun/lib/config.h> 00017 #include <shogun/lib/common.h> 00018 #include <shogun/lib/DataType.h> 00019 #include <shogun/base/SGRefObject.h> 00020 #include <shogun/lib/ShogunException.h> 00021 00022 #include <shogun/base/Parallel.h> 00023 #include <shogun/base/Version.h> 00024 #include <shogun/io/SGIO.h> 00025 00029 namespace shogun 00030 { 00031 class IO; 00032 class Parallel; 00033 class Version; 00034 class Parameter; 00035 class ParameterMap; 00036 class SGParamInfo; 00037 class SGRefObject; 00038 class CSerializableFile; 00039 00040 template <class T, class K> class CMap; 00041 00042 struct TParameter; 00043 template <class T> class DynArray; 00044 template <class T> class SGStringList; 00045 00046 /******************************************************************************* 00047 * Macros for registering parameters/model selection parameters 00048 ******************************************************************************/ 00049 00050 #define VA_NARGS_IMPL(_1, _2, _3, _4, _5, N, ...) N 00051 #define VA_NARGS(...) VA_NARGS_IMPL(__VA_ARGS__, 5, 4, 3, 2, 1) 00052 00053 #define VARARG_IMPL2(base, count, ...) base##count(__VA_ARGS__) 00054 #define VARARG_IMPL(base, count, ...) VARARG_IMPL2(base, count, __VA_ARGS__) 00055 #define VARARG(base, ...) VARARG_IMPL(base, VA_NARGS(__VA_ARGS__), __VA_ARGS__) 00056 00057 #define SG_ADD4(param, name, description, ms_available) {\ 00058 m_parameters->add(param, name, description);\ 00059 if (ms_available)\ 00060 m_model_selection_parameters->add(param, name, description);\ 00061 } 00062 00063 #define SG_ADD5(param, name, description, ms_available, gradient_available) {\ 00064 m_parameters->add(param, name, description);\ 00065 if (ms_available)\ 00066 m_model_selection_parameters->add(param, name, description);\ 00067 if (gradient_available)\ 00068 m_gradient_parameters->add(param, name, description);\ 00069 } 00070 00071 #define SG_ADD(...) VARARG(SG_ADD, __VA_ARGS__) 00072 00073 /******************************************************************************* 00074 * End of macros for registering parameters/model selection parameters 00075 ******************************************************************************/ 00076 00078 enum EModelSelectionAvailability { 00079 MS_NOT_AVAILABLE=0, 00080 MS_AVAILABLE=1, 00081 }; 00082 00084 enum EGradientAvailability 00085 { 00086 GRADIENT_NOT_AVAILABLE=0, 00087 GRADIENT_AVAILABLE=1 00088 }; 00089 00102 class CSGObject : public SGRefObject 00103 { 00104 public: 00106 CSGObject(); 00107 00109 CSGObject(const CSGObject& orig); 00110 00112 virtual ~CSGObject(); 00113 00117 virtual CSGObject *shallow_copy() const 00118 { 00119 SG_NOTIMPLEMENTED 00120 return NULL; 00121 } 00122 00126 virtual CSGObject *deep_copy() const 00127 { 00128 SG_NOTIMPLEMENTED 00129 return NULL; 00130 } 00131 00137 virtual const char* get_name() const = 0; 00138 00147 virtual bool is_generic(EPrimitiveType* generic) const; 00148 00151 template<class T> void set_generic(); 00152 00157 void unset_generic(); 00158 00163 virtual void print_serializable(const char* prefix=""); 00164 00174 virtual bool save_serializable(CSerializableFile* file, 00175 const char* prefix="", int32_t param_version=Version::get_version_parameter()); 00176 00188 virtual bool load_serializable(CSerializableFile* file, 00189 const char* prefix="", int32_t param_version=Version::get_version_parameter()); 00190 00204 DynArray<TParameter*>* load_file_parameters(const SGParamInfo* param_info, 00205 int32_t file_version, CSerializableFile* file, 00206 const char* prefix=""); 00207 00220 DynArray<TParameter*>* load_all_file_parameters(int32_t file_version, 00221 int32_t current_version, 00222 CSerializableFile* file, const char* prefix=""); 00223 00238 void map_parameters(DynArray<TParameter*>* param_base, 00239 int32_t& base_version, 00240 DynArray<const SGParamInfo*>* target_param_infos); 00241 00246 void set_global_io(SGIO* io); 00247 00252 SGIO* get_global_io(); 00253 00258 void set_global_parallel(Parallel* parallel); 00259 00264 Parallel* get_global_parallel(); 00265 00270 void set_global_version(Version* version); 00271 00276 Version* get_global_version(); 00277 00280 SGStringList<char> get_modelsel_names(); 00281 00283 void print_modsel_params(); 00284 00291 char* get_modsel_param_descr(const char* param_name); 00292 00299 index_t get_modsel_param_index(const char* param_name); 00300 00307 void build_gradient_parameter_dictionary(CMap<TParameter*, CSGObject*>* dict); 00308 00309 protected: 00329 virtual TParameter* migrate(DynArray<TParameter*>* param_base, 00330 const SGParamInfo* target); 00331 00354 virtual void one_to_one_migration_prepare(DynArray<TParameter*>* param_base, 00355 const SGParamInfo* target, TParameter*& replacement, 00356 TParameter*& to_migrate, char* old_name=NULL); 00357 00366 virtual void load_serializable_pre() throw (ShogunException); 00367 00376 virtual void load_serializable_post() throw (ShogunException); 00377 00386 virtual void save_serializable_pre() throw (ShogunException); 00387 00396 virtual void save_serializable_post() throw (ShogunException); 00397 00398 public: 00404 virtual bool update_parameter_hash(); 00405 00417 virtual bool equals(CSGObject* other, float64_t accuracy=0.0); 00418 00427 virtual CSGObject* clone(); 00428 00429 private: 00430 void set_global_objects(); 00431 void unset_global_objects(); 00432 void init(); 00433 00439 bool is_param_new(const SGParamInfo param_info) const; 00440 00449 bool save_parameter_version(CSerializableFile* file, const char* prefix="", 00450 int32_t param_version=Version::get_version_parameter()); 00451 00455 int32_t load_parameter_version(CSerializableFile* file, 00456 const char* prefix=""); 00457 00458 /*Gets an incremental hash of all parameters as well as the parameters 00459 * of CSGObject children of the current object's parameters. 00460 * 00461 * @param param Parameter to hash 00462 * @param current hash 00463 * @param carry value for Murmur3 incremental hash 00464 * @param total_length total byte length of all hashed 00465 * parameters so far. Byte length of parameters will be added 00466 * to the total length 00467 */ 00468 void get_parameter_incremental_hash(Parameter* param, 00469 uint32_t& hash, uint32_t& carry, uint32_t& total_length); 00470 00471 public: 00473 SGIO* io; 00474 00476 Parallel* parallel; 00477 00479 Version* version; 00480 00482 Parameter* m_parameters; 00483 00485 Parameter* m_model_selection_parameters; 00486 00488 Parameter* m_gradient_parameters; 00489 00491 ParameterMap* m_parameter_map; 00492 00494 uint32_t m_hash; 00495 00496 private: 00497 00498 EPrimitiveType m_generic; 00499 bool m_load_pre_called; 00500 bool m_load_post_called; 00501 bool m_save_pre_called; 00502 bool m_save_post_called; 00503 }; 00504 } 00505 #endif // __SGOBJECT_H__