// SHOGUN v3.2.0
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Viktor Gal 00008 * Copyright (C) 2012 Viktor Gal 00009 */ 00010 00011 #include <shogun/latent/LatentModel.h> 00012 #include <shogun/labels/BinaryLabels.h> 00013 00014 using namespace shogun; 00015 00016 CLatentModel::CLatentModel() 00017 : m_features(NULL), 00018 m_labels(NULL), 00019 m_do_caching(false), 00020 m_cached_psi(NULL) 00021 { 00022 register_parameters(); 00023 } 00024 00025 CLatentModel::CLatentModel(CLatentFeatures* feats, CLatentLabels* labels, bool do_caching) 00026 : m_features(feats), 00027 m_labels(labels), 00028 m_do_caching(do_caching), 00029 m_cached_psi(NULL) 00030 { 00031 register_parameters(); 00032 SG_REF(m_features); 00033 SG_REF(m_labels); 00034 } 00035 00036 CLatentModel::~CLatentModel() 00037 { 00038 SG_UNREF(m_labels); 00039 SG_UNREF(m_features); 00040 SG_UNREF(m_cached_psi); 00041 } 00042 00043 int32_t CLatentModel::get_num_vectors() const 00044 { 00045 return m_features->get_num_vectors(); 00046 } 00047 00048 void CLatentModel::set_labels(CLatentLabels* labs) 00049 { 00050 SG_REF(labs); 00051 SG_UNREF(m_labels); 00052 m_labels = labs; 00053 } 00054 00055 CLatentLabels* CLatentModel::get_labels() const 00056 { 00057 SG_REF(m_labels); 00058 return m_labels; 00059 } 00060 00061 void CLatentModel::set_features(CLatentFeatures* feats) 00062 { 00063 SG_REF(feats); 00064 SG_UNREF(m_features); 00065 m_features = feats; 00066 } 00067 00068 void CLatentModel::argmax_h(const SGVector<float64_t>& w) 00069 { 00070 int32_t num = get_num_vectors(); 00071 CBinaryLabels* y = CLabelsFactory::to_binary(m_labels->get_labels()); 00072 ASSERT(num > 0) 00073 ASSERT(num == m_labels->get_num_labels()) 00074 00075 // argmax_h only for positive 
examples 00076 for (int32_t i = 0; i < num; ++i) 00077 { 00078 if (y->get_label(i) == 1) 00079 { 00080 // infer h and set it for the argmax_h <w,psi(x,h)> 00081 CData* latent_data = infer_latent_variable(w, i); 00082 m_labels->set_latent_label(i, latent_data); 00083 } 00084 } 00085 } 00086 00087 void CLatentModel::register_parameters() 00088 { 00089 m_parameters->add((CSGObject**) &m_features, "features", "Latent features"); 00090 m_parameters->add((CSGObject**) &m_labels, "labels", "Latent labels"); 00091 m_parameters->add((CSGObject**) &m_cached_psi, "cached_psi", "Cached PSI features after argmax_h"); 00092 m_parameters->add(&m_do_caching, "do_caching", "Indicate whether or not do PSI vector caching after argmax_h"); 00093 } 00094 00095 00096 CLatentFeatures* CLatentModel::get_features() const 00097 { 00098 SG_REF(m_features); 00099 return m_features; 00100 } 00101 00102 void CLatentModel::cache_psi_features() 00103 { 00104 if (m_do_caching) 00105 { 00106 if (m_cached_psi) 00107 SG_UNREF(m_cached_psi); 00108 m_cached_psi = this->get_psi_feature_vectors(); 00109 SG_REF(m_cached_psi); 00110 } 00111 } 00112 00113 CDotFeatures* CLatentModel::get_cached_psi_features() const 00114 { 00115 if (m_do_caching) 00116 { 00117 SG_REF(m_cached_psi); 00118 return m_cached_psi; 00119 } 00120 return NULL; 00121 }