/* SHOGUN v3.2.0 */
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Shell Hu 00008 * Written (W) 2012 Fernando José Iglesias García 00009 * Copyright (C) 2012 Fernando José Iglesias García 00010 */ 00011 00012 #ifndef _STRUCTURED_OUTPUT_MACHINE__H__ 00013 #define _STRUCTURED_OUTPUT_MACHINE__H__ 00014 00015 #include <shogun/labels/StructuredLabels.h> 00016 #include <shogun/lib/StructuredData.h> 00017 #include <shogun/machine/Machine.h> 00018 #include <shogun/structure/StructuredModel.h> 00019 #include <shogun/loss/LossFunction.h> 00020 #include <shogun/structure/SOSVMHelper.h> 00021 00022 namespace shogun 00023 { 00024 00030 enum EStructRiskType 00031 { 00032 N_SLACK_MARGIN_RESCALING = 0, 00033 N_SLACK_SLACK_RESCALING = 1, 00034 ONE_SLACK_MARGIN_RESCALING = 2, 00035 ONE_SLACK_SLACK_RESCALING = 3, 00036 CUSTOMIZED_RISK = 4 00037 }; 00038 00039 class CStructuredModel; 00040 00042 class CStructuredOutputMachine : public CMachine 00043 { 00044 public: 00046 MACHINE_PROBLEM_TYPE(PT_STRUCTURED); 00047 00049 CStructuredOutputMachine(); 00050 00056 CStructuredOutputMachine(CStructuredModel* model, CStructuredLabels* labs); 00057 00059 virtual ~CStructuredOutputMachine(); 00060 00065 void set_model(CStructuredModel* model); 00066 00071 CStructuredModel* get_model() const; 00072 00074 virtual const char* get_name() const 00075 { 00076 return "StructuredOutputMachine"; 00077 } 00078 00083 virtual void set_labels(CLabels* lab); 00084 00089 void set_features(CFeatures* f); 00090 00095 CFeatures* get_features() const; 00096 00101 void set_surrogate_loss(CLossFunction* loss); 00102 00107 CLossFunction* get_surrogate_loss() const; 00108 00117 virtual float64_t risk(float64_t* subgrad, float64_t* W, 00118 TMultipleCPinfo* info=0, 
EStructRiskType rtype = N_SLACK_MARGIN_RESCALING); 00119 00121 CSOSVMHelper* get_helper() const; 00122 00129 void set_verbose(bool verbose); 00130 00135 bool get_verbose() const; 00136 00137 protected: 00165 virtual float64_t risk_nslack_margin_rescale(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info=0); 00166 00174 virtual float64_t risk_nslack_slack_rescale(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info=0); 00175 00183 virtual float64_t risk_1slack_margin_rescale(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info=0); 00184 00192 virtual float64_t risk_1slack_slack_rescale(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info=0); 00193 00201 virtual float64_t risk_customized_formulation(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info=0); 00202 00203 private: 00205 void register_parameters(); 00206 00207 protected: 00209 CStructuredModel* m_model; 00210 00215 CLossFunction* m_surrogate_loss; 00216 00218 CSOSVMHelper* m_helper; 00219 00221 bool m_verbose; 00222 00223 }; /* class CStructuredOutputMachine */ 00224 00225 } /* namespace shogun */ 00226 00227 #endif /* _STRUCTURED_OUTPUT_MACHINE__H__ */