Skip to content
Snippets Groups Projects
Commit 46eaf401 authored by David Lanzendörfer's avatar David Lanzendörfer
Browse files

Integrate convolution

parent 24908f9f
No related branches found
No related tags found
No related merge requests found
AI_Accelerator_Top_TB: AI_Accelerator_Top_TB:
iverilog -o AI_Accelerator_Top_TB -s AI_Accelerator_Top_TB -pARRAY_SIZE_LIMIT=1073807361 \ iverilog -o AI_Accelerator_Top_TB -s AI_Accelerator_Top_TB -pARRAY_SIZE_LIMIT=1073807361 \
verilog/src/mine/constants.v \ verilog/src/mine/constants.v \
verilog/src/chatgpt/Matrix_Convolution.v \
verilog/src/chatgpt/Matrix_Multiplication.v \ verilog/src/chatgpt/Matrix_Multiplication.v \
verilog/src/chatgpt/AI_Accelerator_Top.v \ verilog/src/chatgpt/AI_Accelerator_Top.v \
verilog/benches/AI_Accelerator_Top_TB.v verilog/benches/AI_Accelerator_Top_TB.v
......
...@@ -40,10 +40,8 @@ module AI_Accelerator_Top #( ...@@ -40,10 +40,8 @@ module AI_Accelerator_Top #(
always @(posedge wb_clk_i) begin always @(posedge wb_clk_i) begin
mem_opdone <= next_mem_opdone; mem_opdone <= next_mem_opdone;
if (wb_rst_i) begin if (wb_rst_i) begin
busy <= 1'b0;
started <= 1'b0;
multiplier_enable <= 1'b0;
mmul_data_i <= 0; mmul_data_i <= 0;
mconv_data_i <= 0;
mem_opdone <= 0; mem_opdone <= 0;
next_mem_opdone <= 0; next_mem_opdone <= 0;
end end
...@@ -51,7 +49,7 @@ module AI_Accelerator_Top #( ...@@ -51,7 +49,7 @@ module AI_Accelerator_Top #(
mem_opdone <= 0; mem_opdone <= 0;
next_mem_opdone <= 0; next_mem_opdone <= 0;
end end
else if (! mem_opdone )begin else if (! mem_opdone ) begin
case ( DFFRAM[0] ) // Register 1 holds the operation to be executed case ( DFFRAM[0] ) // Register 1 holds the operation to be executed
// Enable corresponding module based on operation value in operation register // Enable corresponding module based on operation value in operation register
`TYPE_BW'h1: begin // matrix multiplication `TYPE_BW'h1: begin // matrix multiplication
...@@ -67,6 +65,19 @@ module AI_Accelerator_Top #( ...@@ -67,6 +65,19 @@ module AI_Accelerator_Top #(
end end
endcase endcase
end end
`TYPE_BW'h2: begin // matrix convolution
case (mconv_mem_op)
2'b01: begin // Read
mconv_data_i <= DFFRAM[mconv_addr_o];
next_mem_opdone <= 1;
end
2'b11: begin // Write
DFFRAM[mconv_addr_o] <= mconv_data_o;
mconv_data_i <= 0;
next_mem_opdone <= 1;
end
endcase
end
endcase endcase
end end
end end
...@@ -95,6 +106,25 @@ module AI_Accelerator_Top #( ...@@ -95,6 +106,25 @@ module AI_Accelerator_Top #(
wire [31:0] mmul_addr_o; wire [31:0] mmul_addr_o;
wire [1:0] mmul_mem_op; // Read 01 /Write 11 /None 00 wire [1:0] mmul_mem_op; // Read 01 /Write 11 /None 00
// Matrix Convolution
Matrix_Convolution matrix_conv (
.clk(wb_clk_i),
.reset(wb_rst_i),
.enable(convolution_enable),
.done(matrix_conv_done),
.addr_o(mconv_addr_o),
.data_i(mconv_data_i),
.data_o(mconv_data_o),
.mem_opdone(mem_opdone),
.mem_operation(mconv_mem_op)
);
reg convolution_enable; // on switch
wire matrix_conv_done; // status wire
reg [`TYPE_BW-1:0] mconv_data_i;
wire [`TYPE_BW-1:0] mconv_data_o;
wire [31:0] mconv_addr_o;
wire [1:0] mconv_mem_op; // Read 01 /Write 11 /None 00
/* /*
Control Unit Control Unit
Manages the current operation and changes the value Manages the current operation and changes the value
...@@ -107,6 +137,7 @@ module AI_Accelerator_Top #( ...@@ -107,6 +137,7 @@ module AI_Accelerator_Top #(
busy <= 1'b0; busy <= 1'b0;
started <= 1'b0; started <= 1'b0;
multiplier_enable <= 1'b0; multiplier_enable <= 1'b0;
convolution_enable <= 1'b0;
end end
else if ( started ) begin else if ( started ) begin
started <= 1'b0; started <= 1'b0;
...@@ -126,6 +157,18 @@ module AI_Accelerator_Top #( ...@@ -126,6 +157,18 @@ module AI_Accelerator_Top #(
started <= 1'b1; started <= 1'b1;
end end
end end
`TYPE_BW'h2: begin // matrix convolution
if( matrix_conv_done && busy ) begin
busy <= 1'b0;
convolution_enable <= 1'b0; // Enable matrix multiplication module
DFFRAM[5] <= `TYPE_BW'h0; // Done
end
else if ( DFFRAM[5] == `TYPE_BW'hffff_ffff ) begin
busy <= 1'b1; // indicate that we started operation
convolution_enable <= 1'b1; // Enable matrix multiplication module
started <= 1'b1;
end
end
endcase endcase
end end
end end
......
...@@ -78,22 +78,35 @@ module Matrix_Convolution ( ...@@ -78,22 +78,35 @@ module Matrix_Convolution (
always @(posedge clk) begin always @(posedge clk) begin
// Assign initial values // Assign initial values
if (reset) begin if (reset) begin
i <= 0; height_a <= 0;
j <= 0; width_a <= 0;
height_b <= 0;
width_b <= 0;
i <= 1;
j <= 1;
k <= -1; k <= -1;
l <= -1; l <= -1;
data_o <= 0; data_o <= 0;
addr_o <= base_addr_c; addr_o <= 0;;
mem_operation <= 2'b00; mem_operation <= 2'b00;
done <= 0; done <= 0;
state <= IDLE;
// reset result register
result_buffer<= 0;
operator1_buffer <= 0;
operator2_buffer <= 0;
end end
// State machine // State machine
else if (enable) begin else if (enable) begin
case (state) case (state)
IDLE: begin IDLE: begin
state <= FETCH_PARAMS; state <= FETCH_PARAMS;
i <= 0; height_a <= 0;
j <= 0; width_a <= 0;
height_b <= 0;
width_b <= 0;
i <= 1;
j <= 1;
k <= -1; k <= -1;
l <= -1; l <= -1;
height_a <= 0; height_a <= 0;
...@@ -140,6 +153,7 @@ module Matrix_Convolution ( ...@@ -140,6 +153,7 @@ module Matrix_Convolution (
end end
end end
LOOP1: begin // for (int i = 1; i < rows - 1; i++) { LOOP1: begin // for (int i = 1; i < rows - 1; i++) {
$display("LOOP1");
if (i < height_a - 1) begin if (i < height_a - 1) begin
state <= LOOP2; state <= LOOP2;
i <= i + 1; i <= i + 1;
...@@ -149,75 +163,82 @@ module Matrix_Convolution ( ...@@ -149,75 +163,82 @@ module Matrix_Convolution (
end end
end end
LOOP2: begin // for (int j = 1; j < cols - 1; j++) { LOOP2: begin // for (int j = 1; j < cols - 1; j++) {
$display("LOOP2");
if (j < width_a - 1) begin if (j < width_a - 1) begin
state <= LOOP3; state <= LOOP3;
j <= j + 1; j <= j + 1;
end end
else begin else begin
state <= LOOP1; state <= LOOP1;
j <= 0; j <= 1;
end end
end end
LOOP3: begin // for (int k = -1; k <= 1; k++) { LOOP3: begin // for (int k = -1; k <= 1; k++) {
if (k < 2) begin $display("LOOP3 %d", $signed(k));
if ($signed(k) < 2) begin
state <= LOOP4; state <= LOOP4;
k <= k + 1; k <= $signed(k) + 1;
end end
else begin else begin
state <= LOAD_OPERATOR1; state <= LOOP2;
k <= -1; k <= -1;
end end
end end
LOOP4: begin // for (int l = -1; l <= 1; l++) { LOOP4: begin // for (int l = -1; l <= 1; l++) {
if (l < 2) begin $display("LOOP4 %d", $signed(l));
if ($signed(l) < 2) begin
state <= LOAD_OPERATOR1; state <= LOAD_OPERATOR1;
l <= l + 1; l <= $signed(l) + 1;
end end
else begin else begin
state <= PERFORM_OPERATION; state <= WRITE_RESULT;
l <= -1; l <= -1;
end end
end end
LOAD_OPERATOR1: begin // A[i + k][j + l] LOAD_OPERATOR1: begin // A[i + k][j + l]
if ( addr_o == 0 ) begin $display("LOAD_OPERATOR1");
mem_operation <= 2'b01; // read if ( addr_o == 0 ) begin
addr_o <= base_addr_a + ((i + k) * width_a) + (j + l); mem_operation <= 2'b01; // read
end addr_o <= base_addr_a + (($signed(i) + $signed(k)) * $signed(width_a)) + ($signed(j) + $signed(l));
else if (mem_opdone) begin end
operator1_buffer <= data_i; else if (mem_opdone) begin
state <= LOAD_OPERATOR2; operator1_buffer <= data_i;
mem_operation <= 2'b00; // done state <= LOAD_OPERATOR2;
addr_o <= 0; mem_operation <= 2'b00; // done
end addr_o <= 0;
end
end end
LOAD_OPERATOR2: begin // F[k + 1][l + 1] LOAD_OPERATOR2: begin // F[k + 1][l + 1]
if ( addr_o == 0 ) begin $display("LOAD_OPERATOR2");
mem_operation <= 2'b01; // read if ( addr_o == 0 ) begin
addr_o <= base_addr_b + ((k + 1) * width_b) + (l + 1); mem_operation <= 2'b01; // read
end addr_o <= base_addr_b + (($signed(k) + $signed(1) ) * $signed(width_b)) + ($signed(l) + $signed(1));
else if (mem_opdone) begin end
operator2_buffer <= data_i; else if (mem_opdone) begin
state <= PERFORM_OPERATION; operator2_buffer <= data_i;
mem_operation <= 2'b00; // done state <= PERFORM_OPERATION;
addr_o <= 0; mem_operation <= 2'b00; // done
end addr_o <= 0;
end
end end
PERFORM_OPERATION: begin PERFORM_OPERATION: begin
result_buffer <= result_buffer + operator1_buffer * operator2_buffer; result_buffer <= $signed(result_buffer) + $signed(operator1_buffer) * $signed(operator2_buffer);
state <= LOOP4; $display("%d + %d * %d\n", $signed(result_buffer), $signed(operator1_buffer), $signed(operator2_buffer));
state <= LOOP4;
end end
WRITE_RESULT: begin WRITE_RESULT: begin
if ( addr_o == 0 ) begin if ( addr_o == 0 ) begin
mem_operation <= 2'b11; // write mem_operation <= 2'b11; // write
addr_o <= base_addr_c + (i * width_b) + j; addr_o <= base_addr_c + ($signed(i) * width_b) + $signed(j);
data_o <= result_buffer; data_o <= result_buffer;
end end
else if (mem_opdone) begin else if (mem_opdone) begin
result_buffer <= 0; $display("Wrote %d to %x", $signed(data_o), addr_o);
state <= LOOP2; result_buffer <= 0;
mem_operation <= 2'b00; // done state <= LOOP2;
addr_o <= 0; mem_operation <= 2'b00; // done
end addr_o <= 0;
end
end end
FSM_DONE: begin FSM_DONE: begin
done <= 1; done <= 1;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment