SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
RandomSearchModelSelection.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (C) 2011 Heiko Strathmann
00008  * Copyright (C) 2012 Sergey Lisitsyn
00009  */
00010 
00011 #include <shogun/modelselection/RandomSearchModelSelection.h>
00012 #include <shogun/modelselection/ParameterCombination.h>
00013 #include <shogun/modelselection/ModelSelectionParameters.h>
00014 #include <shogun/evaluation/CrossValidation.h>
00015 #include <shogun/mathematics/Statistics.h>
00016 #include <shogun/machine/Machine.h>
00017 
00018 using namespace shogun;
00019 
00020 CRandomSearchModelSelection::CRandomSearchModelSelection() : CModelSelection()
00021 {
00022     set_ratio(0.5);
00023 }
00024 
00025 CRandomSearchModelSelection::CRandomSearchModelSelection(
00026         CMachineEvaluation* machine_eval,
00027         CModelSelectionParameters* model_parameters, float64_t ratio)
00028         : CModelSelection(machine_eval, model_parameters)
00029 {
00030     set_ratio(ratio);
00031 }
00032 
00033 CRandomSearchModelSelection::~CRandomSearchModelSelection()
00034 {
00035 }
00036 
00037 CParameterCombination* CRandomSearchModelSelection::select_model(bool print_state)
00038 {
00039     if (print_state)
00040         SG_PRINT("Generating parameter combinations\n")
00041 
00042     /* Retrieve all possible parameter combinations */
00043     CDynamicObjectArray* all_combinations=
00044             (CDynamicObjectArray*)m_model_parameters->get_combinations();
00045 
00046     int32_t n_all_combinations=all_combinations->get_num_elements();
00047     SGVector<index_t> combinations_indices=CStatistics::sample_indices(n_all_combinations*m_ratio, n_all_combinations);
00048 
00049     CDynamicObjectArray* combinations=new CDynamicObjectArray();
00050 
00051     for (int32_t i=0; i<combinations_indices.vlen; i++)
00052         combinations->append_element(all_combinations->get_element(i));
00053 
00054     CCrossValidationResult* best_result=new CCrossValidationResult();
00055 
00056     CParameterCombination* best_combination=NULL;
00057     if (m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE)
00058     {
00059         if (print_state) SG_PRINT("Direction is maximize\n")
00060         best_result->mean=CMath::ALMOST_NEG_INFTY;
00061     }
00062     else
00063     {
00064         if (print_state) SG_PRINT("Direction is minimize\n")
00065         best_result->mean=CMath::ALMOST_INFTY;
00066     }
00067 
00068     /* underlying learning machine */
00069     CMachine* machine=m_machine_eval->get_machine();
00070 
00071     /* apply all combinations and search for best one */
00072     for (index_t i=0; i<combinations->get_num_elements(); ++i)
00073     {
00074         CParameterCombination* current_combination=(CParameterCombination*)
00075                 combinations->get_element(i);
00076 
00077         /* eventually print */
00078         if (print_state)
00079         {
00080             SG_PRINT("trying combination:\n")
00081             current_combination->print_tree();
00082         }
00083 
00084         current_combination->apply_to_modsel_parameter(
00085                 machine->m_model_selection_parameters);
00086 
00087         /* note that this may implicitly lock and unlockthe machine */
00088         CCrossValidationResult* result =
00089                 (CCrossValidationResult*)(m_machine_eval->evaluate());
00090 
00091         if (result->get_result_type() != CROSSVALIDATION_RESULT)
00092             SG_ERROR("Evaluation result is not of type CCrossValidationResult!")
00093 
00094         if (print_state)
00095             result->print_result();
00096 
00097         /* check if current result is better, delete old combinations */
00098         if (m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE)
00099         {
00100             if (result->mean>best_result->mean)
00101             {
00102                 if (best_combination)
00103                     SG_UNREF(best_combination);
00104 
00105                 best_combination=(CParameterCombination*)
00106                         combinations->get_element(i);
00107 
00108                 SG_REF(result);
00109                 SG_UNREF(best_result);
00110                 best_result=result;
00111             }
00112             else
00113             {
00114                 CParameterCombination* combination=(CParameterCombination*)
00115                         combinations->get_element(i);
00116                 SG_UNREF(combination);
00117             }
00118         }
00119         else
00120         {
00121             if (result->mean<best_result->mean)
00122             {
00123                 if (best_combination)
00124                     SG_UNREF(best_combination);
00125 
00126                 best_combination=(CParameterCombination*)
00127                         combinations->get_element(i);
00128 
00129                 SG_REF(result);
00130                 SG_UNREF(best_result);
00131                 best_result=result;
00132             }
00133             else
00134             {
00135                 CParameterCombination* combination=(CParameterCombination*)
00136                         combinations->get_element(i);
00137                 SG_UNREF(combination);
00138             }
00139         }
00140 
00141         SG_UNREF(result);
00142         SG_UNREF(current_combination);
00143     }
00144 
00145     SG_UNREF(best_result);
00146     SG_UNREF(machine);
00147     SG_UNREF(combinations);
00148 
00149     return best_combination;
00150 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation