SHOGUN
v3.2.0
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2007-2011 Christian Widmer
 * Copyright (C) 2007-2011 Max-Planck-Society
 */

#include <shogun/lib/config.h>

#ifdef USE_SVMLIGHT

#include <shogun/transfer/domain_adaptation/DomainAdaptationSVM.h>
#include <shogun/io/SGIO.h>
#include <shogun/labels/Labels.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/labels/RegressionLabels.h>
#include <iostream>
#include <vector>

using namespace shogun;

CDomainAdaptationSVM::CDomainAdaptationSVM() : CSVMLight()
{
    init();
}

CDomainAdaptationSVM::CDomainAdaptationSVM(float64_t C, CKernel* k, CLabels* lab, CSVM* pre_svm, float64_t B_param) : CSVMLight(C, k, lab)
{
    init();
    init(pre_svm, B_param);
}

CDomainAdaptationSVM::~CDomainAdaptationSVM()
{
    SG_UNREF(presvm);
    SG_DEBUG("deleting DomainAdaptationSVM\n")
}

void CDomainAdaptationSVM::init(CSVM* pre_svm, float64_t B_param)
{
    REQUIRE(pre_svm != NULL, "Pre SVM should not be null");
    // increase reference counts
    SG_REF(pre_svm);

    this->presvm=pre_svm;
    this->B=B_param;
    this->train_factor=1.0;

    // set bias of parent svm to zero
    this->presvm->set_bias(0.0);

    // invoke sanity check
    is_presvm_sane();
}

bool CDomainAdaptationSVM::is_presvm_sane()
{
    if (!presvm) {
        SG_ERROR("presvm is null")
    }

    if (presvm->get_num_support_vectors() == 0) {
        SG_ERROR("presvm has no support vectors, please train first")
    }

    if (presvm->get_bias() != 0) {
        SG_ERROR("presvm bias not set to zero")
    }

    if (presvm->get_kernel()->get_kernel_type() != this->get_kernel()->get_kernel_type()) {
        SG_ERROR("kernel types do not agree")
    }

    if (presvm->get_kernel()->get_feature_type() != this->get_kernel()->get_feature_type()) {
        SG_ERROR("feature types do not agree")
    }

    return true;
}

bool CDomainAdaptationSVM::train_machine(CFeatures* data)
{
    if (data)
    {
        if (m_labels->get_num_labels() != data->get_num_vectors())
            SG_ERROR("Number of training vectors does not match number of labels\n")
        kernel->init(data, data);
    }

    if (m_labels->get_label_type() != LT_BINARY)
        SG_ERROR("DomainAdaptationSVM requires binary labels\n")

    int32_t num_training_points = get_labels()->get_num_labels();
    CBinaryLabels* labels = (CBinaryLabels*) get_labels();

    float64_t* lin_term = SG_MALLOC(float64_t, num_training_points);

    // grab current training features
    CFeatures* train_data = get_kernel()->get_lhs();

    // bias of parent SVM was set to zero in constructor, already contains B
    CBinaryLabels* parent_svm_out = presvm->apply_binary(train_data);

    // pre-compute linear term
    for (int32_t i=0; i<num_training_points; i++)
    {
        lin_term[i] = train_factor * B * labels->get_label(i) * parent_svm_out->get_label(i) - 1.0;
    }

    // set linear term for QP
    this->set_linear_term(SGVector<float64_t>(lin_term, num_training_points));

    // train SVM
    bool success = CSVMLight::train_machine();
    SG_UNREF(labels);

    ASSERT(presvm)

    return success;
}

CSVM* CDomainAdaptationSVM::get_presvm()
{
    SG_REF(presvm);
    return presvm;
}

float64_t CDomainAdaptationSVM::get_B()
{
    return B;
}

float64_t CDomainAdaptationSVM::get_train_factor()
{
    return train_factor;
}

void CDomainAdaptationSVM::set_train_factor(float64_t factor)
{
    train_factor = factor;
}

CBinaryLabels* CDomainAdaptationSVM::apply_binary(CFeatures* data)
{
    ASSERT(data)
    ASSERT(presvm->get_bias()==0.0)

    int32_t num_examples = data->get_num_vectors();

    CBinaryLabels* out_current = CSVMLight::apply_binary(data);

    // recursive call if used on DomainAdaptationSVM object
    CBinaryLabels* out_presvm = presvm->apply_binary(data);

    // combine outputs
    SGVector<float64_t> out_combined(num_examples);
    for (int32_t i=0; i<num_examples; i++)
    {
        out_combined[i] = out_current->get_value(i) + B*out_presvm->get_value(i);
    }
    SG_UNREF(out_current);
    SG_UNREF(out_presvm);

    return new CBinaryLabels(out_combined);
}

void CDomainAdaptationSVM::init()
{
    presvm = NULL;
    B = 0;
    train_factor = 1.0;

    SG_ADD((CSGObject**) &presvm, "presvm", "SVM to regularize against.",
        MS_NOT_AVAILABLE);
    SG_ADD(&B, "B", "regularization parameter B.", MS_AVAILABLE);
    SG_ADD(&train_factor, "train_factor",
        "flag to switch off regularization in training.", MS_AVAILABLE);
}

#endif //USE_SVMLIGHT
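
The listing implements a two-stage transfer scheme: a previously trained source-domain SVM (presvm) is frozen and its bias forced to zero, then the target-domain SVM is trained with the QP's linear term shifted to lin_term[i] = train_factor * B * y_i * f_pre(x_i) - 1.0, pulling the new solution toward the source model. At prediction time, apply_binary() combines both machines as f(x) = f_current(x) + B * f_pre(x). Note that set_train_factor(0.0) reduces the linear term to the standard -1 (no regularization toward presvm during training) while the source model's output is still added at prediction time; this is the "flag to switch off regularization" mentioned in the parameter description.

Below is a minimal usage sketch, assuming the Shogun 3.x C++ API and a build configured with USE_SVMLIGHT. The toy matrices, the kernel width (2.0), and the C and B values (1.0, 0.5) are made-up placeholders, not values taken from the library.

#include <shogun/base/init.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/lib/SGVector.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/kernel/GaussianKernel.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/classifier/svm/SVMLight.h>
#include <shogun/transfer/domain_adaptation/DomainAdaptationSVM.h>

using namespace shogun;

int main()
{
    init_shogun_with_defaults();

    // toy data: 2 features x 4 vectors per domain, labels in {-1,+1}
    SGMatrix<float64_t> src_mat(2, 4);
    SGMatrix<float64_t> tgt_mat(2, 4);
    SGVector<float64_t> src_lab(4);
    SGVector<float64_t> tgt_lab(4);
    for (int32_t i=0; i<4; i++)
    {
        src_mat(0, i) = (i < 2) ? -1.0 : 1.0;  src_mat(1, i) = i;
        tgt_mat(0, i) = (i < 2) ? -0.8 : 1.2;  tgt_mat(1, i) = i;
        src_lab[i] = (i < 2) ? -1.0 : 1.0;
        tgt_lab[i] = (i < 2) ? -1.0 : 1.0;
    }

    CDenseFeatures<float64_t>* src_feats = new CDenseFeatures<float64_t>(src_mat);
    CDenseFeatures<float64_t>* tgt_feats = new CDenseFeatures<float64_t>(tgt_mat);

    // 1) train the source-domain SVM; it must be trained before being
    //    handed to CDomainAdaptationSVM (is_presvm_sane() checks this)
    CSVMLight* presvm = new CSVMLight(1.0,
        new CGaussianKernel(src_feats, src_feats, 2.0),
        new CBinaryLabels(src_lab));
    presvm->train();

    // 2) train on the target domain, regularizing against presvm;
    //    B trades off how strongly the source model is weighted
    CDomainAdaptationSVM* dasvm = new CDomainAdaptationSVM(1.0,
        new CGaussianKernel(tgt_feats, tgt_feats, 2.0),
        new CBinaryLabels(tgt_lab), presvm, 0.5);
    dasvm->train();

    // prediction combines both models: f(x) = f_current(x) + B*f_pre(x)
    CBinaryLabels* pred = dasvm->apply_binary(tgt_feats);
    pred->get_labels().display_vector("predictions");

    SG_UNREF(pred);
    SG_UNREF(dasvm);
    SG_UNREF(presvm);
    exit_shogun();
    return 0;
}

Both machines must use the same kernel and feature types (the sanity check enforces this); here both use a Gaussian kernel on dense real features.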