From 0f5e57d6d67245ed6bec0dda1d25d729e6933091 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Lanzend=C3=B6rfer?= <leviathan@libresilicon.com>
Date: Sun, 30 Jun 2024 07:05:52 +0100
Subject: [PATCH] Add training loop

---
 firmware/rnn.c | 105 ++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 87 insertions(+), 18 deletions(-)

diff --git a/firmware/rnn.c b/firmware/rnn.c
index c1ddcfe..0633a49 100644
--- a/firmware/rnn.c
+++ b/firmware/rnn.c
@@ -387,6 +387,63 @@ void set_bias_values(int bias)
 	}
 }
 
+/*
+ * Run training cycle for N epochs
+ * num_epochs: Amount of epochs
+ * learning_rate_zero: initial learning rate
+ * decay_rate: the decay rate for gradient decay
+ * xp: Pointer to array of values
+ * xi: Index of y
+ *
+ * Returns 0 when successful
+ * Returns 1 when unsuccessful
+ */
+int run_training_epochs(
+	int num_epochs,
+	uint32_t learning_rate_zero,
+	uint32_t decay_rate,
+	uint32_t *xp,
+	uint32_t xi
+)
+{
+	int ret = 1;
+	int last_val = 0;
+	uint32_t train_mask;
+	uint32_t y = xp[xi];
+
+	for(int epoch=0; epoch<num_epochs;epoch++) {
+		/*
+		 * Revert the LSTM states
+		 */
+		reset_network();
+		/*
+		 * Put all the prior tokens of the time series
+		 * into the LSTM for reaching the state we wish to train on
+		 */
+		for(int xpi=0; xpi<xi; xpi++) {
+			last_val = predict_next_token(xp[xpi]);
+		}
+		/*
+		 * Check whether the network has already been trained
+		 * Set return value to zero and exit the loop if so
+		 */
+		if(last_val==y) {
+			ret = 0;
+			break;
+		}
+		/*
+		 * Determine the training mask by finding out which
+		 * neurons misfired or didn't fire although they should have
+		 */
+		train_mask = last_val ^ y;
+		set_alpha(learning_rate_zero/(1+(decay_rate*epoch)));
+		mask_back_propgatation(train_mask);
+	}
+
+	return ret;
+}
+
+
 /*
  * Run training cycle for N epochs
  * msgbuf: buffer for infos
@@ -407,24 +464,36 @@ char* run_training(
 {
-	int last_val;
-	uint32_t train_mask;
-	uint32_t y;
+	int ret;
 
-	for(uint32_t xni=1; xni<xn; xni++) {
+	/*
+	 * Training the RNN for a time series
+	 * 1) Tax -> ation
+	 * 2) Tax -> ation -> is
+	 * 3) Tax -> ation -> is -> theft
+	 *
+	 * xn is the amount of values we have in our token array
+	 * xi tells the functions where y is located in the array (x -> y)
+	 * xi must be xn-1 at most
+	 */
+	for(uint32_t xi=1; xi<xn; xi++) {
+		ret = run_training_epochs(
+			num_epochs,
+			learning_rate_zero,
+			decay_rate,
+			xp,
+			xi
+		);
 		msgbuf = "FAIL";
-		y = xp[xni];
-		for(int epoch=0; epoch<num_epochs;epoch++) {
-			reset_network();
-			for(int xpi=0; xpi<xni; xpi++) {
-				last_val = predict_next_token(xp[xpi]);
-			}
-			if(last_val==y) {
-				msgbuf = "SUCCESS";
-				break;
-			}
-			train_mask = last_val ^ y;
-			set_alpha(learning_rate_zero/(1+(decay_rate*epoch)));
-			mask_back_propgatation(train_mask);
-		}
+		/*
+		 * When the training for the last time series failed
+		 * there's no point in training for a further token
+		 * in the time series.
+		 */
+		if(ret) break;
+
+		/*
+		 * If we haven't terminated the loop by now, it means we're
+		 * still game
+		 */
+		msgbuf = "SUCCESS";
 	}
 
 	return msgbuf;
-- 
GitLab