diff --git a/src/py/tty3.py b/src/py/tty3.py index 8a19ea256dc778dd141d1b758cbcc870fc762227..64def001aede23d1bf96d214cef296b20310478c 100644 --- a/src/py/tty3.py +++ b/src/py/tty3.py @@ -31,12 +31,6 @@ run_command(server,"RWEIGHTS") run_command(server,"INIT") run_command(server,"BIAS") -# Upload token series -run_command(server,"TRAIN") -run_command(server,"TOKENS") -for tok in tokens: - run_command(server,str(tok)) -run_command(server,"DONE") run_command(server,"TRAIN") run_command(server,"LEARNING_RATE") @@ -58,6 +52,49 @@ run_command(server,"DECAY_RATE") # α=(1/(1+decayRate×epochNumber))*​α0 run_command(server,str(decay_rate)) + + +# Priming phase! +# Upload and train token pairs first +run_command(server,"TRAIN") +run_command(server,"TOKENS") +tok=tokens[0] +run_command(server,str(tok)) +tok=tokens[1] +run_command(server,str(tok)) +run_command(server,"DONE") +run_command(server,"TRAIN") +run_command(server,"RUN_EPOCHS") +run_command(server,str(10000)) + +run_command(server,"TRAIN") +run_command(server,"TOKENS") +tok=tokens[1] +run_command(server,str(tok)) +tok=tokens[2] +run_command(server,str(tok)) +run_command(server,"DONE") +run_command(server,"TRAIN") +run_command(server,"RUN_EPOCHS") +run_command(server,str(10000)) + +run_command(server,"TRAIN") +run_command(server,"TOKENS") +tok=tokens[2] +run_command(server,str(tok)) +tok=tokens[3] +run_command(server,str(tok)) +run_command(server,"DONE") +run_command(server,"TRAIN") +run_command(server,"RUN_EPOCHS") +run_command(server,str(10000)) + +# Upload token series +run_command(server,"TRAIN") +run_command(server,"TOKENS") +for tok in tokens: + run_command(server,str(tok)) +run_command(server,"DONE") run_command(server,"TRAIN") run_command(server,"RUN_EPOCHS") run_command(server,str(1000))