Pytorch进阶教学——训练一个图像分类模型(GPU)

目录

1、前言 

2、数据集介绍

3、获取数据

4、创建网络

5、训练模型

6、测试模型

6.1、测试整个模型准确率

6.2、测试单张图片


1、前言 

  • 编写一个可以分类蚂蚁和蜜蜂图片的模型,使用数据集对卷积神经网络进行训练。训练后的模型可以对蚂蚁或蜜蜂的图片进行检测。
  • 使用anaconda新建一个虚拟环境,安装好pytorch。后续缺什么包就安装什么包即可。
  • 使用pycharm新建一个项目,配置好环境。

2、数据集介绍

  • 使用的数据集为蚂蚁和蜜蜂的图片,分为训练集和测试集
  • 【注】数据集下载地址。

3、获取数据

  • 代码中获取数据集使用的是txt文件,所以首先需要提取全部图片的地址和标签放入txt文件中。
  • 下述代码为python提取全部图片地址和标签导出为txt文件的脚本。(自行修改)
    • import os  # 导入os模块,用于操作文件路径等操作系统相关功能。def get_file_name(file_path, output_file, type):  # 绝对路径path_list = os.listdir(file_path)  # 列出指定路径下的所有文件和文件夹,并将结果存储在path_list中with open(output_file, 'a') as file:for filename in path_list:all_file_path = os.path.join(file_path, filename)  # 拼接路径file.write(all_file_path + ' ' + type + '\n')if __name__ == '__main__':ants_file_path = r"D:\BaiduNetdiskWorkspace\PyTorch\image_recognition\hymenoptera_data\train\ants"bees_file_path = r"D:\BaiduNetdiskWorkspace\PyTorch\image_recognition\hymenoptera_data\train\bees"output_file = r"D:\BaiduNetdiskWorkspace\PyTorch\image_recognition\hymenoptera_data\train.txt"get_file_name(ants_file_path, output_file, 'ants')get_file_name(bees_file_path, output_file, 'bees')
    •  
  • 将全部地址修改为相对地址。
    • 使用替换操作实现。例如:
  • 最后txt文件的内容如下:
  • 新建一个dataset.py文件。
    • # 读取数据
      import torch
      import torchvision.transforms as transforms
      from PIL import Image# 读取数据类
      class MyDataset(torch.utils.data.Dataset):  # 继承构建自定义数据集的基类def __init__(self, datatxt, datatransform):datas = open(datatxt, 'r').readlines()  # 按行读取,每行包含图像路径和标签self.images = []self.labels = []self.transform = datatransformfor data in datas:item = data.strip().split(' ')  # 去除首尾空格并按空格分割# 分别将图像路径和标签添加到self.images和self.labels列表中self.images.append(item[0])  # 路径self.labels.append(item[1])  # 标签returndef __len__(self):return len(self.images)# 获取数据集中的一个样本。接收一个索引item,根据索引获取对应的图像路径和标签def __getitem__(self, item):imagepath, label = self.images[item], self.labels[item]image = Image.open(imagepath)  # 打开图片return self.transform(image), label  # 返回转换后的图像和对应的标签# 用于测试
      if __name__ == '__main__':# 利用txt文件读取图片信息,txt文件包括图片路径和标签traintxt = './hymenoptera_data/train.txt'valtxt = './hymenoptera_data/val.txt'# 图片转换形式traindata_transfomer = transforms.Compose([transforms.ToTensor(),  # 转为Tensor格式transforms.Resize(60),  # 调整图像大小,调整为高度或宽度为60像素,另一边按比例调整transforms.RandomCrop(48),  # 裁剪图片,随机裁剪成高度和宽度均为48像素的部分transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.RandomRotation(10),  # 随机旋转transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 对图像进行归一化处理。对每个通道执行了均值为0.5、标准差为0.5的归一化操作])valdata_transfomer = transforms.Compose([transforms.ToTensor(),  # 转为Tensor格transforms.Resize(48),  # 调整图像大小,调整为高度或宽度为48像素,另一边按比例调整transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 加载数据traindataset = MyDataset(traintxt, traindata_transfomer)valdataset = MyDataset(valtxt, valdata_transfomer)print("测试集:" + str(traindataset.__len__()))print("训练集:" + str(valdataset.__len__()))
  • 单独运行结果:(只用于测试)

4、创建网络

  • 新建一个net.py文件。
    • 其中创建了一个简单的三层卷积神经网络。
    • # 三层卷积神经网络
      import torch# 卷积神经网络类
      class SimpleConv3(torch.nn.Module):  # 继承创建神经网络的基类def __init__(self, classes):super(SimpleConv3, self).__init__()# 卷积层self.conv1 = torch.nn.Conv2d(3, 16, 3, 2, 1)  # 输入通道3,输出通道16,3*3的卷积核,步长2,边缘填充1self.conv2 = torch.nn.Conv2d(16, 32, 3, 2, 1)  # 输入通道16,输出通道32,3*3的卷积核,步长2,边缘填充1self.conv3 = torch.nn.Conv2d(32, 64, 3, 2, 1)  # 输入通道32,输出通道64,3*3的卷积核,步长2,边缘填充1# 全连接层self.fc1 = torch.nn.Linear(2304, 100)self.fc2 = torch.nn.Linear(100, classes)def forward(self, x):# 第一次卷积x = torch.nn.functional.relu(self.conv1(x))  # relu为激活函数# 第二次卷积x = torch.nn.functional.relu(self.conv2(x))# 第三次卷积x = torch.nn.functional.relu(self.conv3(x))# 展开成一维向量x = x.view(x.size(0), -1)x = torch.nn.functional.relu(self.fc1(x))x = self.fc2(x)return x# 用于测试
      if __name__ == '__main__':inputs = torch.rand((1, 3, 48, 48))  # 生成一个随机的3通道、48x48大小的张量作为输入net = SimpleConv3(2)  # 二分类output = net(inputs)print(output)
  • 单独运行结果:(只用于测试)

5、训练模型

  • 新建一个train.py文件。
    • 其中可自行设置的参数都有标出。 
    • # 训练模型
      import matplotlibmatplotlib.use('TkAgg')
      import matplotlib.pyplot as plt
      from dataset import MyDataset
      from net import SimpleConv3
      import torch
      import torchvision.transforms as transforms
      from torch.optim import SGD  # 优化相关
      from torch.optim.lr_scheduler import StepLR  # 优化相关
      from sklearn import preprocessing  # 处理label# 图片转换形式
      traindata_transfomer = transforms.Compose([transforms.ToTensor(),  # 转为Tensor格式transforms.Resize(60, antialias=True),  # 调整图像大小,调整为高度或宽度为60像素,另一边按比例调整,antialias=True启用了抗锯齿功能transforms.RandomCrop(48),  # 裁剪图片,随机裁剪成高度和宽度均为48像素的部分transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.RandomRotation(10),  # 随机旋转transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 对图像进行归一化处理。对每个通道执行了均值为0.5、标准差为0.5的归一化操作
      ])if __name__ == '__main__':traintxt = './hymenoptera_data/train.txt'valtxt = './hymenoptera_data/val.txt'# 加载数据traindataset = MyDataset(traintxt, traindata_transfomer)# 创建卷积神经网络net = SimpleConv3(2)  # 二分类# 使用GPUdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")net.to(device)# 测试GPU是否能使用# print("The device is gpu later?:", next(net.parameters()).is_cuda)# print("The device is gpu,", next(net.parameters()).device)# 将数据提供给模型使用traindataloader = torch.utils.data.DataLoader(traindataset, batch_size=128, shuffle=True,num_workers=1)  # batch_size可以自行调节# 优化器optim = SGD(net.parameters(), lr=0.1, momentum=0.9)  # 使用随机梯度下降(SGD)作为优化器,学习率0.1,动量0.9,加速梯度下降过程,lr可自行调节criterion = torch.nn.CrossEntropyLoss()  # 使用交叉熵损失作为损失函数lr_step = StepLR(optim, step_size=200, gamma=0.1)  # 学习率调度器,动态调整学习率,每200个epoch调整一次,每次调整缩小为原来的0.1倍,step_size可自行调节epochs = 5  # 训练次数accs = []losss = []# 训练循环for epoch in range(0, epochs):batch = 0running_acc = 0.0  # 精度running_loss = 0.0  # 损失for data in traindataloader:batch += 1imputs, labels = data# 将标签从元组转换为tensor类型labels = preprocessing.LabelEncoder().fit_transform(labels)labels = torch.as_tensor(labels)# 利用GPU训练模型imputs = imputs.to(device)labels = labels.to(device)# 将数据输入至网络output = net(imputs)# 计算损失loss = criterion(output, labels)# 平均准确率acc = float(torch.sum(labels == torch.argmax(output, 1))) / len(imputs)# 累加损失和准确率,后面会除以batchrunning_acc += accrunning_loss += loss.data.item()optim.zero_grad()  # 清空梯度loss.backward()  # 反向传播optim.step()  # 更新参数lr_step.step()  # 更新优化器的学习率# 一次训练的精度和损失running_acc = running_acc / batchrunning_loss = running_loss / batchaccs.append(running_acc)losss.append(running_loss)print('epoch=' + str(epoch) + ' loss=' + str(running_loss) + ' acc=' + str(running_acc))# 保存模型torch.save(net, 'model.pth')  # 保存模型的权重和结构x = torch.randn(1, 3, 48, 48).to(device)  # # 生成一个随机的3通道、48x48大小的张量作为输入,新建的张量也要送到GPU中net = torch.load('model.pth')  # 从保存的.pth文件中加载模型net.train(False)  # 设置模型为推理模式,意味着不会进行梯度计算或反向传播torch.onnx.export(net, x, 'model.onnx')  # 使用ONNX格式导出模型# 接受模型net、示例输入x和导出的文件名model.onnx作为参数# 可视化结果fig = plt.figure()plot1, = plt.plot(range(len(accs)), accs)  # 创建一个图形对象plot1,绘制accs列表中的数据plot2, = plt.plot(range(len(losss)), losss)  # 创建另一个图形对象plot2,绘制losss列表中的数据plt.ylabel('epoch')  # 设置y轴的标签为epochplt.legend(handles=[plot1, plot2], labels=['acc', 'loss'])  # 创建图例,指定图表中不同曲线的标签plt.show()  # 展示所绘制的图表
  • 【注】本项目使用的是GPU训练模型。如果GPU可以获得,但是无法使用,可能是pytorch的版本不对,需要重新安装。
  • 运行结果:
  • 保存后的模型如下:

6、测试模型

6.1、测试整个模型准确率

  • 利用测试集,测试整个模型的准确率。
  • 新建一个test.py文件。
    • # 测试整个模型的准确率
      import torch
      import torchvision.transforms as transforms
      from dataset import MyDataset  # 您的数据集类
      from sklearn import preprocessing  # 处理label# 定义测试集的数据转换形式
      valdata_transfomer = transforms.Compose([transforms.ToTensor(),  # 转为Tensor格式transforms.Resize(60, antialias=True),  # 调整图像大小,调整为高度或宽度为60像素,另一边按比例调整,antialias=True启用了抗锯齿功能transforms.CenterCrop(48),  # 中心裁剪图片,裁剪成高度和宽度均为48像素的部分transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 对图像进行归一化处理。对每个通道执行了均值为0.5、标准差为0.5的归一化操作
      ])if __name__ == '__main__':valtxt = './hymenoptera_data/val.txt'  # 测试集数据路径# 加载测试集数据valdataset = MyDataset(valtxt, valdata_transfomer)# 加载已训练好的模型,利用GPU进行测试device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")net = torch.load('model.pth').to(device)net.eval()  # 将模型设置为评估模式,意味着不会进行梯度计算或反向传播# 使用 DataLoader 加载测试集数据valdataloader = torch.utils.data.DataLoader(valdataset, batch_size=1, shuffle=False)correct = 0  # 被正确预测的样本数total = 0  # 测试样本数# 测试模型with torch.no_grad():for data in valdataloader:images, labels = data# 将标签从元组转换为tensor类型labels = preprocessing.LabelEncoder().fit_transform(labels)labels = torch.as_tensor(labels)# 利用GPU训练模型images, labels = images.to(device), labels.to(device)outputs = net(images)  # 输入图像并获取模型预测结果_, predicted = torch.max(outputs.data, 1)  # 获取预测值中最大概率的索引total += labels.size(0)  # 累计测试样本数量correct += (predicted == labels).sum().item()  # 计算正确预测的样本数量# 计算并输出模型在测试集上的准确率accuracy = 100 * correct / totalprint('Test Accuracy: {:.2f}%'.format(accuracy))
  • 运行结果:
    • 因为训练模型时只迭代了200次,所以准确率并不高。可以尝试提高训练次数,提高准确率。 

6.2、测试单张图片

  • 使用训练后的模型,对单张图片进行预测。
  • 新建一个testone.py文件。
    • import torch
      from PIL import Image
      import torchvision.transforms as transforms# 定义图片预处理转换
      image_transforms = transforms.Compose([transforms.Resize(60, antialias=True),  # 调整图像大小transforms.CenterCrop(48),  # 中心裁剪transforms.ToTensor(),  # 转为Tensor格式transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化处理
      ])# 定义类别映射字典
      class_mapping = {0: "ant",1: "bee"
      }# 加载已训练好的模型,利用GPU测试
      device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
      net = torch.load('model.pth').to(device)
      net.eval()  # 将模型设置为评估模式,意味着不会进行梯度计算或反向传播# 加载要测试的图片
      image_path = './hymenoptera_data/val/bees/26589803_5ba7000313.jpg'  # 图片路径
      input_image = Image.open(image_path)  # 加载图片
      input_tensor = image_transforms(input_image).unsqueeze(0)  # 对图片进行预处理转换,并增加 batch 维度# 将输入数据移动到GPU上
      input_tensor = input_tensor.to(device)# 使用模型进行预测
      with torch.no_grad():output = net(input_tensor)_, predicted = torch.max(output, 1)  # 在张量中沿指定维度找到最大值及其对应的索引# 输出预测结果
      predicted_class = predicted.item()  # 得到预测的标签
      predicted_label = class_mapping[predicted_class]  # 将标签转换为文字
      print(f"The predicted class for the image is: {predicted_label}")
  • 运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/196085.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【广州华锐互动】VR沉浸式体验铝厂安全事故让伤害教育更加深刻

随着科技的不断发展,虚拟现实(VR)技术已经逐渐渗透到各个领域,为我们的生活带来了前所未有的便捷和体验。在安全生产领域,VR技术的应用也日益受到重视。 VR公司广州华锐互动就开发了多款VR安全事故体验系统&#xff0c…

蓝桥杯-03-蓝桥杯学习计划

蓝桥杯-03-蓝桥杯学习计划 参考资料 相关文献 报了蓝桥杯比赛,几乎零基础,如何准备,请大牛指导一下。谢谢? 蓝桥杯2022各组真题汇总(完整可评测) 基础学习 C语言网 ACM竞赛入门,蓝桥杯竞赛指南 廖雪峰的官方官网 算法题单 洛谷…

vue,nvue,uniapp,到底是什么

vue,nvue,uniapp,到底是什么? 发展猜想: 开发移动端软件,一般是控件逻辑,可拖动控件android studio都给你设计好了。 开发web页面时,用vue,vue是前端框架。主要是终端设备通过浏览器进行访问&#xff08…

ubuntu20.04使用LIO-SAM对热室空间进行重建

一、安装LIO-SAM 1.环境配置 默认已经安装过ros sudo apt-get install -y ros-Noetic-navigation sudo apt-get install -y ros-Noetic-robot-localization sudo apt-get install -y ros-Noetic-robot-state-publisher 安装 gtsam(如果是18.04的ubuntu直接按照官网配置&…

C++ 基础篇

目录 C开发概述 C特点 C跨平台的原因 C编译器 C库 操作系统API C基本概念 注释 变量 常量 两种定义常量方式的区别 表示符命名规则 常见的关键字 数据类型 整型 浮点数 字符型 转义字符 字符串型 布尔类型 运算符 算术运算符 赋值运算符 比较运算符 逻…

【VScode】超详细图片讲解下载安装、环境配置、编译执行、调试

这里是目录 VScode是什么?VScode的下载和安装环境介绍安装中文插件 配置VScodeC/C开发环境下载和配置MinGW-w64 编译器套件下载:配置: 安装C/C插件在VScode上编写代码设置C/C编译选项创建执行任务编译执行如果想写其他代码在同一个文件夹在不…

springboot 整合 Spring Security 中篇(RBAC权限控制)

1.先了解RBAC 是什么 RBAC(Role-Based Access control) ,也就是基于角色的权限分配解决方案 2.数据库读取用户信息和授权信息 1.上篇用户名好授权等信息都是从内存读取实际情况都是从数据库获取; 主要设计两个类 UserDetails和UserDetailsService 看下…

新媒体营销模拟实训室解决方案

一、引言 随着互联网的发展,新媒体已成为企业进行营销和品牌推广的重要渠道。然而,对于许多企业来说,如何在新媒体上进行有效的营销仍是一大挑战。为了解决这个问题,我们推出了一款新媒体营销模拟实训室解决方案,以帮…

【文末送书】Python OpenCV从入门到精通

文章目录 🍔简介opencv🌹内容简介🛸编辑推荐🎄导读🌺彩蛋 🍔简介opencv OpenCV(Open Source Computer Vision Library)是一个开源的计算机视觉库,提供了丰富的图像处理和…

java学习part31String

142-常用类与基础API-String的理解与不可变性_哔哩哔哩_bilibili 1.String 2.字符串常量池 变更储存区的原因是加快被gc的频率 比地址,equals比内容 3.字符串连接 s3s4都是字符串常量,后面几个会利用StringBuilder的toString()&a…

JAVA全栈开发 day16_MySql01

一、数据库 1.数据储存在哪里? 硬盘、网盘、U盘、光盘、内存(临时存储) 数据持久化 使用文件来进行存储,数据库也是一种文件,像excel ,xml 这些都可以进行数据的存储,但大量数据操作&#x…

C#网络编程TCP程序设计(Socket类、TcpClient类和 TcpListener类)

目录 一、Socket类 1.Socket类的常用属性及说明 2.Socket类的常用方法及说明 二、TcpClient类 三、TcpListener类 四、示例 1.源码 2.生成效果 TCP(Transmission Control Protocol)是一种面向连接的、可靠的、基于字节流的传输层通信协议。在C#中,TCP程序设…

react-flip-move结合array-move实现前端列表置顶效果

你有没有遇到这样的需求?点击左侧列表项,则像聊天会话窗口一样将被点击的列表项置顶。 如果只是单纯的置顶的话,直接使用array-move就可以实现了,但置顶效果多少有点突兀~ 先上代码,直接使用array-move的情况&#xf…

数据可视化私有化部署:为何成本居高不下?

尽管在可视化设计这行干了好多年,也接手过不少项目,但昂贵的私有化部署费用总能让我发出由衷的感叹:“这几十万一年也太贵了!”。可以预见,数据可视化软件私有化部署所带来的高昂成本,将是许多企业面临的问…

Jmeter进行压力测试不为人知的秘密

jmeter是apache公司基于java开发的一款开源压力测试工具,体积小,功能全,使用方便,是一个比较轻量级的测试工具,使用起来非常简单。因为jmeter是java开发的,所以运行的时候必须先要安装jdk才可以。jmeter是免…

每日一练【快乐数】

一、题目描述 202. 快乐数 编写一个算法来判断一个数 n 是不是快乐数。 「快乐数」 定义为: 对于一个正整数,每一次将该数替换为它每个位置上的数字的平方和。然后重复这个过程直到这个数变为 1,也可能是 无限循环 但始终变不到 1。如果这…

Elasticsearch高级

文章目录 一.数据聚合二.RestAPI实现聚合三.ES自动补全(联想)四.数据同步五.elasticsearch集群 一.数据聚合 在ES中的数据聚合(aggregations)可以近似看做成mysql中的groupby分组,聚合可以实现对文档数据的统计、分析、运算,常见的聚合的分类有以下几种…

基于APM(PIX)飞控和mission planner制作遥控无人车-从零搭建自主pix无人车普通舵机转向无人车-1(以乐迪crossflight飞控为例)

1.前期准备 准备通过舵机转向的无人车地盘、遥控器、地面站电脑、飞控等。安装驱动程序、端口程序、netframwork等,不再赘述。 2.安装固件 安装ardurover固件,如果在线失败,选择官方最新的固件下载到本地,选择本地安装。 3.调试…

智能仓库PTL管理系统

清晰电子墨水屏显示,无纸化作业,超低功耗 无线通信,穿透力强,极简部署 支持声光提醒,极大提高作业效率 适用场景:工厂,仓库,物流,货品分类等等