今天看代码,对比了常见的公式表达与代码的表达,发觉torch.nn.Linear的数学表达与我想象的有点不同,于是思索了一番。
众多周知,torch.nn.Linear作为全连接层,将下一层的每个结点与上一层的每一节点相连,用来将前边提取的特征综合起来。具体如下:
则显然可以得到:,其中
, ,,
上面的公式进行转置后,得到,
也就是将输入和输出向量都变成了行向量了。
在pytorch中,
实际上这里的x就是行向量,y也是行向量,A的行数与y(输出)相关,列数与x(输入)相关, b是一个行向量,与输出维度有关。
这里可以看到,m作为一个全连接层,输入为20维,输出为30维,则可见A的规模为30x20(输出规模x输入规模),
input作为一个输入矩阵,规模为128x20,这里可见一般在一个tensor中,feature都是行优先,
这样的话,使用m作用到input上,规模为128x20x(20x30)---->128x30. 即为示例中结果。
注意下图A和b的维度。
以上的分析照应了torch.nn.functional.
linear的表达。