SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/lib/common.h> 00012 #include <shogun/io/SGIO.h> 00013 #include <shogun/features/StringFeatures.h> 00014 #include <shogun/labels/Labels.h> 00015 #include <shogun/labels/BinaryLabels.h> 00016 #include <shogun/labels/RegressionLabels.h> 00017 #include <shogun/distributions/LinearHMM.h> 00018 #include <shogun/classifier/PluginEstimate.h> 00019 00020 using namespace shogun; 00021 00022 CPluginEstimate::CPluginEstimate(float64_t pos_pseudo, float64_t neg_pseudo) 00023 : CMachine(), m_pos_pseudo(1e-10), m_neg_pseudo(1e-10), 00024 pos_model(NULL), neg_model(NULL), features(NULL) 00025 { 00026 m_parameters->add(&m_pos_pseudo, 00027 "pos_pseudo","pseudo count for positive class"); 00028 m_parameters->add(&m_neg_pseudo, 00029 "neg_pseudo", "pseudo count for negative class"); 00030 00031 m_parameters->add((CSGObject**) &pos_model, 00032 "pos_model", "LinearHMM modelling positive class."); 00033 m_parameters->add((CSGObject**) &neg_model, 00034 "neg_model", "LinearHMM modelling negative class."); 00035 00036 m_parameters->add((CSGObject**) &features, 00037 "features", "String Features."); 00038 } 00039 00040 CPluginEstimate::~CPluginEstimate() 00041 { 00042 SG_UNREF(pos_model); 00043 SG_UNREF(neg_model); 00044 00045 SG_UNREF(features); 00046 } 00047 00048 bool CPluginEstimate::train_machine(CFeatures* data) 00049 { 00050 ASSERT(m_labels) 00051 ASSERT(m_labels->get_label_type() == LT_BINARY) 00052 if (data) 00053 { 00054 if (data->get_feature_class() != C_STRING || 00055 data->get_feature_type() != F_WORD) 00056 { 00057 SG_ERROR("Features not of class string type word\n") 00058 } 00059 00060 set_features((CStringFeatures<uint16_t>*) data); 00061 } 00062 ASSERT(features) 00063 00064 SG_UNREF(pos_model); 00065 SG_UNREF(neg_model); 00066 00067 pos_model=new CLinearHMM(features); 00068 neg_model=new CLinearHMM(features); 00069 00070 SG_REF(pos_model); 00071 SG_REF(neg_model); 00072 00073 int32_t* pos_indizes=SG_MALLOC(int32_t, ((CStringFeatures<uint16_t>*) features)->get_num_vectors()); 00074 int32_t* neg_indizes=SG_MALLOC(int32_t, ((CStringFeatures<uint16_t>*) features)->get_num_vectors()); 00075 00076 ASSERT(m_labels->get_num_labels()==features->get_num_vectors()) 00077 00078 int32_t pos_idx=0; 00079 int32_t neg_idx=0; 00080 00081 for (int32_t i=0; i<m_labels->get_num_labels(); i++) 00082 { 00083 if (((CBinaryLabels*) m_labels)->get_label(i) > 0) 00084 pos_indizes[pos_idx++]=i; 00085 else 00086 neg_indizes[neg_idx++]=i; 00087 } 00088 00089 SG_INFO("training using pseudos %f and %f\n", m_pos_pseudo, m_neg_pseudo) 00090 pos_model->train(pos_indizes, pos_idx, m_pos_pseudo); 00091 neg_model->train(neg_indizes, neg_idx, m_neg_pseudo); 00092 00093 SG_FREE(pos_indizes); 00094 SG_FREE(neg_indizes); 00095 00096 return true; 00097 } 00098 00099 CBinaryLabels* CPluginEstimate::apply_binary(CFeatures* data) 00100 { 00101 if (data) 00102 { 00103 if (data->get_feature_class() != C_STRING || 00104 data->get_feature_type() != F_WORD) 00105 { 00106 SG_ERROR("Features not of class string type word\n") 00107 } 00108 00109 set_features((CStringFeatures<uint16_t>*) data); 00110 } 00111 00112 ASSERT(features) 00113 SGVector<float64_t> result(features->get_num_vectors()); 00114 00115 for (int32_t vec=0; vec<features->get_num_vectors(); vec++) 00116 result[vec] = apply_one(vec); 00117 00118 return new CBinaryLabels(result); 00119 } 00120 00121 float64_t CPluginEstimate::apply_one(int32_t vec_idx) 00122 { 00123 ASSERT(features) 00124 00125 int32_t len; 00126 bool free_vec; 00127 uint16_t* vector=features->get_feature_vector(vec_idx, len, free_vec); 00128 00129 if ((!pos_model) || (!neg_model)) 00130 SG_ERROR("model(s) not assigned\n") 00131 00132 float64_t result=pos_model->get_log_likelihood_example(vector, len) - neg_model->get_log_likelihood_example(vector, len); 00133 features->free_feature_vector(vector, vec_idx, free_vec); 00134 return result; 00135 }