SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Sergey Lisitsyn 00008 * Copyright (C) 2012 Sergey Lisitsyn 00009 */ 00010 00011 #include <shogun/lib/config.h> 00012 #ifdef HAVE_LAPACK 00013 #include <shogun/transfer/domain_adaptation/DomainAdaptationMulticlassLibLinear.h> 00014 #include <shogun/labels/MulticlassLabels.h> 00015 00016 using namespace shogun; 00017 00018 CDomainAdaptationMulticlassLibLinear::CDomainAdaptationMulticlassLibLinear() : 00019 CMulticlassLibLinear() 00020 { 00021 init_defaults(); 00022 } 00023 00024 CDomainAdaptationMulticlassLibLinear::CDomainAdaptationMulticlassLibLinear( 00025 float64_t target_C, CDotFeatures* target_features, CLabels* target_labels, 00026 CLinearMulticlassMachine* source_machine) : 00027 CMulticlassLibLinear(target_C,target_features,target_labels) 00028 { 00029 init_defaults(); 00030 00031 set_source_machine(source_machine); 00032 } 00033 00034 void CDomainAdaptationMulticlassLibLinear::init_defaults() 00035 { 00036 m_train_factor = 1.0; 00037 m_source_bias = 0.5; 00038 m_source_machine = NULL; 00039 00040 register_parameters(); 00041 } 00042 00043 float64_t CDomainAdaptationMulticlassLibLinear::get_source_bias() const 00044 { 00045 return m_source_bias; 00046 } 00047 00048 void CDomainAdaptationMulticlassLibLinear::set_source_bias(float64_t source_bias) 00049 { 00050 m_source_bias = source_bias; 00051 } 00052 00053 float64_t CDomainAdaptationMulticlassLibLinear::get_train_factor() const 00054 { 00055 return m_train_factor; 00056 } 00057 00058 void CDomainAdaptationMulticlassLibLinear::set_train_factor(float64_t train_factor) 00059 { 00060 m_train_factor = train_factor; 00061 } 00062 00063 CLinearMulticlassMachine* CDomainAdaptationMulticlassLibLinear::get_source_machine() const 00064 { 00065 SG_REF(m_source_machine); 00066 return m_source_machine; 00067 } 00068 00069 void CDomainAdaptationMulticlassLibLinear::set_source_machine( 00070 CLinearMulticlassMachine* source_machine) 00071 { 00072 SG_REF(source_machine); 00073 SG_UNREF(m_source_machine); 00074 m_source_machine = source_machine; 00075 } 00076 00077 void CDomainAdaptationMulticlassLibLinear::register_parameters() 00078 { 00079 SG_ADD((CSGObject**)&m_source_machine, "source_machine", "source domain machine", 00080 MS_NOT_AVAILABLE); 00081 SG_ADD(&m_train_factor, "train_factor", "factor of target domain regularization", 00082 MS_AVAILABLE); 00083 SG_ADD(&m_source_bias, "source_bias", "bias to source domain", 00084 MS_AVAILABLE); 00085 } 00086 00087 CDomainAdaptationMulticlassLibLinear::~CDomainAdaptationMulticlassLibLinear() 00088 { 00089 } 00090 00091 SGMatrix<float64_t> CDomainAdaptationMulticlassLibLinear::obtain_regularizer_matrix() const 00092 { 00093 ASSERT(get_use_bias()==false) 00094 int32_t n_classes = ((CMulticlassLabels*)m_source_machine->get_labels())->get_num_classes(); 00095 int32_t n_features = ((CDotFeatures*)m_source_machine->get_features())->get_dim_feature_space(); 00096 SGMatrix<float64_t> w0(n_classes,n_features); 00097 00098 for (int32_t i=0; i<n_classes; i++) 00099 { 00100 SGVector<float64_t> w = ((CLinearMachine*)m_source_machine->get_machine(i))->get_w(); 00101 for (int32_t j=0; j<n_features; j++) 00102 w0(j,i) = m_train_factor*w[j]; 00103 } 00104 00105 return w0; 00106 } 00107 00108 CBinaryLabels* CDomainAdaptationMulticlassLibLinear::get_submachine_outputs(int32_t i) 00109 { 00110 CBinaryLabels* target_outputs = CMulticlassMachine::get_submachine_outputs(i); 00111 CBinaryLabels* source_outputs = m_source_machine->get_submachine_outputs(i); 00112 int32_t n_target_outputs = target_outputs->get_num_labels(); 00113 ASSERT(n_target_outputs==source_outputs->get_num_labels()) 00114 SGVector<float64_t> result(n_target_outputs); 00115 for (int32_t j=0; j<result.vlen; j++) 00116 result[j] = (1-m_source_bias)*target_outputs->get_value(j) + m_source_bias*source_outputs->get_value(j); 00117 00118 SG_UNREF(target_outputs); 00119 SG_UNREF(source_outputs); 00120 00121 return new CBinaryLabels(result); 00122 } 00123 #endif /* HAVE_LAPACK */