1、介绍
GhostNet 文章地址:[1911.11907] GhostNet: More Features from Cheap Operations (arxiv.org)
主要思想:
特征提取的时候,很多特征图是具有高度相似性的,也就是说存在许多的冗余特征图。
从另一个角度想,利用一系列的线性变化,以很小的代价生成许多能从原始特征发掘所需信息的“幻影”特征图呢
冗余的特征图是非常有必要的,可以保证网络对输入数据的理解更为全面。
ghostnet 的版本,本人在github上搜到了三个版本,这里干脆一起实现了
文末有项目下载
2、代码解释
代码目录如下:
- 红色框为ghostnet 的主干网络
- 蓝色框为数据集,数据保存在不同目录中,如果有测试集的话,也放在这里即可
- 绿色框为训练生成的结果
2.1 训练脚本 train
传入参数如下:
网络的输出个数,代码会计算数据集,然后自动设定!
都是很常规的参数,优化器可供选择的有SGD、Adam,学习率采用自适应衰减
2.2 评估脚本 val
这里评估的代码从训练中独立出来,参数如下:
- pth 传入测试集。没有的话,传入验证集也可以
- 数据集的mean和std在训练日志log文件中可以找到
2.3 训练过程和结果
运行train脚本如下:
评估脚本如下:这里采用的评估指标是混淆矩阵、F1分数等等
训练结果:
weights 下有最好的权重和最后的权重文件
训练和验证集的loss、acc曲线
注意:这里val 比 train 的acc高,因为数据划分不平衡所致,val集数量加多一点即可
学习率衰减:
训练日志:这里有mean和std
数据预处理的可视化:
3、使用
项目地址: 基于ghostNet网络对将香蕉5种不同阶段成熟度的分类【包含数据集+代码+训练结果】资源-CSDN文库
配置GPU的torch训练,参考:Pytorch 配置 GPU 环境_pytorch gpu-CSDN博客
新建虚拟环境:conda create -n ghost python=3.8
激活虚拟环境:conda activate ghost
安装库文件即可:pip install -r requirements.txt
训练自己数据集,将数据集摆放好即可