SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2011 Heiko Strathmann 00008 * Copyright (C) 2012 Sergey Lisitsyn 00009 */ 00010 00011 #include <shogun/modelselection/RandomSearchModelSelection.h> 00012 #include <shogun/modelselection/ParameterCombination.h> 00013 #include <shogun/modelselection/ModelSelectionParameters.h> 00014 #include <shogun/evaluation/CrossValidation.h> 00015 #include <shogun/mathematics/Statistics.h> 00016 #include <shogun/machine/Machine.h> 00017 00018 using namespace shogun; 00019 00020 CRandomSearchModelSelection::CRandomSearchModelSelection() : CModelSelection() 00021 { 00022 set_ratio(0.5); 00023 } 00024 00025 CRandomSearchModelSelection::CRandomSearchModelSelection( 00026 CMachineEvaluation* machine_eval, 00027 CModelSelectionParameters* model_parameters, float64_t ratio) 00028 : CModelSelection(machine_eval, model_parameters) 00029 { 00030 set_ratio(ratio); 00031 } 00032 00033 CRandomSearchModelSelection::~CRandomSearchModelSelection() 00034 { 00035 } 00036 00037 CParameterCombination* CRandomSearchModelSelection::select_model(bool print_state) 00038 { 00039 if (print_state) 00040 SG_PRINT("Generating parameter combinations\n") 00041 00042 /* Retrieve all possible parameter combinations */ 00043 CDynamicObjectArray* all_combinations= 00044 (CDynamicObjectArray*)m_model_parameters->get_combinations(); 00045 00046 int32_t n_all_combinations=all_combinations->get_num_elements(); 00047 SGVector<index_t> combinations_indices=CStatistics::sample_indices(n_all_combinations*m_ratio, n_all_combinations); 00048 00049 CDynamicObjectArray* combinations=new CDynamicObjectArray(); 00050 00051 for (int32_t i=0; i<combinations_indices.vlen; i++) 00052 combinations->append_element(all_combinations->get_element(i)); 00053 00054 CCrossValidationResult* best_result=new CCrossValidationResult(); 00055 00056 CParameterCombination* best_combination=NULL; 00057 if (m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE) 00058 { 00059 if (print_state) SG_PRINT("Direction is maximize\n") 00060 best_result->mean=CMath::ALMOST_NEG_INFTY; 00061 } 00062 else 00063 { 00064 if (print_state) SG_PRINT("Direction is minimize\n") 00065 best_result->mean=CMath::ALMOST_INFTY; 00066 } 00067 00068 /* underlying learning machine */ 00069 CMachine* machine=m_machine_eval->get_machine(); 00070 00071 /* apply all combinations and search for best one */ 00072 for (index_t i=0; i<combinations->get_num_elements(); ++i) 00073 { 00074 CParameterCombination* current_combination=(CParameterCombination*) 00075 combinations->get_element(i); 00076 00077 /* eventually print */ 00078 if (print_state) 00079 { 00080 SG_PRINT("trying combination:\n") 00081 current_combination->print_tree(); 00082 } 00083 00084 current_combination->apply_to_modsel_parameter( 00085 machine->m_model_selection_parameters); 00086 00087 /* note that this may implicitly lock and unlockthe machine */ 00088 CCrossValidationResult* result = 00089 (CCrossValidationResult*)(m_machine_eval->evaluate()); 00090 00091 if (result->get_result_type() != CROSSVALIDATION_RESULT) 00092 SG_ERROR("Evaluation result is not of type CCrossValidationResult!") 00093 00094 if (print_state) 00095 result->print_result(); 00096 00097 /* check if current result is better, delete old combinations */ 00098 if (m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE) 00099 { 00100 if (result->mean>best_result->mean) 00101 { 00102 if (best_combination) 00103 SG_UNREF(best_combination); 00104 00105 best_combination=(CParameterCombination*) 00106 combinations->get_element(i); 00107 00108 SG_REF(result); 00109 SG_UNREF(best_result); 00110 best_result=result; 00111 } 00112 else 00113 { 00114 CParameterCombination* combination=(CParameterCombination*) 00115 combinations->get_element(i); 00116 SG_UNREF(combination); 00117 } 00118 } 00119 else 00120 { 00121 if (result->mean<best_result->mean) 00122 { 00123 if (best_combination) 00124 SG_UNREF(best_combination); 00125 00126 best_combination=(CParameterCombination*) 00127 combinations->get_element(i); 00128 00129 SG_REF(result); 00130 SG_UNREF(best_result); 00131 best_result=result; 00132 } 00133 else 00134 { 00135 CParameterCombination* combination=(CParameterCombination*) 00136 combinations->get_element(i); 00137 SG_UNREF(combination); 00138 } 00139 } 00140 00141 SG_UNREF(result); 00142 SG_UNREF(current_combination); 00143 } 00144 00145 SG_UNREF(best_result); 00146 SG_UNREF(machine); 00147 SG_UNREF(combinations); 00148 00149 return best_combination; 00150 }