From 0fa5860d5b56139585c72cb5396817801aae7211 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Lanzend=C3=B6rfer?= <leviathan@libresilicon.com> Date: Sun, 30 Jun 2024 06:09:49 +0100 Subject: [PATCH] Train on a complete time series --- firmware/firmware.c | 10 +++++----- firmware/include/rnn.h | 8 ++++---- firmware/rnn.c | 38 +++++++++++++++++++++++--------------- src/py/tty3.py | 1 - 4 files changed, 32 insertions(+), 25 deletions(-) diff --git a/firmware/firmware.c b/firmware/firmware.c index 8558656..0d1fcac 100644 --- a/firmware/firmware.c +++ b/firmware/firmware.c @@ -88,15 +88,15 @@ void main() uint32_t current_num_neurons; // Write params - uint32_t value_write_counter; + uint32_t value_write_counter = 0; int new_value; // Training uint32_t new_token; uint32_t token_series[MAX_NUM_TOKENS]; - int token_counter; - uint32_t learning_rate; - uint32_t decay_rate; + int token_counter = 0; + uint32_t learning_rate = 0; + uint32_t decay_rate = 0; while(true) { @@ -304,7 +304,7 @@ void main() case TRAIN_RUN_EPOCHS: uint32_t num_epochs = atoi(numstr); - response = run_training(response, num_epochs, learning_rate, decay_rate, token_series[0], token_series[1]); + //response = run_training(response, num_epochs, learning_rate, decay_rate, token_series, token_counter); command_mode = START; break; diff --git a/firmware/include/rnn.h b/firmware/include/rnn.h index bfa3461..48ede09 100644 --- a/firmware/include/rnn.h +++ b/firmware/include/rnn.h @@ -141,14 +141,14 @@ void reset_network(); * num_epochs: Amount of epochs * learning_rate_zero: initial learning rate * decay_rate: the decay rate for gradient decay - * x: input token - * y: output token + * xp: Pointer to array of values + * xn: Amount of values in array */ char* run_training( char *msgbuf, int num_epochs, uint32_t learning_rate_zero, uint32_t decay_rate, - uint32_t x, - uint32_t y + uint32_t *xp, + uint32_t xn ); diff --git a/firmware/rnn.c b/firmware/rnn.c index 7bf595a..c1ddcfe 100644 --- a/firmware/rnn.c +++ b/firmware/rnn.c @@ -393,32 +393,40 @@ void set_bias_values(int bias) * num_epochs: Amount of epochs * learning_rate_zero: initial learning rate * decay_rate: the decay rate for gradient decay - * x: input token - * y: output token + * xp: Pointer to array of values + * xn: Amount of values in array */ char* run_training( char *msgbuf, int num_epochs, uint32_t learning_rate_zero, uint32_t decay_rate, - uint32_t x, - uint32_t y + uint32_t *xp, + uint32_t xn ) { int last_val; uint32_t train_mask; - - for(int epoch=0; epoch<num_epochs;epoch++) { - reset_network(); - last_val = predict_next_token(x); - if(last_val==y) { - return "SUCCESS"; - break; + uint32_t y; + + for(uint32_t xni=1; xni<xn; xni++) { + msgbuf = "FAIL"; + y = xp[xni]; + for(int epoch=0; epoch<num_epochs;epoch++) { + reset_network(); + for(int xpi=0; xpi<xni; xpi++) { + last_val = predict_next_token(xp[xpi]); + } + if(last_val==y) { + msgbuf = "SUCCESS"; + break; + } + train_mask = last_val ^ y; + set_alpha(learning_rate_zero/(1+(decay_rate*epoch))); + mask_back_propgatation(train_mask); } - train_mask = last_val ^ y; - set_alpha(learning_rate_zero/(1+(decay_rate*epoch))); - mask_back_propgatation(train_mask); } - return "FAIL"; + + return msgbuf; } diff --git a/src/py/tty3.py b/src/py/tty3.py index 29158e4..e05a6ab 100644 --- a/src/py/tty3.py +++ b/src/py/tty3.py @@ -54,7 +54,6 @@ run_command(server,"DECAY_RATE") # α=(1/(1+decayRate×epochNumber))*​α0 run_command(server,str(decay_rate)) - run_command(server,"TRAIN") run_command(server,"RUN_EPOCHS") run_command(server,str(20000)) -- GitLab