SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Thoralf Klein 00008 * Written (W) 2012 Fernando José Iglesias García 00009 * Copyright (C) 2012 Fernando José Iglesias García 00010 */ 00011 00012 #include <shogun/structure/MulticlassSOLabels.h> 00013 00014 using namespace shogun; 00015 00016 CMulticlassSOLabels::CMulticlassSOLabels() 00017 : CStructuredLabels(), m_labels_vector(16) 00018 { 00019 init(); 00020 } 00021 00022 CMulticlassSOLabels::CMulticlassSOLabels(int32_t num_labels) 00023 : CStructuredLabels(), m_labels_vector(num_labels) 00024 { 00025 init(); 00026 } 00027 00028 CMulticlassSOLabels::CMulticlassSOLabels(SGVector< float64_t > const src) 00029 : CStructuredLabels(src.vlen), m_labels_vector(src.vlen) 00030 { 00031 init(); 00032 00033 m_num_classes = SGVector< float64_t >::max(src.vector, src.vlen) + 1; 00034 m_labels_vector.resize_vector(src.vlen); 00035 00036 for ( int32_t i = 0 ; i < src.vlen ; ++i ) 00037 { 00038 if ( src[i] < 0 || src[i] >= m_num_classes ) 00039 SG_ERROR("Found label out of {0, 1, 2, ..., num_classes-1}") 00040 else 00041 add_label( new CRealNumber(src[i]) ); 00042 } 00043 00044 //TODO check that every class has at least one example 00045 } 00046 00047 CMulticlassSOLabels::~CMulticlassSOLabels() 00048 { 00049 } 00050 00051 CStructuredData* CMulticlassSOLabels::get_label(int32_t idx) 00052 { 00053 // ensure_valid("CMulticlassSOLabels::get_label(int32_t)"); 00054 if ( idx < 0 || idx >= get_num_labels() ) 00055 SG_ERROR("Index must be inside [0, num_labels-1]\n") 00056 00057 return (CStructuredData*) new CRealNumber(m_labels_vector[idx]); 00058 } 00059 00060 void CMulticlassSOLabels::add_label(CStructuredData* label) 00061 { 00062 SG_REF(label); 00063 float64_t value = CRealNumber::obtain_from_generic(label)->value; 00064 SG_UNREF(label); 00065 00066 //ensure_valid_sdt(label); 00067 if (m_num_labels_set >= m_labels_vector.vlen) 00068 { 00069 m_labels_vector.resize_vector(m_num_labels_set + 16); 00070 } 00071 00072 00073 m_labels_vector[m_num_labels_set] = value; 00074 m_num_labels_set++; 00075 } 00076 00077 bool CMulticlassSOLabels::set_label(int32_t idx, CStructuredData* label) 00078 { 00079 SG_REF(label); 00080 float64_t value = CRealNumber::obtain_from_generic(label)->value; 00081 SG_UNREF(label); 00082 00083 // ensure_valid_sdt(label); 00084 int32_t real_idx = m_subset_stack->subset_idx_conversion(idx); 00085 00086 if ( real_idx < get_num_labels() ) 00087 { 00088 m_labels_vector[real_idx] = value; 00089 return true; 00090 } 00091 else 00092 { 00093 return false; 00094 } 00095 } 00096 00097 int32_t CMulticlassSOLabels::get_num_labels() const 00098 { 00099 return m_num_labels_set; 00100 } 00101 00102 void CMulticlassSOLabels::init() 00103 { 00104 SG_ADD(&m_num_classes, "m_num_classes", "The number of classes", 00105 MS_NOT_AVAILABLE); 00106 SG_ADD(&m_num_labels_set, "m_num_labels_set", "The number of assigned labels", 00107 MS_NOT_AVAILABLE); 00108 00109 m_num_classes = 0; 00110 m_num_labels_set = 0; 00111 }