SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2008-2009 Soeren Sonnenburg 00008 * Written (W) 2011-2013 Heiko Strathmann 00009 * Written (W) 2013 Thoralf Klein 00010 * Copyright (C) 2008-2009 Fraunhofer Institute FIRST and Max Planck Society 00011 */ 00012 00013 #include <shogun/lib/config.h> 00014 #include <shogun/base/SGObject.h> 00015 #include <shogun/io/SGIO.h> 00016 #include <shogun/base/Parallel.h> 00017 #include <shogun/base/init.h> 00018 #include <shogun/base/Version.h> 00019 #include <shogun/base/Parameter.h> 00020 #include <shogun/base/ParameterMap.h> 00021 #include <shogun/base/DynArray.h> 00022 #include <shogun/lib/Map.h> 00023 #include <shogun/lib/SGStringList.h> 00024 00025 #include "class_list.h" 00026 00027 #include <stdlib.h> 00028 #include <stdio.h> 00029 00030 namespace shogun 00031 { 00032 class CMath; 00033 class Parallel; 00034 class IO; 00035 class Version; 00036 00037 extern Parallel* sg_parallel; 00038 extern SGIO* sg_io; 00039 extern Version* sg_version; 00040 00041 template<> void CSGObject::set_generic<bool>() 00042 { 00043 m_generic = PT_BOOL; 00044 } 00045 00046 template<> void CSGObject::set_generic<char>() 00047 { 00048 m_generic = PT_CHAR; 00049 } 00050 00051 template<> void CSGObject::set_generic<int8_t>() 00052 { 00053 m_generic = PT_INT8; 00054 } 00055 00056 template<> void CSGObject::set_generic<uint8_t>() 00057 { 00058 m_generic = PT_UINT8; 00059 } 00060 00061 template<> void CSGObject::set_generic<int16_t>() 00062 { 00063 m_generic = PT_INT16; 00064 } 00065 00066 template<> void CSGObject::set_generic<uint16_t>() 00067 { 00068 m_generic = PT_UINT16; 00069 } 00070 00071 template<> void CSGObject::set_generic<int32_t>() 00072 { 00073 m_generic = PT_INT32; 00074 } 00075 00076 template<> void CSGObject::set_generic<uint32_t>() 00077 { 00078 m_generic = PT_UINT32; 00079 } 00080 00081 template<> void CSGObject::set_generic<int64_t>() 00082 { 00083 m_generic = PT_INT64; 00084 } 00085 00086 template<> void CSGObject::set_generic<uint64_t>() 00087 { 00088 m_generic = PT_UINT64; 00089 } 00090 00091 template<> void CSGObject::set_generic<float32_t>() 00092 { 00093 m_generic = PT_FLOAT32; 00094 } 00095 00096 template<> void CSGObject::set_generic<float64_t>() 00097 { 00098 m_generic = PT_FLOAT64; 00099 } 00100 00101 template<> void CSGObject::set_generic<floatmax_t>() 00102 { 00103 m_generic = PT_FLOATMAX; 00104 } 00105 00106 template<> void CSGObject::set_generic<CSGObject*>() 00107 { 00108 m_generic = PT_SGOBJECT; 00109 } 00110 00111 template<> void CSGObject::set_generic<complex128_t>() 00112 { 00113 m_generic = PT_COMPLEX128; 00114 } 00115 00116 } /* namespace shogun */ 00117 00118 using namespace shogun; 00119 00120 CSGObject::CSGObject() 00121 : SGRefObject() 00122 { 00123 init(); 00124 set_global_objects(); 00125 } 00126 00127 CSGObject::CSGObject(const CSGObject& orig) 00128 :SGRefObject(orig), io(orig.io), parallel(orig.parallel), version(orig.version) 00129 { 00130 init(); 00131 set_global_objects(); 00132 } 00133 00134 CSGObject::~CSGObject() 00135 { 00136 unset_global_objects(); 00137 delete m_parameters; 00138 delete m_model_selection_parameters; 00139 delete m_gradient_parameters; 00140 delete m_parameter_map; 00141 } 00142 00143 void CSGObject::set_global_objects() 00144 { 00145 if (!sg_io || !sg_parallel || !sg_version) 00146 { 00147 fprintf(stderr, "call init_shogun() before using the library, dying.\n"); 00148 exit(1); 00149 } 00150 00151 SG_REF(sg_io); 00152 SG_REF(sg_parallel); 00153 SG_REF(sg_version); 00154 00155 io=sg_io; 00156 parallel=sg_parallel; 00157 version=sg_version; 00158 } 00159 00160 void CSGObject::unset_global_objects() 00161 { 00162 SG_UNREF(version); 00163 SG_UNREF(parallel); 00164 SG_UNREF(io); 00165 } 00166 00167 void CSGObject::set_global_io(SGIO* new_io) 00168 { 00169 SG_REF(new_io); 00170 SG_UNREF(sg_io); 00171 sg_io=new_io; 00172 } 00173 00174 SGIO* CSGObject::get_global_io() 00175 { 00176 SG_REF(sg_io); 00177 return sg_io; 00178 } 00179 00180 void CSGObject::set_global_parallel(Parallel* new_parallel) 00181 { 00182 SG_REF(new_parallel); 00183 SG_UNREF(sg_parallel); 00184 sg_parallel=new_parallel; 00185 } 00186 00187 bool CSGObject::update_parameter_hash() 00188 { 00189 uint32_t new_hash = 0; 00190 uint32_t carry = 0; 00191 uint32_t length = 0; 00192 00193 get_parameter_incremental_hash(m_parameters, new_hash, 00194 carry, length); 00195 00196 new_hash = CHash::FinalizeIncrementalMurmurHash3(new_hash, 00197 carry, length); 00198 00199 if(new_hash != m_hash) 00200 { 00201 m_hash = new_hash; 00202 return true; 00203 } 00204 00205 else 00206 return false; 00207 } 00208 00209 Parallel* CSGObject::get_global_parallel() 00210 { 00211 SG_REF(sg_parallel); 00212 return sg_parallel; 00213 } 00214 00215 void CSGObject::set_global_version(Version* new_version) 00216 { 00217 SG_REF(new_version); 00218 SG_UNREF(sg_version); 00219 sg_version=new_version; 00220 } 00221 00222 Version* CSGObject::get_global_version() 00223 { 00224 SG_REF(sg_version); 00225 return sg_version; 00226 } 00227 00228 bool CSGObject::is_generic(EPrimitiveType* generic) const 00229 { 00230 *generic = m_generic; 00231 00232 return m_generic != PT_NOT_GENERIC; 00233 } 00234 00235 void CSGObject::unset_generic() 00236 { 00237 m_generic = PT_NOT_GENERIC; 00238 } 00239 00240 void CSGObject::print_serializable(const char* prefix) 00241 { 00242 SG_PRINT("\n%s\n================================================================================\n", get_name()) 00243 m_parameters->print(prefix); 00244 } 00245 00246 bool CSGObject::save_serializable(CSerializableFile* file, 00247 const char* prefix, int32_t param_version) 00248 { 00249 SG_DEBUG("START SAVING CSGObject '%s'\n", get_name()) 00250 try 00251 { 00252 save_serializable_pre(); 00253 } 00254 catch (ShogunException& e) 00255 { 00256 SG_SWARNING("%s%s::save_serializable_pre(): ShogunException: " 00257 "%s\n", prefix, get_name(), 00258 e.get_exception_string()); 00259 return false; 00260 } 00261 00262 if (!m_save_pre_called) 00263 { 00264 SG_SWARNING("%s%s::save_serializable_pre(): Implementation " 00265 "error: BASE_CLASS::LOAD_SERIALIZABLE_PRE() not " 00266 "called!\n", prefix, get_name()); 00267 return false; 00268 } 00269 00270 /* save parameter version */ 00271 if (!save_parameter_version(file, prefix, param_version)) 00272 return false; 00273 00274 if (!m_parameters->save(file, prefix)) 00275 return false; 00276 00277 try 00278 { 00279 save_serializable_post(); 00280 } 00281 catch (ShogunException& e) 00282 { 00283 SG_SWARNING("%s%s::save_serializable_post(): ShogunException: " 00284 "%s\n", prefix, get_name(), 00285 e.get_exception_string()); 00286 return false; 00287 } 00288 00289 if (!m_save_post_called) 00290 { 00291 SG_SWARNING("%s%s::save_serializable_post(): Implementation " 00292 "error: BASE_CLASS::LOAD_SERIALIZABLE_POST() not " 00293 "called!\n", prefix, get_name()); 00294 return false; 00295 } 00296 00297 if (prefix == NULL || *prefix == '\0') 00298 file->close(); 00299 00300 SG_DEBUG("DONE SAVING CSGObject '%s' (%p)\n", get_name(), this) 00301 00302 return true; 00303 } 00304 00305 bool CSGObject::load_serializable(CSerializableFile* file, 00306 const char* prefix, int32_t param_version) 00307 { 00308 SG_DEBUG("START LOADING CSGObject '%s'\n", get_name()) 00309 try 00310 { 00311 load_serializable_pre(); 00312 } 00313 catch (ShogunException& e) 00314 { 00315 SG_SWARNING("%s%s::load_serializable_pre(): ShogunException: " 00316 "%s\n", prefix, get_name(), 00317 e.get_exception_string()); 00318 return false; 00319 } 00320 if (!m_load_pre_called) 00321 { 00322 SG_SWARNING("%s%s::load_serializable_pre(): Implementation " 00323 "error: BASE_CLASS::LOAD_SERIALIZABLE_PRE() not " 00324 "called!\n", prefix, get_name()); 00325 return false; 00326 } 00327 00328 /* try to load version of parameters */ 00329 int32_t file_version=load_parameter_version(file, prefix); 00330 SG_DEBUG("file_version=%d, current_version=%d\n", file_version, param_version) 00331 00332 if (file_version<0) 00333 { 00334 SG_WARNING("%s%s::load_serializable(): File contains no parameter " 00335 "version. Seems like your file is from the days before this " 00336 "was introduced. Ignore warning or serialize with this version " 00337 "of shogun to get rid of above and this warnings.\n", 00338 prefix, get_name()); 00339 } 00340 00341 if (file_version>param_version) 00342 { 00343 if (param_version==Version::get_version_parameter()) 00344 { 00345 SG_WARNING("%s%s::load_serializable(): parameter version of file " 00346 "larger than the one of shogun. Try with a more recent" 00347 "version of shogun.\n", prefix, get_name()); 00348 } 00349 else 00350 { 00351 SG_WARNING("%s%s::load_serializable(): parameter version of file " 00352 "larger than the current. This is probably an implementation" 00353 " error.\n", prefix, get_name()); 00354 } 00355 return false; 00356 } 00357 00358 if (file_version==param_version) 00359 { 00360 /* load normally if file has current version */ 00361 SG_DEBUG("loading normally\n") 00362 00363 /* load all parameters, except new ones */ 00364 for (int32_t i=0; i<m_parameters->get_num_parameters(); i++) 00365 { 00366 TParameter* current=m_parameters->get_parameter(i); 00367 00368 /* skip new parameters */ 00369 if (is_param_new(SGParamInfo(current, param_version))) 00370 continue; 00371 00372 if (!current->load(file, prefix)) 00373 return false; 00374 } 00375 } 00376 else 00377 { 00378 /* load all parameters from file, mappings to current version */ 00379 DynArray<TParameter*>* param_base=load_all_file_parameters(file_version, 00380 param_version, file, prefix); 00381 00382 /* create an array of param infos from current parameters */ 00383 DynArray<const SGParamInfo*>* param_infos= 00384 new DynArray<const SGParamInfo*>(); 00385 for (index_t i=0; i<m_parameters->get_num_parameters(); ++i) 00386 { 00387 TParameter* current=m_parameters->get_parameter(i); 00388 00389 /* skip new parameters */ 00390 if (is_param_new(SGParamInfo(current, param_version))) 00391 continue; 00392 00393 param_infos->append_element( 00394 new SGParamInfo(current, param_version)); 00395 } 00396 00397 /* map all parameters, result may be empty if input is */ 00398 map_parameters(param_base, file_version, param_infos); 00399 SG_DEBUG("mapping is done!\n") 00400 00401 /* this is assumed now, mapping worked or no parameters in base */ 00402 ASSERT(file_version==param_version || !param_base->get_num_elements()) 00403 00404 /* delete above created param infos */ 00405 for (index_t i=0; i<param_infos->get_num_elements(); ++i) 00406 delete param_infos->get_element(i); 00407 00408 delete param_infos; 00409 00410 /* replace parameters by loaded and mapped */ 00411 SG_DEBUG("replacing parameter data by loaded/mapped values\n") 00412 for (index_t i=0; i<m_parameters->get_num_parameters(); ++i) 00413 { 00414 TParameter* current=m_parameters->get_parameter(i); 00415 char* s=SG_MALLOC(char, 200); 00416 current->m_datatype.to_string(s, 200); 00417 SG_DEBUG("processing \"%s\": %s\n", current->m_name, s) 00418 SG_FREE(s); 00419 00420 /* skip new parameters */ 00421 if (is_param_new(SGParamInfo(current, param_version))) 00422 { 00423 SG_DEBUG("%s is new, skipping\n", current->m_name) 00424 continue; 00425 } 00426 00427 /* search for current parameter in mapped ones */ 00428 index_t index=CMath::binary_search(param_base->get_array(), 00429 param_base->get_num_elements(), current); 00430 00431 TParameter* migrated=param_base->get_element(index); 00432 00433 /* now copy data from migrated TParameter instance 00434 * (this automatically deletes the old data allocations) */ 00435 SG_DEBUG("copying migrated data into parameter\n") 00436 current->copy_data(migrated); 00437 } 00438 00439 /* delete the migrated parameter data base */ 00440 SG_DEBUG("deleting old parameter base\n") 00441 for (index_t i=0; i<param_base->get_num_elements(); ++i) 00442 { 00443 TParameter* current=param_base->get_element(i); 00444 SG_DEBUG("deleting old \"%s\"\n", current->m_name) 00445 delete current; 00446 } 00447 delete param_base; 00448 } 00449 00450 try 00451 { 00452 load_serializable_post(); 00453 } 00454 catch (ShogunException& e) 00455 { 00456 SG_SWARNING("%s%s::load_serializable_post(): ShogunException: " 00457 "%s\n", prefix, get_name(), 00458 e.get_exception_string()); 00459 return false; 00460 } 00461 00462 if (!m_load_post_called) 00463 { 00464 SG_SWARNING("%s%s::load_serializable_post(): Implementation " 00465 "error: BASE_CLASS::LOAD_SERIALIZABLE_POST() not " 00466 "called!\n", prefix, get_name()); 00467 return false; 00468 } 00469 SG_DEBUG("DONE LOADING CSGObject '%s' (%p)\n", get_name(), this) 00470 00471 return true; 00472 } 00473 00474 DynArray<TParameter*>* CSGObject::load_file_parameters( 00475 const SGParamInfo* param_info, int32_t file_version, 00476 CSerializableFile* file, const char* prefix) 00477 { 00478 /* ensure that recursion works */ 00479 SG_SDEBUG("entering %s::load_file_parameters\n", get_name()) 00480 if (file_version>param_info->m_param_version) 00481 { 00482 SG_SERROR("parameter version of \"%s\" in file (%d) is more recent than" 00483 " provided %d!\n", param_info->m_name, file_version, 00484 param_info->m_param_version); 00485 } 00486 00487 DynArray<TParameter*>* result_array=new DynArray<TParameter*>(); 00488 00489 /* do mapping */ 00490 char* s=param_info->to_string(); 00491 SG_SDEBUG("try to get mapping for: %s\n", s) 00492 SG_FREE(s); 00493 00494 /* mapping has only be deleted if was created here (no mapping was found) */ 00495 bool free_mapped=false; 00496 DynArray<const SGParamInfo*>* mapped=m_parameter_map->get(param_info); 00497 if (!mapped) 00498 { 00499 /* since a new mapped array will be created, set deletion flag */ 00500 free_mapped=true; 00501 mapped=new DynArray<const SGParamInfo*>(); 00502 00503 /* if no mapping was found, nothing has changed. Simply create new param 00504 * info with decreased version */ 00505 SG_SDEBUG("no mapping found\n") 00506 if (file_version<param_info->m_param_version) 00507 { 00508 /* create new array and put param info with decreased version in */ 00509 mapped->append_element(new SGParamInfo(param_info->m_name, 00510 param_info->m_ctype, param_info->m_stype, 00511 param_info->m_ptype, param_info->m_param_version-1)); 00512 00513 SG_SDEBUG("using:\n") 00514 for (index_t i=0; i<mapped->get_num_elements(); ++i) 00515 { 00516 s=mapped->get_element(i)->to_string(); 00517 SG_SDEBUG("\t%s\n", s) 00518 SG_FREE(s); 00519 } 00520 } 00521 else 00522 { 00523 /* create new array and put original param info in */ 00524 SG_SDEBUG("reached file version\n") 00525 mapped->append_element(param_info->duplicate()); 00526 } 00527 } 00528 else 00529 { 00530 SG_SDEBUG("found:\n") 00531 for (index_t i=0; i<mapped->get_num_elements(); ++i) 00532 { 00533 s=mapped->get_element(i)->to_string(); 00534 SG_SDEBUG("\t%s\n", s) 00535 SG_FREE(s); 00536 } 00537 } 00538 00539 00540 /* case file version same as provided version. 00541 * means that parameters have to be loaded from file, recursion stops */ 00542 if (file_version==param_info->m_param_version) 00543 { 00544 SG_SDEBUG("recursion stop, loading from file\n") 00545 /* load all parameters in mapping from file */ 00546 for (index_t i=0; i<mapped->get_num_elements(); ++i) 00547 { 00548 const SGParamInfo* current=mapped->get_element(i); 00549 s=current->to_string(); 00550 SG_SDEBUG("loading %s\n", s) 00551 SG_FREE(s); 00552 00553 TParameter* loaded; 00554 /* allocate memory for length and matrix/vector 00555 * This has to be done because this stuff normally is in the class 00556 * variables which do not exist in this case. Deletion is handled 00557 * via the allocated_from_scratch flag of TParameter */ 00558 00559 /* create type and copy lengths, empty data for now */ 00560 TSGDataType type(current->m_ctype, current->m_stype, 00561 current->m_ptype); 00562 loaded=new TParameter(&type, NULL, current->m_name, ""); 00563 00564 /* allocate data/length variables for the TParameter, lengths are not 00565 * important now, so set to one */ 00566 SGVector<index_t> dims(2); 00567 dims[0]=1; 00568 dims[1]=1; 00569 loaded->allocate_data_from_scratch(dims); 00570 00571 /* tell instance to load data from file */ 00572 if (!loaded->load(file, prefix)) 00573 { 00574 s=param_info->to_string(); 00575 SG_ERROR("Could not load %s. The reason for this might be wrong " 00576 "parameter mappings\n", s); 00577 SG_FREE(s); 00578 } 00579 00580 SG_DEBUG("loaded lengths: y=%d, x=%d\n", 00581 loaded->m_datatype.m_length_y ? *loaded->m_datatype.m_length_y : -1, 00582 loaded->m_datatype.m_length_x ? *loaded->m_datatype.m_length_x : -1); 00583 00584 /* append new TParameter to result array */ 00585 result_array->append_element(loaded); 00586 } 00587 SG_SDEBUG("done loading from file\n") 00588 } 00589 /* recursion with mapped type, a mapping exists in this case (ensured by 00590 * above assert) */ 00591 else 00592 { 00593 /* for all elements in mapping, do recursion */ 00594 for (index_t i=0; i<mapped->get_num_elements(); ++i) 00595 { 00596 const SGParamInfo* current=mapped->get_element(i); 00597 s=current->to_string(); 00598 SG_SDEBUG("starting recursion over %s\n", s) 00599 00600 /* recursively get all file parameters for this parameter */ 00601 DynArray<TParameter*>* recursion_array= 00602 load_file_parameters(current, file_version, file, prefix); 00603 00604 SG_SDEBUG("recursion over %s done\n", s) 00605 SG_FREE(s); 00606 00607 /* append all recursion data to current array */ 00608 SG_SDEBUG("appending all results to current result\n") 00609 for (index_t j=0; j<recursion_array->get_num_elements(); ++j) 00610 result_array->append_element(recursion_array->get_element(j)); 00611 00612 /* clean up */ 00613 delete recursion_array; 00614 } 00615 } 00616 00617 SG_SDEBUG("cleaning up old mapping \n") 00618 00619 00620 /* clean up mapping */ 00621 if (free_mapped) 00622 { 00623 for (index_t i=0; i<mapped->get_num_elements(); ++i) 00624 delete mapped->get_element(i); 00625 00626 delete mapped; 00627 } 00628 00629 SG_SDEBUG("leaving %s::load_file_parameters\n", get_name()) 00630 return result_array; 00631 } 00632 00633 DynArray<TParameter*>* CSGObject::load_all_file_parameters(int32_t file_version, 00634 int32_t current_version, CSerializableFile* file, const char* prefix) 00635 { 00636 DynArray<TParameter*>* result=new DynArray<TParameter*>(); 00637 00638 for (index_t i=0; i<m_parameters->get_num_parameters(); ++i) 00639 { 00640 TParameter* current=m_parameters->get_parameter(i); 00641 00642 /* extract current parameter info */ 00643 const SGParamInfo* info=new SGParamInfo(current, current_version); 00644 00645 /* skip new parameters */ 00646 if (is_param_new(*info)) 00647 { 00648 delete info; 00649 continue; 00650 } 00651 00652 /* in the other case, load parameters data from file */ 00653 DynArray<TParameter*>* temp=load_file_parameters(info, file_version, 00654 file, prefix); 00655 00656 /* and append them all to array */ 00657 for (index_t j=0; j<temp->get_num_elements(); ++j) 00658 result->append_element(temp->get_element(j)); 00659 00660 /* clean up */ 00661 delete temp; 00662 delete info; 00663 } 00664 00665 /* sort array before returning */ 00666 CMath::qsort(result->get_array(), result->get_num_elements()); 00667 00668 return result; 00669 } 00670 00671 void CSGObject::map_parameters(DynArray<TParameter*>* param_base, 00672 int32_t& base_version, DynArray<const SGParamInfo*>* target_param_infos) 00673 { 00674 SG_DEBUG("entering %s::map_parameters\n", get_name()) 00675 /* NOTE: currently the migration is done step by step over every version */ 00676 00677 if (!target_param_infos->get_num_elements()) 00678 { 00679 SG_DEBUG("no target parameter infos\n") 00680 SG_DEBUG("leaving %s::map_parameters\n", get_name()) 00681 return; 00682 } 00683 00684 /* map all target parameter infos once */ 00685 DynArray<const SGParamInfo*>* mapped_infos= 00686 new DynArray<const SGParamInfo*>(); 00687 DynArray<SGParamInfo*>* to_delete=new DynArray<SGParamInfo*>(); 00688 for (index_t i=0; i<target_param_infos->get_num_elements(); ++i) 00689 { 00690 const SGParamInfo* current=target_param_infos->get_element(i); 00691 00692 char* s=current->to_string(); 00693 SG_DEBUG("trying to get parameter mapping for %s\n", s) 00694 SG_FREE(s); 00695 00696 DynArray<const SGParamInfo*>* mapped=m_parameter_map->get(current); 00697 00698 if (mapped) 00699 { 00700 mapped_infos->append_element(mapped->get_element(0)); 00701 for (index_t j=0; j<mapped->get_num_elements(); ++j) 00702 { 00703 s=mapped->get_element(j)->to_string(); 00704 SG_DEBUG("found mapping: %s\n", s) 00705 SG_FREE(s); 00706 } 00707 } 00708 else 00709 { 00710 /* these have to be deleted above */ 00711 SGParamInfo* no_change=new SGParamInfo(*current); 00712 no_change->m_param_version--; 00713 s=no_change->to_string(); 00714 SG_DEBUG("no mapping found, using %s\n", s) 00715 SG_FREE(s); 00716 mapped_infos->append_element(no_change); 00717 to_delete->append_element(no_change); 00718 } 00719 } 00720 00721 /* assert that at least one mapping exists */ 00722 ASSERT(mapped_infos->get_num_elements()) 00723 int32_t mapped_version=mapped_infos->get_element(0)->m_param_version; 00724 00725 /* assert that all param versions are equal for now (if not empty param) */ 00726 for (index_t i=1; i<mapped_infos->get_num_elements(); ++i) 00727 { 00728 ASSERT(mapped_infos->get_element(i)->m_param_version==mapped_version || 00729 *mapped_infos->get_element(i)==SGParamInfo()); 00730 } 00731 00732 /* recursion, after this call, base is at version of mapped infos */ 00733 if (mapped_version>base_version) 00734 map_parameters(param_base, base_version, mapped_infos); 00735 00736 /* delete mapped parameter infos array */ 00737 delete mapped_infos; 00738 00739 /* delete newly created parameter infos which have to name or type change */ 00740 for (index_t i=0; i<to_delete->get_num_elements(); ++i) 00741 delete to_delete->get_element(i); 00742 00743 delete to_delete; 00744 00745 ASSERT(base_version==mapped_version) 00746 00747 /* do migration of one version step, create new base */ 00748 DynArray<TParameter*>* new_base=new DynArray<TParameter*>(); 00749 for (index_t i=0; i<target_param_infos->get_num_elements(); ++i) 00750 { 00751 char* s=target_param_infos->get_element(i)->to_string(); 00752 SG_DEBUG("migrating one step to target: %s\n", s) 00753 SG_FREE(s); 00754 TParameter* p=migrate(param_base, target_param_infos->get_element(i)); 00755 new_base->append_element(p); 00756 } 00757 00758 /* replace base by new base, delete old base, if it was created in migrate */ 00759 SG_DEBUG("deleting parameters base version %d\n", base_version) 00760 for (index_t i=0; i<param_base->get_num_elements(); ++i) 00761 delete param_base->get_element(i); 00762 00763 SG_DEBUG("replacing old parameter base\n") 00764 *param_base=*new_base; 00765 base_version=mapped_version+1; 00766 00767 SG_DEBUG("new parameter base of size %d:\n", param_base->get_num_elements()) 00768 for (index_t i=0; i<param_base->get_num_elements(); ++i) 00769 { 00770 TParameter* current=param_base->get_element(i); 00771 TSGDataType type=current->m_datatype; 00772 if (type.m_ptype==PT_SGOBJECT) 00773 { 00774 if (type.m_ctype==CT_SCALAR) 00775 { 00776 CSGObject* object=*(CSGObject**)current->m_parameter; 00777 SG_DEBUG("(%d:) \"%s\": sgobject \"%s\" at %p\n", i, 00778 current->m_name, object ? object->get_name() : "", 00779 object); 00780 } 00781 else 00782 { 00783 index_t len=1; 00784 len*=type.m_length_x ? *type.m_length_x : 1; 00785 len*=type.m_length_y ? *type.m_length_y : 1; 00786 CSGObject** array=*(CSGObject***)current->m_parameter; 00787 for (index_t j=0; j<len; ++j) 00788 { 00789 CSGObject* object=array[j]; 00790 SG_DEBUG("(%d:) \"%s\": sgobject \"%s\" at %p\n", i, 00791 current->m_name, object ? object->get_name() : "", 00792 object); 00793 } 00794 } 00795 } 00796 else 00797 { 00798 char* s=SG_MALLOC(char, 200); 00799 current->m_datatype.to_string(s, 200); 00800 SG_DEBUG("(%d:) \"%s\": type: %s at %p\n", i, current->m_name, s, 00801 current->m_parameter); 00802 SG_FREE(s); 00803 } 00804 } 00805 00806 /* because content was copied, new base may be deleted */ 00807 delete new_base; 00808 00809 /* sort the just created new base */ 00810 SG_DEBUG("sorting base\n") 00811 CMath::qsort(param_base->get_array(), param_base->get_num_elements()); 00812 00813 /* at this point the param_base is at the same version as the version of 00814 * the provided parameter infos */ 00815 SG_DEBUG("leaving %s::map_parameters\n", get_name()) 00816 } 00817 00818 void CSGObject::one_to_one_migration_prepare(DynArray<TParameter*>* param_base, 00819 const SGParamInfo* target, TParameter*& replacement, 00820 TParameter*& to_migrate, char* old_name) 00821 { 00822 SG_DEBUG("CSGObject::entering CSGObject::one_to_one_migration_prepare() for " 00823 "\"%s\"\n", target->m_name); 00824 00825 /* generate type of target structure */ 00826 TSGDataType type(target->m_ctype, target->m_stype, target->m_ptype); 00827 00828 /* first find index of needed data. 00829 * in this case, element in base with same name or old name */ 00830 char* name=target->m_name; 00831 if (old_name) 00832 name=old_name; 00833 00834 /* dummy for searching, search and save result in to_migrate parameter */ 00835 TParameter* t=new TParameter(&type, NULL, name, ""); 00836 index_t i=CMath::binary_search(param_base->get_array(), 00837 param_base->get_num_elements(), t); 00838 delete t; 00839 00840 /* assert that something is found */ 00841 ASSERT(i>=0) 00842 to_migrate=param_base->get_element(i); 00843 00844 /* result structure, data NULL for now */ 00845 replacement=new TParameter(&type, NULL, target->m_name, 00846 to_migrate->m_description); 00847 00848 SGVector<index_t> dims(2); 00849 dims[0]=1; 00850 dims[1]=1; 00851 /* allocate content to write into, lengths are needed for this */ 00852 if (to_migrate->m_datatype.m_length_x) 00853 dims[0]=*to_migrate->m_datatype.m_length_x; 00854 00855 if (to_migrate->m_datatype.m_length_y) 00856 dims[1]=*to_migrate->m_datatype.m_length_y; 00857 00858 replacement->allocate_data_from_scratch(dims); 00859 00860 /* in case of sgobject, copy pointer data and SG_REF */ 00861 if (to_migrate->m_datatype.m_ptype==PT_SGOBJECT) 00862 { 00863 /* note that the memory is already allocated before the migrate call */ 00864 CSGObject* object=*((CSGObject**)to_migrate->m_parameter); 00865 *((CSGObject**)replacement->m_parameter)=object; 00866 SG_REF(object); 00867 SG_DEBUG("copied and SG_REF sgobject pointer for \"%s\" at %p\n", 00868 object->get_name(), object); 00869 } 00870 00871 /* tell the old TParameter to delete its data on deletion */ 00872 to_migrate->m_delete_data=true; 00873 00874 SG_DEBUG("CSGObject::leaving CSGObject::one_to_one_migration_prepare() for " 00875 "\"%s\"\n", target->m_name); 00876 } 00877 00878 TParameter* CSGObject::migrate(DynArray<TParameter*>* param_base, 00879 const SGParamInfo* target) 00880 { 00881 SG_DEBUG("entering %s::migrate\n", get_name()) 00882 /* this is only executed, iff there was no migration method which handled 00883 * migration to the provided target. In this case, it is assumed that the 00884 * parameter simply has not changed. Verify this here and return copy of 00885 * data in case its true. 00886 * If not, throw an exception -- parameter migration HAS to be implemented 00887 * by hand everytime, a parameter changes type or name. */ 00888 00889 TParameter* result=NULL; 00890 00891 /* first find index of needed data. 00892 * in this case, element in base with same name */ 00893 /* type is also needed */ 00894 TSGDataType type(target->m_ctype, target->m_stype, 00895 target->m_ptype); 00896 00897 /* dummy for searching, search and save result */ 00898 TParameter* t=new TParameter(&type, NULL, target->m_name, ""); 00899 index_t i=CMath::binary_search(param_base->get_array(), 00900 param_base->get_num_elements(), t); 00901 delete t; 00902 00903 /* check if name change occurred while no migration method was specified */ 00904 if (i<0) 00905 { 00906 SG_ERROR("Name change for parameter that has to be mapped to \"%s\"," 00907 " and to no migration method available\n", target->m_name); 00908 } 00909 00910 TParameter* to_migrate=param_base->get_element(i); 00911 00912 /* check if element in base is equal to target one */ 00913 if (*target==SGParamInfo(to_migrate, target->m_param_version)) 00914 { 00915 char* s=SG_MALLOC(char, 200); 00916 to_migrate->m_datatype.to_string(s, 200); 00917 SG_DEBUG("nothing changed, using old data: %s\n", s) 00918 SG_FREE(s); 00919 result=new TParameter(&to_migrate->m_datatype, NULL, to_migrate->m_name, 00920 to_migrate->m_description); 00921 00922 SGVector<index_t> dims(2); 00923 dims[0]=1; 00924 dims[1]=1; 00925 if (to_migrate->m_datatype.m_length_x) 00926 dims[0]=*to_migrate->m_datatype.m_length_x; 00927 00928 if (to_migrate->m_datatype.m_length_y) 00929 dims[1]=*to_migrate->m_datatype.m_length_y; 00930 00931 /* allocate lengths and evtl scalar data but not non-scalar data (no 00932 * new_cont call */ 00933 result->allocate_data_from_scratch(dims, false); 00934 00935 /* now use old data */ 00936 if (to_migrate->m_datatype.m_ctype==CT_SCALAR && 00937 to_migrate->m_datatype.m_ptype!=PT_SGOBJECT) 00938 { 00939 /* copy data */ 00940 SG_DEBUG("copying scalar data\n") 00941 memcpy(result->m_parameter,to_migrate->m_parameter, 00942 to_migrate->m_datatype.get_size()); 00943 } 00944 else 00945 { 00946 /* copy content of pointer */ 00947 SG_DEBUG("copying content of poitner for non-scalar data\n") 00948 *(void**)result->m_parameter=*(void**)(to_migrate->m_parameter); 00949 } 00950 } 00951 else 00952 { 00953 char* s=target->to_string(); 00954 SG_ERROR("No migration method available for %s!\n", s) 00955 SG_FREE(s); 00956 } 00957 00958 SG_DEBUG("leaving %s::migrate\n", get_name()) 00959 00960 return result; 00961 } 00962 00963 bool CSGObject::save_parameter_version(CSerializableFile* file, 00964 const char* prefix, int32_t param_version) 00965 { 00966 TSGDataType t(CT_SCALAR, ST_NONE, PT_INT32); 00967 TParameter p(&t, ¶m_version, "version_parameter", 00968 "Version of parameters of this object"); 00969 return p.save(file, prefix); 00970 } 00971 00972 int32_t CSGObject::load_parameter_version(CSerializableFile* file, 00973 const char* prefix) 00974 { 00975 TSGDataType t(CT_SCALAR, ST_NONE, PT_INT32); 00976 int32_t v; 00977 TParameter tp(&t, &v, "version_parameter", ""); 00978 if (tp.load(file, prefix)) 00979 return v; 00980 else 00981 return -1; 00982 } 00983 00984 void CSGObject::load_serializable_pre() throw (ShogunException) 00985 { 00986 m_load_pre_called = true; 00987 } 00988 00989 void CSGObject::load_serializable_post() throw (ShogunException) 00990 { 00991 m_load_post_called = true; 00992 } 00993 00994 void CSGObject::save_serializable_pre() throw (ShogunException) 00995 { 00996 m_save_pre_called = true; 00997 } 00998 00999 void CSGObject::save_serializable_post() throw (ShogunException) 01000 { 01001 m_save_post_called = true; 01002 } 01003 01004 #ifdef TRACE_MEMORY_ALLOCS 01005 #include <shogun/lib/Map.h> 01006 extern CMap<void*, shogun::MemoryBlock>* sg_mallocs; 01007 #endif 01008 01009 void CSGObject::init() 01010 { 01011 #ifdef TRACE_MEMORY_ALLOCS 01012 if (sg_mallocs) 01013 { 01014 int32_t idx=sg_mallocs->index_of(this); 01015 if (idx>-1) 01016 { 01017 MemoryBlock* b=sg_mallocs->get_element_ptr(idx); 01018 b->set_sgobject(); 01019 } 01020 } 01021 #endif 01022 01023 io = NULL; 01024 parallel = NULL; 01025 version = NULL; 01026 m_parameters = new Parameter(); 01027 m_model_selection_parameters = new Parameter(); 01028 m_gradient_parameters=new Parameter(); 01029 m_parameter_map=new ParameterMap(); 01030 m_generic = PT_NOT_GENERIC; 01031 m_load_pre_called = false; 01032 m_load_post_called = false; 01033 m_hash = 0; 01034 } 01035 01036 void CSGObject::print_modsel_params() 01037 { 01038 SG_PRINT("parameters available for model selection for %s:\n", get_name()) 01039 01040 index_t num_param=m_model_selection_parameters->get_num_parameters(); 01041 01042 if (!num_param) 01043 SG_PRINT("\tnone\n") 01044 01045 for (index_t i=0; i<num_param; i++) 01046 { 01047 TParameter* current=m_model_selection_parameters->get_parameter(i); 01048 index_t l=200; 01049 char* type=SG_MALLOC(char, l); 01050 if (type) 01051 { 01052 current->m_datatype.to_string(type, l); 01053 SG_PRINT("\t%s (%s): %s\n", current->m_name, current->m_description, 01054 type); 01055 SG_FREE(type); 01056 } 01057 } 01058 } 01059 01060 SGStringList<char> CSGObject::get_modelsel_names() 01061 { 01062 index_t num_param=m_model_selection_parameters->get_num_parameters(); 01063 01064 SGStringList<char> result(num_param, -1); 01065 01066 index_t max_string_length=-1; 01067 01068 for (index_t i=0; i<num_param; i++) 01069 { 01070 char* name=m_model_selection_parameters->get_parameter(i)->m_name; 01071 index_t len=strlen(name); 01072 // +1 to have a zero terminated string 01073 result.strings[i]=SGString<char>(name, len+1); 01074 01075 if (len>max_string_length) 01076 max_string_length=len; 01077 } 01078 01079 result.max_string_length=max_string_length; 01080 01081 return result; 01082 } 01083 01084 char* CSGObject::get_modsel_param_descr(const char* param_name) 01085 { 01086 index_t index=get_modsel_param_index(param_name); 01087 01088 if (index<0) 01089 { 01090 SG_ERROR("There is no model selection parameter called \"%s\" for %s", 01091 param_name, get_name()); 01092 } 01093 01094 return m_model_selection_parameters->get_parameter(index)->m_description; 01095 } 01096 01097 index_t CSGObject::get_modsel_param_index(const char* param_name) 01098 { 01099 /* use fact that names extracted from below method are in same order than 01100 * in m_model_selection_parameters variable */ 01101 SGStringList<char> names=get_modelsel_names(); 01102 01103 /* search for parameter with provided name */ 01104 index_t index=-1; 01105 for (index_t i=0; i<names.num_strings; i++) 01106 { 01107 TParameter* current=m_model_selection_parameters->get_parameter(i); 01108 if (!strcmp(param_name, current->m_name)) 01109 { 01110 index=i; 01111 break; 01112 } 01113 } 01114 01115 return index; 01116 } 01117 01118 bool CSGObject::is_param_new(const SGParamInfo param_info) const 01119 { 01120 /* check if parameter is new in this version (has empty mapping) */ 01121 DynArray<const SGParamInfo*>* value=m_parameter_map->get(¶m_info); 01122 bool result=value && *value->get_element(0) == SGParamInfo(); 01123 01124 return result; 01125 } 01126 01127 void CSGObject::get_parameter_incremental_hash(Parameter* param, 01128 uint32_t& hash, uint32_t& carry, uint32_t& total_length) 01129 { 01130 if (!param) 01131 return; 01132 01133 for (index_t i=0; i<param->get_num_parameters(); i++) 01134 { 01135 TParameter* p = param->get_parameter(i); 01136 SG_DEBUG("Updating hash for parameter \"%s\"\n", p->m_name ? p->m_name : "(nil)"); 01137 01138 if (!p || !p->is_valid()) 01139 continue; 01140 01141 if (p->m_datatype.m_ptype != PT_SGOBJECT) 01142 { 01143 p->get_incremental_hash(hash, carry, total_length); 01144 continue; 01145 } 01146 01147 CSGObject* child = *((CSGObject**)(p->m_parameter)); 01148 01149 if (child) 01150 get_parameter_incremental_hash( 01151 child->m_parameters, hash, 01152 carry, total_length); 01153 } 01154 } 01155 01156 void CSGObject::build_gradient_parameter_dictionary(CMap<TParameter*, CSGObject*>* dict) 01157 { 01158 for (index_t i=0; i<m_gradient_parameters->get_num_parameters(); i++) 01159 { 01160 TParameter* p=m_gradient_parameters->get_parameter(i); 01161 dict->add(p, this); 01162 } 01163 01164 for (index_t i=0; i<m_model_selection_parameters->get_num_parameters(); i++) 01165 { 01166 TParameter* p=m_model_selection_parameters->get_parameter(i); 01167 CSGObject* child=*(CSGObject**)(p->m_parameter); 01168 01169 if ((p->m_datatype.m_ptype == PT_SGOBJECT) && 01170 (p->m_datatype.m_ctype == CT_SCALAR) && child) 01171 { 01172 child->build_gradient_parameter_dictionary(dict); 01173 } 01174 } 01175 } 01176 01177 bool CSGObject::equals(CSGObject* other, float64_t accuracy) 01178 { 01179 SG_DEBUG("entering %s::equals()\n", get_name()); 01180 01181 if (other==this) 01182 { 01183 SG_DEBUG("leaving %s::equals(): other object is me\n", get_name()); 01184 return true; 01185 } 01186 01187 if (!other) 01188 { 01189 SG_DEBUG("leaving %s::equals(): other object is NULL\n", get_name()); 01190 return false; 01191 } 01192 01193 SG_DEBUG("comparing \"%s\" to \"%s\"\n", get_name(), other->get_name()); 01194 01195 /* a crude type check based on the get_name */ 01196 if (strcmp(other->get_name(), get_name())) 01197 { 01198 SG_INFO("leaving %s::equals(): name of other object differs\n", get_name()); 01199 return false; 01200 } 01201 01202 /* should not be necessary but just ot be sure that type has not changed. 01203 * Will assume that parameters are in same order with same name from here */ 01204 if (m_parameters->get_num_parameters()!=other->m_parameters->get_num_parameters()) 01205 { 01206 SG_INFO("leaving %s::equals(): number of parameters of other object " 01207 "differs\n", get_name()); 01208 return false; 01209 } 01210 01211 for (index_t i=0; i<m_parameters->get_num_parameters(); ++i) 01212 { 01213 SG_DEBUG("comparing parameter %d\n", i); 01214 01215 TParameter* this_param=m_parameters->get_parameter(i); 01216 TParameter* other_param=other->m_parameters->get_parameter(i); 01217 01218 /* some checks to make sure parameters have same order and names and 01219 * are not NULL. Should never be the case but check anyway. */ 01220 if (!this_param && !other_param) 01221 continue; 01222 01223 if (!this_param && other_param) 01224 { 01225 SG_DEBUG("leaving %s::equals(): parameter %d is NULL where other's " 01226 "parameter \"%s\" is not\n", get_name(), other_param->m_name); 01227 return false; 01228 } 01229 01230 if (this_param && !other_param) 01231 { 01232 SG_DEBUG("leaving %s::equals(): parameter %d is \"%s\" where other's " 01233 "parameter is NULL\n", get_name(), this_param->m_name); 01234 return false; 01235 } 01236 01237 SG_DEBUG("comparing parameter \"%s\" to other's \"%s\"\n", 01238 this_param->m_name, other_param->m_name); 01239 01240 /* hard-wired exception for DynamicObjectArray parameter num_elements */ 01241 if (!strcmp("DynamicObjectArray", get_name()) && 01242 !strcmp(this_param->m_name, "num_elements") && 01243 !strcmp(other_param->m_name, "num_elements")) 01244 { 01245 SG_DEBUG("Ignoring DynamicObjectArray::num_elements field\n"); 01246 continue; 01247 } 01248 01249 /* hard-wired exception for DynamicArray parameter num_elements */ 01250 if (!strcmp("DynamicArray", get_name()) && 01251 !strcmp(this_param->m_name, "num_elements") && 01252 !strcmp(other_param->m_name, "num_elements")) 01253 { 01254 SG_DEBUG("Ignoring DynamicArray::num_elements field\n"); 01255 continue; 01256 } 01257 01258 /* use equals method of TParameter from here */ 01259 if (!this_param->equals(other_param, accuracy)) 01260 { 01261 SG_INFO("leaving %s::equals(): parameters at position %d with name" 01262 " \"%s\" differs from other object parameter with name " 01263 "\"%s\"\n", 01264 get_name(), i, this_param->m_name, other_param->m_name); 01265 return false; 01266 } 01267 } 01268 01269 SG_DEBUG("leaving %s::equals(): object are equal\n", get_name()); 01270 return true; 01271 } 01272 01273 CSGObject* CSGObject::clone() 01274 { 01275 SG_DEBUG("entering %s::clone()\n", get_name()); 01276 01277 SG_DEBUG("constructing an empty instance of %s\n", get_name()); 01278 CSGObject* copy=new_sgserializable(get_name(), this->m_generic); 01279 01280 SG_REF(copy); 01281 01282 REQUIRE(copy, "Could not create empty instance of \"%s\". The reason for " 01283 "this usually is that get_name() of the class returns something " 01284 "wrong, or that a class has a wrongly set generic type.\n", 01285 get_name()); 01286 01287 for (index_t i=0; i<m_parameters->get_num_parameters(); ++i) 01288 { 01289 SG_DEBUG("cloning parameter \"%s\" at index %d\n", 01290 m_parameters->get_parameter(i)->m_name, i); 01291 01292 if (!m_parameters->get_parameter(i)->copy(copy->m_parameters->get_parameter(i))) 01293 { 01294 SG_DEBUG("leaving %s::clone(): Clone failed. Returning NULL\n", 01295 get_name()); 01296 return NULL; 01297 } 01298 } 01299 01300 SG_DEBUG("leaving %s::clone(): Clone successful\n", get_name()); 01301 return copy; 01302 }