DNN全连接层(线性层)
计算公式:
y = w * x + b
W和b是参与训练的参数
W的维度决定了隐含层输出的维度,一般称为隐单元个数(hidden size)
b是偏差值(本文没考虑)
举例:
输入:x (维度1 x 3)
隐含层1:w(维度3 x 5)
隐含层2: w(维度5 x 2)
个人思想如下:
比如说如上图,我们有输入层是3个,中间层是5个,输出层要求是2个。利用线性代数,输入是【1×3】,那么需要乘【3×5】的权重矩阵得到【1×5】,再由【1×5】乘【5×2】的权重矩阵,最后得到【1×2】的结果。在本代码中没有考虑偏差值(bias),利用pytorch中随机初始化的权重实现模型预测。
import torch
import torch.nn as nn
import numpy as np
"""
用pytorch框架实现单层的全连接网络
不使用偏置bias
"""
class TorchModel(nn.Module): #nn.module是torch自带的库def __init__(self, input_size, hidden_size, output_size):super(TorchModel, self).__init__()self.layer1 = nn.Linear(input_size, hidden_size, bias=False)#nn.linear是torch的线性层,input_size是输入的维度,hidden_size是这一层的输出的维度self.layer2 = nn.Linear(hidden_size, output_size, bias=False)#这个线性层可以有很多个def forward(self, x): #开始计算的函数hidden = self.layer1(x) #传入输入第一层# print("torch hidden", hidden)y_pred = self.layer2(hidden) #传入输入第二层return y_pred
x = np.array([1, 0, 0]) #网络输入#torch实验
torch_model = TorchModel(len(x), 5, 2) #这三个数分别代表输入,中间,结果层的维度
#print(torch_model.state_dict()) #可以打印出pytorch随机初始化的权重
torch_model_w1 = torch_model.state_dict()["layer1.weight"].numpy()
#通过取字典方式将权重取出来并把torch的权重转化为numpy的
torch_model_w2 = torch_model.state_dict()["layer2.weight"].numpy()
#print(torch_model_w1, "torch w1 权重")
#这里你会发现随机初始化的权重矩阵是5×3,所以当自定义模型时需要转置,但是在pytorch中会自动转置相乘
#print(torch_model_w2, "torch w2 权重")
torch_x = torch.FloatTensor([x]) #numpy的输入转化为torch
y_pred = torch_model.forward(torch_x)
print("torch模型预测结果:", y_pred)
以上是pytorch模型实现DNN的简单方法。
自定义模型手工实现:
(注意因为自定义模型需要得到模型中的权重,而上面代码利用的是pytorch的随机自定义模型,为了能让两者对比答案是否相同,自定义模型中的权重需要继承pytorch的随机权重)
"""
手动实现简单的神经网络
用自定义框架实现单层的全连接网络
不使用偏置bias
"""
#自定义模型
class DiyModel:def __init__(self, weight1, weight2):self.weight1 = weight1 #收到在torch随机的权重self.weight2 = weight2def forward(self, x):hidden = np.dot(x, self.weight1.T) #将输入与第一层权重的转置相乘y_pred = np.dot(hidden, self.weight2.T)return y_preddiy_model = DiyModel(torch_model_w1, torch_model_w2)
y_pred_diy = diy_model.forward(np.array([x]))
print("diy模型预测结果:", y_pred_diy)
如需运行须将自定义模型放入pytorch的代码下面继承输入和随机权重,通过最后结果能发现两者相同。
结果如下:
可以发现两者代码结果相同~