From 868cbd8815edff7a7c6fd1d2a48092ccce2f009b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Lanzend=C3=B6rfer?= <leviathan@libresilicon.com>
Date: Fri, 28 Jun 2024 16:13:06 +0100
Subject: [PATCH] First attempt at running training

---
 Makefile               | 46 +++++------------------------
 firmware/firmware.c    | 34 ++++++++++++++++-----
 firmware/include/rnn.h | 18 ++++++++++++
 firmware/rnn.c         | 67 +++++++++++++++++++++---------------------
 src/py/tty3.py         | 52 ++++++++++++++++++++++++++++----
 5 files changed, 132 insertions(+), 85 deletions(-)

diff --git a/Makefile b/Makefile
index 34dc6d0..6e53018 100644
--- a/Makefile
+++ b/Makefile
@@ -8,30 +8,6 @@ NETWORK_PARAMS = \
 
 TARGET_BOARD?=ecp5_minifpga
 
-RISCV_GNU_TOOLCHAIN_INSTALL_PREFIX ?= /opt/riscv
-PYTHON = python3
-GCC_WARNS = -Werror -Wall -Wextra -Wshadow -Wundef -Wpointer-arith -Wcast-qual -Wcast-align -Wwrite-strings
-GCC_WARNS += -Wredundant-decls -Wstrict-prototypes -Wmissing-prototypes -pedantic # -Wconversion
-TOOLCHAIN_PREFIX = $(RISCV_GNU_TOOLCHAIN_INSTALL_PREFIX)/bin/riscv32-unknown-elf-
-COMPRESSED_ISA = C
-
-FIRMWARE_OBJS= \
-firmware/start.o \
-firmware/main.o \
-firmware/print.o
-
-FIRMWARE_FILES= \
-firmware/start.s \
-firmware/rnn.c \
-firmware/firmware.c \
-firmware/io.c \
-firmware/string.c
-
-FIRMWARE_CFLAGS=
-FIRMWARE_CFLAGS+=-mabi=ilp32 -march=rv32imc
-FIRMWARE_CFLAGS+=-Wl,--build-id=none,-Bstatic,-T,firmware/sections.lds,--strip-debug
-FIRMWARE_CFLAGS+=-ffreestanding -nostdlib
-
 BENCHES=
 
 VERILATOR_DIR?=/usr/share/verilator/include
@@ -78,7 +54,7 @@ src/cpptb/params.h:
 		echo "#define $(PARAM) $(VALUE)" >> $@ ; \
 	)
 
-firmware/defines.h:
+firmware/include/defines.h:
 	@ $(foreach PARAMS,$(NETWORK_PARAMS), \
 		$(eval PARAM = $(word 1,$(subst :, ,$(PARAMS)))) \
 		$(eval VALUE = $(word 2,$(subst :, ,$(PARAMS)))) \
@@ -100,20 +76,14 @@ result/soc.json: result/firmware.hex src/rtl/params.vh result
 result/soc_out.config: $(LPFILE) # result/soc.json
 	nextpnr-ecp5 --json result/soc.json --lpf $(LPFILE) --textcfg $@ --freq 50 --package CABGA256 --lpf-allow-unconstrained
 
-result/firmware.hex: firmware/firmware.bin submodules/picorv32/firmware/makehex.py result
-	$(TOOLCHAIN_PREFIX)objcopy -O verilog firmware/firmware.elf result/firmware.hex
-	cp result/firmware.hex firmware.hex
-
-firmware/firmware.bin: firmware/firmware.elf
-	$(TOOLCHAIN_PREFIX)objcopy -O binary $< $@
-
-firmware/firmware.elf: firmware/sections.lds firmware/defines.h $(FIRMWARE_FILES)
-	$(TOOLCHAIN_PREFIX)gcc $(CFLAGS) $(FIRMWARE_CFLAGS) -o $@ $(FIRMWARE_FILES)
+result/firmware.hex: result firmware/include/defines.h
+	make -C firmware firmware.hex
+	cp firmware/firmware.hex firmware.hex
+	cp firmware/firmware.hex result/firmware.hex
 
 result:
 	mkdir -p result
 
 clean:
-	rm -rf firmware/firmware.elf result/firmware.hex \
-		firmware/firmware.bin firmware/*.o firmware/firmware.map \
-		*.bin *.vcd obj_dir src/cpptb/*.o soc.json src/rtl/params.vh \
-		src/cpptb/params.h firmware/defines.h
+	rm -rf *.bin *.vcd obj_dir src/cpptb/*.o soc.json src/rtl/params.vh \
+		src/cpptb/params.h firmware/include/defines.h result/firmware.hex
+	make -C firmware clean
diff --git a/firmware/firmware.c b/firmware/firmware.c
index aaf0253..8558656 100644
--- a/firmware/firmware.c
+++ b/firmware/firmware.c
@@ -77,7 +77,8 @@ void main()
 		TRAIN,
 		TRAIN_STORE_TOKENS,
 		TRAIN_STORE_LEARNING_RATE,
-		TRAIN_PROCESS
+		TRAIN_STORE_DECAY_RATE,
+		TRAIN_RUN_EPOCHS
 	} command_mode;
 
 	command_mode = START;
@@ -94,6 +95,8 @@ void main()
 	uint32_t new_token;
 	uint32_t token_series[MAX_NUM_TOKENS];
 	int token_counter;
+	uint32_t learning_rate;
+	uint32_t decay_rate;
 
 	while(true)
 	{
@@ -266,10 +269,14 @@ void main()
 					response = "OK";
 					command_mode = TRAIN_STORE_LEARNING_RATE;
 				}
-				/*else if(!strcmp(msg,"RUN")) {
-					command_mode = TRAIN_PROCESS;
+				else if(!strcmp(msg,"DECAY_RATE")) {
 					response = "OK";
-				}*/
+					command_mode = TRAIN_STORE_DECAY_RATE;
+				}
+				else if(!strcmp(msg,"RUN_EPOCHS")) {
+					command_mode = TRAIN_RUN_EPOCHS;
+					response = "OK";
+				}
 				break;
 
 			case TRAIN_STORE_TOKENS:
@@ -277,15 +284,28 @@ void main()
 					new_token = atoi(numstr);
 					token_series[token_counter] = new_token;
 					token_counter++;
-					response = "OK";
+					response = numstr;
 				} else {
 					response = "END";
 				}
 				break;
 
 			case TRAIN_STORE_LEARNING_RATE:
-				new_token = atoi(numstr);
-				response = "OK";
+				learning_rate = atoi(numstr);
+				response = numstr;
+				command_mode = START;
+				break;
+
+			case TRAIN_STORE_DECAY_RATE:
+				decay_rate = atoi(numstr);
+				response = numstr;
+				command_mode = START;
+				break;
+
+			case TRAIN_RUN_EPOCHS: {
+				uint32_t num_epochs = atoi(numstr);
+				response = run_training(response, num_epochs, learning_rate, decay_rate, token_series[0], token_series[1]);
+				command_mode = START;
+			}
 				break;
 		}
diff --git a/firmware/include/rnn.h b/firmware/include/rnn.h
index ad36ee7..bfa3461 100644
--- a/firmware/include/rnn.h
+++ b/firmware/include/rnn.h
@@ -134,3 +134,21 @@ void set_bias_values(int bias);
  * Resetting the network: Clear LTSM
  */
 void reset_network();
+
+/*
+ * Run a training cycle for N epochs
+ * msgbuf: buffer for status messages
+ * num_epochs: number of epochs
+ * learning_rate_zero: initial learning rate
+ * decay_rate: decay rate for the learning-rate schedule
+ * x: input token
+ * y: expected output token
+ */
+char* run_training(
+	char *msgbuf,
+	int num_epochs,
+	uint32_t learning_rate_zero,
+	uint32_t decay_rate,
+	uint32_t x,
+	uint32_t y
+);
diff --git a/firmware/rnn.c b/firmware/rnn.c
index af34057..7bf595a 100644
--- a/firmware/rnn.c
+++ b/firmware/rnn.c
@@ -318,28 +318,21 @@ void reset_network()
 /*
  * Setting the delta for the weights
  * We try to achieve something like a gradient
- * with random noise added
  */
-void set_dws(int dw, bool randomize)
+void set_alpha(int alpha)
 {
 	int i, x, y;
-	int dwd;
-
-	dwd = (dw/NUM_HIDDEN_NEURONS_W);
+
 	for(i=0;i<NUM_INPUT_NEURONS;i++) {
-		if(randomize) set_layer_weight(LAYER_TYPE_ENCODER, LAYER_VALUE_TYPE_DELTA_W, i, 0, dwd+get_random_char()/NUM_HIDDEN_NEURONS_W);
-		else set_layer_weight(LAYER_TYPE_ENCODER, LAYER_VALUE_TYPE_DELTA_W, i, 0, dwd);
+		set_layer_weight(LAYER_TYPE_ENCODER, LAYER_VALUE_TYPE_DELTA_W, i, 0, alpha);
 	}
 	for(x=0;x<NUM_HIDDEN_NEURONS_W;x++) {
-		dwd = (x+1)*(dw/NUM_HIDDEN_NEURONS_W);
 		for(y=0;y<NUM_HIDDEN_NEURONS_H;y++) {
-			if(randomize) set_layer_weight(LAYER_TYPE_HIDDEN,LAYER_VALUE_TYPE_DELTA_W, x*NUM_HIDDEN_NEURONS_H+y, 0, dwd+get_random_char()/NUM_HIDDEN_NEURONS_W);
-			else set_layer_weight(LAYER_TYPE_HIDDEN,LAYER_VALUE_TYPE_DELTA_W, x*NUM_HIDDEN_NEURONS_H+y, 0, dwd);
+			set_layer_weight(LAYER_TYPE_HIDDEN,LAYER_VALUE_TYPE_DELTA_W, x*NUM_HIDDEN_NEURONS_H+y, 0, alpha);
 		}
 	}
 	for(i=0;i<NUM_OUTPUT_NEURONS;i++) {
-		if(randomize) set_layer_weight(LAYER_TYPE_DECODER,LAYER_VALUE_TYPE_DELTA_W, i, 0, dw+get_random_char());
-		else set_layer_weight(LAYER_TYPE_DECODER,LAYER_VALUE_TYPE_DELTA_W, i, 0, dw);
+		set_layer_weight(LAYER_TYPE_DECODER,LAYER_VALUE_TYPE_DELTA_W, i, 0, alpha);
 	}
 }
@@ -394,32 +387,38 @@ void set_bias_values(int bias)
 	}
 }
 
-void run_training()
+/*
+ * Run a training cycle for N epochs
+ * msgbuf: buffer for status messages
+ * num_epochs: number of epochs
+ * learning_rate_zero: initial learning rate
+ * decay_rate: decay rate for the learning-rate schedule
+ * x: input token
+ * y: expected output token
+ */
+char* run_training(
+	char *msgbuf,
+	int num_epochs,
+	uint32_t learning_rate_zero,
+	uint32_t decay_rate,
+	uint32_t x,
+	uint32_t y
+)
 {
-	int dw = 200000000;
-	int bias = 0;
 	int last_val;
 	uint32_t train_mask;
-	last_val = predict_next_token(27017);
-
-	/* Compare values
-	 * The phrase "Taxation is theft"
-	 * tokenized with the GPT2 tokenizer
-	 * is tensor([[27017, 341, 318, 12402]])
-	 * this means the next value we want is 341
-	 */
-	while(last_val!=341) {
-
-		train_mask = last_val ^ 341;
-
-		//set_dws(dw);
-		set_dws(dw, false);
-		dw=dw/100;
-		dw=99*dw;
+	for(int epoch=0; epoch<num_epochs; epoch++) {
+		reset_network();
+		last_val = predict_next_token(x);
+		if(last_val==y) {
+			return "SUCCESS";
+		}
+		train_mask = last_val ^ y;
+		set_alpha(learning_rate_zero/(1+(decay_rate*epoch)));
 		mask_back_propgatation(train_mask);
-
-		last_val = predict_next_token(27017);
 	}
-
+	return "FAIL";
 }
+
diff --git a/src/py/tty3.py b/src/py/tty3.py
index 7ff1660..a0df053 100644
--- a/src/py/tty3.py
+++ b/src/py/tty3.py
@@ -6,6 +6,12 @@ from fac_tools import load_weights_and_biases
 
 from transformers import AutoTokenizer
 
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+prompt = "Taxation is theft"
+input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+tokens = input_ids[0].numpy()
+print(tokens)
+
 server = get_fac_wrapper("telnet")
 
 print("Reading in JSON")
@@ -13,12 +19,14 @@ with open("test_files/weights_and_biases.json", "r") as f:
 	weights_and_biases = json.loads(f.read())
 	f.close()
 
-#load_weights_and_biases(server, weights_and_biases)
+load_weights_and_biases(server, weights_and_biases)
+
+INTMAX=2147483647
+lr=0.95
+decay_rate=1
+
+run_command(server,"HELLO")
 
-tokenizer = AutoTokenizer.from_pretrained("gpt2")
-prompt = "Taxation is theft"
-input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-tokens = input_ids[0].numpy()
 run_command(server,"TRAIN")
 run_command(server,"TOKENS")
 for tok in tokens:
@@ -26,8 +34,40 @@ for tok in tokens:
 run_command(server,"DONE")
 
 run_command(server,"TRAIN")
-run_command(server,"SHOOT")
+run_command(server,"LEARNING_RATE")
+# PicoRV32 has no floating-point unit and all weights
+# are stored as signed 32-bit integers, so we do the
+# math on the host: multiply the maximum integer value
+# by the desired learning rate before sending it.
+lr=lr*INTMAX
+run_command(server,str(int(lr)))
+
+
+run_command(server,"TRAIN")
+run_command(server,"DECAY_RATE")
+# The decay rate is a hyperparameter. Because the step
+# size alpha is derived from the initial learning rate α0
+# by the integer formula below, fractional decay rates
+# would not survive the firmware's integer math anyway.
+# The decay rate is usually set to 1.
+# α = (1/(1+decayRate×epochNumber))*α0
+run_command(server,str(decay_rate))
+
+
+run_command(server,"TRAIN")
+run_command(server,"RUN_EPOCHS")
+run_command(server,str(100))
+
+weights_and_biases = dump_neural_network(server)
+
+j = json.dumps(weights_and_biases, indent=4, sort_keys=True)
+print(j)
 
 run_command(server,"TERMINATE")
 
 server.close()
+
+print("Writing out JSON")
+with open("result/weights_and_biases_trained.json", "w") as f:
+	f.write(j)
+	f.close()
-- 
GitLab
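
The fixed-point arithmetic is the crux of this patch: PicoRV32 has no FPU, so tty3.py encodes the fractional learning rate as a signed 32-bit integer, and run_training() then shrinks the step size each epoch with pure integer division. A minimal host-side sketch of that schedule, using only the two formulas visible in the diff (encode_learning_rate and alpha_schedule are illustrative names, not part of the repo):

INT32_MAX = 2147483647

def encode_learning_rate(lr: float) -> int:
    # Mirrors "lr = lr * INTMAX" in tty3.py: map a fractional rate
    # onto the signed 32-bit scale that the firmware weights use.
    return int(lr * INT32_MAX)

def alpha_schedule(lr0: int, decay_rate: int, num_epochs: int) -> list:
    # Mirrors set_alpha(learning_rate_zero/(1+(decay_rate*epoch))) in
    # rnn.c: C integer division, so alpha only shrinks in whole steps.
    return [lr0 // (1 + decay_rate * epoch) for epoch in range(num_epochs)]

lr0 = encode_learning_rate(0.95)
print(alpha_schedule(lr0, decay_rate=1, num_epochs=5))
# alpha is halved at epoch 1, cut to a third at epoch 2, and so on.

With decay_rate=1 this reproduces the 1/(1+t) decay from the comment in tty3.py, carried out entirely in integers.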
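
The error signal driving mask_back_propgatation() is not a numeric loss but a bitmask: run_training() XORs the predicted token against the target token, so exactly the differing bit positions are flagged for the backward pass. A small illustration with the first two GPT-2 tokens of "Taxation is theft" quoted in the removed rnn.c comment (the helper name is hypothetical):

def train_mask(predicted: int, target: int) -> int:
    # Same operation as "train_mask = last_val ^ y" in rnn.c:
    # a set bit marks a position where prediction and target disagree.
    return predicted ^ target

predicted, target = 27017, 341  # first two tokens of the training phrase
mask = train_mask(predicted, target)
print(f"{predicted:016b}\n{target:016b}\n{mask:016b}")
# Training stops once the mask would be all zeroes (last_val == y).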
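
Replacing the firmware's bare "OK" with an echo of the stored value (response = numstr) also lets the host verify every token write. A hypothetical wrapper over the repo's command helper, assuming run_command(server, msg) returns the device's reply as a string and is importable from fac_tools like the other helpers in tty3.py; adjust both assumptions to the real API:

from fac_tools import run_command  # assumed import path

def upload_tokens(server, tokens):
    # Drive the TRAIN -> TOKENS -> ... -> DONE exchange from firmware.c.
    run_command(server, "TRAIN")
    run_command(server, "TOKENS")
    for tok in tokens:
        reply = run_command(server, str(int(tok)))
        if reply != str(int(tok)):  # firmware echoes each stored token back
            raise RuntimeError(f"token {tok} not acknowledged: {reply!r}")
    run_command(server, "DONE")     # firmware answers "END" on termination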