SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2011-2012 Heiko Strathmann 00008 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/modelselection/GridSearchModelSelection.h> 00012 #include <shogun/modelselection/ParameterCombination.h> 00013 #include <shogun/modelselection/ModelSelectionParameters.h> 00014 #include <shogun/evaluation/CrossValidation.h> 00015 #include <shogun/machine/Machine.h> 00016 00017 using namespace shogun; 00018 00019 CGridSearchModelSelection::CGridSearchModelSelection() : CModelSelection() 00020 { 00021 } 00022 00023 CGridSearchModelSelection::CGridSearchModelSelection( 00024 CMachineEvaluation* machine_eval, 00025 CModelSelectionParameters* model_parameters) 00026 : CModelSelection(machine_eval, model_parameters) 00027 { 00028 } 00029 00030 CGridSearchModelSelection::~CGridSearchModelSelection() 00031 { 00032 } 00033 00034 CParameterCombination* CGridSearchModelSelection::select_model(bool print_state) 00035 { 00036 if (print_state) 00037 SG_PRINT("Generating parameter combinations\n") 00038 00039 /* Retrieve all possible parameter combinations */ 00040 CDynamicObjectArray* combinations= 00041 (CDynamicObjectArray*)m_model_parameters->get_combinations(); 00042 00043 CCrossValidationResult* best_result=new CCrossValidationResult(); 00044 00045 CParameterCombination* best_combination=NULL; 00046 if (m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE) 00047 { 00048 if (print_state) SG_PRINT("Direction is maximize\n") 00049 best_result->mean=CMath::ALMOST_NEG_INFTY; 00050 } 00051 else 00052 { 00053 if (print_state) SG_PRINT("Direction is minimize\n") 00054 best_result->mean=CMath::ALMOST_INFTY; 00055 } 00056 00057 /* underlying learning machine */ 00058 CMachine* machine=m_machine_eval->get_machine(); 00059 00060 /* apply all combinations and search for best one */ 00061 for (index_t i=0; i<combinations->get_num_elements(); ++i) 00062 { 00063 CParameterCombination* current_combination=(CParameterCombination*) 00064 combinations->get_element(i); 00065 00066 /* eventually print */ 00067 if (print_state) 00068 { 00069 SG_PRINT("trying combination:\n") 00070 current_combination->print_tree(); 00071 } 00072 00073 current_combination->apply_to_modsel_parameter( 00074 machine->m_model_selection_parameters); 00075 00076 /* note that this may implicitly lock and unlockthe machine */ 00077 CCrossValidationResult* result= 00078 (CCrossValidationResult*)(m_machine_eval->evaluate()); 00079 00080 if (result->get_result_type() != CROSSVALIDATION_RESULT) 00081 SG_ERROR("Evaluation result is not of type CCrossValidationResult!") 00082 00083 if (print_state) 00084 result->print_result(); 00085 00086 /* check if current result is better, delete old combinations */ 00087 if (m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE) 00088 { 00089 if (result->mean>best_result->mean) 00090 { 00091 if (best_combination) 00092 SG_UNREF(best_combination); 00093 00094 best_combination=(CParameterCombination*) 00095 combinations->get_element(i); 00096 00097 SG_REF(result); 00098 SG_UNREF(best_result); 00099 best_result=result; 00100 } 00101 else 00102 { 00103 CParameterCombination* combination=(CParameterCombination*) 00104 combinations->get_element(i); 00105 SG_UNREF(combination); 00106 } 00107 } 00108 else 00109 { 00110 if (result->mean<best_result->mean) 00111 { 00112 if (best_combination) 00113 SG_UNREF(best_combination); 00114 00115 best_combination=(CParameterCombination*) 00116 combinations->get_element(i); 00117 00118 SG_REF(result); 00119 SG_UNREF(best_result); 00120 best_result=result; 00121 } 00122 else 00123 { 00124 CParameterCombination* combination=(CParameterCombination*) 00125 combinations->get_element(i); 00126 SG_UNREF(combination); 00127 } 00128 } 00129 00130 SG_UNREF(result); 00131 SG_UNREF(current_combination); 00132 } 00133 00134 SG_UNREF(best_result); 00135 SG_UNREF(machine); 00136 SG_UNREF(combinations); 00137 00138 return best_combination; 00139 }