【PyTorch 卷积】实战自定义的图片归类

前言

        卷积神经网络是一类包含卷积计算且具有深度结构的前馈神经网络,是深度学习的代表算法之一,它通过卷积层、池化层、全连接层等结构,可以有效地处理如时间序列和图片数据等。关于卷积的概念网络上也比较多,这里就不一一描述了。实战为主当然要从实际问题出发,用代码的方式加深印象。在写代码前,我先说一下为什么我要写这篇文章?

        之前我也用 Tensorflow.js 跟着别人试过图片分类,虽然结果是有了,但是对代码的理解和印象并不深刻。后来由于工作业务原因才接触 PyTorch,发现这个框架更好上手,整一圈后就想用这个把之前用得图片也实现一下分类。开始也是看文章实现,但是网上大部分都是用 MNIST 数据集实现的手写字识别,而业务中有时就是一些指定的不规则小众图片识别,所以下面就简单实现一个自定义的图片集归类。

流程

  • 根据自己的定义,收集图片并归类
  • 读取图片数据和归类标签,保存数据集
  • 固定图片大小 (会变形),归一化转张量
  • 定义超参数,损失函数和优化器等
  • 炼丹,重复查看损失值准确率等指标
  • 保存模型参数,加载测试图片分类效果

环境

  • Python 3.8
  • Torch 1.9.0
  • Pillow 10.0
  • Torchvision
  • Numpy
  • Pandas
  • Matplotlib

编码

        写代码前已经把需要的图片做好了分类,上面的依赖包也已经安装完毕。由于只是演示这里没有用预训练模型(ResNet、VGG),因为训练时要用的是 Tensor,所以需要先读取文件夹内的图片先转化为 PIL 的对象数据或 Numpy 数据,然后可以对图片进行调整,最后全都转成 Tensor(也可以跳过 PIL 直接转张量)。这里需要注意的是对灰彩图片通道,不同尺寸图的统一处理,就是灰色图的单通道要通过复制的方式创建三个通道,所以图片设置一样的像素大小。因为在卷积网络中,输入的通道数和输入大小要一致,不然可能在训练中报错。

图片数据生成

        这里就是遍历各个分类文件夹的图片转换为对象信息数据,和提取所有分类,分别保存到指定位置,当然也可以在这里划分训练数据,校验数据,测试数据,需要的可以扩展这里就跳过了。

# -*- coding: utf-8 -*-
import os
import pickle as pkl
import pandas as pd
from PIL import Imageall_cate = []
data_set = []
directory = "./data/train"
for index, data in enumerate(os.walk(directory)):root, dirs, files = dataif index == 0:all_cate += dirselse:sorted(all_cate)root_names = root.split("\\")dir_name = root_names[-1]for img in files:img_path = root + "\\" + imgimg_np = Image.open(img_path)dict = {}dict['img_np'] = img_npdict['label'] = all_cate.index(dir_name) + 1data_set.append(dict)# 字典转DataFrame
df = pd.DataFrame(data_set)
pkl.dump(df, open('data/train_dataset.p', 'wb'))
open("data/all_cate.txt", encoding="utf-8", mode="w+").write("\n".join(all_cate))print("存档数据成功~")
批量数据集标准化

        这里是读取序列化的图片信息,对所有图片统一像素 (一般配置电脑最好在 100px 以内,不然会很卡) 并标准归一化后,转换为 Tensor。然后判断图片通道数,如果是灰色图,可以复制张量三次以创建三个通道,最后通过 torch 的 DataLoader 在训练前完成数据集的加载。

# -*- coding: utf-8 -*-
import torch
from torchvision import transforms
import pickle as pkl
from torch.utils.data import Datasetclass DataSet(Dataset):def __init__(self, pkl_file):df = pkl.load(open(pkl_file, 'rb'))self.dataFrame = dfdef __len__(self):return len(self.dataFrame)def __getitem__(self, item):img_np = self.dataFrame.iloc[item, 0]label = self.dataFrame.iloc[item, 1]transform = transforms.Compose([transforms.Resize((100, 100)),  # 根据需要调整图像大小transforms.ToTensor(),transforms.Normalize([0.5], [0.5])    # 标准归一化, p1.均值  p2.方差])image_tensor = transform(img_np)if image_tensor.shape[0] == 1:  image_tensor = image_tensor.repeat(3, 1, 1)  res = {'img_tensor': image_tensor,'label': torch.LongTensor([label-1])    # 需要实际的索引值}return res
神经网络模型

        这里创建的是卷积神经网络,接收 3 通道,第一层卷积层卷积核 3x3,输出 25 维张量,通过批标准化(BatchNorm2d)进行归一化处理,最后通过 ReLU 激活函数进行非线性变换。第一层池化使用 2x2 的最大池化操作对卷积后的特征图进行下采样。第二层也是卷积和对应的池化,最后是全连接层。将经过池化的特征图展平,然后通过一个有 1024 个神经元的全连接层,再通过 ReLU 激活函数进行非线性变换。之后是一个有 128 个神经元的全连接层,最后再通过 ReLU 激活函数进行非线性变换,输出 5 个神经元代表分类的概率分布。

# -*- coding: utf-8 -*-
import torch.nn as nn
import torch
import math
import torch.functional as Fclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.layer1 = nn.Sequential(nn.Conv2d(3, 25, kernel_size=3),nn.BatchNorm2d(25),nn.ReLU(inplace=True))self.layer2 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))self.layer3 = nn.Sequential(nn.Conv2d(25, 50, kernel_size=3),nn.BatchNorm2d(50),nn.ReLU(inplace=True))self.layer4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))self.fc = nn.Sequential(nn.Linear(50 * 23 * 23, 1024),nn.ReLU(inplace=True),nn.Linear(1024, 128),nn.ReLU(inplace=True),nn.Linear(128, 5))def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = x.view(x.size(0), -1)x = self.fc(x)return x
开始训练
# -*- coding:utf-8 -*-
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from data_set import DataSet
from torch.autograd import Variable
from utils import *
import cnn
import torch.nn as nn
import numpy as np
import torch.optim as optim# 定义超参数
batch_size = 1
learning_rate = 0.02
num_epoches = 1# 加载图片tensor训练集
tain_dataset = DataSet("data/train_dataset.p")
train_loader = DataLoader(tain_dataset, batch_size=batch_size, shuffle=True)model = cnn.CNN()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 训练模型
train_loses = []
records = []
for i in range(num_epoches):for ii, data in enumerate(train_loader):img = data['img_tensor']label = data['label'].view(-1)optimizer.zero_grad()out = model(img)loss = criterion(out, label)train_loses.append(loss.data.item())loss.backward()optimizer.step()if ii % 50 == 0:print('epoch: {}, loop: {}, loss: {:.4}'.format(i, ii, np.mean(train_loses)))records.append([np.mean(train_loses)])# 绘制模型的损失,准确率走势图
train_loss = [data[0] for data in records]
plt.plot(train_loss, label = 'Train Loss')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.legend()
plt.show()# 模型评估(略)
# model.eval()# 模型保存
torch.save(model, 'params/cnn_imgs_02.pkl')
模型检测

        训练完成保存参数到本地,下面就是将加载进的参数来测试其他图片的分类效果,同样的也是将指定图片和训练时一样的转换操作,最后将预测结果取出最大分布索引值,根据索引就可以匹配出分类名称了。另一个是工具函数,将 tensor 格式的图片在预测结果后显示在 pyplot 中。

# -*- coding:utf-8 -*-
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from data_set import DataSet
from utils import *
import torchvision
from PIL import Image
from torchvision import transforms
import cnndef imshow(img):img = img / 2 + 0.5npimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()model = torch.load("params/cnn_imgs_02.pkl")img_path= "imgs/05.jpg"
img_np = Image.open(img_path)
transform = transforms.Compose([transforms.Resize((100, 100)),  transforms.ToTensor(),transforms.Normalize([0.5], [0.5])  
])
image_tensor = transform(img_np)# 如果是灰度图片
if image_tensor.shape[0] == 1:  image_tensor = image_tensor.repeat(3, 1, 1)  image_tensor = image_tensor.view(-1, 3, 100, 100)predict = model(image_tensor)
indices = torch.max(predict, 1)[1].item()all_cate = []
for line in open("data/all_cate.txt", encoding="utf-8", mode="r"):all_cate.append(line.strip())cate_name = ""
try:cate_name = all_cate[indices]
except ValueError:cate_name = "未知"print("识别结果是:", cate_name)
# imshow(torchvision.utils.make_grid(image_tensor))
# 原图显示
img_np.show()
exit()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/131024.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【强化学习】17 ——DDPG(Deep Deterministic Policy Gradient)

文章目录 前言DDPG特点 随机策略与确定性策略DDPG:深度确定性策略梯度伪代码代码实践 前言 之前的章节介绍了基于策略梯度的算法 REINFORCE、Actor-Critic 以及两个改进算法——TRPO 和 PPO。这类算法有一个共同的特点:它们都是在线策略算法&#xff0c…

【踩坑及思考】浏览器存储 cookie 最大值超过 4kb,或 http 头 cookie 超过限制值

背景 本地生产环境:超过最大值 cookie token 不存储;客户生产环境:打开系统空白,且控制台报 http 400 错误; 出现了两种现象 现象一:浏览器对大于 4kb 的 cookie 值不存储 导致用户名密码登录&#xff…

解决问题 [Vue warn]: Missing required prop: “index“

vue项目控制台报错 [Vue warn]: Missing required prop: “index” 出现这个报错原因是<el-submenu></el-submenu>标签中缺少index属性&#xff0c;需要加上才能不报错 解决办法是&#xff1a; <el-submenu index""></el-submenu>

linux 下 物理迁移 mysql 数据库 不能启动问题

1、授权问题 # chown -R 777 /app/db/mysql 2、/etc/my.cnf配置问题 [mysqld] basedir/app/db/mysql datadir/app/db/mysql/data socket/app/db/mysql/mysql.sock.lock innodb_buffer_pool_size128M innodb_force_recovery 1 symbolic-links0 [mysqld_safe] log-error/app/…

linux驱动开发环境搭建

使用的是parallel 创建的ubuntu 16.04 ubuntu20.04虚拟机 源码准备 # 先查看本机版本 $ uname -r 5.15.0-86-generic# 搜索相关源码 $ sudo apt-cache search linux-source [sudo] password for showme: linux-source - Linux kernel source with Ubuntu patches linux-sourc…

笔记软件 Keep It mac v2.3.3中文版新增功能

Keep It mac是一款专为 Mac、iPad 和 iPhone 设计的笔记和信息管理应用程序。它允许用户在一个地方组织和管理他们的笔记、网络链接、PDF、图像和其他类型的内容。Keep It 还具有标记、搜索、突出显示、编辑和跨设备同步功能。 Keep It for mac更新日志 修复了更改注释或富文本…

Nacos-2.2.2源码修改集成高斯数据库GaussDB,postresql

一 &#xff0c;下载代码 Release 2.2.2 (Apr 11, 2023) alibaba/nacos GitHub 二&#xff0c; 执行打包 mvn -Prelease-nacos -Dmaven.test.skiptrue -Drat.skiptrue clean install -U 或 mvn -Prelease-nacos ‘-Dmaven.test.skiptrue’ ‘-Drat.skiptrue’ clean instal…

【H.264】RTP h264 码流 实例解析分析 3 : webrtc

【srs】SRS检测IBMF还是annexb 【H.264】RTP h264 码流 实例解析分析 2 : mediasoup收包 mediasoup 并没完整解析rtp包的内容,可能与mediasoup 只需要转发,不需要解码有关系。 webrtc 本身都是全的。 m98代码,先说关键: webrtc的VideoRtpDepacketizer 第一:对RTPVideoType…

成员变量为动态数据时不可轻易使用

问题描述 业务验收阶段&#xff0c;遇到了一个由于成员变量导致的线程问题 有一个kafka切面&#xff0c;用来处理某些功能在调用前后的发送消息&#xff0c;资产类型type是成员变量定义&#xff1b; 资产1类型推送消息是以zichan1为节点&#xff1b;资产2类型推送消息是以zi…

社区牛奶智能售货机为你带来便利与实惠

社区牛奶智能售货机为你带来便利与实惠 低成本&#xff1a;社区牛奶智能货机的最大优势在于成本低廉&#xff0c;租金和人工开支都很少。大部分时间&#xff0c;货柜都是由无人操作来完成销售任务。 购买便利&#xff1a;社区居民只需通过手机扫码支付&#xff0c;支付后即可自…

二、计算机组成原理与体系结构

&#xff08;一&#xff09;数据的表示 不同进制之间的转换 R 进制转十进制使用按权展开法&#xff0c;其具体操作方式为&#xff1a;将 R 进制数的每一位数值用 Rk 形式表示&#xff0c;即幂的底数是 R &#xff0c;指数为 k &#xff0c;k 与该位和小数点之间的距离有关。当…

【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)

目录 0. 前言 1. Cifar10数据集 2. AlexNet网络模型 2.1 AlexNet的网络结构 2.2 激活函数ReLu 2.3 Dropout方法 2.4 数据增强 3. 使用GPU加速进行批量训练 4. 网络模型构建 5. 训练过程 6. 完整代码 0. 前言 按照国际惯例&#xff0c;首先声明&#xff1a;本文只是我…

[开源]企业级在线办公系统,基于实时音视频完成在线视频会议功能

一、开源项目简介 企业级在线办公系统 本项目使用了SpringBootMybatisSpringMVC框架&#xff0c;技术功能点应用了WebSocket、Redis、Activiti7工作流引擎&#xff0c; 基于TRTC腾讯实时音视频完成在线视频会议功能。 二、开源协议 使用GPL-3.0开源协议 三、界面展示 部分…

2024天津理工大学中环信息学院专升本机械设计制造自动化专业考纲

2024年天津理工大学中环信息学院高职升本科《机械设计制造及其自动化》专业课考试大纲《机械设计》《机械制图》 《机械设计》考试大纲 教 材&#xff1a;《机械设计》&#xff08;第十版&#xff09;&#xff0c;高等教育出版社&#xff0c;濮良贵、陈国定、吴立言主编&#…

1.如何实现统一的API前缀-web组件篇

文章目录 1. 问题的由来2.实现原理3. 总结 1. 问题的由来 系统提供了 2 种类型的用户&#xff0c;分别满足对应的管理后台、用户 App 场景。 两种场景的前缀不同&#xff0c;分别为/admin-api/和/app-api/&#xff0c;都写在一个controller里面&#xff0c;显然比较混乱。分开…

AI:60-基于深度学习的瓜果蔬菜分类识别

🚀 本文选自专栏:AI领域专栏 从基础到实践,深入了解算法、案例和最新趋势。无论你是初学者还是经验丰富的数据科学家,通过案例和项目实践,掌握核心概念和实用技能。每篇案例都包含代码实例,详细讲解供大家学习。 📌📌📌在这个漫长的过程,中途遇到了不少问题,但是…

网络基础扫盲-多路转发

博客内容&#xff1a;多路转发的常见方式select&#xff0c;poll&#xff0c;epoll 文章目录 一、五种IO模型二、多路转发的常见接口1.select2、poll3、epoll 总结 前言 Linux下一切皆文件&#xff0c;是文件就会存在IO的情况&#xff0c;IO的方式决定了效率的高低。 一、五种…

基于java+springboot+vue在线选课系统

项目介绍 本系统结合计算机系统的结构、概念、模型、原理、方法&#xff0c;在计算机各种优势的情况下&#xff0c;采用JAVA语言&#xff0c;结合SpringBoot框架与Vue框架以及MYSQL数据库设计并实现的。员工管理系统主要包括个人中心、课程管理、专业管理、院系信息管理、学生…

Cube MX 开发高精度电流源跳坑过程/SPI连接ADS1255/1256系列问题总结/STM32 硬件SPI开发过程

文章目录 概要整体架构流程技术名词解释技术细节小结 概要 1.使用STM32F系列开发一款高精度恒流电源&#xff0c;用到了24位高精度采样芯片ADS1255/ADS1256系列。 2.使用时发现很多的坑&#xff0c;详细介绍了每个坑的具体情况和实际的解决办法。 坑1&#xff1a;波特率设置…

如何使用Ruby 多线程爬取数据

现在比较主流的爬虫应该是用python&#xff0c;之前也写了很多关于python的文章。今天在这里我们主要说说ruby。我觉得ruby也是ok的&#xff0c;我试试看写了一个爬虫的小程序&#xff0c;并作出相应的解析。 Ruby中实现网页抓取&#xff0c;一般用的是mechanize&#xff0c;使…