水果三分类项目
Git源码:传送门
水果种类:草莓、树莓、桑葚
- 0:草莓 strawberry
- 1:树莓 raspberry
- 2:桑葚 mulberry
项目设计
- 获取数据 spider.py
- 数据清洗 cleaner.py
- 自定义数据集 dataset.py
- 网络构建 net.py
- 训练模型 train.py
- 推理预测 predict.py
- 摄像头 video_process.py、video2Frame.py
获取数据 spider.py
数据清洗 cleaner.py
删除不是图片的文件
删除无法打开的图片
效果图
手动删除错误图片
随机分配数据集
训练:测试 ≈ 8 :1
自定义数据集 dataset.py
尺度缩放
100 * 100
300 * 300
网络构建 net.py
网络层数
激活函数
训练模型 train.py
迭代次数:
100
200
推理预测 predict.py
输入图片,输出水果类别,记录预测结果
摄像头 video_process.py、video2Frame.py
读取视频
摄像头识别
项目实现
FCNN
v1:6层,ReLU,均方差,100*100
网络结构
self.fc = nn.Sequential(nn.Linear(100 * 100 * 3, 784),nn.ReLU(),nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 128),nn.ReLU(),nn.Linear(128, 64),nn.ReLU(),nn.Linear(64, 32),nn.ReLU(),nn.Linear(32, 3),nn.Softmax(dim=1)
)
可视化结果
最优准确率:0.9063
平均准确率:87%左右
准确率图示
损失图示
预测结果
总数:90(草莓30+树莓30+桑葚30)
错误:18
分析与解决方案
v2:6层,ReLU,交叉熵,100*100
对比版本1.0
更新损失函数:均方差–>交叉熵
网络结构
self.fc = nn.Sequential(nn.Linear(100 * 100 * 3, 784),nn.ReLU(),nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 128),nn.ReLU(),nn.Linear(128, 64),nn.ReLU(),nn.Linear(64, 32),nn.ReLU(),nn.Linear(32, 3),
)
可视化结果
最优准确率:0.9109
平均准确率:87%左右
准确率图示
损失图示
预测结果
总数:90(草莓30+树莓30+桑葚30)
错误:15
分析与解决方案
结论
交叉熵效果优于均方差,损失图示显示有过拟合的效果,考虑优化网络层数
v3:4层,ReLU,交叉熵,100*100
对比版本2.0
更改网络层数:6–>4
网络结构
self.fc = nn.Sequential(nn.Linear(100 * 100 * 3, 784),nn.ReLU(),nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 64),nn.ReLU(),nn.Linear(64, 3),
)
可视化结果
最优准确率:0.9
平均准确率:85%
准确率图示
损失图示
结论
更改网络层数没有明显的数据变化,考虑增加尺度缩放大小、更换激活函数、增加迭代次数
v4:4层,ReLU,交叉熵,300*300
对比版本3.0
更新尺度缩放:100 * 100->300 * 300
更新迭代次数:200
网络结构
self.fc = nn.Sequential(nn.Linear(300 * 300 * 3, 784),nn.ReLU(),nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 64),nn.ReLU(),nn.Linear(64, 3),
)
可视化结果
最优准确率:0.9
平均准确率:85%
结论
v5:4层,PReLU,交叉熵,100*100
对比版本4
更新激活函数:ReLu–>PReLU
网络结构
self.fc = nn.Sequential(nn.Linear(100 * 100 * 3, 784),nn.PReLU(),nn.Linear(784, 256),nn.PReLU(),nn.Linear(256, 64),nn.PReLU(),nn.Linear(64, 3),
)
可视化结果
CNN
v1:LeNet5:2cnn+3fc,32*32
网络结构
self.cnn = nn.Sequential(# 32 * 32 * 3 --> 28 * 28 * 6nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),nn.ReLU(),# 28 * 28 * 6 --> 14 * 14 * 6nn.MaxPool2d(2),# 14 * 14 * 6 --> 10 * 10 * 16nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),# 10 * 10 * 16 --> 5 * 5 * 16nn.ReLU(),nn.MaxPool2d(2),# 5 * 5 * 16 --> 1 * 1 * 120# nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5),
)
self.fc = nn.Sequential(# 5 * 5 * 16 --> 1 * 1 * 120nn.Linear(400, 120),nn.ReLU(),# 1 * 1 * 120 --> 1 * 1 * 84nn.Linear(120, 84),nn.ReLU(),# 1 * 1 * 84 --> 1 * 1 * 3nn.Linear(84, 3),nn.Softmax(dim=1),
)
可视化结果
v2:LeNet5:3cnn+2fc,32*32
网络结构
self.cnn = nn.Sequential(# 32 * 32 * 3 --> 28 * 28 * 6nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),nn.ReLU(),# 28 * 28 * 6 --> 14 * 14 * 6nn.MaxPool2d(2),# 14 * 14 * 6 --> 10 * 10 * 16nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),# 10 * 10 * 16 --> 5 * 5 * 16nn.ReLU(),nn.MaxPool2d(2),# 5 * 5 * 16 --> 1 * 1 * 120nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5),nn.ReLU(),
)
self.fc = nn.Sequential(# 5 * 5 * 16 --> 1 * 1 * 120# nn.Linear(400, 120),# nn.ReLU(),# 1 * 1 * 120 --> 1 * 1 * 84nn.Linear(120, 84),nn.ReLU(),# 1 * 1 * 84 --> 1 * 1 * 3nn.Linear(84, 3),nn.Softmax(dim=1),
)
可视化结果
v3:LeNet5:2cnn+bn+3fc,32*32
网络结构
self.cnn = nn.Sequential(# 32 * 32 * 3 --> 28 * 28 * 6nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, bias=False),nn.BatchNorm2d(num_features=6),nn.ReLU(),# 28 * 28 * 6 --> 14 * 14 * 6nn.MaxPool2d(2),# 14 * 14 * 6 --> 10 * 10 * 16nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, bias=False),nn.BatchNorm2d(num_features=16),# 10 * 10 * 16 --> 5 * 5 * 16nn.ReLU(),nn.MaxPool2d(2),# 5 * 5 * 16 --> 1 * 1 * 120# nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5),
)
self.fc = nn.Sequential(# 5 * 5 * 16 --> 1 * 1 * 120nn.Linear(400, 120),nn.ReLU(),# 1 * 1 * 120 --> 1 * 1 * 84nn.Linear(120, 84),nn.ReLU(),# 1 * 1 * 84 --> 1 * 1 * 3nn.Linear(84, 3),nn.Softmax(dim=1),
)
可视化结果
结论
cnn_v1可视化效果最稳定
对比
总数90(草莓30+树莓30+桑葚30),统计错误数据
cnn_v2优于cnn_v1
cnn优于fcnn
视频展示
总结
问题汇总
爬取数据
问题:获取过多无关数据,例如:草莓熊图片、草莓标准的商品、草莓的广告宣传
解决方案:更换爬取图片的关键字,例如:“草莓”–>“水果草莓”
数据清洗
问题1:无法打开的图片
解决方案:代码实现删除不是图片的文件
问题2:存在无法打开的图片
解决方案:代码实现删除无法打开的图片
问题3:错误信息的图片,例如:水果副产品、错误水果图片
解决方案:手动删除错误信息的图片
问题4:存在重复图片
解决方案:手动删除
问题5:1类中有2类图片
解决方案:手动检查所有图片,调整数据
数据预处理
问题:图片大小不同,获得的图片张量不同
解决方案:尺度缩放,统一大小
推理预测
问题:加载文件报错torch.load with map_location=torch.device(“cpu”) to map your storages to the CPU
原代码:
net.load_state_dict(torch.load(best_weight_path))
改为:
net.load_state_dict(torch.load(best_weight_path, map_location='cpu’))
最优参数
损失函数:交叉熵
网络层数:4层
尺度缩放:100*100
激活函数:ReLU
效果预览
v3:共19张图片,错误2