SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2010 Soeren Sonnenburg 00008 * Copyright (C) 2010 Berlin Institute of Technology 00009 */ 00010 00011 #ifndef _SCATTERKERNELNORMALIZER_H___ 00012 #define _SCATTERKERNELNORMALIZER_H___ 00013 00014 #include <shogun/kernel/normalizer/KernelNormalizer.h> 00015 #include <shogun/kernel/normalizer/IdentityKernelNormalizer.h> 00016 #include <shogun/kernel/Kernel.h> 00017 #include <shogun/labels/Labels.h> 00018 #include <shogun/labels/MulticlassLabels.h> 00019 #include <shogun/io/SGIO.h> 00020 00021 namespace shogun 00022 { 00024 class CScatterKernelNormalizer: public CKernelNormalizer 00025 { 00026 00027 public: 00029 CScatterKernelNormalizer() : CKernelNormalizer() 00030 { 00031 init(); 00032 } 00033 00036 CScatterKernelNormalizer(float64_t const_diag, float64_t const_offdiag, 00037 CLabels* labels,CKernelNormalizer* normalizer=NULL) 00038 : CKernelNormalizer() 00039 { 00040 init(); 00041 00042 m_testing_class=-1; 00043 m_const_diag=const_diag; 00044 m_const_offdiag=const_offdiag; 00045 00046 ASSERT(labels) 00047 SG_REF(labels); 00048 m_labels=labels; 00049 ASSERT(labels->get_label_type()==LT_MULTICLASS) 00050 labels->ensure_valid(); 00051 00052 if (normalizer==NULL) 00053 normalizer=new CIdentityKernelNormalizer(); 00054 SG_REF(normalizer); 00055 m_normalizer=normalizer; 00056 00057 SG_DEBUG("Constructing ScatterKernelNormalizer with const_diag=%g" 00058 " const_offdiag=%g num_labels=%d and normalizer='%s'\n", 00059 const_diag, const_offdiag, labels->get_num_labels(), 00060 normalizer->get_name()); 00061 } 00062 00064 virtual ~CScatterKernelNormalizer() 00065 { 00066 SG_UNREF(m_labels); 00067 SG_UNREF(m_normalizer); 00068 } 00069 00072 virtual bool init(CKernel* k) 00073 { 00074 m_normalizer->init(k); 00075 return true; 00076 } 00077 00082 int32_t get_testing_class() 00083 { 00084 return m_testing_class; 00085 } 00086 00091 void set_testing_class(int32_t c) 00092 { 00093 m_testing_class=c; 00094 } 00095 00101 virtual float64_t normalize(float64_t value, int32_t idx_lhs, 00102 int32_t idx_rhs) 00103 { 00104 value=m_normalizer->normalize(value, idx_lhs, idx_rhs); 00105 float64_t c=m_const_offdiag; 00106 00107 if (m_testing_class>=0) 00108 { 00109 if (((CMulticlassLabels*) m_labels)->get_label(idx_lhs) == m_testing_class) 00110 c=m_const_diag; 00111 } 00112 else 00113 { 00114 if (((CMulticlassLabels*) m_labels)->get_label(idx_lhs) == ((CMulticlassLabels*) m_labels)->get_label(idx_rhs)) 00115 c=m_const_diag; 00116 00117 } 00118 return value*c; 00119 } 00120 00125 virtual float64_t normalize_lhs(float64_t value, int32_t idx_lhs) 00126 { 00127 SG_ERROR("normalize_lhs not implemented") 00128 return 0; 00129 } 00130 00135 virtual float64_t normalize_rhs(float64_t value, int32_t idx_rhs) 00136 { 00137 SG_ERROR("normalize_rhs not implemented") 00138 return 0; 00139 } 00140 00142 virtual const char* get_name() const 00143 { 00144 return "ScatterKernelNormalizer"; 00145 } 00146 00147 private: 00148 void init() 00149 { 00150 m_const_diag = 1.0; 00151 m_const_offdiag = 1.0; 00152 00153 m_labels = NULL; 00154 m_normalizer = NULL; 00155 00156 m_testing_class = -1; 00157 00158 SG_ADD(&m_testing_class, "m_testing_class", 00159 "Testing Class.", MS_NOT_AVAILABLE); 00160 SG_ADD(&m_const_diag, "m_const_diag", 00161 "Factor to multiply to diagonal elements.", MS_AVAILABLE); 00162 SG_ADD(&m_const_offdiag, "m_const_offdiag", 00163 "Factor to multiply to off-diagonal elements.", MS_AVAILABLE); 00164 00165 SG_ADD((CSGObject**) &m_labels, "m_labels", "Labels", MS_NOT_AVAILABLE); 00166 SG_ADD((CSGObject**) &m_normalizer, "m_normalizer", "Kernel normalizer.", 00167 MS_AVAILABLE); 00168 } 00169 00170 protected: 00171 00173 float64_t m_const_diag; 00175 float64_t m_const_offdiag; 00176 00178 CLabels* m_labels; 00179 00181 CKernelNormalizer* m_normalizer; 00182 00184 int32_t m_testing_class; 00185 }; 00186 } 00187 #endif 00188