01——LenNet网络结构,图片识别

目录

1、model.py文件 (预训练的模型)

2、train.py文件(会产生训练好的.th文件)

3、predict.py文件(预测文件)

4、结果展示:


1、model.py文件 (预训练的模型)

import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()# RGB图像;  这里用了16个卷积核;卷积核的尺寸为5x5的self.conv1 = nn.Conv2d(3, 16, 5)  # 输入的是RBG图片,所以in_channel为3; out_channels=卷积核个数;kernel_size:5x5的self.pool1 = nn.MaxPool2d(2, 2)  # kernal_size:2x2   stride:2self.conv2 = nn.Conv2d(16, 32, 5)  # 这里使用32个卷积核;kernal_size:5x5self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)  # 全连接层的输入,是一个一维向量,所以我们要把输入的特征向量展平。# 将得到的self.poolx(x) 的output(32,5,5)展开;  图片上给的全连接层是120self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)  # 这里的10,是需要根据训练集修改的def forward(self, x):   # 正向传播# Pytorch Tensor的通道排序:[channel,height,width]'''卷积后的尺寸大小计算:N = (W-F+2P)/S + 1其中,默认的padding:0   stride:1①输入图片大小:WxW②Filter大小 FxF  (卷积核大小)③步长S④padding的像素数P'''x = F.relu(self.conv1(x))   # 输入特征图为32x32大小的RGB图片;  input(3,32,32)  output(16,28,28)x = self.pool1(x)           # 经过最大下采样会将图片的高度和宽度:缩小为原来的一半  output(16,14,14)   池化层,只改变特征矩阵的高和宽;x = F.relu(self.conv2(x))   # output(32, 10, 10)  因为第二个卷积层的卷积核大小是32个,这里就是32x = self.pool2(x)           # 经过最大下采样会将图片的高度和宽度:缩小为原来的一半output(32, 5, 5)x = x.view(-1, 32*5*5)   # x.view()  将其展开成一维向量,-1表示第一个维度batch需要自动推理x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x
# 测试下
# import torch
# input1 = torch.rand([32,3,32,32])
# model = LeNet()
# print(model)
# output = model(input1)

2、train.py文件(会产生训练好的.th文件)

import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data
import torchvision
from torch import nn, optim
from torchvision import transformsfrom pilipala_pytorch.pytorch_learning.Test1_pytorch_demo.model import LeNet# 1、下载数据集
# 图形预处理 ;其中transforms.Compose()是用来组合多个图像转换操作的,使得这些操作可以顺序地应用于图像。
transform = transforms.Compose([transforms.ToTensor(),   # 将PIL图像或ndarray转换为torch.Tensor,并将像素值的范围从[0,255]缩放到[0.0, 1.0]transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]   # 对图像进行标准化;标准化通常用于使模型的训练更加稳定。
)
# 50000张训练图片
train_ds = torchvision.datasets.CIFAR10('data',train=True,transform=transform,download=False)
# 10000张测试图片
test_ds = torchvision.datasets.CIFAR10('data',train=False,transform=transform,download=False)
# 2、加载数据集
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=36, shuffle=True, num_workers=0)    # shuffle数据是否是随机提取的,一般设置为True
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=10000, shuffle=True, num_workers=0)test_image,test_label = next(iter(test_dl))  # 将test_dl 转换为一个可迭代的迭代器,通过next()方法获取数据classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')'''标准化处理:output = (input - 0.5) / 0.5反标准化处理: input = output * 0.5 + 0.5 = output / 2 + 0.5
'''
# 测试下展示图片
# def imshow(img):
#     img = img / 2 + 0.5   # unnormalize  反标准化处理
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1,2,0)))
#     plt.show()
#
# # 打印标签
# print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
# imshow(torchvision.utils.make_grid(test_image))# 实例化网络模型
net = LeNet()
# 定义相关参数
loss_function = nn.CrossEntropyLoss()  # 定义损失函数
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器, 这里使用的是Adam优化器
# 训练过程
for epoch in range(5):  # 定义循环,将训练集迭代多少轮running_loss = 0.0  # 叠加,训练过程中的损失for step,data in enumerate(train_dl,start=0):  # 遍历训练集样本inputs, labels = data   # 获取图像及其对应的标签optimizer.zero_grad()  # 将历史梯度清零;如果不清除历史梯度,就会对计算的历史梯度进行累加outputs = net(inputs)   # 将输入的图片输入到网络,进行正向传播loss = loss_function(outputs, labels)  # outputs网络预测的值, labels真实标签loss.backward()optimizer.step()running_loss += loss.item()if step % 500 == 499:with torch.no_grad():  # with 是一个上下文管理器outputs = net(test_image)  # [batch,10]predict_y = torch.max(outputs, dim=1)[1]   # 网络预测最大的那个accuracy = (predict_y == test_label).sum().item() / test_label.size(0)  # 得到的是tensor  (predict_y == test_label).sum()  要通过item()拿到数值print("[%d, %5d] train_loss: %.3f test_accuracy:%.3f" % (epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0
print('Finished Training')save_path = './Lenet.pth'  # 保存模型
torch.save(net.state_dict(), save_path)  # net.state_dict() 模型字典;save_path 模型路径

测试下展示图片:

运行下,train.py文件,看下正确率、损失率:

3、predict.py文件(预测文件)

import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNettransform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship' , 'truck')net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))  # 加载train里面的训练好 产生的模型。im = Image.open('2.jpg')  # 载入准备好的图片
im = transform(im)  # 如果要将图片放入网络,进行正向传播,就得转换下格式   得到的结果为:[C,H,W]
im = torch.unsqueeze(im, dim=0)    # 增加一个维度;得到 [N,C,H,W],从而模拟一个批量大小为1的输入。with torch.no_grad():  # 不需要计算损失梯度outputs = net(im)predict = torch.max(outputs, dim=1)[1].data.numpy()   # outputs是一个张量;torch.max()用于找到张量在指定维度上的最大值;# torch.max()函数返回两个张量,一个包含最大值,另一个包含最大值的作用。# .data()属性用于从变量中提取底层的张量数据。直接使用.data()已经被认为是不安全的,推荐使用.detach()# .numpy() 表示将pytorch转换成numpy数组,从而使用numpy库的各种功能来操作数据。
print(classes[int(predict)])#     predict = torch.softmax(outputs,dim=1)  # 可以返回概率
# print(predict)

4、结果展示:

返回结果:预测是猫的概率为 86%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/748823.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OceanBase中分区的运用

分区的概念 分区实质上是根据特定的规则,将表划分为若干个独立的物理存储单位。以MySQL为例,表会被拆分为多个物理文件。而在OceanBase​​​​​​​中,每个分区则表现为一个物理副本组,每个分区默认都拥有三个副本。 分区表的优…

lqb省赛日志[8/37]-[搜索·DFS·BFS]

一只小蒟蒻备考蓝桥杯的日志 文章目录 笔记DFS记忆化搜索 刷题心得小结 笔记 DFS 参考 深度优先搜索(DFS) 总结(算法剪枝优化总结) DFS的模板框架: function dfs(当前状态){if(当前状态 目的状态){}for(寻找新状态){if(状态合法){vis[访问该点];dfs(新状态);?…

一道题学会如何使用哈希表

给你一个整数数组 nums 和一个整数 k ,请你统计并返回 该数组中和为 k 的子数组的个数 。 子数组是数组中元素的连续非空序列。 示例 1: 输入:nums [1,1,1], k 2 输出:2示例 2: 输入:nums [1,2,3], …

视频素材哪里有?这几个高清无水印素材库看看

关于短视频素材,我知道大家都在找那种能让人一看就心旷神怡的地方,尤其是我们自媒体人,更是离不开这些优质的短视频素材来吸引观众的眼球。别着急,今天我就给大家安利几个网站,保证让你找到满意的短视频素材。 1&…

黑马微服务p30踩坑

报错详情 : orderservice开不起来 : 发生报错 : 然后检查了以下端口啥的 ,配置啥的都是没有问题的 ; 解决办法 : 1 . 修改nacos1,2,3中的端口,将conf 中 cluster.conf中 的 127.0.0.1 全部改成自己本机的真实ipv4地址; 本机真实ipv4地址查看 :…

C#求水仙花数

目录 1.何谓水仙花数 2.求三位数的水仙花数 3.在遍历中使用Math.DivRem方法再求水仙花数 1.何谓水仙花数 水仙花数(Narcissistic number)是指一个 n 位正整数,它的每个位上的数字的 n 次幂之和等于它本身。例如,153 是一个 3 …

SpringBoot3下Kafka分组均衡消费实现

首先添加maven依赖&#xff1a; <dependency><groupId>org.springframework.kafka</groupId><artifactId>spring-kafka</artifactId><version>2.8.11</version><exclusions><!--此处一定要排除kafka-clients&#xff0c;然…

C++进阶:详解多态(多态、虚函数、抽象类以及虚函数原理详解)

C进阶&#xff1a;详解多态&#xff08;多态、虚函数、抽象类以及虚函数原理详解&#xff09; 结束了继承的介绍&#xff1a;C进阶&#xff1a;详细讲解继承 那紧接着的肯定就是多态啦 文章目录 1.多态的概念2.多态的定义和实现2.1多态的构成条件2.2虚函数2.2.1虚函数的概念2…

网站的上线部署

域名&#xff1a; 原本用于标识互联网上计算机使用的是IP地址&#xff0c;但是由于IP地址不便于记忆&#xff0c;所以人们便设计出来比较容易记忆的域名&#xff0c;然后通过DNS将域名和IP地址关联&#xff0c;这样人们便可以通过记忆域名直接访问到对应的计算机 DNS服务器&am…

Ant Design Pro complete版本的下载及运行

前言 complete 版本提供了很多基础、美观的页面和组件&#xff0c;对于前端不太熟练的小白十分友好&#xff0c;可以直接套用或者修改提供的代码完成自己的页面开发&#xff0c;简直不要太爽。故记录一些下载的步骤。 环境 E:\code>npm -v 9.8.1E:\code>node -v v18.1…

Apache Paimon 使用之 Lookup Joins 解析

Lookup Join 是流式查询中的一种 Join&#xff0c;Join 要求一个表具有处理时间属性&#xff0c;另一个表由lookup source connector支持。 Paimon支持在主键表和附加表上进行Lookup Join。 a) 准备 创建一个Paimon表并实时更新它。 -- Create a paimon catalog CREATE CAT…

RabbitMQ学习总结-延迟消息

1.死信交换机 一致不被消费的信息/过期的信息/被标记nack/reject的信息&#xff0c;这些消息都可以进入死信交换机&#xff0c;但是首先要配置的有私信交换机。私信交换机可以再RabbitMQ的客户端上选定配置-dead-letter-exchange。 2.延迟消息 像我们买车票&#xff0c;外卖…

Python | 机器学习中的模型验证曲线

模型验证是数据科学项目的重要组成部分&#xff0c;因为我们希望选择一个不仅在训练数据集上表现良好&#xff0c;而且在测试数据集上具有良好准确性的模型。模型验证帮助我们找到一个具有低方差的模型。 什么是验证曲线 验证曲线是一种重要的诊断工具&#xff0c;它显示了机…

基于Java+SpringMVC+vue+element宠物管理系统设计实现

基于JavaSpringMVCvueelement宠物管理系统设计实现 博主介绍&#xff1a;5年java开发经验&#xff0c;专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》 欢迎点赞 收藏 ⭐留言 文末获取源…

当你能量越来越高的时候, 你开始真正理解每一个人,

在追求个人成长和内心平静的道路上&#xff0c;我们不可避免地会意识到能量频率对于我们自身以及周围人的影响。随着能量越来越高&#xff0c;我们开始真正理解每一个人&#xff0c;意识到世界并非黑白分明&#xff0c;而是充满了各种不同的状态和选择。在这篇博客中&#xff0…

计算机二级C语言的注意事项及相应真题-4-程序设计

目录 31.找出学生的最高分&#xff0c;由函数值返回32.计算并输出下列多项式的值33.将一个数字字符串转换成与其面值相同的长整型整数。可调用strlen函数求字符串的长度34.将字符串中的前导*号全部移到字符串的尾部。函数fun中给出的语句仅供参考35.将一组得分中&#xff0c;去…

算法---滑动窗口练习-7(串联所有单词的子串)

串联所有单词的子串 1. 题目解析2. 讲解算法原理3. 编写代码 1. 题目解析 题目地址&#xff1a;串联所有单词的子串 2. 讲解算法原理 算法的基本思想是使用滑动窗口来遍历字符串s&#xff0c;并利用两个哈希表&#xff08;hash1和hash2&#xff09;来统计窗口中子串的频次。 …

docker容器技术基础入门-1

文章目录 容器(Container)传统虚拟化与容器的区别Linux容器技术Linux NamespacesCGroupsLXCdocker基本概念docker工作方式docker容器编排 容器(Container) 容器是一种基础工具&#xff1b;泛指任何可以用于容纳其他物品的工具&#xff0c;可以部分或完全封闭&#xff0c;被用于…

#QT(定时轮播电子相册)

1.IDE&#xff1a;QTCreator 2.实验&#xff1a; &#xff08;1&#xff09;使用QOBJECT的TIMER &#xff08;2&#xff09;EVENT时间 &#xff08;3&#xff09;多定时器定时溢出判断 &#xff08;4&#xff09;QLABEL填充图片 3.记录 4.代码 widget.h #ifndef WIDGET_H…

批量查询快递不再难,前缀单号助你轻松搞定!

在快递业务日益繁忙的当下&#xff0c;批量查询快递单号成为了许多人的迫切需求。如何能够快速、准确地找到所需的快递单号呢&#xff1f;其实&#xff0c;利用前缀单号进行批量查询是一个高效且实用的方法。下面&#xff0c;就让我们一起了解如何利用前缀单号轻松查找快递单号…