wagon.cc (speech_tools-2.4-release) | : | wagon.cc (speech_tools-2.5.0-release) | ||
---|---|---|---|---|
skipping to change at line 65 | skipping to change at line 65 | |||
Discretes wgn_discretes; | Discretes wgn_discretes; | |||
WDataSet wgn_dataset; | WDataSet wgn_dataset; | |||
WDataSet wgn_test_dataset; | WDataSet wgn_test_dataset; | |||
EST_FMatrix wgn_DistMatrix; | EST_FMatrix wgn_DistMatrix; | |||
EST_Track wgn_VertexTrack; | EST_Track wgn_VertexTrack; | |||
EST_Track wgn_VertexFeats; | EST_Track wgn_VertexFeats; | |||
EST_Track wgn_UnitTrack; | EST_Track wgn_UnitTrack; | |||
int wgn_min_cluster_size = 50; | int wgn_min_cluster_size = 50; | |||
int wgn_max_questions = 2000000; /* not ideal, but adequate */ | ||||
int wgn_held_out = 0; | int wgn_held_out = 0; | |||
float wgn_dropout_feats = 0.0; | ||||
float wgn_dropout_samples = 0.0; | ||||
int wgn_cos = 1; | ||||
int wgn_prune = TRUE; | int wgn_prune = TRUE; | |||
int wgn_quiet = FALSE; | int wgn_quiet = FALSE; | |||
int wgn_verbose = FALSE; | int wgn_verbose = FALSE; | |||
int wgn_count_field = -1; | int wgn_count_field = -1; | |||
EST_String wgn_count_field_name = ""; | EST_String wgn_count_field_name = ""; | |||
int wgn_predictee = 0; | int wgn_predictee = 0; | |||
EST_String wgn_predictee_name = ""; | EST_String wgn_predictee_name = ""; | |||
float wgn_float_range_split = 10; | float wgn_float_range_split = 10; | |||
float wgn_balance = 0; | float wgn_balance = 0; | |||
EST_String wgn_opt_param = ""; | EST_String wgn_opt_param = ""; | |||
skipping to change at line 661 | skipping to change at line 665 | |||
WNode *leaf; | WNode *leaf; | |||
float predict,real; | float predict,real; | |||
EST_SuffStats x,y,xx,yy,xy,se,e; | EST_SuffStats x,y,xx,yy,xy,se,e; | |||
double cor,error; | double cor,error; | |||
double count; | double count; | |||
for (p=dataset.head(); p != 0; p=p->next()) | for (p=dataset.head(); p != 0; p=p->next()) | |||
{ | { | |||
leaf = tree.predict_node((*dataset(p))); | leaf = tree.predict_node((*dataset(p))); | |||
// do ols to get predict; | // do ols to get predict; | |||
predict = 0.0; | predict = 0.0; // This is incomplete ! you need to use leaf | |||
real = dataset(p)->get_flt_val(wgn_predictee); | real = dataset(p)->get_flt_val(wgn_predictee); | |||
if (wgn_count_field == -1) | if (wgn_count_field == -1) | |||
count = 1.0; | count = 1.0; | |||
else | else | |||
count = dataset(p)->get_flt_val(wgn_count_field); | count = dataset(p)->get_flt_val(wgn_count_field); | |||
x.cumulate(predict,count); | x.cumulate(predict,count); | |||
y.cumulate(real,count); | y.cumulate(real,count); | |||
error = predict-real; | error = predict-real; | |||
se.cumulate((error*error),count); | se.cumulate((error*error),count); | |||
e.cumulate(fabs(error),count); | e.cumulate(fabs(error),count); | |||
skipping to change at line 722 | skipping to change at line 726 | |||
return cor; // should really be % variance, I think | return cor; // should really be % variance, I think | |||
} | } | |||
static int wagon_split(int margin, WNode &node) | static int wagon_split(int margin, WNode &node) | |||
{ | { | |||
// Split given node (if possible) | // Split given node (if possible) | |||
WQuestion q; | WQuestion q; | |||
WNode *l,*r; | WNode *l,*r; | |||
node.set_impurity(WImpurity(node.get_data())); | node.set_impurity(WImpurity(node.get_data())); | |||
if (wgn_max_questions < 1) | ||||
return FALSE; | ||||
q = find_best_question(node.get_data()); | q = find_best_question(node.get_data()); | |||
/* printf("q.score() %f impurity %f\n", | /* printf("q.score() %f impurity %f\n", | |||
q.get_score(), | q.get_score(), | |||
node.get_impurity().measure()); */ | node.get_impurity().measure()); */ | |||
double impurity_measure = node.get_impurity().measure(); | double impurity_measure = node.get_impurity().measure(); | |||
double question_score = q.get_score(); | double question_score = q.get_score(); | |||
if ((question_score < WGN_HUGE_VAL) && | if ((question_score < WGN_HUGE_VAL) && | |||
(question_score < impurity_measure)) | (question_score < impurity_measure)) | |||
{ | { | |||
// Ok its worth a split | // Ok its worth a split | |||
l = new WNode(); | l = new WNode(); | |||
r = new WNode(); | r = new WNode(); | |||
wgn_find_split(q,node.get_data(),l->get_data(),r->get_data()); | wgn_find_split(q,node.get_data(),l->get_data(),r->get_data()); | |||
node.set_subnodes(l,r); | node.set_subnodes(l,r); | |||
node.set_question(q); | node.set_question(q); | |||
if (wgn_verbose) | if (wgn_verbose) | |||
{ | { | |||
int i; | int i; | |||
for (i=0; i < margin; i++) | for (i=0; i < margin; i++) | |||
cout << " "; | cout << " "; | |||
cout << q << endl; | cout << q << endl; | |||
} | } | |||
wgn_max_questions--; | ||||
margin++; | margin++; | |||
wagon_split(margin,*l); | wagon_split(margin,*l); | |||
margin++; | margin++; | |||
wagon_split(margin,*r); | wagon_split(margin,*r); | |||
margin--; | margin--; | |||
return TRUE; | return TRUE; | |||
} | } | |||
else | else | |||
{ | { | |||
if (wgn_verbose) | if (wgn_verbose) | |||
skipping to change at line 775 | skipping to change at line 782 | |||
margin--; | margin--; | |||
return FALSE; | return FALSE; | |||
} | } | |||
} | } | |||
void wgn_find_split(WQuestion &q,WVectorVector &ds, | void wgn_find_split(WQuestion &q,WVectorVector &ds, | |||
WVectorVector &y,WVectorVector &n) | WVectorVector &y,WVectorVector &n) | |||
{ | { | |||
int i, iy, in; | int i, iy, in; | |||
y.resize(q.get_yes()); | if (wgn_dropout_samples > 0.0) | |||
n.resize(q.get_no()); | { | |||
// You need to count the number of yes/no again in all ds | ||||
for (iy=in=i=0; i < ds.n(); i++) | ||||
if (q.ask(*ds(i)) == TRUE) | ||||
iy++; | ||||
else | ||||
in++; | ||||
} | ||||
else | ||||
{ | ||||
// Current counts are corrent (as all data was used) | ||||
iy = q.get_yes(); | ||||
in = q.get_no(); | ||||
} | ||||
y.resize(iy); | ||||
n.resize(in); | ||||
for (iy=in=i=0; i < ds.n(); i++) | for (iy=in=i=0; i < ds.n(); i++) | |||
if (q.ask(*ds(i)) == TRUE) | if (q.ask(*ds(i)) == TRUE) | |||
y[iy++] = ds(i); | y[iy++] = ds(i); | |||
else | else | |||
n[in++] = ds(i); | n[in++] = ds(i); | |||
} | } | |||
static float wgn_random_number(float x) | ||||
{ | ||||
// Returns random number between 0 and x | ||||
return (((float)random())/RAND_MAX)*x; | ||||
} | ||||
#ifdef OMP_WAGON | ||||
static WQuestion find_best_question(WVectorVector &dset) | ||||
{ | ||||
// Ask all possible questions and find the best one | ||||
int i; | ||||
float bscore,tscore; | ||||
WQuestion test_ques, best_ques; | ||||
WQuestion** questions=new WQuestion*[wgn_dataset.width()]; | ||||
float* scores = new float[wgn_dataset.width()]; | ||||
bscore = tscore = WGN_HUGE_VAL; | ||||
best_ques.set_score(bscore); | ||||
#pragma omp parallel | ||||
#pragma omp for | ||||
for (i=0;i < wgn_dataset.width(); i++) | ||||
{ | ||||
questions[i] = new WQuestion; | ||||
questions[i]->set_score(bscore);} | ||||
#pragma omp parallel | ||||
#pragma omp for | ||||
for (i=0;i < wgn_dataset.width(); i++) | ||||
{ | ||||
if ((wgn_dataset.ignore(i) == TRUE) || | ||||
(i == wgn_predictee)) | ||||
scores[i] = WGN_HUGE_VAL; // ignore this feature this time | ||||
else if (wgn_random_number(1.0) < wgn_dropout_feats) | ||||
scores[i] = WGN_HUGE_VAL; // randomly dropout feature | ||||
else if (wgn_dataset.ftype(i) == wndt_binary) | ||||
{ | ||||
construct_binary_ques(i,*questions[i]); | ||||
scores[i] = wgn_score_question(*questions[i],dset); | ||||
} | ||||
else if (wgn_dataset.ftype(i) == wndt_float) | ||||
{ | ||||
scores[i] = construct_float_ques(i,*questions[i],dset); | ||||
} | ||||
else if (wgn_dataset.ftype(i) == wndt_ignore) | ||||
scores[i] = WGN_HUGE_VAL; // always ignore this feature | ||||
#if 0 | ||||
// This doesn't work reasonably | ||||
else if (wgn_csubset && (wgn_dataset.ftype(i) >= wndt_class)) | ||||
{ | ||||
wagon_error("subset selection temporarily deleted"); | ||||
tscore = construct_class_ques_subset(i,test_ques,dset); | ||||
} | ||||
#endif | ||||
else if (wgn_dataset.ftype(i) >= wndt_class) | ||||
scores[i] = construct_class_ques(i,*questions[i],dset); | ||||
} | ||||
for (i=0;i < wgn_dataset.width(); i++) | ||||
{ | ||||
if (scores[i] < bscore) | ||||
{ | ||||
memcpy(&best_ques,questions[i],sizeof(*questions[i])); | ||||
best_ques.set_score(scores[i]); | ||||
bscore = scores[i]; | ||||
} | ||||
delete questions[i]; | ||||
} | ||||
delete [] questions; | ||||
delete [] scores; | ||||
return best_ques; | ||||
} | ||||
#else | ||||
// No OMP parallelism | ||||
static WQuestion find_best_question(WVectorVector &dset) | static WQuestion find_best_question(WVectorVector &dset) | |||
{ | { | |||
// Ask all possible questions and find the best one | // Ask all possible questions and find the best one | |||
int i; | int i; | |||
float bscore,tscore; | float bscore,tscore; | |||
WQuestion test_ques, best_ques; | WQuestion test_ques, best_ques; | |||
bscore = tscore = WGN_HUGE_VAL; | bscore = tscore = WGN_HUGE_VAL; | |||
best_ques.set_score(bscore); | best_ques.set_score(bscore); | |||
// test each feature with each possible question | // test each feature with each possible question | |||
for (i=0;i < wgn_dataset.width(); i++) | for (i=0;i < wgn_dataset.width(); i++) | |||
{ | { | |||
if ((wgn_dataset.ignore(i) == TRUE) || | if ((wgn_dataset.ignore(i) == TRUE) || | |||
(i == wgn_predictee)) | (i == wgn_predictee)) | |||
tscore = WGN_HUGE_VAL; // ignore this feature this time | tscore = WGN_HUGE_VAL; // ignore this feature this time | |||
else if (wgn_random_number(1.0) < wgn_dropout_feats) | ||||
tscore = WGN_HUGE_VAL; // randomly dropout feature | ||||
else if (wgn_dataset.ftype(i) == wndt_binary) | else if (wgn_dataset.ftype(i) == wndt_binary) | |||
{ | { | |||
construct_binary_ques(i,test_ques); | construct_binary_ques(i,test_ques); | |||
tscore = wgn_score_question(test_ques,dset); | tscore = wgn_score_question(test_ques,dset); | |||
} | } | |||
else if (wgn_dataset.ftype(i) == wndt_float) | else if (wgn_dataset.ftype(i) == wndt_float) | |||
{ | { | |||
tscore = construct_float_ques(i,test_ques,dset); | tscore = construct_float_ques(i,test_ques,dset); | |||
} | } | |||
else if (wgn_dataset.ftype(i) == wndt_ignore) | else if (wgn_dataset.ftype(i) == wndt_ignore) | |||
skipping to change at line 832 | skipping to change at line 927 | |||
if (tscore < bscore) | if (tscore < bscore) | |||
{ | { | |||
best_ques = test_ques; | best_ques = test_ques; | |||
best_ques.set_score(tscore); | best_ques.set_score(tscore); | |||
bscore = tscore; | bscore = tscore; | |||
} | } | |||
} | } | |||
return best_ques; | return best_ques; | |||
} | } | |||
#endif | ||||
static float construct_class_ques(int feat,WQuestion &ques,WVectorVector &ds) | static float construct_class_ques(int feat,WQuestion &ques,WVectorVector &ds) | |||
{ | { | |||
// Find out which member of a class gives the best split | // Find out which member of a class gives the best split | |||
float tscore,bscore = WGN_HUGE_VAL; | float tscore,bscore = WGN_HUGE_VAL; | |||
int cl; | int cl; | |||
WQuestion test_q; | WQuestion test_q; | |||
test_q.set_fp(feat); | test_q.set_fp(feat); | |||
test_q.set_oper(wnop_is); | test_q.set_oper(wnop_is); | |||
skipping to change at line 1028 | skipping to change at line 1124 | |||
WImpurity y,n; | WImpurity y,n; | |||
int d, num_yes, num_no; | int d, num_yes, num_no; | |||
float count; | float count; | |||
WVector *wv; | WVector *wv; | |||
num_yes = num_no = 0; | num_yes = num_no = 0; | |||
y.data = &ds; | y.data = &ds; | |||
n.data = &ds; | n.data = &ds; | |||
for (d=0; d < ds.n(); d++) | for (d=0; d < ds.n(); d++) | |||
{ | { | |||
if ((ignorenth < 2) || | if (wgn_random_number(1.0) < wgn_dropout_samples) | |||
{ | ||||
continue; // dropout this sample | ||||
} | ||||
else if ((ignorenth < 2) || | ||||
(d%ignorenth != ignorenth-1)) | (d%ignorenth != ignorenth-1)) | |||
{ | { | |||
wv = ds(d); | wv = ds(d); | |||
if (wgn_count_field == -1) | if (wgn_count_field == -1) | |||
count = 1.0; | count = 1.0; | |||
else | else | |||
count = (*wv)[wgn_count_field]; | count = (*wv)[wgn_count_field]; | |||
if (q.ask(*wv) == TRUE) | if (q.ask(*wv) == TRUE) | |||
{ | { | |||
End of changes. 11 change blocks. | ||||
5 lines changed or deleted | 105 lines changed or added |