From 0cd68badfcc924ad3da43c9cb5b5c1babbee888a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Lanzend=C3=B6rfer?= <leviathan@libresilicon.com>
Date: Fri, 28 Jun 2024 00:07:00 +0100
Subject: [PATCH] Prepare testing of training cycles

---
 .gitlab-ci.yml      | 11 ++++++++
 firmware/firmware.c | 43 +++++++++++++++++++++++++++-
 firmware/rnn.c      | 68 +++++++++++++--------------------------------
 src/py/tty3.py      | 33 ++++++++++++++++++++++
 4 files changed, 105 insertions(+), 50 deletions(-)
 create mode 100644 src/py/tty3.py

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 533b16f..83e3d8e 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -27,6 +27,17 @@ write_test:
     paths:
     - result
 
+train_test:
+  stage: build
+  script:
+  - mkdir -p result
+  - make
+  - ./testbench_ecp5_minifpga.bin &
+  - sleep 1s && python3 src/py/tty3.py
+  artifacts:
+    paths:
+    - result
+
 #json:
 #  stage: build
 #  script:
diff --git a/firmware/firmware.c b/firmware/firmware.c
index 0b0af18..cccdc66 100644
--- a/firmware/firmware.c
+++ b/firmware/firmware.c
@@ -23,6 +23,7 @@
 
 #define MEM_TOTAL 0x20000
 #define MSGBUF 40
+#define MAX_NUM_TOKENS 20
 
 // a pointer to this is a null pointer, but the compiler does not
 // know that because "sram" is a linker symbol from sections.lds.
@@ -66,7 +67,11 @@ void main()
       // Write mode
       WRITE,
       WRITE_LAYER,
-      WRITE_LAYER_WEIGHTS
+      WRITE_LAYER_WEIGHTS,
+      // Training the network
+      TRAIN,
+      TRAIN_STORE_TOKENS,
+      TRAIN_PROCESS
     } command_mode;
     command_mode = START;
 
@@ -75,9 +80,15 @@ void main()
     uint32_t current_max_num_values;
     uint32_t current_num_neurons;
 
+    // Write params
     uint32_t value_write_counter;
     int new_value;
 
+    // Training
+    uint32_t new_token;
+    uint32_t token_series[MAX_NUM_TOKENS];
+    int token_counter;
+
     while(true) {
 
         syn(); // Booted
@@ -115,6 +126,12 @@ void main()
                 else if(!strcmp(msg,"WRITE")) {
                     command_mode = WRITE;
                 }
+                else if(!strcmp(msg,"RESET")) {
+                    reset_network();
+                }
+                else if(!strcmp(msg,"TRAIN")) {
+                    command_mode = TRAIN;
+                }
                 break;
             
             case INIT:
@@ -232,6 +249,37 @@ void main()
                     value_write_counter++;
                 }
                 break;
+
+            case TRAIN:
+                if(!strcmp(msg,"TOKENS")) {
+                    command_mode = TRAIN_STORE_TOKENS;
+                    response = "OK";
+                    token_counter = 0;
+                }
+                else if(!strcmp(msg,"RUN")) {
+                    command_mode = TRAIN_PROCESS;
+                    // TODO(review): no TRAIN_PROCESS case exists in this switch yet, so training is never dispatched — confirm intended follow-up
+                    response = "OK";
+                }
+                break;
+
+            case TRAIN_STORE_TOKENS:
+                // "DONE" ends the token upload and returns to TRAIN command mode
+                if(!strcmp(msg,"DONE")) {
+                    command_mode = TRAIN;
+                    response = "OK";
+                }
+                else if(token_counter<MAX_NUM_TOKENS) {
+                    // parse the received message itself, not the numeric print buffer
+                    new_token = atoi(msg,10);
+                    token_series[token_counter] = new_token;
+                    token_counter++;
+                    response = "OK";
+                } else {
+                    response = "END";
+                }
+                break;
+
         }
         write_response(response);
     }
diff --git a/firmware/rnn.c b/firmware/rnn.c
index 06e015b..af34057 100644
--- a/firmware/rnn.c
+++ b/firmware/rnn.c
@@ -394,62 +394,32 @@ void set_bias_values(int bias)
     }
 }
 
-#if 0
+void run_training()
+{
     int dw = 200000000;
     int bias = 0;
     int last_val;
     uint32_t train_mask;
-        /*
-
-        citoa(27017, numstr, 10);
-        print("Input value: ");
-        print(numstr);
-        ack();
-
-        last_val = predict_next_token(27017);
-
-        citoa(last_val, numstr, 10);
-        print("Output value: ");
-        print(numstr);
-        ack();
-
-        leds(8);*/
 
-        /* Compare values
-         * The phrase "Taxation is theft"
-         * tokenized with the GPT2 tokenizer
-         * is tensor([[27017,   341,   318, 12402]])
-         * this means the next value we want is 341
-         */
-        /*while(last_val!=341) {
+    last_val = predict_next_token(27017);
 
-            train_mask = last_val ^ 341;
+    /* Compare values
+     * The phrase "Taxation is theft"
+     * tokenized with the GPT2 tokenizer
+     * is tensor([[27017,   341,   318, 12402]])
+     * this means the next value we want is 341
+     */
+    while(last_val!=341) {
 
-            print("Run training single shot: ");
-            citoa(train_mask, numstr, 10);
-            print(numstr);
-            ack();
-    
-            //set_dws(dw);
-            set_dws(dw, false);
-            dw=dw/100;
-            dw=99*dw;
-            mask_back_propgatation(train_mask);
-
-            citoa(27017, numstr, 10);
-            print("Input value: ");
-            print(numstr);
-            ack();
-
-            last_val = predict_next_token(27017);
+        train_mask = last_val ^ 341;
 
-            citoa(last_val, numstr, 10);
-            print("Output value: ");
-            print(numstr);
-            ack();
-        }*/
+        //set_dws(dw);
+        set_dws(dw, false);
+        dw=dw/100;
+        dw=99*dw;
+        mask_back_propgatation(train_mask);
 
-        //print("EXIT");
-        //ack();
+        last_val = predict_next_token(27017);
+    }
 
-#endif
+}
diff --git a/src/py/tty3.py b/src/py/tty3.py
new file mode 100644
index 0000000..7ff1660
--- /dev/null
+++ b/src/py/tty3.py
@@ -0,0 +1,35 @@
+import json
+
+from fac_tools import run_command
+from fac_tools import get_fac_wrapper
+from fac_tools import load_weights_and_biases
+
+from transformers import AutoTokenizer
+
+server = get_fac_wrapper("telnet")
+
+print("Reading in JSON")
+with open("test_files/weights_and_biases.json", "r") as f:
+    weights_and_biases = json.loads(f.read())
+
+#load_weights_and_biases(server, weights_and_biases)
+
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+prompt = "Taxation is theft"
+input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+tokens = input_ids[0].numpy()
+
+# Upload the token series; each token is sent as one decimal message
+run_command(server,"TRAIN")
+run_command(server,"TOKENS")
+for tok in tokens:
+    run_command(server,str(tok))
+run_command(server,"DONE")
+
+# Start training: the firmware's TRAIN mode accepts "RUN" (not "SHOOT")
+run_command(server,"TRAIN")
+run_command(server,"RUN")
+
+run_command(server,"TERMINATE")
+
+server.close()
-- 
GitLab