SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
GridSearchModelSelection.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011-2012 Heiko Strathmann
00008  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/modelselection/GridSearchModelSelection.h>
00012 #include <shogun/modelselection/ParameterCombination.h>
00013 #include <shogun/modelselection/ModelSelectionParameters.h>
00014 #include <shogun/evaluation/CrossValidation.h>
00015 #include <shogun/machine/Machine.h>
00016 
00017 using namespace shogun;
00018 
00019 CGridSearchModelSelection::CGridSearchModelSelection() : CModelSelection()
00020 {
00021 }
00022 
00023 CGridSearchModelSelection::CGridSearchModelSelection(
00024         CMachineEvaluation* machine_eval,
00025         CModelSelectionParameters* model_parameters)
00026         : CModelSelection(machine_eval, model_parameters)
00027 {
00028 }
00029 
00030 CGridSearchModelSelection::~CGridSearchModelSelection()
00031 {
00032 }
00033 
00034 CParameterCombination* CGridSearchModelSelection::select_model(bool print_state)
00035 {
00036     if (print_state)
00037         SG_PRINT("Generating parameter combinations\n")
00038 
00039     /* Retrieve all possible parameter combinations */
00040     CDynamicObjectArray* combinations=
00041             (CDynamicObjectArray*)m_model_parameters->get_combinations();
00042 
00043     CCrossValidationResult* best_result=new CCrossValidationResult();
00044 
00045     CParameterCombination* best_combination=NULL;
00046     if (m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE)
00047     {
00048         if (print_state) SG_PRINT("Direction is maximize\n")
00049         best_result->mean=CMath::ALMOST_NEG_INFTY;
00050     }
00051     else
00052     {
00053         if (print_state) SG_PRINT("Direction is minimize\n")
00054         best_result->mean=CMath::ALMOST_INFTY;
00055     }
00056 
00057     /* underlying learning machine */
00058     CMachine* machine=m_machine_eval->get_machine();
00059 
00060     /* apply all combinations and search for best one */
00061     for (index_t i=0; i<combinations->get_num_elements(); ++i)
00062     {
00063         CParameterCombination* current_combination=(CParameterCombination*)
00064                 combinations->get_element(i);
00065 
00066         /* eventually print */
00067         if (print_state)
00068         {
00069             SG_PRINT("trying combination:\n")
00070             current_combination->print_tree();
00071         }
00072 
00073         current_combination->apply_to_modsel_parameter(
00074                 machine->m_model_selection_parameters);
00075 
00076         /* note that this may implicitly lock and unlockthe machine */
00077         CCrossValidationResult* result=
00078                 (CCrossValidationResult*)(m_machine_eval->evaluate());
00079 
00080         if (result->get_result_type() != CROSSVALIDATION_RESULT)
00081             SG_ERROR("Evaluation result is not of type CCrossValidationResult!")
00082 
00083         if (print_state)
00084             result->print_result();
00085 
00086         /* check if current result is better, delete old combinations */
00087         if (m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE)
00088         {
00089             if (result->mean>best_result->mean)
00090             {
00091                 if (best_combination)
00092                     SG_UNREF(best_combination);
00093 
00094                 best_combination=(CParameterCombination*)
00095                         combinations->get_element(i);
00096 
00097                 SG_REF(result);
00098                 SG_UNREF(best_result);
00099                 best_result=result;
00100             }
00101             else
00102             {
00103                 CParameterCombination* combination=(CParameterCombination*)
00104                         combinations->get_element(i);
00105                 SG_UNREF(combination);
00106             }
00107         }
00108         else
00109         {
00110             if (result->mean<best_result->mean)
00111             {
00112                 if (best_combination)
00113                     SG_UNREF(best_combination);
00114 
00115                 best_combination=(CParameterCombination*)
00116                         combinations->get_element(i);
00117 
00118                 SG_REF(result);
00119                 SG_UNREF(best_result);
00120                 best_result=result;
00121             }
00122             else
00123             {
00124                 CParameterCombination* combination=(CParameterCombination*)
00125                         combinations->get_element(i);
00126                 SG_UNREF(combination);
00127             }
00128         }
00129 
00130         SG_UNREF(result);
00131         SG_UNREF(current_combination);
00132     }
00133 
00134     SG_UNREF(best_result);
00135     SG_UNREF(machine);
00136     SG_UNREF(combinations);
00137 
00138     return best_combination;
00139 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation