From 868cbd8815edff7a7c6fd1d2a48092ccce2f009b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Lanzend=C3=B6rfer?= <leviathan@libresilicon.com>
Date: Fri, 28 Jun 2024 16:13:06 +0100
Subject: [PATCH] First attempt at running training
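
The firmware build now lives in a standalone Makefile under firmware/
and is invoked recursively from the top level (the firmware/Makefile
itself is not part of this diff). The serial command set gains two new
TRAIN subcommands, DECAY_RATE and RUN_EPOCHS, and run_training() no
longer applies hard-coded weight deltas; instead it applies a decayed
learning rate per epoch,

    alpha = learning_rate_zero / (1 + decay_rate * epoch)

and trains a single token transition x -> y for at most num_epochs
epochs, returning "SUCCESS" or "FAIL". On the host side, tty3.py now
drives a full run: load the weights, send the GPT-2 token series for
the prompt, set the fixed-point learning and decay rates, train for
100 epochs, and dump the trained network to
result/weights_and_biases_trained.json.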

---
 Makefile               | 46 ++++-------------------------
 firmware/firmware.c    | 35 +++++++++++++++------
 firmware/include/rnn.h | 20 +++++++++++++-
 firmware/rnn.c         | 69 +++++++++++++++++++-----------------------
 src/py/tty3.py         | 55 ++++++++++++++++++++++++++++-----
 5 files changed, 137 insertions(+), 88 deletions(-)

diff --git a/Makefile b/Makefile
index 34dc6d0..6e53018 100644
--- a/Makefile
+++ b/Makefile
@@ -8,30 +8,6 @@ NETWORK_PARAMS = \
 
 TARGET_BOARD?=ecp5_minifpga
 
-RISCV_GNU_TOOLCHAIN_INSTALL_PREFIX ?= /opt/riscv
-PYTHON = python3
-GCC_WARNS  = -Werror -Wall -Wextra -Wshadow -Wundef -Wpointer-arith -Wcast-qual -Wcast-align -Wwrite-strings
-GCC_WARNS += -Wredundant-decls -Wstrict-prototypes -Wmissing-prototypes -pedantic # -Wconversion
-TOOLCHAIN_PREFIX = $(RISCV_GNU_TOOLCHAIN_INSTALL_PREFIX)/bin/riscv32-unknown-elf-
-COMPRESSED_ISA = C
-
-FIRMWARE_OBJS= \
-firmware/start.o \
-firmware/main.o \
-firmware/print.o
-
-FIRMWARE_FILES= \
-firmware/start.s \
-firmware/rnn.c \
-firmware/firmware.c \
-firmware/io.c \
-firmware/string.c
-
-FIRMWARE_CFLAGS=
-FIRMWARE_CFLAGS+=-mabi=ilp32 -march=rv32imc
-FIRMWARE_CFLAGS+=-Wl,--build-id=none,-Bstatic,-T,firmware/sections.lds,--strip-debug
-FIRMWARE_CFLAGS+=-ffreestanding -nostdlib
-
 BENCHES=
 
 VERILATOR_DIR?=/usr/share/verilator/include
@@ -78,7 +54,7 @@ src/cpptb/params.h:
 		echo "#define $(PARAM) $(VALUE)" >> $@ ; \
 	)
 
-firmware/defines.h:
+firmware/include/defines.h:
 	@ $(foreach PARAMS,$(NETWORK_PARAMS), \
 		$(eval PARAM = $(word 1,$(subst :, ,$(PARAMS)))) \
 		$(eval VALUE = $(word 2,$(subst :, ,$(PARAMS)))) \
@@ -100,20 +76,14 @@ result/soc.json: result/firmware.hex src/rtl/params.vh result
 result/soc_out.config: $(LPFILE) # result/soc.json 
 	nextpnr-ecp5 --json result/soc.json --lpf $(LPFILE) --textcfg $@ --freq 50 --package CABGA256 --lpf-allow-unconstrained
 
-result/firmware.hex: firmware/firmware.bin submodules/picorv32/firmware/makehex.py result
-	$(TOOLCHAIN_PREFIX)objcopy -O verilog firmware/firmware.elf result/firmware.hex
-	cp result/firmware.hex firmware.hex
-
-firmware/firmware.bin: firmware/firmware.elf
-	$(TOOLCHAIN_PREFIX)objcopy -O binary $< $@
-
-firmware/firmware.elf: firmware/sections.lds firmware/defines.h $(FIRMWARE_FILES)
-	$(TOOLCHAIN_PREFIX)gcc $(CFLAGS) $(FIRMWARE_CFLAGS) -o $@ $(FIRMWARE_FILES)
+result/firmware.hex: result firmware/include/defines.h
+	$(MAKE) -C firmware firmware.hex
+	cp firmware/firmware.hex firmware.hex
+	cp firmware/firmware.hex result/firmware.hex
 
 result:
 	mkdir -p result
 clean:
-	rm -rf firmware/firmware.elf result/firmware.hex \
-		firmware/firmware.bin firmware/*.o firmware/firmware.map \
-		*.bin *.vcd obj_dir src/cpptb/*.o soc.json src/rtl/params.vh \
-		src/cpptb/params.h firmware/defines.h
+	rm -rf *.bin *.vcd obj_dir src/cpptb/*.o soc.json src/rtl/params.vh \
+		src/cpptb/params.h firmware/include/defines.h result/firmware.hex
+	$(MAKE) -C firmware clean
diff --git a/firmware/firmware.c b/firmware/firmware.c
index aaf0253..8558656 100644
--- a/firmware/firmware.c
+++ b/firmware/firmware.c
@@ -77,7 +77,8 @@ void main()
       TRAIN,
       TRAIN_STORE_TOKENS,
       TRAIN_STORE_LEARNING_RATE,
-      TRAIN_PROCESS
+      TRAIN_STORE_DECAY_RATE,
+      TRAIN_RUN_EPOCHS
     } command_mode;
     command_mode = START;
 
@@ -94,6 +95,8 @@ void main()
     uint32_t new_token;
     uint32_t token_series[MAX_NUM_TOKENS];
     int token_counter;
+    uint32_t learning_rate;
+    uint32_t decay_rate;
 
     while(true) {
 
@@ -266,10 +269,14 @@ void main()
                     response = "OK";
                     command_mode = TRAIN_STORE_LEARNING_RATE;
                 }
-                /*else if(!strcmp(msg,"RUN")) {
-                    command_mode = TRAIN_PROCESS;
+                else if(!strcmp(msg,"DECAY_RATE")) {
                     response = "OK";
-                }*/
+                    command_mode = TRAIN_STORE_DECAY_RATE;
+                }
+                else if(!strcmp(msg,"RUN_EPOCHS")) {
+                    command_mode = TRAIN_RUN_EPOCHS;
+                    response = "OK";
+                }
                 break;
 
             case TRAIN_STORE_TOKENS:
@@ -277,15 +284,29 @@
                     new_token = atoi(numstr);
                     token_series[token_counter] = new_token;
                     token_counter++;
-                    response = "OK";
+                    response = numstr;
                 } else {
                     response = "END";
                 }
                 break;
 
             case TRAIN_STORE_LEARNING_RATE:
-                new_token = atoi(numstr);
-                response = "OK";
+                learning_rate = atoi(numstr);
+                response = numstr;
+                command_mode = START;
+                break;
+
+            case TRAIN_STORE_DECAY_RATE:
+                decay_rate = atoi(numstr);
+                response = numstr;
+                command_mode = START;
+                break;
+
+            case TRAIN_RUN_EPOCHS: {
+                int num_epochs = atoi(numstr);
+                response = run_training(response, num_epochs, learning_rate, decay_rate, token_series[0], token_series[1]);
+                command_mode = START;
+            }
                 break;
 
         }
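
For reference, a rough sketch of the host/firmware exchange these
states implement (responses taken from the cases above; the "OK"
replies of the START state live outside this hunk and are assumed,
and token values are illustrative):

    host: TRAIN           fw: OK
    host: TOKENS          fw: OK
    host: 27017           fw: 27017          (each token is echoed)
    host: DONE            fw: END
    host: TRAIN           fw: OK
    host: LEARNING_RATE   fw: OK
    host: 2040109464      fw: 2040109464
    host: TRAIN           fw: OK
    host: DECAY_RATE      fw: OK
    host: 1               fw: 1
    host: TRAIN           fw: OK
    host: RUN_EPOCHS      fw: OK
    host: 100             fw: SUCCESS or FAIL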
diff --git a/firmware/include/rnn.h b/firmware/include/rnn.h
index ad36ee7..bfa3461 100644
--- a/firmware/include/rnn.h
+++ b/firmware/include/rnn.h
@@ -134,3 +134,21 @@ void set_bias_values(int bias);
- * Resetting the network: Clear LTSM
+ * Resetting the network: Clear LSTM
  */
 void reset_network();
+
+/*
+ * Run the training cycle for up to num_epochs epochs
+ * msgbuf: buffer for status messages
+ * num_epochs: maximum number of epochs to run
+ * learning_rate_zero: initial learning rate (fixed-point)
+ * decay_rate: decay rate for the learning-rate schedule
+ * x: input token
+ * y: output token
+ */
+char* run_training(
+    char *msgbuf,
+    int num_epochs,
+    uint32_t learning_rate_zero,
+    uint32_t decay_rate,
+    uint32_t x,
+    uint32_t y
+);
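
For illustration, a minimal caller sketch against this declaration,
mirroring what the TRAIN_RUN_EPOCHS handler in firmware/firmware.c
does (values are hypothetical; print_str is the helper from the
existing firmware print code):

    char msgbuf[32];
    uint32_t lr0 = 2040109464u;  /* int(0.95 * 2147483647), as sent by tty3.py */
    /* train token_series[0] -> token_series[1] for at most 100 epochs */
    char *status = run_training(msgbuf, 100, lr0, 1u, token_series[0], token_series[1]);
    print_str(status);           /* "SUCCESS" or "FAIL" */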
diff --git a/firmware/rnn.c b/firmware/rnn.c
index af34057..7bf595a 100644
--- a/firmware/rnn.c
+++ b/firmware/rnn.c
@@ -318,28 +318,21 @@ void reset_network()
 /*
- * Setting the delta for the weights
- * We try to achieve something like a gradients
- * with random noise added
+ * Setting alpha, the uniform per-weight delta:
+ * a crude stand-in for a per-weight gradient
  */
-void set_dws(int dw, bool randomize)
+void set_alpha(int alpha)
 {
     int i, x, y;
-    int dwd;
-    
-    dwd = (dw/NUM_HIDDEN_NEURONS_W);
+
     for(i=0;i<NUM_INPUT_NEURONS;i++) {
-        if(randomize) set_layer_weight(LAYER_TYPE_ENCODER, LAYER_VALUE_TYPE_DELTA_W, i, 0, dwd+get_random_char()/NUM_HIDDEN_NEURONS_W);
-        else set_layer_weight(LAYER_TYPE_ENCODER, LAYER_VALUE_TYPE_DELTA_W, i, 0, dwd);
+        set_layer_weight(LAYER_TYPE_ENCODER, LAYER_VALUE_TYPE_DELTA_W, i, 0, alpha);
     }
     for(x=0;x<NUM_HIDDEN_NEURONS_W;x++) {
-        dwd = (x+1)*(dw/NUM_HIDDEN_NEURONS_W);
         for(y=0;y<NUM_HIDDEN_NEURONS_H;y++) {
-            if(randomize) set_layer_weight(LAYER_TYPE_HIDDEN,LAYER_VALUE_TYPE_DELTA_W, x*NUM_HIDDEN_NEURONS_H+y, 0, dwd+get_random_char()/NUM_HIDDEN_NEURONS_W);
-            else set_layer_weight(LAYER_TYPE_HIDDEN,LAYER_VALUE_TYPE_DELTA_W, x*NUM_HIDDEN_NEURONS_H+y, 0, dwd);
+            set_layer_weight(LAYER_TYPE_HIDDEN,LAYER_VALUE_TYPE_DELTA_W, x*NUM_HIDDEN_NEURONS_H+y, 0, alpha);
         }
     }    
     for(i=0;i<NUM_OUTPUT_NEURONS;i++) {
-        if(randomize) set_layer_weight(LAYER_TYPE_DECODER,LAYER_VALUE_TYPE_DELTA_W, i, 0, dw+get_random_char());
-        else set_layer_weight(LAYER_TYPE_DECODER,LAYER_VALUE_TYPE_DELTA_W, i, 0, dw);
+        set_layer_weight(LAYER_TYPE_DECODER,LAYER_VALUE_TYPE_DELTA_W, i, 0, alpha);
     }
 }
 
@@ -394,32 +387,36 @@
     }
 }
 
-void run_training()
+/*
+ * Run the training cycle for up to num_epochs epochs
+ * msgbuf: buffer for status messages
+ * num_epochs: maximum number of epochs to run
+ * learning_rate_zero: initial learning rate (fixed-point)
+ * decay_rate: decay rate for the learning-rate schedule
+ * x: input token
+ * y: output token
+ */
+char* run_training(
+    char *msgbuf,
+    int num_epochs,
+    uint32_t learning_rate_zero,
+    uint32_t decay_rate,
+    uint32_t x,
+    uint32_t y
+)
 {
-    int dw = 200000000;
-    int bias = 0;
     int last_val;
     uint32_t train_mask;
 
-    last_val = predict_next_token(27017);
-
-    /* Compare values
-     * The phrase "Taxation is theft"
-     * tokenized with the GPT2 tokenizer
-     * is tensor([[27017,   341,   318, 12402]])
-     * this means the next value we want is 341
-     */
-    while(last_val!=341) {
-
-        train_mask = last_val ^ 341;
-
-        //set_dws(dw);
-        set_dws(dw, false);
-        dw=dw/100;
-        dw=99*dw;
+    for(int epoch=0; epoch<num_epochs;epoch++) {
+        reset_network();
+        last_val = predict_next_token(x);
+        if(last_val==y) {
+            return "SUCCESS";
+        }
+        train_mask = last_val ^ y;
+        set_alpha(learning_rate_zero/(1+(decay_rate*epoch)));
         mask_back_propgatation(train_mask);
-
-        last_val = predict_next_token(27017);
     }
-
+    return "FAIL";
 }
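
Worked example of the schedule above: all of this is integer math, so
the division truncates. A host-side sketch (not firmware code; the
initial value matches the fixed-point 0.95 sent by tty3.py):

    #include <stdio.h>
    #include <stdint.h>

    int main(void) {
        uint32_t alpha0 = 2040109464u;  /* ~0.95 of INT32_MAX */
        uint32_t decay_rate = 1u;
        /* alpha for the first four epochs:
         * 2040109464, 1020054732, 680036488, 510027366 */
        for (uint32_t epoch = 0; epoch < 4; epoch++)
            printf("epoch %u: alpha = %u\n",
                   (unsigned)epoch, (unsigned)(alpha0 / (1u + decay_rate * epoch)));
        return 0;
    }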
diff --git a/src/py/tty3.py b/src/py/tty3.py
index 7ff1660..a0df053 100644
--- a/src/py/tty3.py
+++ b/src/py/tty3.py
@@ -6,6 +6,15 @@ from fac_tools import load_weights_and_biases
 
 from transformers import AutoTokenizer
+
+# assumed: dump_neural_network lives in fac_tools, like load_weights_and_biases
+from fac_tools import dump_neural_network
 
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+prompt = "Taxation is theft"
+input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+tokens = input_ids[0].numpy()
+print(tokens)
+
 server = get_fac_wrapper("telnet")
 
 print("Reading in JSON")
@@ -13,12 +19,14 @@ with open("test_files/weights_and_biases.json", "r") as f:
     weights_and_biases = json.loads(f.read())
     f.close()
 
-#load_weights_and_biases(server, weights_and_biases)
+load_weights_and_biases(server, weights_and_biases)
+
+INTMAX=2147483647
+lr=0.95
+decay_rate=1
+
+run_command(server,"HELLO")
 
-tokenizer = AutoTokenizer.from_pretrained("gpt2")
-prompt = "Taxation is theft"
-input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-tokens = input_ids[0].numpy()
 run_command(server,"TRAIN")
 run_command(server,"TOKENS")
 for tok in tokens:
@@ -26,8 +37,40 @@ for tok in tokens:
 run_command(server,"DONE")
 
 run_command(server,"TRAIN")
-run_command(server,"SHOOT")
+run_command(server,"LEARNING_RATE")
+# The PicoRV32 has no floating-point unit and all
+# weights are stored as signed 32-bit integers, so
+# we do the math on the host: multiply the maximum
+# integer value by the desired learning rate.
+lr=lr*INTMAX
+run_command(server,str(int(lr)))
+
+
+run_command(server,"TRAIN")
+run_command(server,"DECAY_RATE")
+# The decay rate is a hyperparameter. The per-epoch
+# alpha is derived from the initial learning rate α0
+# as shown below, so with integer-only arithmetic a
+# fractional decay rate makes no sense anyway; it is
+# usually set to 1.
+# α = α0 / (1 + decayRate × epochNumber)
+run_command(server,str(decay_rate))
+
+
+run_command(server,"TRAIN")
+run_command(server,"RUN_EPOCHS")
+run_command(server,str(100))
+
+weights_and_biases = dump_neural_network(server)
+
+j = json.dumps(weights_and_biases, indent=4, sort_keys=True)
+print(j)
 
 run_command(server,"TERMINATE")
 
 server.close()
+
+print("Writing out JSON")
+with open("result/weights_and_biases_trained.json", "w") as f:
+    f.write(j)
+    f.close()
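
Note: the dumped JSON has the same shape that load_weights_and_biases()
consumes, so a later run can presumably be warm-started by loading
result/weights_and_biases_trained.json instead of the initial
test_files/weights_and_biases.json.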
-- 
GitLab