文章目录
- 1. 前言
- 2. 导入必要的库
- 3. 加载数据集
- 4. 定义线性层网络结构
- 5. 实例化网络并打印输出
- 6. 定义非线性层网络结构
- 7. 总结
1. 前言
在深度学习中,线性层和非线性层是构建神经网络的基本单元。本文将通过PyTorch实现一个简单的网络,详细讲解线性层与非线性层的使用和区别。
2. 导入必要的库
首先,我们需要导入PyTorch以及一些常用的模块:
import torch
from torch import nn
import torchvision
from torch.utils.data import DataLoader
3. 加载数据集
使用torchvision加载CIFAR-10数据集,并将其转换为Tensor格式。
dataset = torchvision.datasets.CIFAR10(root="data1", train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=64, drop_last=True)
root="data1"
:数据存储路径。train=False
:加载测试集。transform=torchvision.transforms.ToTensor()
:将图像数据转换为Tensor。download=True
:如果数据集不存在,则下载。drop_last=True
:如果最后一个batch大小小于batch_size,则丢弃。
4. 定义线性层网络结构
构建一个包含线性层的简单神经网络:
class NN(nn.Module):def __init__(self):super(NN, self).__init__()self.linear1 = nn.Linear(196608, 10) # 定义一个线性层def forward(self, input):output = self.linear1(input) # 前向传播return output
nn.Linear(196608, 10)
:定义一个线性层,输入维度为196608,输出维度为10。
5. 实例化网络并打印输出
使用DataLoader加载数据,遍历数据并打印输出结果。
mynn = NN() # 实例化网络for data in dataloader:imgs, targets = dataprint(imgs.shape) # 打印图像的形状output = torch.flatten(imgs) # 展平图像print(output.shape) # 打印展平后的形状output = mynn(output) # 输入到网络中print(output.shape) # 打印输出的形状print("------------------")
torch.flatten(imgs)
:将图像展平为一维。- 将展平后的图像输入到网络中,得到输出。
输出结果:
torch.Size([64, 3, 32, 32])
torch.Size([196608])
torch.Size([10])
------------------
每次遍历数据加载器,我们可以看到原始图像的形状,展平后的形状,以及通过线性层后的输出形状。
6. 定义非线性层网络结构
为了演示非线性层,我们可以在网络中加入激活函数,例如ReLU(Rectified Linear Unit):
class NNWithNonLinearity(nn.Module):def __init__(self):super(NNWithNonLinearity, self).__init__()self.linear1 = nn.Linear(196608, 10)self.relu = nn.ReLU() # 定义ReLU激活函数def forward(self, input):output = self.linear1(input)output = self.relu(output) # 应用激活函数return output
nn.ReLU()
:定义ReLU激活函数。- 将线性层的输出通过ReLU激活函数,增加非线性。
实例化非线性网络并打印输出:
mynn_nonlin = NNWithNonLinearity()for data in dataloader:imgs, targets = dataoutput = torch.flatten(imgs)output = mynn_nonlin(output)print(output)print("------------------")
7. 总结
线性层和非线性层是神经网络的基本构件。线性层执行线性变换,而非线性层(例如激活函数)引入非线性,从而使网络能够拟合复杂的函数。本文通过实例演示了如何在PyTorch中使用这些层,理解了它们的工作原理和应用。
通过这种方式,我们可以更好地理解和构建复杂的神经网络,提高模型的表现力和泛化能力。