// SHOGUN v3.2.0
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Fernando José Iglesias García 00008 * Copyright (C) 2012 Fernando José Iglesias García 00009 */ 00010 00011 #ifndef _STRUCTURED_MODEL__H__ 00012 #define _STRUCTURED_MODEL__H__ 00013 00014 #include <shogun/base/SGObject.h> 00015 #include <shogun/features/Features.h> 00016 #include <shogun/labels/StructuredLabels.h> 00017 00018 #include <shogun/lib/common.h> 00019 #include <shogun/lib/SGVector.h> 00020 #include <shogun/lib/StructuredData.h> 00021 00022 namespace shogun 00023 { 00024 00025 #define IGNORE_IN_CLASSLIST 00026 00031 IGNORE_IN_CLASSLIST struct TMultipleCPinfo { 00037 TMultipleCPinfo(uint32_t from, uint32_t N) : m_from(from), m_N(N) { } 00039 uint32_t m_from; 00041 uint32_t m_N; 00042 }; 00043 00044 class CStructuredModel; 00045 00047 struct CResultSet : public CSGObject 00048 { 00050 CResultSet(); 00051 00053 virtual ~CResultSet(); 00054 00056 virtual const char* get_name() const; 00057 00059 CStructuredData* argmax; 00060 00062 SGVector< float64_t > psi_truth; 00063 00065 SGVector< float64_t > psi_pred; 00066 00069 float64_t score; 00070 00072 float64_t delta; 00073 }; 00074 00085 class CStructuredModel : public CSGObject 00086 { 00087 public: 00089 CStructuredModel(); 00090 00096 CStructuredModel(CFeatures* features, CStructuredLabels* labels); 00097 00099 virtual ~CStructuredModel(); 00100 00112 virtual void init_primal_opt( 00113 float64_t regularization, 00114 SGMatrix< float64_t > & A, SGVector< float64_t > a, 00115 SGMatrix< float64_t > B, SGVector< float64_t > & b, 00116 SGVector< float64_t > lb, SGVector< float64_t > ub, 00117 SGMatrix < float64_t > & C); 00118 00123 virtual int32_t get_dim() const = 0; 00124 00129 void 
set_labels(CStructuredLabels* labs); 00130 00135 CStructuredLabels* get_labels(); 00136 00138 virtual CStructuredLabels* structured_labels_factory(int32_t num_labels=0); 00139 00144 void set_features(CFeatures* feats); 00145 00150 CFeatures* get_features(); 00151 00164 SGVector< float64_t > get_joint_feature_vector(int32_t feat_idx, int32_t lab_idx); 00165 00178 virtual SGVector< float64_t > get_joint_feature_vector(int32_t feat_idx, CStructuredData* y); 00179 00193 virtual CResultSet* argmax(SGVector< float64_t > w, int32_t feat_idx, bool const training = true) = 0; 00194 00202 float64_t delta_loss(int32_t ytrue_idx, CStructuredData* ypred); 00203 00211 virtual float64_t delta_loss(CStructuredData* y1, CStructuredData* y2); 00212 00214 virtual const char* get_name() const { return "StructuredModel"; } 00215 00220 virtual void init_training(); 00221 00229 virtual bool check_training_setup() const; 00230 00240 virtual int32_t get_num_aux() const; 00241 00251 virtual int32_t get_num_aux_con() const; 00252 00253 private: 00255 void init(); 00256 00257 protected: 00259 CStructuredLabels* m_labels; 00260 00262 CFeatures* m_features; 00263 00264 }; /* class CStructuredModel */ 00265 00266 } /* namespace shogun */ 00267 00268 #endif /* _STRUCTURED_MODEL__H__ */