Commit 0f5e57d6 authored by David Lanzendörfer

Add training loop

parent 7f68d353
@@ -387,6 +387,63 @@ void set_bias_values(int bias)
 }
 }
+/*
+ * Run training cycle for N epochs
+ * num_epochs: Number of epochs
+ * learning_rate_zero: Initial learning rate
+ * decay_rate: Decay factor for the learning-rate schedule
+ * xp: Pointer to the token array
+ * xi: Index of the target token y within xp
+ *
+ * Returns 0 when successful
+ * Returns 1 when unsuccessful
+ */
+int run_training_epochs(
+int num_epochs,
+uint32_t learning_rate_zero,
+uint32_t decay_rate,
+uint32_t *xp,
+uint32_t xi
+)
+{
+int ret = 1;
+/* Initialized so the comparison below is defined even when xi is 0 */
+int last_val = -1;
+uint32_t train_mask;
+uint32_t y = xp[xi];
+for(int epoch=0; epoch<num_epochs; epoch++) {
+/*
+ * Revert the LSTM states
+ */
+reset_network();
+/*
+ * Feed all prior tokens of the time series into the LSTM
+ * to reach the state we wish to train on
+ */
+for(int xpi=0; xpi<xi; xpi++) {
+last_val = predict_next_token(xp[xpi]);
+}
+/*
+ * Check whether the network already predicts the target token;
+ * set the return value to zero and exit the loop if so
+ */
+if(last_val==y) {
+ret = 0;
+break;
+}
+/*
+ * Determine the training mask by finding out which neurons
+ * misfired or failed to fire although they should have
+ */
+train_mask = last_val ^ y;
+set_alpha(learning_rate_zero/(1+(decay_rate*epoch)));
+mask_back_propgatation(train_mask);
+}
+return ret;
+}
 /*
 * Run training cycle for N epochs
 * msgbuf: buffer for infos
@@ -407,24 +464,38 @@ char* run_training(
 {
-int last_val;
-uint32_t train_mask;
-uint32_t y;
+int ret;
-for(uint32_t xni=1; xni<xn; xni++) {
+/*
+ * Training the RNN for a time series:
+ * 1) Tax -> ation
+ * 2) Tax -> ation -> is
+ * 3) Tax -> ation -> is -> theft
+ *
+ * xn is the number of tokens in our token array
+ * xi tells run_training_epochs where y is located in the array (x -> y)
+ * xi must be xn-1 at most
+ */
+for(uint32_t xi=1; xi<xn; xi++) {
+ret = run_training_epochs(
+num_epochs,
+learning_rate_zero,
+decay_rate,
+xp,
+xi
+);
-msgbuf = "FAIL";
-y = xp[xni];
-for(int epoch=0; epoch<num_epochs;epoch++) {
-reset_network();
-for(int xpi=0; xpi<xni; xpi++) {
-last_val = predict_next_token(xp[xpi]);
-}
-if(last_val==y) {
-msgbuf = "SUCCESS";
-break;
-}
-train_mask = last_val ^ y;
-set_alpha(learning_rate_zero/(1+(decay_rate*epoch)));
-mask_back_propgatation(train_mask);
-}
+/*
+ * When training for the current time series failed there is
+ * no point in training on a further token of the series.
+ */
+if(ret) {
+msgbuf = "FAIL";
+break;
+}
+/*
+ * If we haven't terminated the loop by now, this prefix
+ * trained successfully and we can move on to the next token.
+ */
+msgbuf = "SUCCESS";
 }
 return msgbuf;
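As the comment in run_training describes, each training step pairs a prefix of the token array with the token that follows it. Below is a minimal standalone illustration of the (prefix -> target) pairs the outer loop generates, using the example tokens from the comment; the harness and token strings are illustrative, not part of the commit:

#include <stdio.h>

int main(void)
{
    /* Token series from the comment: Tax -> ation -> is -> theft */
    const char *xp[] = { "Tax", "ation", "is", "theft" };
    unsigned xn = 4;

    /* For each xi, tokens xp[0..xi-1] are replayed through the LSTM
     * and xp[xi] is the target the network is trained to predict. */
    for (unsigned xi = 1; xi < xn; xi++) {
        printf("train step %u: [", xi);
        for (unsigned xpi = 0; xpi < xi; xpi++)
            printf("%s%s", xpi ? " " : "", xp[xpi]);
        printf("] -> %s\n", xp[xi]);
    }
    /* Prints:
     * train step 1: [Tax] -> ation
     * train step 2: [Tax ation] -> is
     * train step 3: [Tax ation is] -> theft */
    return 0;
}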
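run_training_epochs decays the learning rate with an inverse-time schedule, alpha = learning_rate_zero / (1 + decay_rate * epoch), evaluated in integer arithmetic, and derives the training mask as the XOR of prediction and target, so set bits mark exactly the neurons that misfired or failed to fire. A minimal sketch of both computations, with illustrative values not taken from the commit:

#include <stdio.h>
#include <stdint.h>
#include <inttypes.h>

int main(void)
{
    uint32_t learning_rate_zero = 1000; /* illustrative value */
    uint32_t decay_rate = 1;            /* illustrative value */

    /* Inverse-time decay in integer arithmetic, as in the set_alpha()
     * call above: epoch 0 -> 1000, 1 -> 500, 2 -> 333, 3 -> 250 */
    for (uint32_t epoch = 0; epoch < 4; epoch++)
        printf("alpha[%" PRIu32 "] = %" PRIu32 "\n", epoch,
               learning_rate_zero / (1 + (decay_rate * epoch)));

    /* XOR training mask: a set bit marks a neuron whose output differs
     * from the target, e.g. 0b1010 ^ 0b1001 = 0b0011 */
    uint32_t last_val = 0xA, y = 0x9;
    printf("train_mask = 0x%" PRIX32 "\n", last_val ^ y);
    return 0;
}

Note that because the division is integral, alpha eventually truncates to zero for large epoch counts, which effectively freezes further weight updates.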