`timescale 1ns / 1ps
// Description : Fully connected (linear) layer.
// Change Logs : 2024.05.10 - Yang.Long - 1.0.0

// nnLinear
//
// Streams one input sample per valid cycle into G_FEATURESo parallel lanes.
// Each lane passes the sample and its own G_WDEPTH-bit slice of s_axis_weight
// through a SigMultiply instance (defined elsewhere; presumably a signed
// multiplier - confirm against that module), accumulates G_FEATURESi products,
// optionally adds its slice of s_axis_bias, and presents the result on its
// G_PDEPTH-bit slice of m_axis_tdata for one cycle together with m_axis_tvalid.
//
// All sums are kept in G_PDEPTH bits: wider intermediate values wrap, there is
// no saturation.
module nnLinear #(
    parameter G_WDEPTH    = 12,    // weight bit width
    parameter G_PDEPTH    = 8,     // pixel (sample) bit width
    parameter G_LINEXLEN  = 160,   // image width  (not referenced inside this module)
    parameter G_LINEYLEN  = 160,   // image height (not referenced inside this module)
    parameter G_FEATURESi = 120,   // number of input neurons
    parameter G_FEATURESo = 10,    // number of output neurons
    parameter G_FEATURESb = 1'b1   // 1: add bias, 0: no bias
)(
    input  wire                                   isysclk,        // clock
    input  wire                                   isysrst,        // asynchronous reset, active low
    input  wire                                   s_axis_ruser,   // restarts the input counter
    input  wire                                   s_axis_rvalid,  // input sample valid
    input  wire signed [G_PDEPTH-1:0]             s_axis_rdata,   // input sample
    // NOTE(review): s_axis_readen is not driven anywhere in this module (same
    // as the original code). The intended handshake is not visible from here,
    // so it is deliberately left untouched - confirm with the upstream block.
    output wire                                   s_axis_readen,
    input  wire signed [G_WDEPTH*G_FEATURESo-1:0] s_axis_weight,  // one weight per lane, lane i at [(i+1)*G_WDEPTH-1 : i*G_WDEPTH]
    input  wire signed [G_WDEPTH*G_FEATURESo-1:0] s_axis_bias,    // one bias per lane, same packing as s_axis_weight
    output reg                                    m_axis_tuser,   // s_axis_ruser delayed by one clock
    output reg                                    m_axis_tvalid,  // high for one clock when m_axis_tdata holds a result
    output reg  signed [G_PDEPTH*G_FEATURESo-1:0] m_axis_tdata    // lane i result at [(i+1)*G_PDEPTH-1 : i*G_PDEPTH]
);

/*
PyTorch reference for the operation:

    import torch
    linear = torch.nn.Linear(in_features=3, out_features=5, bias=True)
    b = torch.tensor([[1, 1, 1]], dtype=torch.float32)
    out2 = linear(b)
    print(linear.weight.data)
    print(linear.bias.data)
    print(out2)

Input:
    b = torch.tensor([[1, 1, 1]], dtype=torch.float32)

Output:
    tensor([[-0.1069, -0.3522,  0.3378],
            [ 0.2721,  0.3001,  0.4206],
            [-0.1825,  0.1193, -0.0052],
            [-0.1361, -0.3696, -0.3186],
            [-0.5642,  0.5640,  0.4559]])
    tensor([ 0.0126, -0.3215,  0.3172, -0.0352, -0.5045])
    tensor([[-0.1088,  0.6713,  0.2488, -0.8595, -0.0489]], grad_fn=<AddmmBackward0>)
*/

localparam ACTIVERST = 1'b0;   // reset is asserted when isysrst equals this level

// Smallest integer r such that 2**r >= number (ceil(log2(number)) for number >= 1).
function integer log2;
    input integer number;
    begin
        log2 = 0;
        while (2**log2 < number) begin
            log2 = log2 + 1;
        end
    end
endfunction

localparam A         = G_FEATURESi + 2;
localparam G_XPCOUNT = log2(A + 4);   // counter width, sized with headroom above G_FEATURESi

reg  [G_XPCOUNT-1:0]       buffer_xcnt;                         // index of the product currently being accumulated
reg  signed [G_PDEPTH-1:0] buffer_csum      [G_FEATURESo-1:0];  // running sum per lane
wire                       temp_axis_tvalid [G_FEATURESo-1:0];  // per-lane SigMultiply output valid
wire signed [G_PDEPTH-1:0] temp_axis_tdata  [G_FEATURESo-1:0];  // per-lane SigMultiply output data

// Product counter: cleared by reset or s_axis_ruser, advanced by lane 0's
// valid, wraps after G_FEATURESi products.
always @(posedge isysclk or negedge isysrst) begin
    if (isysrst == ACTIVERST)
        buffer_xcnt <= 0;
    else if (s_axis_ruser == 1'b1)
        buffer_xcnt <= 0;
    else if (temp_axis_tvalid[0] == 1'b1) begin
        if (buffer_xcnt == G_FEATURESi - 1)
            buffer_xcnt <= 0;
        else
            buffer_xcnt <= buffer_xcnt + 1'b1;
    end
end

genvar i;
generate
    for (i = 0; i <= G_FEATURESo-1; i = i + 1) begin : g_neuron

        SigMultiply #(
            .G_PDEPTH      ( G_PDEPTH )
        ) u_SigMultiply (
            .isysclk       ( isysclk ),
            .isysrst       ( isysrst ),
            .s_axis_rvalid ( s_axis_rvalid ),
            .s_axis_rdat1  ( s_axis_rdata ),
            .s_axis_rdat2  ( s_axis_weight[(i+1)*G_WDEPTH-1:G_WDEPTH*i] ),
            .m_axis_tvalid ( temp_axis_tvalid[i] ),
            .m_axis_tdata  ( temp_axis_tdata[i] )
        );

        // Per-lane accumulator. The first product of a frame (counter == 0)
        // loads the sum; later products are added. The sum is held when the
        // multiplier output is not valid, so idle cycles cannot corrupt it.
        always @(posedge isysclk or negedge isysrst) begin
            if (isysrst == ACTIVERST)
                buffer_csum[i] <= 0;
            else if (temp_axis_tvalid[i] == 1'b1) begin
                if (buffer_xcnt == 0)
                    buffer_csum[i] <= temp_axis_tdata[i];
                else
                    buffer_csum[i] <= buffer_csum[i] + temp_axis_tdata[i];
            end
        end

        // Bias term for this lane; forced to zero when G_FEATURESb is 0 so a
        // single output register serves both configurations.
        wire signed [G_WDEPTH-1:0] bias_term;
        if (G_FEATURESb == 1'b0) begin : g_no_bias
            assign bias_term = {G_WDEPTH{1'b0}};
        end else begin : g_bias
            assign bias_term = s_axis_bias[(i+1)*G_WDEPTH-1:G_WDEPTH*i];
        end

        // Output register: on the last product of a frame emit
        // sum + last product + bias (truncated to G_PDEPTH bits); zero otherwise.
        always @(posedge isysclk or negedge isysrst) begin
            if (isysrst == ACTIVERST)
                m_axis_tdata[(i+1)*G_PDEPTH-1:G_PDEPTH*i] <= 0;
            else if (temp_axis_tvalid[i] == 1'b1 && buffer_xcnt == G_FEATURESi - 1)
                m_axis_tdata[(i+1)*G_PDEPTH-1:G_PDEPTH*i] <= buffer_csum[i] + temp_axis_tdata[i] + bias_term;
            else
                m_axis_tdata[(i+1)*G_PDEPTH-1:G_PDEPTH*i] <= 0;
        end

    end
endgenerate

// Output valid. Driven exactly once, from lane 0 (the lane that also advances
// buffer_xcnt), instead of once per generate iteration, so the register has a
// single driver. All lanes share s_axis_rvalid, so their valids are expected
// to be identical - confirm against SigMultiply.
always @(posedge isysclk or negedge isysrst) begin
    if (isysrst == ACTIVERST)
        m_axis_tvalid <= 1'b0;
    else if (temp_axis_tvalid[0] == 1'b1 && buffer_xcnt == G_FEATURESi - 1)
        m_axis_tvalid <= 1'b1;
    else
        m_axis_tvalid <= 1'b0;
end

// Frame marker: s_axis_ruser registered once. Reset added so the output is
// defined before the first clock edge, matching the other registers.
always @(posedge isysclk or negedge isysrst) begin
    if (isysrst == ACTIVERST)
        m_axis_tuser <= 1'b0;
    else
        m_axis_tuser <= s_axis_ruser;
end

endmodule