From 78e93e83587e79f9e624c6d39294b423e5488af0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Lanzend=C3=B6rfer?= <leviathan@libresilicon.com>
Date: Sun, 30 Jun 2024 22:03:58 +0100
Subject: [PATCH] Train pairs first

---
 src/py/tty3.py | 49 +++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 43 insertions(+), 6 deletions(-)

diff --git a/src/py/tty3.py b/src/py/tty3.py
index 8a19ea2..64def00 100644
--- a/src/py/tty3.py
+++ b/src/py/tty3.py
@@ -31,12 +31,6 @@ run_command(server,"RWEIGHTS")
 run_command(server,"INIT")
 run_command(server,"BIAS")
 
-# Upload token series
-run_command(server,"TRAIN")
-run_command(server,"TOKENS")
-for tok in tokens:
-    run_command(server,str(tok))
-run_command(server,"DONE")
 
 run_command(server,"TRAIN")
 run_command(server,"LEARNING_RATE")
@@ -58,6 +52,49 @@ run_command(server,"DECAY_RATE")
 # α=(1/(1+decayRate×epochNumber))*​α0
 run_command(server,str(decay_rate))
 
+
+
+# Priming phase!
+# Upload and train token pairs first
+run_command(server,"TRAIN")
+run_command(server,"TOKENS")
+tok=tokens[0]
+run_command(server,str(tok))
+tok=tokens[1]
+run_command(server,str(tok))
+run_command(server,"DONE")
+run_command(server,"TRAIN")
+run_command(server,"RUN_EPOCHS")
+run_command(server,str(10000))
+
+run_command(server,"TRAIN")
+run_command(server,"TOKENS")
+tok=tokens[1]
+run_command(server,str(tok))
+tok=tokens[2]
+run_command(server,str(tok))
+run_command(server,"DONE")
+run_command(server,"TRAIN")
+run_command(server,"RUN_EPOCHS")
+run_command(server,str(10000))
+
+run_command(server,"TRAIN")
+run_command(server,"TOKENS")
+tok=tokens[2]
+run_command(server,str(tok))
+tok=tokens[3]
+run_command(server,str(tok))
+run_command(server,"DONE")
+run_command(server,"TRAIN")
+run_command(server,"RUN_EPOCHS")
+run_command(server,str(10000))
+
+# Upload token series
+run_command(server,"TRAIN")
+run_command(server,"TOKENS")
+for tok in tokens:
+    run_command(server,str(tok))
+run_command(server,"DONE")
 run_command(server,"TRAIN")
 run_command(server,"RUN_EPOCHS")
 run_command(server,str(1000))
-- 
GitLab