SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012-2013 Heiko Strathmann 00008 */ 00009 00010 #include <shogun/statistics/MMDKernelSelection.h> 00011 #include <shogun/kernel/CombinedKernel.h> 00012 #include <shogun/statistics/KernelTwoSampleTestStatistic.h> 00013 #include <shogun/statistics/LinearTimeMMD.h> 00014 #include <shogun/statistics/QuadraticTimeMMD.h> 00015 00016 00017 using namespace shogun; 00018 00019 CMMDKernelSelection::CMMDKernelSelection() 00020 { 00021 init(); 00022 } 00023 00024 CMMDKernelSelection::CMMDKernelSelection( 00025 CKernelTwoSampleTestStatistic* mmd) 00026 { 00027 init(); 00028 00029 /* ensure that mmd contains an instance of a MMD related class */ 00030 REQUIRE(mmd, "CMMDKernelSelection::CMMDKernelSelection(): No MMD instance " 00031 "provided!\n"); 00032 REQUIRE(mmd->get_statistic_type()==S_LINEAR_TIME_MMD || 00033 mmd->get_statistic_type()==S_QUADRATIC_TIME_MMD, 00034 "CMMDKernelSelection::CMMDKernelSelection(): provided instance " 00035 "for kernel two sample testing has to be a MMD-based class! The " 00036 "provided is of class \"%s\"\n", mmd->get_name()); 00037 00038 /* ensure that there is a combined kernel */ 00039 CKernel* kernel=mmd->get_kernel(); 00040 REQUIRE(kernel, "CMMDKernelSelection::CMMDKernelSelection(): underlying " 00041 "\"%s\" has no kernel set!\n", mmd->get_name()); 00042 REQUIRE(kernel->get_kernel_type()==K_COMBINED, "CMMDKernelSelection::" 00043 "CMMDKernelSelection(): kernel of underlying \"%s\" is of type \"%s\"" 00044 " but is has to be CCombinedKernel\n", mmd->get_name(), 00045 kernel->get_name()); 00046 SG_UNREF(kernel); 00047 00048 m_mmd=mmd; 00049 SG_REF(m_mmd); 00050 } 00051 00052 00053 CMMDKernelSelection::~CMMDKernelSelection() 00054 { 00055 SG_UNREF(m_mmd); 00056 } 00057 00058 void CMMDKernelSelection::init() 00059 { 00060 m_mmd=NULL; 00061 00062 SG_ADD((CSGObject**)&m_mmd, "mmd", "Underlying MMD instance", 00063 MS_NOT_AVAILABLE); 00064 } 00065 00066 CKernel* CMMDKernelSelection::select_kernel() 00067 { 00068 SG_DEBUG("entering CMMDKernelSelection::select_kernel()\n") 00069 00070 /* compute measures and return single kernel with maximum measure */ 00071 SGVector<float64_t> measures=compute_measures(); 00072 00073 /* find maximum and return corresponding kernel */ 00074 float64_t max=measures[0]; 00075 index_t max_idx=0; 00076 for (index_t i=1; i<measures.vlen; ++i) 00077 { 00078 if (measures[i]>max) 00079 { 00080 max=measures[i]; 00081 max_idx=i; 00082 } 00083 } 00084 00085 /* find kernel with corresponding index */ 00086 CCombinedKernel* combined=(CCombinedKernel*)m_mmd->get_kernel(); 00087 CKernel* current=combined->get_kernel(max_idx); 00088 00089 SG_UNREF(combined); 00090 SG_DEBUG("leaving CMMDKernelSelection::select_kernel()\n"); 00091 00092 /* current is not SG_UNREF'ed nor SG_REF'ed since the counter needs to be 00093 * incremented exactly by one */ 00094 return current; 00095 } 00096