From 0cd68badfcc924ad3da43c9cb5b5c1babbee888a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Lanzend=C3=B6rfer?= <leviathan@libresilicon.com>
Date: Fri, 28 Jun 2024 00:07:00 +0100
Subject: [PATCH] Prepare testing of training cycles

---
Review notes (text between "---" and the diffstat is ignored by git-am):

 * Line breaks and indentation were rebuilt from a whitespace-collapsed copy
   of this patch. Every hunk body matches the line counts of its "@@" header.
   NOTE(review): the indentation of context lines is a best guess; use
   "git apply --ignore-whitespace" if it does not match the tree.
 * firmware/rnn.c: the first "last_val = predict_next_token(27017);" of
   run_training() is emitted as an added line. The hunk header (32 new lines)
   and the diffstat (19 insertions) only add up that way.
 * src/py/tty3.py: sends "RUN" instead of "SHOOT", because the TRAIN state
   added to firmware.c below only handles "TOKENS" and "RUN". The JSON file
   is read with json.load(f) and the redundant f.close() inside the "with"
   block is gone. The file is still 33 lines, so header and diffstat hold,
   but the blob id on its "index" line no longer matches the content.
 * NOTE(review): the TRAIN_STORE_TOKENS case below has no transition out of
   that state, so "DONE" and every later command look like they are parsed
   as tokens. Confirm the intended protocol before relying on train_test.
 * NOTE(review): run_training() is not called anywhere in this patch, "bias"
   is unused in it, and "dw" reaches 0 under integer division, after which
   the while loop may never terminate. Confirm before wiring it up.

 .gitlab-ci.yml      | 11 ++++++++
 firmware/firmware.c | 43 +++++++++++++++++++++++++++-
 firmware/rnn.c      | 68 +++++++++++++--------------------------------
 src/py/tty3.py      | 33 ++++++++++++++++++++++
 4 files changed, 105 insertions(+), 50 deletions(-)
 create mode 100644 src/py/tty3.py

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 533b16f..83e3d8e 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -27,6 +27,17 @@ write_test:
     paths:
       - result
 
+train_test:
+  stage: build
+  script:
+    - mkdir -p result
+    - make
+    - ./testbench_ecp5_minifpga.bin &
+    - sleep 1s && python3 src/py/tty3.py
+  artifacts:
+    paths:
+      - result
+
 #json:
 #  stage: build
 #  script:
diff --git a/firmware/firmware.c b/firmware/firmware.c
index 0b0af18..cccdc66 100644
--- a/firmware/firmware.c
+++ b/firmware/firmware.c
@@ -23,6 +23,7 @@
 
 #define MEM_TOTAL 0x20000
 #define MSGBUF 40
+#define MAX_NUM_TOKENS 20
 
 // a pointer to this is a null pointer, but the compiler does not
 // know that because "sram" is a linker symbol from sections.lds.
@@ -66,7 +67,11 @@ void main()
 		// Write mode
 		WRITE,
 		WRITE_LAYER,
-		WRITE_LAYER_WEIGHTS
+		WRITE_LAYER_WEIGHTS,
+		// Training the network
+		TRAIN,
+		TRAIN_STORE_TOKENS,
+		TRAIN_PROCESS
 	} command_mode;
 	command_mode = START;
 
@@ -75,9 +80,15 @@ void main()
 	uint32_t current_max_num_values;
 	uint32_t current_num_neurons;
 
+	// Write params
 	uint32_t value_write_counter;
 	int new_value;
 
+	// Training
+	uint32_t new_token;
+	uint32_t token_series[MAX_NUM_TOKENS];
+	int token_counter;
+
 	while(true) {
 		syn(); // Booted
 
@@ -115,6 +126,12 @@ void main()
 				else if(!strcmp(msg,"WRITE")) {
 					command_mode = WRITE;
 				}
+				else if(!strcmp(msg,"RESET")) {
+					reset_network();
+				}
+				else if(!strcmp(msg,"TRAIN")) {
+					command_mode = TRAIN;
+				}
 				break;
 
 			case INIT:
@@ -232,6 +249,30 @@ void main()
 					value_write_counter++;
 				}
 				break;
+
+			case TRAIN:
+				if(!strcmp(msg,"TOKENS")) {
+					command_mode = TRAIN_STORE_TOKENS;
+					response = "OK";
+					token_counter = 0;
+				}
+				else if(!strcmp(msg,"RUN")) {
+					command_mode = TRAIN_PROCESS;
+					response = "OK";
+				}
+				break;
+
+			case TRAIN_STORE_TOKENS:
+				if(token_counter<MAX_NUM_TOKENS) {
+					new_token = atoi(numstr,10);
+					token_series[token_counter] = new_token;
+					token_counter++;
+					response = "OK";
+				} else {
+					response = "END";
+				}
+				break;
+
 		}
 		write_response(response);
 	}
diff --git a/firmware/rnn.c b/firmware/rnn.c
index 06e015b..af34057 100644
--- a/firmware/rnn.c
+++ b/firmware/rnn.c
@@ -394,62 +394,32 @@ void set_bias_values(int bias)
 	}
 }
 
-#if 0
+void run_training()
+{
 	int dw = 200000000;
 	int bias = 0;
 	int last_val;
 	uint32_t train_mask;
 
-	/*
-
-	citoa(27017, numstr, 10);
-	print("Input value: ");
-	print(numstr);
-	ack();
-
-	last_val = predict_next_token(27017);
-
-	citoa(last_val, numstr, 10);
-	print("Output value: ");
-	print(numstr);
-	ack();
-
-	leds(8);*/
-	/* Compare values
-	 * The phrase "Taxation is theft"
-	 * tokenized with the GPT2 tokenizer
-	 * is tensor([[27017, 341, 318, 12402]])
-	 * this means the next value we want is 341
-	 */
-	/*while(last_val!=341) {
+	last_val = predict_next_token(27017);
 
-		train_mask = last_val ^ 341;
+	/* Compare values
+	 * The phrase "Taxation is theft"
+	 * tokenized with the GPT2 tokenizer
+	 * is tensor([[27017, 341, 318, 12402]])
+	 * this means the next value we want is 341
+	 */
+	while(last_val!=341) {
 
-		print("Run training single shot: ");
-		citoa(train_mask, numstr, 10);
-		print(numstr);
-		ack();
-
-		//set_dws(dw);
-		set_dws(dw, false);
-		dw=dw/100;
-		dw=99*dw;
-		mask_back_propgatation(train_mask);
-
-		citoa(27017, numstr, 10);
-		print("Input value: ");
-		print(numstr);
-		ack();
-
-		last_val = predict_next_token(27017);
+		train_mask = last_val ^ 341;
 
-		citoa(last_val, numstr, 10);
-		print("Output value: ");
-		print(numstr);
-		ack();
-	}*/
+		//set_dws(dw);
+		set_dws(dw, false);
+		dw=dw/100;
+		dw=99*dw;
+		mask_back_propgatation(train_mask);
 
-	//print("EXIT");
-	//ack();
+		last_val = predict_next_token(27017);
+	}
 
-#endif
+}
diff --git a/src/py/tty3.py b/src/py/tty3.py
new file mode 100644
index 0000000..7ff1660
--- /dev/null
+++ b/src/py/tty3.py
@@ -0,0 +1,33 @@
+import json
+
+from fac_tools import run_command
+from fac_tools import get_fac_wrapper
+from fac_tools import load_weights_and_biases
+
+from transformers import AutoTokenizer
+
+server = get_fac_wrapper("telnet")
+
+print("Reading in JSON")
+with open("test_files/weights_and_biases.json", "r") as f:
+    weights_and_biases = json.load(f)
+
+#load_weights_and_biases(server, weights_and_biases)
+
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+prompt = "Taxation is theft"
+input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+tokens = input_ids[0].numpy()
+run_command(server,"TRAIN")
+run_command(server,"TOKENS")
+for tok in tokens:
+    run_command(server,str(tok))
+run_command(server,"DONE")
+
+run_command(server,"TRAIN")
+# The firmware's TRAIN state only recognises "TOKENS" and "RUN".
+run_command(server,"RUN")
+
+run_command(server,"TERMINATE")
+
+server.close()
-- 
GitLab