Commit 34ee2fdf authored by David Lanzendörfer

Try single epoch training for now

parent 20521bab
Pipeline #100 canceled after 1 hour, 22 minutes, and 51 seconds
@@ -78,7 +78,8 @@ void main()
     TRAIN_STORE_TOKENS,
     TRAIN_STORE_LEARNING_RATE,
     TRAIN_STORE_DECAY_RATE,
-    TRAIN_RUN_EPOCHS
+    TRAIN_RUN_EPOCHS,
+    TRAIN_RUN_SINGLE_EPOCH
   } command_mode;
   command_mode = START;
@@ -97,6 +98,8 @@ void main()
   int token_counter = 0;
   uint32_t learning_rate = 0;
   uint32_t decay_rate = 0;
+  uint32_t epoch_value = 0;
+  int new_value = 0; /* result of the single-epoch run below */
   while(true) {
@@ -277,6 +280,10 @@ void main()
           command_mode = TRAIN_RUN_EPOCHS;
           response = "OK";
         }
+        else if(!strcmp(msg,"RUN_SINGLE_EPOCH")) {
+          command_mode = TRAIN_RUN_SINGLE_EPOCH;
+          response = "OK";
+        }
         break;
       case TRAIN_STORE_TOKENS:
@@ -303,11 +310,17 @@ void main()
         break;
       case TRAIN_RUN_EPOCHS:
-        uint32_t num_epochs = atoi(numstr);
-        response = run_training(response, num_epochs, learning_rate, decay_rate, token_series, token_counter);
+        epoch_value = atoi(numstr);
+        response = run_training(response, epoch_value, learning_rate, decay_rate, token_series, token_counter);
         command_mode = START;
         break;
+      case TRAIN_RUN_SINGLE_EPOCH:
+        /* Mode is not reset to START here, so each further number runs one more epoch */
+        epoch_value = atoi(numstr);
+        new_value = run_training_single_epoch(epoch_value, learning_rate, decay_rate, token_series, token_counter);
+        response = new_value ? "SUCCESS" : "FAILURE";
+        break;
     }
     write_response(response);
   }
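On the host side, the new command is driven by sending TRAIN, then RUN_SINGLE_EPOCH, and then one epoch index per request; because command_mode is not reset to START in this case, every further number runs exactly one more epoch. A minimal sketch of that exchange, reusing the test script's get_fac_wrapper/run_command helpers (the import path and the 500-epoch bound are assumptions, not taken from the commit):

# Hypothetical import path; the test script does not show its imports.
from fac_wrapper import get_fac_wrapper, run_command

server = get_fac_wrapper("telnet")

run_command(server, "TRAIN")             # enter the training menu
run_command(server, "RUN_SINGLE_EPOCH")  # arm single-epoch mode

# Each number now runs one epoch; stop once the board reports SUCCESS.
for epoch in range(500):
    ret = run_command(server, str(epoch))
    if "SUCCESS" in ret:
        break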
@@ -152,3 +152,22 @@ char* run_training(
   uint32_t *xp,
   uint32_t xn
 );
+/*
+ * Run training cycle for 1 epoch
+ * epoch: the epoch we are in
+ * learning_rate_zero: initial learning rate
+ * decay_rate: the decay rate for learning-rate decay
+ * xp: pointer to the array of token values
+ * xi: index of the target token y
+ *
+ * Returns 1 when the network already predicts y
+ * Returns 0 when a training step was performed
+ */
+int run_training_single_epoch(
+  int epoch,
+  uint32_t learning_rate_zero,
+  uint32_t decay_rate,
+  uint32_t *xp,
+  uint32_t xi
+);
@@ -387,6 +387,59 @@ void set_bias_values(int bias)
   }
 }
+/*
+ * Run training cycle for 1 epoch
+ * epoch: the epoch we are in
+ * learning_rate_zero: initial learning rate
+ * decay_rate: the decay rate for learning-rate decay
+ * xp: pointer to the array of token values
+ * xi: index of the target token y
+ *
+ * Returns 1 when the network already predicts y
+ * Returns 0 when a training step was performed
+ */
+int run_training_single_epoch(
+  int epoch,
+  uint32_t learning_rate_zero,
+  uint32_t decay_rate,
+  uint32_t *xp,
+  uint32_t xi
+)
+{
+  int last_val = -1; /* sentinel in case xi == 0 and no prediction runs */
+  uint32_t train_mask;
+  uint32_t y = xp[xi];
+
+  /*
+   * Revert the LSTM states
+   */
+  reset_network();
+
+  /*
+   * Feed all the prior tokens of the time series
+   * into the LSTM to reach the state we wish to train on
+   */
+  for(int xpi=0; xpi<xi; xpi++) {
+    last_val = predict_next_token(xp[xpi]);
+  }
+
+  /*
+   * If the network already predicts the target token,
+   * report convergence (return 1) without training
+   */
+  if(last_val==y) {
+    return 1;
+  }
+
+  /*
+   * Determine the training mask by finding out which
+   * neurons misfired or didn't fire although they should have
+   */
+  train_mask = last_val ^ y;
+  set_alpha(learning_rate_zero/(1+(decay_rate*epoch)));
+  mask_back_propgatation(train_mask);
+  return 0;
+}
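To see what the mask and the decay schedule compute, here is a small host-side model in plain Python (the 8-bit firing patterns are made-up illustration values, not taken from the firmware):

INTMAX = 2147483647

# XOR marks exactly the bits that differ between the prediction and
# the target, i.e. the neurons that misfired or failed to fire.
last_val = 0b10110010        # what the network predicted
y        = 0b10011010        # the target token
train_mask = last_val ^ y    # 0b00101000: only these neurons get updated

# Fixed-point learning-rate decay, mirroring
# set_alpha(learning_rate_zero / (1 + decay_rate * epoch)):
lr0 = int(0.95 * INTMAX)     # 0.95 encoded as a scaled integer
decay_rate = 100
for epoch in (0, 1, 5, 100):
    alpha = lr0 // (1 + decay_rate * epoch)
    print(epoch, alpha, alpha / INTMAX)   # decode back to a fraction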
 /*
  * Run training cycle for N epochs
  * num_epochs: number of epochs
@@ -406,41 +459,21 @@ int run_training_epochs(
   uint32_t xi
 )
 {
-  int ret = 1;
-  int last_val;
-  uint32_t train_mask;
-  uint32_t y = xp[xi];;
   for(int epoch=0; epoch<num_epochs; epoch++) {
-    /*
-     * Revert the LSTM states
-     */
-    reset_network();
-    /*
-     * Put all the prior tokens of the time series
-     * into the LSTM for reaching the state we wish to train on
-     */
-    for(int xpi=0; xpi<xi; xpi++) {
-      last_val = predict_next_token(xp[xpi]);
-    }
-    /*
-     * Check whether the network has already been trained
-     * Set return value to zero and exit the loop if so
-     */
-    if(last_val==y) {
-      ret = 0;
-      break;
-    }
-    /*
-     * Determine the training mask my finding out which
-     * neurons misfired or didn't fire although they should have
-     */
-    train_mask = last_val ^ y;
-    set_alpha(learning_rate_zero/(1+(decay_rate*epoch)));
-    mask_back_propgatation(train_mask);
+    if(run_training_single_epoch(
+         epoch,
+         learning_rate_zero,
+         decay_rate,
+         xp,
+         xi
+       )) return 0; /* converged */
   }
-  return ret;
+  return 1; /* no convergence within num_epochs */
 }
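The wrapper now simply delegates, and the two return conventions are inverted: run_training_single_epoch returns 1 once the prediction matches, while run_training_epochs keeps returning 0 for success. A Python model of that control flow (step stands in for the single-epoch call; it is an illustrative placeholder, not firmware code):

def run_training_epochs(num_epochs, step):
    # step(epoch) models run_training_single_epoch: truthy once
    # the network already predicts the target token.
    for epoch in range(num_epochs):
        if step(epoch):
            return 0   # converged, mirroring the firmware's 0 == success
    return 1           # ran out of epochs without converging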
@@ -17,7 +17,7 @@ server = get_fac_wrapper("telnet")
 INTMAX=2147483647
 lr=0.95
-decay_rate=10
+decay_rate=100
 run_command(server,"HELLO")
@@ -31,7 +31,6 @@ run_command(server,"RWEIGHTS")
 run_command(server,"INIT")
 run_command(server,"BIAS")
 run_command(server,"TRAIN")
 run_command(server,"LEARNING_RATE")
 # PicoRV does not have a floating point unit
@@ -41,7 +40,6 @@ run_command(server,"LEARNING_RATE")
 lr=lr*INTMAX
 run_command(server,str(int(lr)))
 run_command(server,"TRAIN")
 run_command(server,"DECAY_RATE")
 # The decay rate is a hyper parameter
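Since the PicoRV core has no floating point unit, the script encodes the fractional learning rate as a fixed-point integer scaled by INTMAX before sending it. A quick round-trip check in plain Python (values taken from the script itself):

INTMAX = 2147483647          # INT32_MAX, the fixed-point scale factor
lr = 0.95
lr_fixed = int(lr * INTMAX)  # 2040109464, the value the script sends
assert lr_fixed <= INTMAX    # any rate <= 1.0 fits in a signed 32-bit word
print(lr_fixed / INTMAX)     # ~0.95 recovered on the host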
@@ -62,9 +60,13 @@ tok=tokens[1]
 run_command(server,str(tok))
 run_command(server,"DONE")
 run_command(server,"TRAIN")
-run_command(server,"RUN_EPOCHS")
-run_command(server,str(1000))
+run_command(server,"RUN_SINGLE_EPOCH")
+for i in range(0,500):
+    ret = run_command(server,str(i))
+    if "SUCCESS" in ret:
+        break
+'''
 run_command(server,"TRAIN")
 run_command(server,"TOKENS")
 tok=tokens[1]
@@ -74,7 +76,7 @@ run_command(server,str(tok))
 run_command(server,"DONE")
 run_command(server,"TRAIN")
 run_command(server,"RUN_EPOCHS")
-run_command(server,str(1000))
+run_command(server,str(500))
 run_command(server,"TRAIN")
 run_command(server,"TOKENS")
@@ -85,7 +87,7 @@ run_command(server,str(tok))
 run_command(server,"DONE")
 run_command(server,"TRAIN")
 run_command(server,"RUN_EPOCHS")
-run_command(server,str(1000))
+run_command(server,str(500))
 # Upload token series
 run_command(server,"TRAIN")
@@ -96,6 +98,7 @@ run_command(server,"DONE")
 run_command(server,"TRAIN")
 run_command(server,"RUN_EPOCHS")
 run_command(server,str(1000))
+'''
 # Store the weights and biases
 weights_and_biases = dump_neural_network(server)
@@ -101,7 +101,6 @@ module neuron #(
   else
     if(!backprop_done && backprop_running)
     begin
-      $display("Backprop with dW=", $signed(dw));
       // Adjust normal weights
       if( regidx < NUMBER_SYNAPSES )
       begin