From 0fa5860d5b56139585c72cb5396817801aae7211 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Lanzend=C3=B6rfer?= <leviathan@libresilicon.com>
Date: Sun, 30 Jun 2024 06:09:49 +0100
Subject: [PATCH] Train on a complete time series

---
 firmware/firmware.c    | 10 +++++-----
 firmware/include/rnn.h |  8 ++++----
 firmware/rnn.c         | 38 +++++++++++++++++++++++---------------
 src/py/tty3.py         |  1 -
 4 files changed, 32 insertions(+), 25 deletions(-)

diff --git a/firmware/firmware.c b/firmware/firmware.c
index 8558656..0d1fcac 100644
--- a/firmware/firmware.c
+++ b/firmware/firmware.c
@@ -88,15 +88,15 @@ void main()
     uint32_t current_num_neurons;
 
     // Write params
-    uint32_t value_write_counter;
+    uint32_t value_write_counter = 0;
     int new_value;
 
     // Training
     uint32_t new_token;
     uint32_t token_series[MAX_NUM_TOKENS];
-    int token_counter;
-    uint32_t learning_rate;
-    uint32_t decay_rate;
+    int token_counter = 0;
+    uint32_t learning_rate = 0;
+    uint32_t decay_rate = 0;
 
     while(true) {
 
@@ -304,7 +304,7 @@ void main()
 
             case TRAIN_RUN_EPOCHS:
                 uint32_t num_epochs = atoi(numstr);
-                response = run_training(response, num_epochs, learning_rate, decay_rate, token_series[0], token_series[1]);
+                //response = run_training(response, num_epochs, learning_rate, decay_rate, token_series, token_counter);
                 command_mode = START;
                 break;
 
diff --git a/firmware/include/rnn.h b/firmware/include/rnn.h
index bfa3461..48ede09 100644
--- a/firmware/include/rnn.h
+++ b/firmware/include/rnn.h
@@ -141,14 +141,14 @@ void reset_network();
  * num_epochs: Amount of epochs
  * learning_rate_zero: initial learning rate
  * decay_rate: the decay rate for gradient decay
- * x: input token
- * y: output token
+ * xp: Pointer to array of values
+ * xn: Amount of values in array
  */
 char* run_training(
     char *msgbuf,
     int num_epochs,
     uint32_t learning_rate_zero,
     uint32_t decay_rate,
-    uint32_t x,
-    uint32_t y
+    uint32_t *xp,
+    uint32_t xn
 );
diff --git a/firmware/rnn.c b/firmware/rnn.c
index 7bf595a..c1ddcfe 100644
--- a/firmware/rnn.c
+++ b/firmware/rnn.c
@@ -393,32 +393,40 @@ void set_bias_values(int bias)
  * num_epochs: Amount of epochs
  * learning_rate_zero: initial learning rate
  * decay_rate: the decay rate for gradient decay
- * x: input token
- * y: output token
+ * xp: Pointer to array of values
+ * xn: Amount of values in array
  */
 char* run_training(
     char *msgbuf,
     int num_epochs,
     uint32_t learning_rate_zero,
     uint32_t decay_rate,
-    uint32_t x,
-    uint32_t y
+    uint32_t *xp,
+    uint32_t xn
 )
 {
     int last_val;
     uint32_t train_mask;
-
-    for(int epoch=0; epoch<num_epochs;epoch++) {
-        reset_network();
-        last_val = predict_next_token(x);
-        if(last_val==y) {
-            return "SUCCESS";
-            break;
+    uint32_t y;
+
+    for(uint32_t xni=1; xni<xn; xni++) {
+        msgbuf = "FAIL";
+        y = xp[xni];
+        for(int epoch=0; epoch<num_epochs;epoch++) {
+            reset_network();
+            for(int xpi=0; xpi<xni; xpi++) {
+                last_val = predict_next_token(xp[xpi]);
+            }
+            if(last_val==y) {
+                msgbuf = "SUCCESS";
+                break;
+            }
+            train_mask = last_val ^ y;
+            set_alpha(learning_rate_zero/(1+(decay_rate*epoch)));
+            mask_back_propgatation(train_mask);
         }
-        train_mask = last_val ^ y;
-        set_alpha(learning_rate_zero/(1+(decay_rate*epoch)));
-        mask_back_propgatation(train_mask);
     }
-    return "FAIL";
+
+    return msgbuf;
 }
 
diff --git a/src/py/tty3.py b/src/py/tty3.py
index 29158e4..e05a6ab 100644
--- a/src/py/tty3.py
+++ b/src/py/tty3.py
@@ -54,7 +54,6 @@ run_command(server,"DECAY_RATE")
 # α=(1/(1+decayRate×epochNumber))*​α0
 run_command(server,str(decay_rate))
 
-
 run_command(server,"TRAIN")
 run_command(server,"RUN_EPOCHS")
 run_command(server,str(20000))
-- 
GitLab