From 34ee2fdff2a212658f0a9052d1d4ca0f3fa1ea3f Mon Sep 17 00:00:00 2001
From: David Lanzendörfer <leviathan@libresilicon.com>
Date: Mon, 1 Jul 2024 07:57:43 +0100
Subject: [PATCH] Try single epoch training for now

---
 firmware/firmware.c    | 19 +++++++--
 firmware/include/rnn.h | 19 +++++++++
 firmware/rnn.c         | 89 +++++++++++++++++++++++++++++-------------
 src/py/tty3.py         | 17 ++++----
 src/rtl/neuron.v       |  1 -
 5 files changed, 106 insertions(+), 39 deletions(-)

diff --git a/firmware/firmware.c b/firmware/firmware.c
index bf107ce..8b66b95 100644
--- a/firmware/firmware.c
+++ b/firmware/firmware.c
@@ -78,7 +78,8 @@ void main()
 		TRAIN_STORE_TOKENS,
 		TRAIN_STORE_LEARNING_RATE,
 		TRAIN_STORE_DECAY_RATE,
-		TRAIN_RUN_EPOCHS
+		TRAIN_RUN_EPOCHS,
+		TRAIN_RUN_SINGLE_EPOCH
 	} command_mode;
 
 	command_mode = START;
@@ -97,6 +98,8 @@ void main()
 	int token_counter = 0;
 	uint32_t learning_rate = 0;
 	uint32_t decay_rate = 0;
+	uint32_t epoch_value = 0;
+
 
 	while(true)
 	{
@@ -277,6 +280,10 @@ void main()
 				command_mode = TRAIN_RUN_EPOCHS;
 				response = "OK";
 			}
+			else if(!strcmp(msg,"RUN_SINGLE_EPOCH")) {
+				command_mode = TRAIN_RUN_SINGLE_EPOCH;
+				response = "OK";
+			}
 			break;
 
 		case TRAIN_STORE_TOKENS:
@@ -303,11 +310,17 @@ void main()
 			break;
 
 		case TRAIN_RUN_EPOCHS:
-			uint32_t num_epochs = atoi(numstr);
-			response = run_training(response, num_epochs, learning_rate, decay_rate, token_series, token_counter);
+			epoch_value = atoi(numstr);
+			response = run_training(response, epoch_value, learning_rate, decay_rate, token_series, token_counter);
 			command_mode = START;
 			break;
 
+		case TRAIN_RUN_SINGLE_EPOCH:
+			epoch_value = atoi(numstr);
+			new_value = run_training_single_epoch(epoch_value, learning_rate, decay_rate, token_series, token_counter);
+			response = new_value?"SUCCESS":"FAILURE";
+			break;
+
 	}
 	write_response(response);
 }
diff --git a/firmware/include/rnn.h b/firmware/include/rnn.h
index 48ede09..9fc064b 100644
--- a/firmware/include/rnn.h
+++ b/firmware/include/rnn.h
@@ -152,3 +152,22 @@ char* run_training(
 	uint32_t *xp,
 	uint32_t xn
);
+
+/*
+ * Run one training epoch
+ * epoch: Index of the current epoch (used to decay the learning rate)
+ * learning_rate_zero: initial learning rate
+ * decay_rate: decay rate of the learning-rate schedule
+ * xp: Pointer to the token series
+ * xi: Index of the target token y within xp
+ *
+ * Returns 1 when the network already predicts the target token (converged)
+ * Returns 0 when a backpropagation step was applied
+ */
+int run_training_single_epoch(
+	int epoch,
+	uint32_t learning_rate_zero,
+	uint32_t decay_rate,
+	uint32_t *xp,
+	uint32_t xi
+);
diff --git a/firmware/rnn.c b/firmware/rnn.c
index 05c7b69..921a1ec 100644
--- a/firmware/rnn.c
+++ b/firmware/rnn.c
@@ -387,6 +387,59 @@ void set_bias_values(int bias)
 	}
 }
 
+/*
+ * Run one training epoch
+ * epoch: Index of the current epoch (used to decay the learning rate)
+ * learning_rate_zero: initial learning rate
+ * decay_rate: decay rate of the learning-rate schedule
+ * xp: Pointer to the token series
+ * xi: Index of the target token y within xp
+ *
+ * Returns 1 when the network already predicts the target token (converged)
+ * Returns 0 when a backpropagation step was applied
+ */
+int run_training_single_epoch(
+	int epoch,
+	uint32_t learning_rate_zero,
+	uint32_t decay_rate,
+	uint32_t *xp,
+	uint32_t xi
+)
+{
+	int last_val;
+	uint32_t train_mask;
+	uint32_t y = xp[xi];
+
+	/*
+	 * Revert the LSTM states
+	 */
+	reset_network();
+	/*
+	 * Feed all the prior tokens of the time series
+	 * into the LSTM to reach the state we wish to train on
+	 */
+	for(int xpi=0; xpi<xi; xpi++) {
+		last_val = predict_next_token(xp[xpi]);
+	}
+	/*
+	 * If the network already predicts the target token,
+	 * report convergence (nothing left to train)
+	 */
+	if(last_val==y) {
+		return 1;
+	}
+	/*
+	 * Determine the training mask by finding out which
+	 * neurons misfired or failed to fire when they should have
+	 */
+	train_mask = last_val ^ y;
+	set_alpha(learning_rate_zero/(1+(decay_rate*epoch)));
+	mask_back_propgatation(train_mask);
+
+	return 0;
+}
+
+
 /*
  * Run training cycle for N epochs
  * num_epochs: Amount of epochs
@@ -406,41 +459,21 @@ int run_training_epochs(
 	uint32_t xi
 )
 {
-	int ret = 1;
 	int last_val;
 	uint32_t train_mask;
 	uint32_t y = xp[xi];;
 
 	for(int epoch=0; epoch<num_epochs;epoch++) {
-		/*
-		 * Revert the LSTM states
-		 */
-		reset_network();
-		/*
-		 * Put all the prior tokens of the time series
-		 * into the LSTM for reaching the state we wish to train on
-		 */
-		for(int xpi=0; xpi<xi; xpi++) {
-			last_val = predict_next_token(xp[xpi]);
-		}
-		/*
-		 * Check whether the network has already been trained
-		 * Set return value to zero and exit the loop if so
-		 */
-		if(last_val==y) {
-			ret = 0;
-			break;
-		}
-		/*
-		 * Determine the training mask my finding out which
-		 * neurons misfired or didn't fire although they should have
-		 */
-		train_mask = last_val ^ y;
-		set_alpha(learning_rate_zero/(1+(decay_rate*epoch)));
-		mask_back_propgatation(train_mask);
+		if(run_training_single_epoch(
+			epoch,
+			learning_rate_zero,
+			decay_rate,
+			xp,
+			xi
+		)) return 0;
 	}
 
-	return ret;
+	return 1;
 }
diff --git a/src/py/tty3.py b/src/py/tty3.py
index 0922d96..1078dcd 100644
--- a/src/py/tty3.py
+++ b/src/py/tty3.py
@@ -17,7 +17,7 @@ server = get_fac_wrapper("telnet")
 INTMAX=2147483647
 
 lr=0.95
-decay_rate=10
+decay_rate=100
 
 run_command(server,"HELLO")
 
@@ -31,7 +31,6 @@ run_command(server,"RWEIGHTS")
 run_command(server,"INIT")
 run_command(server,"BIAS")
 
-
 run_command(server,"TRAIN")
 run_command(server,"LEARNING_RATE")
 # PicoRV does not have a floating point unit
@@ -41,7 +40,6 @@ run_command(server,"LEARNING_RATE")
 lr=lr*INTMAX
 run_command(server,str(int(lr)))
 
-
 run_command(server,"TRAIN")
 run_command(server,"DECAY_RATE")
 # The decay rate is a hyper parameter
@@ -62,9 +60,13 @@ tok=tokens[1]
 run_command(server,str(tok))
 run_command(server,"DONE")
 run_command(server,"TRAIN")
-run_command(server,"RUN_EPOCHS")
-run_command(server,str(1000))
+run_command(server,"RUN_SINGLE_EPOCH")
+for i in range(0,500):
+    ret = run_command(server,str(i))
+    if "SUCCESS" in ret:
+        break
 
+'''
 run_command(server,"TRAIN")
 run_command(server,"TOKENS")
 tok=tokens[1]
@@ -74,7 +76,7 @@ run_command(server,str(tok))
 run_command(server,"DONE")
 run_command(server,"TRAIN")
 run_command(server,"RUN_EPOCHS")
-run_command(server,str(1000))
+run_command(server,str(500))
 
 run_command(server,"TRAIN")
 run_command(server,"TOKENS")
@@ -85,7 +87,7 @@ run_command(server,str(tok))
 run_command(server,"DONE")
 run_command(server,"TRAIN")
 run_command(server,"RUN_EPOCHS")
-run_command(server,str(1000))
+run_command(server,str(500))
 
 # Upload token series
 run_command(server,"TRAIN")
@@ -96,6 +98,7 @@ run_command(server,"DONE")
 run_command(server,"TRAIN")
 run_command(server,"RUN_EPOCHS")
 run_command(server,str(1000))
+'''
 
 # Store the weights and biases
 weights_and_biases = dump_neural_network(server)
diff --git a/src/rtl/neuron.v b/src/rtl/neuron.v
index d1d5d2c..6895e96 100644
--- a/src/rtl/neuron.v
+++ b/src/rtl/neuron.v
@@ -101,7 +101,6 @@ module neuron #(
 	else
 	if(!backprop_done && backprop_running)
 	begin
-		$display("Backprop with dW=", $signed(dw));
 		// Adjust normal weights
 		if( regidx < NUMBER_SYNAPSES )
 		begin
-- 
GitLab
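
A note on the masking step this patch factors out: the predicted token
(last_val) and the target token (y) are treated as bit vectors over the
output neurons, so last_val ^ y sets exactly the bits where a neuron
misfired or failed to fire, and mask_back_propgatation() then updates
only those neurons. A minimal standalone sketch of that step, not
firmware code, with made-up token values:

	#include <inttypes.h>
	#include <stdint.h>
	#include <stdio.h>

	int main(void)
	{
		uint32_t last_val = 0x0000000A;	/* predicted: neurons 1 and 3 fired */
		uint32_t y        = 0x00000006;	/* target:    neurons 1 and 2 should fire */

		/* XOR marks every disagreeing bit position */
		uint32_t train_mask = last_val ^ y;

		/* prints 0x0000000C: neuron 2 failed to fire, neuron 3 misfired */
		printf("train_mask = 0x%08" PRIX32 "\n", train_mask);
		return 0;
	}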
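
The patch also keeps the inverse-decay schedule alpha =
learning_rate_zero / (1 + decay_rate * epoch), now evaluated once per
RUN_SINGLE_EPOCH command with the host-supplied epoch counter. Because
the PicoRV32 has no FPU, tty3.py pre-scales lr = 0.95 by INTMAX before
sending it. A standalone sketch of how that fixed-point schedule falls
off with the patch's decay_rate = 100, assuming plain integer division
as in run_training_single_epoch():

	#include <inttypes.h>
	#include <stdint.h>
	#include <stdio.h>

	#define INTMAX 2147483647	/* same scale factor tty3.py uses */

	int main(void)
	{
		/* host side (tty3.py): lr = 0.95 in fixed point, decay_rate = 100 */
		uint32_t learning_rate_zero = (uint32_t)(0.95 * (double)INTMAX);
		uint32_t decay_rate = 100;

		/* firmware side: the alpha set by each single-epoch call */
		for (int epoch = 0; epoch < 5; epoch++) {
			uint32_t alpha = learning_rate_zero / (1 + (decay_rate * epoch));
			printf("epoch %d: alpha = %10" PRIu32 " (~%.6f)\n",
			       epoch, alpha, (double)alpha / INTMAX);
		}
		return 0;
	}

With decay_rate = 100 the step size drops by two orders of magnitude
after the first epoch, which matches the host script's switch from one
big RUN_EPOCHS batch to polling single epochs until "SUCCESS".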