From 0f5e57d6d67245ed6bec0dda1d25d729e6933091 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Lanzend=C3=B6rfer?= <leviathan@libresilicon.com>
Date: Sun, 30 Jun 2024 07:05:52 +0100
Subject: [PATCH] Add training loop

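Factor the per-token epoch loop out of run_training() into a new
helper, run_training_epochs(), and abort the loop over the time
series as soon as one prefix fails to converge within num_epochs.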
---
 firmware/rnn.c | 103 +++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 87 insertions(+), 16 deletions(-)

diff --git a/firmware/rnn.c b/firmware/rnn.c
index c1ddcfe..0633a49 100644
--- a/firmware/rnn.c
+++ b/firmware/rnn.c
@@ -387,6 +387,63 @@ void set_bias_values(int bias)
     }
 }
 
+/*
+ * Run the training cycle for up to num_epochs epochs
+ * num_epochs: maximum number of epochs
+ * learning_rate_zero: initial learning rate
+ * decay_rate: decay rate for the learning-rate schedule
+ * xp: pointer to the token array
+ * xi: index of the target token y within xp
+ *
+ * Returns 0 on success (the network predicts y)
+ * Returns 1 on failure (no convergence within num_epochs)
+ */
+int run_training_epochs(
+    int num_epochs,
+    uint32_t learning_rate_zero,
+    uint32_t decay_rate,
+    uint32_t *xp,
+    uint32_t xi
+)
+{
+    int ret = 1;
+    int last_val = -1;    /* prediction from the most recent epoch */
+    uint32_t train_mask;
+    uint32_t y = xp[xi];  /* target token to be predicted */
+
+    for(int epoch=0; epoch<num_epochs; epoch++) {
+        /*
+         * Revert the LSTM states
+         */
+        reset_network();
+        /*
+         * Put all the prior tokens of the time series
+         * into the LSTM for reaching the state we wish to train on
+         */
+        for(int xpi=0; xpi<xi; xpi++) {
+            last_val = predict_next_token(xp[xpi]);
+        }
+        /*
+         * Check whether the prediction for position xi already
+         * matches y; if so, the network is trained for this prefix:
+         * set the return value to zero and exit the loop
+         */
+        if(last_val==y) {
+            ret = 0;
+            break;
+        }
+        /*
+         * Determine the training mask by finding out which
+         * neurons misfired or failed to fire although they should have
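+         * e.g. (illustrative values) last_val = 0b0110, y = 0b0011
+         *      -> train_mask = 0b0101: bits 0 and 2 need correction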
+         */
+        train_mask = last_val ^ y;
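+        /*
+         * Inverse-time decay of the learning rate; e.g. with
+         * learning_rate_zero = 1000 and decay_rate = 1 the epochs
+         * see alpha = 1000, 500, 333, ... (integer division)
+         */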
+        set_alpha(learning_rate_zero/(1+(decay_rate*epoch)));
+        mask_back_propgatation(train_mask);
+    }
+
+    return ret;
+}
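+
+/*
+ * Usage sketch (hypothetical token values): train the network to
+ * predict series[1] from the prefix series[0], i.e. 17 -> 5, with at
+ * most 100 epochs, initial learning rate 1000 and decay rate 1:
+ *
+ *     uint32_t series[] = {17, 5, 23};
+ *     int failed = run_training_epochs(100, 1000, 1, series, 1);
+ */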
+
 /*
  * Run training cycle for N epochs
  * msgbuf: buffer for infos
@@ -407,24 +464,38 @@ char* run_training(
 {
-    int last_val;
-    uint32_t train_mask;
-    uint32_t y;
+    int ret;
 
-    for(uint32_t xni=1; xni<xn; xni++) {
+    /*
+     * Training the RNN on a time series, prefix by prefix:
+     * 1) Tax -> ation
+     * 2) Tax -> ation -> is
+     * 3) Tax -> ation -> is -> theft
+     *
+     * xn is the number of tokens in the token array
+     * xi tells run_training_epochs where y is located in the array (x -> y)
+     * xi can be at most xn-1
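+     * e.g. with xn = 4 the loop trains xi = 1, 2 and 3 in turn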
+     */
+    for(uint32_t xi=1; xi<xn; xi++) {
+        ret = run_training_epochs(
+            num_epochs,
+            learning_rate_zero,
+            decay_rate,
+            xp,
+            xi
+        );
         msgbuf = "FAIL";
-        y = xp[xni];
-        for(int epoch=0; epoch<num_epochs;epoch++) {
-            reset_network();
-            for(int xpi=0; xpi<xni; xpi++) {
-                last_val = predict_next_token(xp[xpi]);
-            }
-            if(last_val==y) {
-                msgbuf = "SUCCESS";
-                break;
-            }
-            train_mask = last_val ^ y;
-            set_alpha(learning_rate_zero/(1+(decay_rate*epoch)));
-            mask_back_propgatation(train_mask);
-        }
+        /*
+         * When training on the current prefix of the time series
+         * failed, there is no point in training on any further
+         * token of the series.
+         */
+        if(ret) break;
+        /*
+         * If we haven't left the loop by now, training on this
+         * prefix succeeded
+         */
+        msgbuf = "SUCCESS";
     }
 
     return msgbuf;
-- 
GitLab