SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
MMDKernelSelectionMedian.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2013 Heiko Strathmann
00008  */
00009 
00010 #include <shogun/statistics/MMDKernelSelectionMedian.h>
00011 #include <shogun/statistics/LinearTimeMMD.h>
00012 #include <shogun/features/streaming/StreamingFeatures.h>
00013 #include <shogun/statistics/QuadraticTimeMMD.h>
00014 #include <shogun/distance/EuclideanDistance.h>
00015 #include <shogun/kernel/GaussianKernel.h>
00016 #include <shogun/kernel/CombinedKernel.h>
00017 #include <shogun/mathematics/Statistics.h>
00018 
00019 
00020 using namespace shogun;
00021 
00022 CMMDKernelSelectionMedian::CMMDKernelSelectionMedian() :
00023         CMMDKernelSelection()
00024 {
00025     init();
00026 }
00027 
00028 CMMDKernelSelectionMedian::CMMDKernelSelectionMedian(
00029         CKernelTwoSampleTestStatistic* mmd, index_t num_data_distance) :
00030         CMMDKernelSelection(mmd)
00031 {
00032     /* assert that a combined kernel is used */
00033     CKernel* kernel=mmd->get_kernel();
00034     CFeatures* lhs=kernel->get_lhs();
00035     CFeatures* rhs=kernel->get_rhs();
00036     REQUIRE(kernel, "%s::%s(): No kernel set!\n", get_name(), get_name());
00037     REQUIRE(kernel->get_kernel_type()==K_COMBINED, "%s::%s(): Requires "
00038             "CombinedKernel as kernel. Yours is %s", get_name(), get_name(),
00039             kernel->get_name());
00040 
00041     /* assert that all subkernels are Gaussian kernels */
00042     CCombinedKernel* combined=(CCombinedKernel*)kernel;
00043 
00044     for (index_t k_idx=0; k_idx<combined->get_num_kernels(); k_idx++)
00045     {
00046         CKernel* subkernel=combined->get_kernel(k_idx);
00047         REQUIRE(kernel, "%s::%s(): Subkernel (%d) of current kernel is not"
00048                 " of type GaussianKernel\n", get_name(), get_name(), k_idx);
00049         SG_UNREF(subkernel);
00050     }
00051 
00052     /* assert 64 bit dense features since EuclideanDistance can only handle
00053      * those */
00054     if (m_mmd->get_statistic_type()==S_QUADRATIC_TIME_MMD)
00055     {
00056         CFeatures* features=((CQuadraticTimeMMD*)m_mmd)->get_p_and_q();
00057         REQUIRE(features->get_feature_class()==C_DENSE &&
00058                 features->get_feature_type()==F_DREAL, "%s::select_kernel(): "
00059                 "Only 64 bit float dense features allowed, these are \"%s\""
00060                 " and of type %d\n",
00061                 get_name(), features->get_name(), features->get_feature_type());
00062         SG_UNREF(features);
00063     }
00064     else if (m_mmd->get_statistic_type()==S_LINEAR_TIME_MMD)
00065     {
00066         CStreamingFeatures* p=((CLinearTimeMMD*)m_mmd)->get_streaming_p();
00067         CStreamingFeatures* q=((CLinearTimeMMD*)m_mmd)->get_streaming_q();
00068         REQUIRE(p->get_feature_class()==C_STREAMING_DENSE &&
00069                 p->get_feature_type()==F_DREAL, "%s::select_kernel(): "
00070                 "Only 64 bit float streaming dense features allowed, these (p) "
00071                 "are \"%s\" and of type %d\n",
00072                 get_name(), p->get_name(), p->get_feature_type());
00073 
00074         REQUIRE(p->get_feature_class()==C_STREAMING_DENSE &&
00075                 p->get_feature_type()==F_DREAL, "%s::select_kernel(): "
00076                 "Only 64 bit float streaming dense features allowed, these (q) "
00077                 "are \"%s\" and of type %d\n",
00078                 get_name(), q->get_name(), q->get_feature_type());
00079         SG_UNREF(p);
00080         SG_UNREF(q);
00081     }
00082 
00083     SG_UNREF(kernel);
00084     SG_UNREF(lhs);
00085     SG_UNREF(rhs);
00086 
00087     init();
00088 
00089     m_num_data_distance=num_data_distance;
00090 }
00091 
00092 CMMDKernelSelectionMedian::~CMMDKernelSelectionMedian()
00093 {
00094 }
00095 
00096 void CMMDKernelSelectionMedian::init()
00097 {
00098     SG_ADD(&m_num_data_distance, "m_num_data_distance", "Number of elements to "
00099             "to compute median distance on", MS_NOT_AVAILABLE);
00100 
00101     /* this is a sensible value */
00102     m_num_data_distance=1000;
00103 }
00104 
00105 SGVector<float64_t> CMMDKernelSelectionMedian::compute_measures()
00106 {
00107     SG_ERROR("%s::compute_measures(): Not implemented. Use select_kernel() "
00108             "method!\n", get_name());
00109     return SGVector<float64_t>();
00110 }
00111 
00112 CKernel* CMMDKernelSelectionMedian::select_kernel()
00113 {
00114     /* number of data for distace */
00115     index_t num_data=CMath::min(m_num_data_distance, m_mmd->get_m());
00116 
00117     SGMatrix<float64_t> dists;
00118 
00119     /* compute all pairwise distances, depends which mmd statistic is used */
00120     if (m_mmd->get_statistic_type()==S_QUADRATIC_TIME_MMD)
00121     {
00122         /* fixed data, create merged copy of a random subset */
00123 
00124         /* create vector with that correspond to the num_data first points of
00125          * each distribution, remember data is stored jointly */
00126         SGVector<index_t> subset(num_data*2);
00127         index_t m=m_mmd->get_m();
00128         for (index_t i=0; i<num_data; ++i)
00129         {
00130             /* num_data samples from each half of joint sample */
00131             subset[i]=i;
00132             subset[i+num_data]=i+m;
00133         }
00134 
00135         /* add subset and compute pairwise distances */
00136         CQuadraticTimeMMD* quad_mmd=(CQuadraticTimeMMD*)m_mmd;
00137         CFeatures* features=quad_mmd->get_p_and_q();
00138         features->add_subset(subset);
00139 
00140         /* cast is safe, see constructor */
00141         CDenseFeatures<float64_t>* dense_features=
00142                 (CDenseFeatures<float64_t>*) features;
00143 
00144         CEuclideanDistance* distance=new CEuclideanDistance(dense_features,
00145                 dense_features);
00146         dists=distance->get_distance_matrix();
00147         features->remove_subset();
00148         SG_UNREF(distance);
00149         SG_UNREF(features);
00150     }
00151     else if (m_mmd->get_statistic_type()==S_LINEAR_TIME_MMD)
00152     {
00153         /* just stream the desired number of points */
00154         CLinearTimeMMD* linear_mmd=(CLinearTimeMMD*)m_mmd;
00155 
00156         CStreamingFeatures* p=linear_mmd->get_streaming_p();
00157         CStreamingFeatures* q=linear_mmd->get_streaming_q();
00158 
00159         /* cast is safe, see constructor */
00160         CDenseFeatures<float64_t>* p_streamed=(CDenseFeatures<float64_t>*)
00161                 p->get_streamed_features(num_data);
00162         CDenseFeatures<float64_t>* q_streamed=(CDenseFeatures<float64_t>*)
00163                     q->get_streamed_features(num_data);
00164 
00165         /* for safety */
00166         SG_REF(p_streamed);
00167         SG_REF(q_streamed);
00168 
00169         /* create merged feature object */
00170         CDenseFeatures<float64_t>* merged=(CDenseFeatures<float64_t>*)
00171                 p_streamed->create_merged_copy(q_streamed);
00172 
00173         /* compute pairwise distances */
00174         CEuclideanDistance* distance=new CEuclideanDistance(merged, merged);
00175         dists=distance->get_distance_matrix();
00176 
00177         /* clean up */
00178         SG_UNREF(distance);
00179         SG_UNREF(p_streamed);
00180         SG_UNREF(q_streamed);
00181         SG_UNREF(p);
00182         SG_UNREF(q);
00183     }
00184 
00185     /* create a vector where the zeros have been removed, use upper triangle
00186      * only since distances are symmetric */
00187     SGVector<float64_t> dist_vec(dists.num_rows*(dists.num_rows-1)/2);
00188     index_t write_idx=0;
00189     for (index_t i=0; i<dists.num_rows; ++i)
00190     {
00191         for (index_t j=i+1; j<dists.num_rows; ++j)
00192             dist_vec[write_idx++]=dists(i,j);
00193     }
00194 
00195     /* now we have distance matrix, compute median, allow to modify matrix */
00196     float64_t median_distance=CStatistics::median(dist_vec, true);
00197     SG_DEBUG("median_distance: %f\n", median_distance);
00198 
00199     /* shogun has no square and factor two in its kernel width, MATLAB does
00200      * median_width = sqrt(0.5*median_distance), we do this */
00201     float64_t shogun_sigma=median_distance;
00202     SG_DEBUG("kernel width (shogun): %f\n", shogun_sigma);
00203 
00204     /* now of all kernels, find the one which has its width closest
00205      * Cast is safe due to constructor of MMDKernelSelection class */
00206     CCombinedKernel* combined=(CCombinedKernel*)m_mmd->get_kernel();
00207     float64_t min_distance=CMath::MAX_REAL_NUMBER;
00208     CKernel* min_kernel=NULL;
00209     float64_t distance;
00210     for (index_t i=0; i<combined->get_num_subkernels(); ++i)
00211     {
00212         CKernel* current=combined->get_kernel(i);
00213         REQUIRE(current->get_kernel_type()==K_GAUSSIAN, "%s::select_kernel(): "
00214                 "%d-th kernel is not a Gaussian but \"%s\"!\n", get_name(), i,
00215                 current->get_name());
00216 
00217         /* check if width is closer to median width */
00218         distance=CMath::abs(((CGaussianKernel*)current)->get_width()-
00219                 shogun_sigma);
00220 
00221         if (distance<min_distance)
00222         {
00223             min_distance=distance;
00224             min_kernel=current;
00225         }
00226 
00227         /* next kernel */
00228         SG_UNREF(current);
00229     }
00230     SG_UNREF(combined);
00231 
00232     /* returned referenced kernel */
00233     SG_REF(min_kernel);
00234     return min_kernel;
00235 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation