SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
DomainAdaptationSVM.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2007-2011 Christian Widmer
00008  * Copyright (C) 2007-2011 Max-Planck-Society
00009  */
00010 
00011 #include <shogun/lib/config.h>
00012 
00013 #ifdef USE_SVMLIGHT
00014 
00015 #include <shogun/transfer/domain_adaptation/DomainAdaptationSVM.h>
00016 #include <shogun/io/SGIO.h>
00017 #include <shogun/labels/Labels.h>
00018 #include <shogun/labels/BinaryLabels.h>
00019 #include <shogun/labels/RegressionLabels.h>
00020 #include <iostream>
00021 #include <vector>
00022 
00023 using namespace shogun;
00024 
00025 CDomainAdaptationSVM::CDomainAdaptationSVM() : CSVMLight()
00026 {
00027     init();
00028 }
00029 
00030 CDomainAdaptationSVM::CDomainAdaptationSVM(float64_t C, CKernel* k, CLabels* lab, CSVM* pre_svm, float64_t B_param) : CSVMLight(C, k, lab)
00031 {
00032     init();
00033     init(pre_svm, B_param);
00034 }
00035 
00036 CDomainAdaptationSVM::~CDomainAdaptationSVM()
00037 {
00038     SG_UNREF(presvm);
00039     SG_DEBUG("deleting DomainAdaptationSVM\n")
00040 }
00041 
00042 
00043 void CDomainAdaptationSVM::init(CSVM* pre_svm, float64_t B_param)
00044 {
00045     REQUIRE(pre_svm != NULL, "Pre SVM should not be null");
00046     // increase reference counts
00047     SG_REF(pre_svm);
00048 
00049     this->presvm=pre_svm;
00050     this->B=B_param;
00051     this->train_factor=1.0;
00052 
00053     // set bias of parent svm to zero
00054     this->presvm->set_bias(0.0);
00055 
00056     // invoke sanity check
00057     is_presvm_sane();
00058 }
00059 
00060 bool CDomainAdaptationSVM::is_presvm_sane()
00061 {
00062     if (!presvm) {
00063         SG_ERROR("presvm is null")
00064     }
00065 
00066     if (presvm->get_num_support_vectors() == 0) {
00067         SG_ERROR("presvm has no support vectors, please train first")
00068     }
00069 
00070     if (presvm->get_bias() != 0) {
00071         SG_ERROR("presvm bias not set to zero")
00072     }
00073 
00074     if (presvm->get_kernel()->get_kernel_type() != this->get_kernel()->get_kernel_type()) {
00075         SG_ERROR("kernel types do not agree")
00076     }
00077 
00078     if (presvm->get_kernel()->get_feature_type() != this->get_kernel()->get_feature_type()) {
00079         SG_ERROR("feature types do not agree")
00080     }
00081 
00082     return true;
00083 }
00084 
00085 
00086 bool CDomainAdaptationSVM::train_machine(CFeatures* data)
00087 {
00088 
00089     if (data)
00090     {
00091         if (m_labels->get_num_labels() != data->get_num_vectors())
00092             SG_ERROR("Number of training vectors does not match number of labels\n")
00093         kernel->init(data, data);
00094     }
00095 
00096     if (m_labels->get_label_type() != LT_BINARY)
00097         SG_ERROR("DomainAdaptationSVM requires binary labels\n")
00098 
00099     int32_t num_training_points = get_labels()->get_num_labels();
00100     CBinaryLabels* labels = (CBinaryLabels*) get_labels();
00101 
00102     float64_t* lin_term = SG_MALLOC(float64_t, num_training_points);
00103 
00104     // grab current training features
00105     CFeatures* train_data = get_kernel()->get_lhs();
00106 
00107     // bias of parent SVM was set to zero in constructor, already contains B
00108     CBinaryLabels* parent_svm_out = presvm->apply_binary(train_data);
00109 
00110     // pre-compute linear term
00111     for (int32_t i=0; i<num_training_points; i++)
00112     {
00113         lin_term[i] = train_factor * B * labels->get_label(i) * parent_svm_out->get_label(i) - 1.0;
00114     }
00115 
00116     //set linear term for QP
00117     this->set_linear_term(SGVector<float64_t>(lin_term, num_training_points));
00118 
00119     //train SVM
00120     bool success = CSVMLight::train_machine();
00121     SG_UNREF(labels);
00122 
00123     ASSERT(presvm)
00124 
00125     return success;
00126 
00127 }
00128 
00129 
00130 CSVM* CDomainAdaptationSVM::get_presvm()
00131 {
00132     SG_REF(presvm);
00133     return presvm;
00134 }
00135 
00136 
00137 float64_t CDomainAdaptationSVM::get_B()
00138 {
00139     return B;
00140 }
00141 
00142 
00143 float64_t CDomainAdaptationSVM::get_train_factor()
00144 {
00145     return train_factor;
00146 }
00147 
00148 
00149 void CDomainAdaptationSVM::set_train_factor(float64_t factor)
00150 {
00151     train_factor = factor;
00152 }
00153 
00154 
00155 CBinaryLabels* CDomainAdaptationSVM::apply_binary(CFeatures* data)
00156 {
00157     ASSERT(data)
00158     ASSERT(presvm->get_bias()==0.0)
00159 
00160     int32_t num_examples = data->get_num_vectors();
00161 
00162     CBinaryLabels* out_current = CSVMLight::apply_binary(data);
00163 
00164     // recursive call if used on DomainAdaptationSVM object
00165     CBinaryLabels* out_presvm = presvm->apply_binary(data);
00166 
00167     // combine outputs
00168     SGVector<float64_t> out_combined(num_examples);
00169     for (int32_t i=0; i<num_examples; i++)
00170     {
00171         out_combined[i] = out_current->get_value(i) + B*out_presvm->get_value(i);
00172     }
00173     SG_UNREF(out_current);
00174     SG_UNREF(out_presvm);
00175 
00176     return new CBinaryLabels(out_combined);
00177 
00178 }
00179 
00180 void CDomainAdaptationSVM::init()
00181 {
00182     presvm = NULL;
00183     B = 0;
00184     train_factor = 1.0;
00185 
00186     SG_ADD((CSGObject**) &presvm, "presvm", "SVM to regularize against.",
00187             MS_NOT_AVAILABLE);
00188     SG_ADD(&B, "B", "regularization parameter B.", MS_AVAILABLE);
00189     SG_ADD(&train_factor, "train_factor",
00190             "flag to switch off regularization in training.", MS_AVAILABLE);
00191 }
00192 
00193 #endif //USE_SVMLIGHT
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation