SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2008 Soeren Sonnenburg 00008 * Written (W) 1999-2008 Gunnar Raetsch 00009 * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #include <shogun/ui/GUIHMM.h> 00013 #include <shogun/ui/SGInterface.h> 00014 00015 #include <shogun/lib/config.h> 00016 #include <shogun/lib/common.h> 00017 #include <shogun/features/StringFeatures.h> 00018 #include <shogun/labels/Labels.h> 00019 #include <shogun/labels/RegressionLabels.h> 00020 #include <shogun/mathematics/Statistics.h> 00021 00022 #include <unistd.h> 00023 00024 using namespace shogun; 00025 00026 CGUIHMM::CGUIHMM(CSGInterface* ui_) 00027 : CSGObject(), ui(ui_) 00028 { 00029 working=NULL; 00030 00031 pos=NULL; 00032 neg=NULL; 00033 test=NULL; 00034 00035 PSEUDO=1e-10; 00036 M=4; 00037 } 00038 00039 CGUIHMM::~CGUIHMM() 00040 { 00041 SG_UNREF(working); 00042 } 00043 00044 bool CGUIHMM::new_hmm(int32_t n, int32_t m) 00045 { 00046 SG_UNREF(working); 00047 working=new CHMM(n, m, NULL, PSEUDO); 00048 M=m; 00049 return true; 00050 } 00051 00052 bool CGUIHMM::baum_welch_train() 00053 { 00054 if (!working) 00055 SG_ERROR("Create HMM first.\n") 00056 00057 CFeatures* trainfeatures=ui->ui_features->get_train_features(); 00058 if (!trainfeatures) 00059 SG_ERROR("Assign train features first.\n") 00060 if (trainfeatures->get_feature_type()!=F_WORD || 00061 trainfeatures->get_feature_class()!=C_STRING) 00062 SG_ERROR("Features must be STRING of type WORD.\n") 00063 00064 CStringFeatures<uint16_t>* sf=(CStringFeatures<uint16_t>*) trainfeatures; 00065 SG_DEBUG("Stringfeatures have %ld orig_symbols %ld symbols %d order %ld max_symbols\n", (int64_t) sf->get_original_num_symbols(), (int64_t) sf->get_num_symbols(), sf->get_order(), (int64_t) sf->get_max_num_symbols()) 00066 00067 working->set_observations(sf); 00068 00069 return working->baum_welch_viterbi_train(BW_NORMAL); 00070 } 00071 00072 00073 bool CGUIHMM::baum_welch_trans_train() 00074 { 00075 if (!working) 00076 SG_ERROR("Create HMM first.\n") 00077 00078 CFeatures* trainfeatures=ui->ui_features->get_train_features(); 00079 if (!trainfeatures) 00080 SG_ERROR("Assign train features first.\n") 00081 if (trainfeatures->get_feature_type()!=F_WORD || 00082 trainfeatures->get_feature_class()!=C_STRING) 00083 SG_ERROR("Features must be STRING of type WORD.\n") 00084 00085 working->set_observations((CStringFeatures<uint16_t>*) trainfeatures); 00086 00087 return working->baum_welch_viterbi_train(BW_TRANS); 00088 } 00089 00090 00091 bool CGUIHMM::baum_welch_train_defined() 00092 { 00093 if (!working) 00094 SG_ERROR("Create HMM first.\n") 00095 if (!working->get_observations()) 00096 SG_ERROR("Assign observation first.\n") 00097 00098 return working->baum_welch_viterbi_train(BW_DEFINED); 00099 } 00100 00101 bool CGUIHMM::viterbi_train() 00102 { 00103 if (!working) 00104 SG_ERROR("Create HMM first.\n") 00105 if (!working->get_observations()) 00106 SG_ERROR("Assign observation first.\n") 00107 00108 return working->baum_welch_viterbi_train(VIT_NORMAL); 00109 } 00110 00111 bool CGUIHMM::viterbi_train_defined() 00112 { 00113 if (!working) 00114 SG_ERROR("Create HMM first.\n") 00115 if (!working->get_observations()) 00116 SG_ERROR("Assign observation first.\n") 00117 00118 return working->baum_welch_viterbi_train(VIT_DEFINED); 00119 } 00120 00121 bool CGUIHMM::linear_train(char align) 00122 { 00123 if (!working) 00124 SG_ERROR("Create HMM first.\n") 00125 00126 CFeatures* trainfeatures=ui->ui_features->get_train_features(); 00127 if (!trainfeatures) 00128 SG_ERROR("Assign train features first.\n") 00129 if (trainfeatures->get_feature_type()!=F_WORD || 00130 trainfeatures->get_feature_class()!=C_STRING) 00131 SG_ERROR("Features must be STRING of type WORD.\n") 00132 00133 working->set_observations((CStringFeatures<uint16_t>*) ui-> 00134 ui_features->get_train_features()); 00135 00136 bool right_align=false; 00137 if (align=='r') 00138 { 00139 SG_INFO("Using alignment to right.\n") 00140 right_align=true; 00141 } 00142 else 00143 SG_INFO("Using alignment to left.\n") 00144 working->linear_train(right_align); 00145 00146 return true; 00147 } 00148 00149 CRegressionLabels* CGUIHMM::classify(CRegressionLabels* result) 00150 { 00151 CStringFeatures<uint16_t>* obs= (CStringFeatures<uint16_t>*) ui-> 00152 ui_features->get_test_features(); 00153 ASSERT(obs) 00154 int32_t num_vec=obs->get_num_vectors(); 00155 00156 //CStringFeatures<uint16_t>* old_pos=pos->get_observations(); 00157 //CStringFeatures<uint16_t>* old_neg=neg->get_observations(); 00158 00159 pos->set_observations(obs); 00160 neg->set_observations(obs); 00161 00162 if (!result) 00163 result=new CRegressionLabels(num_vec); 00164 00165 for (int32_t i=0; i<num_vec; i++) 00166 result->set_label(i, pos->model_probability(i) - neg->model_probability(i)); 00167 00168 //pos->set_observations(old_pos); 00169 //neg->set_observations(old_neg); 00170 return result; 00171 } 00172 00173 float64_t CGUIHMM::classify_example(int32_t idx) 00174 { 00175 CStringFeatures<uint16_t>* obs= (CStringFeatures<uint16_t>*) ui-> 00176 ui_features->get_test_features(); 00177 ASSERT(obs) 00178 00179 //CStringFeatures<uint16_t>* old_pos=pos->get_observations(); 00180 //CStringFeatures<uint16_t>* old_neg=neg->get_observations(); 00181 00182 pos->set_observations(obs); 00183 neg->set_observations(obs); 00184 00185 float64_t result=pos->model_probability(idx) - neg->model_probability(idx); 00186 //pos->set_observations(old_pos); 00187 //neg->set_observations(old_neg); 00188 return result; 00189 } 00190 00191 CRegressionLabels* CGUIHMM::one_class_classify(CRegressionLabels* result) 00192 { 00193 ASSERT(working) 00194 00195 CStringFeatures<uint16_t>* obs= (CStringFeatures<uint16_t>*) ui-> 00196 ui_features->get_test_features(); 00197 ASSERT(obs) 00198 int32_t num_vec=obs->get_num_vectors(); 00199 00200 //CStringFeatures<uint16_t>* old_pos=working->get_observations(); 00201 working->set_observations(obs); 00202 00203 if (!result) 00204 result=new CRegressionLabels(num_vec); 00205 00206 for (int32_t i=0; i<num_vec; i++) 00207 result->set_label(i, working->model_probability(i)); 00208 00209 //working->set_observations(old_pos); 00210 return result; 00211 } 00212 00213 CRegressionLabels* CGUIHMM::linear_one_class_classify(CRegressionLabels* result) 00214 { 00215 ASSERT(working) 00216 00217 CStringFeatures<uint16_t>* obs= (CStringFeatures<uint16_t>*) ui-> 00218 ui_features->get_test_features(); 00219 ASSERT(obs) 00220 int32_t num_vec=obs->get_num_vectors(); 00221 00222 //CStringFeatures<uint16_t>* old_pos=working->get_observations(); 00223 working->set_observations(obs); 00224 00225 if (!result) 00226 result=new CRegressionLabels(num_vec); 00227 00228 for (int32_t i=0; i<num_vec; i++) 00229 result->set_label(i, working->linear_model_probability(i)); 00230 00231 //working->set_observations(old_pos); 00232 return result; 00233 } 00234 00235 00236 float64_t CGUIHMM::one_class_classify_example(int32_t idx) 00237 { 00238 ASSERT(working) 00239 00240 CStringFeatures<uint16_t>* obs= (CStringFeatures<uint16_t>*) ui-> 00241 ui_features->get_test_features(); 00242 ASSERT(obs) 00243 00244 //CStringFeatures<uint16_t>* old_pos=pos->get_observations(); 00245 00246 pos->set_observations(obs); 00247 neg->set_observations(obs); 00248 00249 float64_t result=working->model_probability(idx); 00250 //working->set_observations(old_pos); 00251 return result; 00252 } 00253 00254 bool CGUIHMM::append_model(char* filename, int32_t base1, int32_t base2) 00255 { 00256 if (!working) 00257 SG_ERROR("Create HMM first.\n") 00258 if (!filename) 00259 SG_ERROR("Invalid filename.\n") 00260 00261 FILE* model_file=fopen(filename, "r"); 00262 if (!model_file) 00263 SG_ERROR("Opening file %s failed.\n", filename) 00264 00265 CHMM* h=new CHMM(model_file,PSEUDO); 00266 if (!h || !h->get_status()) 00267 { 00268 SG_UNREF(h); 00269 fclose(model_file); 00270 SG_ERROR("Reading file %s failed.\n", filename) 00271 } 00272 00273 fclose(model_file); 00274 SG_INFO("File %s successfully read.\n", filename) 00275 00276 SG_DEBUG("h %d , M: %d\n", h, h->get_M()) 00277 if (base1!=-1 && base2!=-1) 00278 { 00279 float64_t* cur_o=SG_MALLOC(float64_t, h->get_M()); 00280 float64_t* app_o=SG_MALLOC(float64_t, h->get_M()); 00281 00282 for (int32_t i=0; i<h->get_M(); i++) 00283 { 00284 if (i==base1) 00285 cur_o[i]=0; 00286 else 00287 cur_o[i]=-1000; 00288 00289 if (i==base2) 00290 app_o[i]=0; 00291 else 00292 app_o[i]=-1000; 00293 } 00294 00295 working->append_model(h, cur_o, app_o); 00296 00297 SG_FREE(cur_o); 00298 SG_FREE(app_o); 00299 } 00300 else 00301 working->append_model(h); 00302 00303 SG_UNREF(h); 00304 SG_INFO("New model has %i states.\n", working->get_N()) 00305 return true; 00306 } 00307 00308 bool CGUIHMM::add_states(int32_t num_states, float64_t value) 00309 { 00310 if (!working) 00311 SG_ERROR("Create HMM first.\n") 00312 00313 working->add_states(num_states, value); 00314 SG_INFO("New model has %i states, value %f.\n", working->get_N(), value) 00315 return true; 00316 } 00317 00318 bool CGUIHMM::set_pseudo(float64_t pseudo) 00319 { 00320 PSEUDO=pseudo; 00321 SG_INFO("Current setting: pseudo=%e.\n", PSEUDO) 00322 return true; 00323 } 00324 00325 bool CGUIHMM::convergence_criteria(int32_t num_iterations, float64_t epsilon) 00326 { 00327 if (!working) 00328 SG_ERROR("Create HMM first.\n") 00329 00330 working->set_iterations(num_iterations); 00331 working->set_epsilon(epsilon); 00332 00333 SG_INFO("Current HMM convergence criteria: iterations=%i, epsilon=%e\n", working->get_iterations(), working->get_epsilon()) 00334 return true; 00335 } 00336 00337 bool CGUIHMM::set_hmm_as(char* target) 00338 { 00339 if (!working) 00340 SG_ERROR("Create HMM first!\n") 00341 00342 if (strncmp(target, "POS", 3)==0) 00343 { 00344 SG_UNREF(pos); 00345 pos=working; 00346 working=NULL; 00347 } 00348 else if (strncmp(target, "NEG", 3)==0) 00349 { 00350 SG_UNREF(neg); 00351 neg=working; 00352 working=NULL; 00353 } 00354 else if (strncmp(target, "TEST", 4)==0) 00355 { 00356 SG_UNREF(test); 00357 test=working; 00358 working=NULL; 00359 } 00360 else 00361 SG_ERROR("Target POS|NEG|TEST is missing.\n") 00362 00363 return true; 00364 } 00365 00366 bool CGUIHMM::load(char* filename) 00367 { 00368 bool result=false; 00369 00370 FILE* model_file=fopen(filename, "r"); 00371 if (!model_file) 00372 SG_ERROR("Opening file %s failed.\n", filename) 00373 00374 SG_UNREF(working); 00375 working=new CHMM(model_file, PSEUDO); 00376 fclose(model_file); 00377 00378 if (working && working->get_status()) 00379 { 00380 SG_INFO("Loaded HMM successfully from file %s.\n", filename) 00381 result=true; 00382 } 00383 00384 M=working->get_M(); 00385 00386 return result; 00387 } 00388 00389 bool CGUIHMM::save(char* filename, bool is_binary) 00390 { 00391 bool result=false; 00392 00393 if (!working) 00394 SG_ERROR("Create HMM first.\n") 00395 00396 FILE* file=fopen(filename, "w"); 00397 if (file) 00398 { 00399 if (is_binary) 00400 result=working->save_model_bin(file); 00401 else 00402 result=working->save_model(file); 00403 } 00404 00405 if (!file || !result) 00406 SG_ERROR("Writing to file %s failed!\n", filename) 00407 else 00408 SG_INFO("Successfully written model into %s!\n", filename) 00409 00410 if (file) 00411 fclose(file); 00412 00413 return result; 00414 } 00415 00416 bool CGUIHMM::load_definitions(char* filename, bool do_init) 00417 { 00418 if (!working) 00419 SG_ERROR("Create HMM first.\n") 00420 00421 bool result=false; 00422 FILE* def_file=fopen(filename, "r"); 00423 if (!def_file) 00424 SG_ERROR("Opening file %s failed\n", filename) 00425 00426 if (working->load_definitions(def_file, true, do_init)) 00427 { 00428 SG_INFO("Definitions successfully read from %s.\n", filename) 00429 result=true; 00430 } 00431 else 00432 SG_ERROR("Couldn't load definitions form file %s.\n", filename) 00433 00434 fclose(def_file); 00435 return result; 00436 } 00437 00438 bool CGUIHMM::save_likelihood(char* filename, bool is_binary) 00439 { 00440 bool result=false; 00441 00442 if (!working) 00443 SG_ERROR("Create HMM first\n") 00444 00445 FILE* file=fopen(filename, "w"); 00446 if (file) 00447 { 00449 //if (binary) 00450 // result=working->save_model_bin(file); 00451 //else 00452 00453 result=working->save_likelihood(file); 00454 } 00455 00456 if (!file || !result) 00457 SG_ERROR("Writing to file %s failed!\n", filename) 00458 else 00459 SG_INFO("Successfully written likelihoods into %s!\n", filename) 00460 00461 if (file) 00462 fclose(file); 00463 00464 return result; 00465 } 00466 00467 bool CGUIHMM::save_path(char* filename, bool is_binary) 00468 { 00469 bool result=false; 00470 if (!working) 00471 SG_ERROR("Create HMM first.\n") 00472 00473 FILE* file=fopen(filename, "w"); 00474 if (file) 00475 { 00477 //if (binary) 00478 //_train()/ result=working->save_model_bin(file); 00479 //else 00480 CStringFeatures<uint16_t>* obs=(CStringFeatures<uint16_t>*) ui-> 00481 ui_features->get_test_features(); 00482 ASSERT(obs) 00483 working->set_observations(obs); 00484 00485 result=working->save_path(file); 00486 } 00487 00488 if (!file || !result) 00489 SG_ERROR("Writing to file %s failed!\n", filename) 00490 else 00491 SG_INFO("Successfully written path into %s!\n", filename) 00492 00493 if (file) 00494 fclose(file); 00495 00496 return result; 00497 } 00498 00499 bool CGUIHMM::chop(float64_t value) 00500 { 00501 if (!working) 00502 SG_ERROR("Create HMM first.\n") 00503 00504 working->chop(value); 00505 return true; 00506 } 00507 00508 bool CGUIHMM::likelihood() 00509 { 00510 if (!working) 00511 SG_ERROR("Create HMM first!\n") 00512 00513 working->output_model(false); 00514 return true; 00515 } 00516 00517 bool CGUIHMM::output_hmm() 00518 { 00519 if (!working) 00520 SG_ERROR("Create HMM first!\n") 00521 00522 working->output_model(true); 00523 return true; 00524 } 00525 00526 bool CGUIHMM::output_hmm_defined() 00527 { 00528 if (!working) 00529 SG_ERROR("Create HMM first!\n") 00530 00531 working->output_model_defined(true); 00532 return true; 00533 } 00534 00535 bool CGUIHMM::best_path(int32_t from, int32_t to) 00536 { 00537 // FIXME: from unused??? 00538 00539 if (!working) 00540 SG_ERROR("Create HMM first.\n") 00541 00542 //get path 00543 working->best_path(0); 00544 00545 for (int32_t t=0; t<working->get_observations()->get_vector_length(0)-1 && t<to; t++) 00546 SG_PRINT("%d ", working->get_best_path_state(0, t)) 00547 SG_PRINT("\n") 00548 00549 //for (t=0; t<p_observations->get_vector_length(0)-1 && t<to; t++) 00550 // SG_PRINT("%d ", PATH(0)[t]) 00551 // 00552 return true; 00553 } 00554 00555 bool CGUIHMM::normalize(bool keep_dead_states) 00556 { 00557 if (!working) 00558 SG_ERROR("Create HMM first.\n") 00559 00560 working->normalize(keep_dead_states); 00561 return true; 00562 } 00563 00564 bool CGUIHMM::relative_entropy(float64_t*& values, int32_t& len) 00565 { 00566 if (!pos || !neg) 00567 SG_ERROR("Set pos and neg HMM first!\n") 00568 00569 int32_t pos_N=pos->get_N(); 00570 int32_t neg_N=neg->get_N(); 00571 int32_t pos_M=pos->get_M(); 00572 int32_t neg_M=neg->get_M(); 00573 if (pos_M!=neg_M || pos_N!=neg_N) 00574 SG_ERROR("Pos and neg HMM's differ in number of emissions or states.\n") 00575 00576 float64_t* p=SG_MALLOC(float64_t, pos_M); 00577 float64_t* q=SG_MALLOC(float64_t, neg_M); 00578 00579 SG_FREE(values); 00580 values=SG_MALLOC(float64_t, pos_N); 00581 00582 for (int32_t i=0; i<pos_N; i++) 00583 { 00584 for (int32_t j=0; j<pos_M; j++) 00585 { 00586 p[j]=pos->get_b(i, j); 00587 q[j]=neg->get_b(i, j); 00588 } 00589 00590 values[i]=CStatistics::relative_entropy(p, q, pos_M); 00591 } 00592 SG_FREE(p); 00593 SG_FREE(q); 00594 00595 len=pos_N; 00596 return true; 00597 } 00598 00599 bool CGUIHMM::entropy(float64_t*& values, int32_t& len) 00600 { 00601 if (!working) 00602 SG_ERROR("Create HMM first!\n") 00603 00604 int32_t n=working->get_N(); 00605 int32_t m=working->get_M(); 00606 float64_t* p=SG_MALLOC(float64_t, m); 00607 00608 SG_FREE(values); 00609 values=SG_MALLOC(float64_t, n); 00610 00611 for (int32_t i=0; i<n; i++) 00612 { 00613 for (int32_t j=0; j<m; j++) 00614 p[j]=working->get_b(i, j); 00615 00616 values[i]=CStatistics::entropy(p, m); 00617 } 00618 SG_FREE(p); 00619 00620 len=m; 00621 return true; 00622 } 00623 00624 bool CGUIHMM::permutation_entropy(int32_t width, int32_t seq_num) 00625 { 00626 if (!working) 00627 SG_ERROR("Create hmm first.\n") 00628 00629 if (!working->get_observations()) 00630 SG_ERROR("Set observations first.\n") 00631 00632 return working->permutation_entropy(width, seq_num); 00633 }