From 34ee2fdff2a212658f0a9052d1d4ca0f3fa1ea3f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Lanzend=C3=B6rfer?= <leviathan@libresilicon.com>
Date: Mon, 1 Jul 2024 07:57:43 +0100
Subject: [PATCH] Try single epoch training for now

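Add a TRAIN_RUN_SINGLE_EPOCH firmware command so the host can drive
training one epoch at a time and stop as soon as the network predicts
the target token. run_training_epochs() is refactored on top of the
new run_training_single_epoch() helper, tty3.py switches the first
token series to the streamed single-epoch flow (the remaining series
are commented out for now) and raises decay_rate from 10 to 100, and
the per-step $display in neuron.v is dropped to quiet simulation
output.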
---
 firmware/firmware.c    | 19 +++++++--
 firmware/include/rnn.h | 29 ++++++++++++++
 firmware/rnn.c         | 99 ++++++++++++++++++++++++++++++++---------------
 src/py/tty3.py         | 19 ++++++---
 src/rtl/neuron.v       |  1 -
 5 files changed, 125 insertions(+), 42 deletions(-)

diff --git a/firmware/firmware.c b/firmware/firmware.c
index bf107ce..8b66b95 100644
--- a/firmware/firmware.c
+++ b/firmware/firmware.c
@@ -78,7 +78,8 @@ void main()
       TRAIN_STORE_TOKENS,
       TRAIN_STORE_LEARNING_RATE,
       TRAIN_STORE_DECAY_RATE,
-      TRAIN_RUN_EPOCHS
+      TRAIN_RUN_EPOCHS,
+      TRAIN_RUN_SINGLE_EPOCH
     } command_mode;
     command_mode = START;
 
@@ -97,6 +98,8 @@ void main()
     int token_counter = 0;
     uint32_t learning_rate = 0;
     uint32_t decay_rate = 0;
+    uint32_t epoch_value = 0;
+
 
     while(true) {
 
@@ -277,6 +280,10 @@ void main()
                     command_mode = TRAIN_RUN_EPOCHS;
                     response = "OK";
                 }
+                else if(!strcmp(msg,"RUN_SINGLE_EPOCH")) {
+                    command_mode = TRAIN_RUN_SINGLE_EPOCH;
+                    response = "OK";
+                }
                 break;
 
             case TRAIN_STORE_TOKENS:
@@ -303,11 +310,17 @@ void main()
                 break;
 
             case TRAIN_RUN_EPOCHS:
-                uint32_t num_epochs = atoi(numstr);
-                response = run_training(response, num_epochs, learning_rate, decay_rate, token_series, token_counter);
+                epoch_value = atoi(numstr);
+                response = run_training(response, epoch_value, learning_rate, decay_rate, token_series, token_counter);
                 command_mode = START;
                 break;
 
+            case TRAIN_RUN_SINGLE_EPOCH:
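+                /* Stay in this mode so the host can stream successive epoch numbers */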
+                epoch_value = atoi(numstr);
+                response = run_training_single_epoch(epoch_value, learning_rate, decay_rate, token_series, token_counter) ? "SUCCESS" : "FAILURE";
+                break;
+
         }
         write_response(response);
     }
diff --git a/firmware/include/rnn.h b/firmware/include/rnn.h
index 48ede09..9fc064b 100644
--- a/firmware/include/rnn.h
+++ b/firmware/include/rnn.h
@@ -152,3 +152,32 @@ char* run_training(
     uint32_t *xp,
     uint32_t xn
 );
+
+/*
+ * Run the training cycle for a single epoch
+ * epoch: The current epoch, used for learning-rate decay
+ * learning_rate_zero: initial learning rate
+ * decay_rate: decay rate of the learning-rate schedule
+ * xp: Pointer to the token series
+ * xi: Index of the target token y within xp
+ *
+ * Returns 1 when the network already predicts y (converged)
+ * Returns 0 when a training step was applied
+ */
+int run_training_single_epoch(
+    int epoch,
+    uint32_t learning_rate_zero,
+    uint32_t decay_rate,
+    uint32_t *xp,
+    uint32_t xi
+);
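+
+/*
+ * Sketch of the intended host-driven loop (MAX_EPOCHS and the other
+ * names are placeholders, not existing symbols): call the helper with
+ * an increasing epoch index until it reports convergence.
+ *
+ *   for(int epoch = 0; epoch < MAX_EPOCHS; epoch++)
+ *       if(run_training_single_epoch(epoch, lr0, decay, tokens, n))
+ *           break;
+ */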
diff --git a/firmware/rnn.c b/firmware/rnn.c
index 05c7b69..921a1ec 100644
--- a/firmware/rnn.c
+++ b/firmware/rnn.c
@@ -387,6 +387,66 @@ void set_bias_values(int bias)
     }
 }
 
+/*
+ * Run the training cycle for a single epoch
+ * epoch: The current epoch, used for learning-rate decay
+ * learning_rate_zero: initial learning rate
+ * decay_rate: decay rate of the learning-rate schedule
+ * xp: Pointer to the token series
+ * xi: Index of the target token y within xp
+ *
+ * Returns 1 when the network already predicts y (converged)
+ * Returns 0 when a training step was applied
+ */
+int run_training_single_epoch(
+    int epoch,
+    uint32_t learning_rate_zero,
+    uint32_t decay_rate,
+    uint32_t *xp,
+    uint32_t xi
+)
+{
+    int last_val = -1; /* sentinel: no prediction made yet */
+    uint32_t train_mask;
+    uint32_t y = xp[xi];
+
+    /*
+     * Reset the LSTM state
+     */
+    reset_network();
+    /*
+     * Feed all prior tokens of the time series into the LSTM
+     * to reach the state we wish to train on
+     */
+    for(int xpi=0; xpi<xi; xpi++) {
+        last_val = predict_next_token(xp[xpi]);
+    }
+    /*
+     * If the prediction already matches the target, the network
+     * has converged for this sample: report success
+     */
+    if(last_val==y) {
+        return 1;
+    }
+    /*
+     * Determine the training mask by finding out which neurons
+     * misfired or failed to fire when they should have
+     */
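+    /* e.g. predicted 0b0110 vs. target 0b0011 -> train_mask = 0b0101 */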
+    train_mask = last_val ^ y;
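+    /*
+     * Inverse-time decay of the learning rate:
+     * alpha = lr0 / (1 + decay_rate*epoch), so with the host-side
+     * defaults in tty3.py (lr0 = 0.95*INTMAX, decay_rate = 100)
+     * epoch 0 trains at the full rate and epoch 1 already at lr0/101
+     */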
+    set_alpha(learning_rate_zero/(1+(decay_rate*epoch)));
+    mask_back_propgatation(train_mask);
+
+    return 0;
+}
+
+
 /*
  * Run training cycle for N epochs
  * num_epochs: Amount of epochs
@@ -406,41 +459,18 @@ int run_training_epochs(
     uint32_t xi
 )
 {
-    int ret = 1;
-    int last_val;
-    uint32_t train_mask;
-    uint32_t y = xp[xi];;
 
     for(int epoch=0; epoch<num_epochs;epoch++) {
-        /*
-         * Revert the LSTM states
-         */
-        reset_network();
-        /*
-         * Put all the prior tokens of the time series
-         * into the LSTM for reaching the state we wish to train on
-         */
-        for(int xpi=0; xpi<xi; xpi++) {
-            last_val = predict_next_token(xp[xpi]);
-        }
-        /*
-         * Check whether the network has already been trained
-         * Set return value to zero and exit the loop if so
-         */
-        if(last_val==y) {
-            ret = 0;
-            break;
-        }
-        /*
-         * Determine the training mask my finding out which
-         * neurons misfired or didn't fire although they should have
-         */
-        train_mask = last_val ^ y;
-        set_alpha(learning_rate_zero/(1+(decay_rate*epoch)));
-        mask_back_propgatation(train_mask);
+        if(run_training_single_epoch(
+            epoch,
+            learning_rate_zero,
+            decay_rate,
+            xp,
+            xi
+        )) return 0;
     }
 
-    return ret;
+    return 1;
 }
 
 
diff --git a/src/py/tty3.py b/src/py/tty3.py
index 0922d96..1078dcd 100644
--- a/src/py/tty3.py
+++ b/src/py/tty3.py
@@ -17,7 +17,7 @@ server = get_fac_wrapper("telnet")
 
 INTMAX=2147483647
 lr=0.95
-decay_rate=10
+decay_rate=100
 
 run_command(server,"HELLO")
 
@@ -31,7 +31,6 @@ run_command(server,"RWEIGHTS")
 run_command(server,"INIT")
 run_command(server,"BIAS")
 
-
 run_command(server,"TRAIN")
 run_command(server,"LEARNING_RATE")
 # PicoRV does not have a floating point unit
@@ -41,7 +40,6 @@ run_command(server,"LEARNING_RATE")
 lr=lr*INTMAX
 run_command(server,str(int(lr)))
 
-
 run_command(server,"TRAIN")
 run_command(server,"DECAY_RATE")
 # The decay rate is a hyper parameter
@@ -62,9 +60,15 @@ tok=tokens[1]
 run_command(server,str(tok))
 run_command(server,"DONE")
 run_command(server,"TRAIN")
-run_command(server,"RUN_EPOCHS")
-run_command(server,str(1000))
+run_command(server,"RUN_SINGLE_EPOCH")
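+# Stream epoch numbers one at a time; the firmware answers SUCCESS
+# once the prediction matches the target token, so we can stop early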
+for i in range(500):
+    ret = run_command(server,str(i))
+    if "SUCCESS" in ret:
+        break
 
+'''
 run_command(server,"TRAIN")
 run_command(server,"TOKENS")
 tok=tokens[1]
@@ -74,7 +76,7 @@ run_command(server,str(tok))
 run_command(server,"DONE")
 run_command(server,"TRAIN")
 run_command(server,"RUN_EPOCHS")
-run_command(server,str(1000))
+run_command(server,str(500))
 
 run_command(server,"TRAIN")
 run_command(server,"TOKENS")
@@ -85,7 +87,7 @@ run_command(server,str(tok))
 run_command(server,"DONE")
 run_command(server,"TRAIN")
 run_command(server,"RUN_EPOCHS")
-run_command(server,str(1000))
+run_command(server,str(500))
 
 # Upload token series
 run_command(server,"TRAIN")
@@ -96,6 +98,7 @@ run_command(server,"DONE")
 run_command(server,"TRAIN")
 run_command(server,"RUN_EPOCHS")
 run_command(server,str(1000))
+'''
 
 # Store the weights and biases
 weights_and_biases = dump_neural_network(server)
diff --git a/src/rtl/neuron.v b/src/rtl/neuron.v
index d1d5d2c..6895e96 100644
--- a/src/rtl/neuron.v
+++ b/src/rtl/neuron.v
@@ -101,7 +101,6 @@ module neuron #(
         else
         if(!backprop_done && backprop_running)
         begin
-            $display("Backprop with dW=", $signed(dw));
             // Adjust normal weights
             if( regidx < NUMBER_SYNAPSES )
             begin
-- 
GitLab