From bfa6f6bf909a40ae1a35a5565e2238ac1d4a5f7a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Lanzend=C3=B6rfer?= <leviathan@libresilicon.com>
Date: Mon, 1 Jul 2024 09:04:43 +0100
Subject: [PATCH] Making more API functions

---
 src/py/fac_tools.py | 24 +++++++++++++++++++++
 src/py/tty1.py      | 10 +++------
 src/py/tty3.py      | 51 ++++++---------------------------------------
 3 files changed, 33 insertions(+), 52 deletions(-)

diff --git a/src/py/fac_tools.py b/src/py/fac_tools.py
index f745915..f6f6293 100644
--- a/src/py/fac_tools.py
+++ b/src/py/fac_tools.py
@@ -140,3 +140,27 @@ def load_weights_and_biases(server,weights_and_biases):
     write_value_array(server, "HIDDEN", "BIAS", weights_and_biases['hidden']['biases'])
     write_value_multi_array(server, "HIDDEN", "WEIGHTS", weights_and_biases['hidden']['weights'])
     write_value_multi_array(server, "HIDDEN", "RWEIGHTS", weights_and_biases['hidden']['rnn_weights'])
+
+def init_weights_and_biases(server):
+    run_command(server,"INIT")
+    run_command(server,"WEIGHTS")
+    run_command(server,"INIT")
+    run_command(server,"RWEIGHTS")
+    run_command(server,"INIT")
+    run_command(server,"BIAS")
+    run_command(server,"TRAIN")
+    run_command(server,"LEARNING_RATE")
+
+def train_token_series(server, tokens, epochs):
+    run_command(server,"TRAIN")
+    run_command(server,"TOKENS")
+    for tok in tokens:
+        run_command(server,str(tok))
+    run_command(server,"DONE")
+    run_command(server,"TRAIN")
+    run_command(server,"RUN_SINGLE_EPOCH")
+    for i in range(0,epochs):
+        print("Training epoch: ",i)
+        ret = run_command(server,str(i))
+        if "SUCCESS" in ret:
+            break
diff --git a/src/py/tty1.py b/src/py/tty1.py
index 1e2b47e..0fbaa8f 100644
--- a/src/py/tty1.py
+++ b/src/py/tty1.py
@@ -4,17 +4,13 @@ from fac_tools import run_command
 from fac_tools import fetch_all_values
 from fac_tools import get_fac_wrapper
 from fac_tools import dump_neural_network
+from fac_tools import init_weights_and_biases
 
 server = get_fac_wrapper("telnet")
 
-run_command(server,"INIT")
-run_command(server,"WEIGHTS")
 
-run_command(server,"INIT")
-run_command(server,"RWEIGHTS")
-
-run_command(server,"INIT")
-run_command(server,"BIAS")
+# Initialize the values
+init_weights_and_biases(server)
 
 weights_and_biases = dump_neural_network(server)
 
diff --git a/src/py/tty3.py b/src/py/tty3.py
index 1078dcd..116ef47 100644
--- a/src/py/tty3.py
+++ b/src/py/tty3.py
@@ -4,6 +4,8 @@ from fac_tools import run_command
 from fac_tools import get_fac_wrapper
 from fac_tools import load_weights_and_biases
 from fac_tools import dump_neural_network
+from fac_tools import init_weights_and_biases
+from fac_tools import train_token_series
 
 from transformers import AutoTokenizer
 
@@ -22,17 +24,8 @@ decay_rate=100
 run_command(server,"HELLO")
 
 # Initialize the values
-run_command(server,"INIT")
-run_command(server,"WEIGHTS")
+init_weights_and_biases(server)
 
-run_command(server,"INIT")
-run_command(server,"RWEIGHTS")
-
-run_command(server,"INIT")
-run_command(server,"BIAS")
-
-run_command(server,"TRAIN")
-run_command(server,"LEARNING_RATE")
 # PicoRV does not have a floating point unit
 # All weights are stored as signed 32 bit integers
 # this means we have to do the math here and multiply
@@ -52,43 +45,11 @@ run_command(server,str(decay_rate))
 
 # Priming phase!
 # Upload and train token pairs first
-run_command(server,"TRAIN")
-run_command(server,"TOKENS")
-tok=tokens[0]
-run_command(server,str(tok))
-tok=tokens[1]
-run_command(server,str(tok))
-run_command(server,"DONE")
-run_command(server,"TRAIN")
-run_command(server,"RUN_SINGLE_EPOCH")
-for i in range(0,500):
-    ret = run_command(server,str(i))
-    if "SUCCESS" in ret:
-        break
+train_token_series(server, tokens[0:2], 20000)
+train_token_series(server, tokens[1:3], 20000)
+train_token_series(server, tokens[2:4], 20000)
 
 '''
-run_command(server,"TRAIN")
-run_command(server,"TOKENS")
-tok=tokens[1]
-run_command(server,str(tok))
-tok=tokens[2]
-run_command(server,str(tok))
-run_command(server,"DONE")
-run_command(server,"TRAIN")
-run_command(server,"RUN_EPOCHS")
-run_command(server,str(500))
-
-run_command(server,"TRAIN")
-run_command(server,"TOKENS")
-tok=tokens[2]
-run_command(server,str(tok))
-tok=tokens[3]
-run_command(server,str(tok))
-run_command(server,"DONE")
-run_command(server,"TRAIN")
-run_command(server,"RUN_EPOCHS")
-run_command(server,str(500))
-
 # Upload token series
 run_command(server,"TRAIN")
 run_command(server,"TOKENS")
-- 
GitLab