Pytorch实战01——CIAR10数据集

目录

1、model.py文件 (预训练的模型)

2、train.py文件(会产生训练好的.th文件)

3、predict.py文件(预测文件)

4、结果展示:


1、model.py文件 (预训练的模型)

import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()# RGB图像;  这里用了16个卷积核;卷积核的尺寸为5x5的self.conv1 = nn.Conv2d(3, 16, 5)  # 输入的是RBG图片,所以in_channel为3; out_channels=卷积核个数;kernel_size:5x5的self.pool1 = nn.MaxPool2d(2, 2)  # kernal_size:2x2   stride:2self.conv2 = nn.Conv2d(16, 32, 5)  # 这里使用32个卷积核;kernal_size:5x5self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)  # 全连接层的输入,是一个一维向量,所以我们要把输入的特征向量展平。# 将得到的self.poolx(x) 的output(32,5,5)展开;  图片上给的全连接层是120self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)  # 这里的10,是需要根据训练集修改的def forward(self, x):   # 正向传播# Pytorch Tensor的通道排序:[channel,height,width]'''卷积后的尺寸大小计算:N = (W-F+2P)/S + 1其中,默认的padding:0   stride:1①输入图片大小:WxW②Filter大小 FxF  (卷积核大小)③步长S④padding的像素数P'''x = F.relu(self.conv1(x))   # 输入特征图为32x32大小的RGB图片;  input(3,32,32)  output(16,28,28)x = self.pool1(x)           # 经过最大下采样会将图片的高度和宽度:缩小为原来的一半  output(16,14,14)   池化层,只改变特征矩阵的高和宽;x = F.relu(self.conv2(x))   # output(32, 10, 10)  因为第二个卷积层的卷积核大小是32个,这里就是32x = self.pool2(x)           # 经过最大下采样会将图片的高度和宽度:缩小为原来的一半output(32, 5, 5)x = x.view(-1, 32*5*5)   # x.view()  将其展开成一维向量,-1代表第一个维度x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x
# 测试下
# import torch
# input1 = torch.rand([32,3,32,32])
# model = LeNet()
# print(model)
# output = model(input1)

2、train.py文件(会产生训练好的.th文件)

import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data
import torchvision
from torch import nn, optim
from torchvision import transformsfrom pilipala_pytorch.pytorch_learning.Test1_pytorch_demo.model import LeNet# 1、下载数据集
# 图形预处理 ;其中transforms.Compose()是用来组合多个图像转换操作的,使得这些操作可以顺序地应用于图像。
transform = transforms.Compose([transforms.ToTensor(),   # 将PIL图像或ndarray转换为torch.Tensor,并将像素值的范围从[0,255]缩放到[0.0, 1.0]transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]   # 对图像进行标准化;标准化通常用于使模型的训练更加稳定。
)
# 50000张训练图片
train_ds = torchvision.datasets.CIFAR10('data',train=True,transform=transform,download=False)
# 10000张测试图片
test_ds = torchvision.datasets.CIFAR10('data',train=False,transform=transform,download=False)
# 2、加载数据集
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=36, shuffle=True, num_workers=0)    # shuffle数据是否是随机提取的,一般设置为True
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=10000, shuffle=True, num_workers=0)test_image,test_label = next(iter(test_dl))  # 将test_dl 转换为一个可迭代的迭代器,通过next()方法获取数据classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')'''标准化处理:output = (input - 0.5) / 0.5反标准化处理: input = output * 0.5 + 0.5 = output / 2 + 0.5
'''
# 测试下展示图片
# def imshow(img):
#     img = img / 2 + 0.5   # unnormalize  反标准化处理
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1,2,0)))
#     plt.show()
#
# # 打印标签
# print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
# imshow(torchvision.utils.make_grid(test_image))# 实例化网络模型
net = LeNet()
# 定义相关参数
loss_function = nn.CrossEntropyLoss()  # 定义损失函数
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器, 这里使用的是Adam优化器
# 训练过程
for epoch in range(5):  # 定义循环,将训练集迭代多少轮running_loss = 0.0  # 叠加,训练过程中的损失for step,data in enumerate(train_dl,start=0):  # 遍历训练集样本inputs, labels = data   # 获取图像及其对应的标签optimizer.zero_grad()  # 将历史梯度清零;如果不清除历史梯度,就会对计算的历史梯度进行累加outputs = net(inputs)   # 将输入的图片输入到网络,进行正向传播loss = loss_function(outputs, labels)  # outputs网络预测的值, labels真实标签loss.backward()optimizer.step()running_loss += loss.item()if step % 500 == 499:with torch.no_grad():  # with 是一个上下文管理器outputs = net(test_image)  # [batch,10]predict_y = torch.max(outputs, dim=1)[1]   # 网络预测最大的那个accuracy = (predict_y == test_label).sum().item() / test_label.size(0)  # 得到的是tensor  (predict_y == test_label).sum()  要通过item()拿到数值print("[%d, %5d] train_loss: %.3f test_accuracy:%.3f" % (epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0
print('Finished Training')save_path = './Lenet.pth'  # 保存模型
torch.save(net.state_dict(), save_path)  # net.state_dict() 模型字典;save_path 模型路径

3、predict.py文件(预测文件)

import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNettransform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship' , 'truck')net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))  # 加载train里面的训练好 产生的模型。im = Image.open('2.jpg')  # 载入准备好的图片
im = transform(im)  # 如果要将图片放入网络,进行正向传播,就得转换下格式   得到的结果为:[C,H,W]
im = torch.unsqueeze(im, dim=0)    # 增加一个维度;得到 [N,C,H,W],从而模拟一个批量大小为1的输入。with torch.no_grad():  # 不需要计算损失梯度outputs = net(im)predict = torch.max(outputs, dim=1)[1].data.numpy()   # outputs是一个张量;torch.max()用于找到张量在指定维度上的最大值;# torch.max()函数返回两个张量,一个包含最大值,另一个包含最大值的作用。# .data()属性用于从变量中提取底层的张量数据。直接使用.data()已经被认为是不安全的,推荐使用.detach()# .numpy() 表示将pytorch转换成numpy数组,从而使用numpy库的各种功能来操作数据。
print(classes[int(predict)])#     predict = torch.softmax(outputs,dim=1)  # 可以返回概率
# print(predict)

4、结果展示:

返回结果:预测是猫的概率为 86%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/742043.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

吴恩达机器学习笔记 十七 通过偏差与方差诊断性能 正则化 偏差 方差

高偏差(欠拟合):在训练集上表现得也不好 高方差(过拟合):J_cv要远大于J_train 刚刚好:J_cv和J_train都小 J_cv和J_train与拟合多项式阶数的关系 从一阶到四阶,训练集的误差越来越小…

力扣串题:验证回文串2

整体思路&#xff1a;先找到可能存在问题的点&#xff0c;然后判断&#xff0c;如果一切正常则左指针会来到字符串中部 bool isValidPalindrome(char *s, int i, int j) {while (i < j) {if (s[i] ! s[j]) {return false;}i;j--;}return true; }bool validPalindrome(char …

禁用文本框输入中文,禁用中文输入法的ImeMode方法

之前遇到一个问题&#xff0c;在文本框切换输入法为中文后&#xff0c;使用扫码枪扫码时 会出现 比如条码NH123456 在文本框内会显示 你好23456 这里可以使用输入法编辑器ImeMode枚举属性 如果文本框只能输入英文数字&#xff0c;可以使用ImeMode.Disable&#xff0c;但默…

LeetCode(力扣)算法题_1261_在受污染的二叉树中查找元素

今天是2024年3月12日&#xff0c;可能是因为今天是植树节的原因&#xff0c;今天的每日一题是二叉树&#x1f64f;&#x1f3fb; 在受污染的二叉树中查找元素 题目描述 给出一个满足下述规则的二叉树&#xff1a; root.val 0 如果 treeNode.val x 且 treeNode.left ! n…

js【详解】ajax (含XMLHttpRequest、 同源策略、跨域)

ajax 的核心API – XMLHttpRequest get 请求 // 新建 XMLHttpRequest 对象的实例 const xhr new XMLHttpRequest(); // 发起 get 请求&#xff0c;open 的三个参数为&#xff1a;请求类型&#xff0c;请求地址&#xff0c;是否异步请求&#xff08; true 为异步&#xff0c;f…

Linux使用git命令行教程

. 个人主页&#xff1a;晓风飞 专栏&#xff1a;数据结构|Linux|C语言 路漫漫其修远兮&#xff0c;吾将上下而求索 文章目录 git安装git仓库的创建.git 文件添加文件git 三板斧(add,commit,push)解释拓展git log.gitignore git安装 首先输入git --version看看有没有安装git 如…

Python语言在编程业界的地位——《跟老吕学Python编程》附录资料

Python语言在编程业界的地位——《跟老吕学Python编程》附录资料 ⭐️Python语言在编程业界的地位2024年3月编程语言排行榜&#xff08;TIOBE前十&#xff09; ⭐️Python开发语言开发环境介绍1.**IDLE**2.⭐️PyCharm3.**Anaconda**4.**Jupyter Notebook**5.**Sublime Text** …

操作系统——cpu、内存、缓存介绍

一、内存是什么 内存就是系统资源的代名词&#xff0c;它是其他硬件设备与 CPU 沟通的桥梁&#xff0c; 计算机中的所有程序都在内存中运行。其作用是暂时存放CPU的运算数据&#xff0c;以及与硬盘交换的数据。也是相当于CPU与硬盘沟通的桥梁。只要计算机在运行&#xff0c;CP…

【C++那些事儿】深入理解C++类与对象:从概念到实践(下)| 再谈构造函数(初始化列表)| explicit关键字 | static成员 | 友元

&#x1f4f7; 江池俊&#xff1a;个人主页 &#x1f525; 个人专栏&#xff1a;✅C那些事儿 ✅Linux技术宝典 &#x1f305; 此去关山万里&#xff0c;定不负云起之望 文章目录 1. 再谈构造函数1.1 构造函数体赋值1.2 初始化列表1.3 explicit 关键字 2. static成员2.1 概念…

EasyPoi 教程

文章目录 EasyPoi教程文档1. 前传1.1 前言 这个服务即将关闭,文档迁移到 http://www.wupaas.com/ 请大家访问最新网站1.2 Easypoi介绍1.3 使用1.4 测试项目1.5 可能存在的小坑 2. Excel 注解版2.1 Excel导入导出2.2 注解注解介绍ExcelTargetExcelEntityExcelCollectionExcelIgn…

【PTA】L1-026 L1-027(c++) L1-028 L1-029 L1-030 L1-031(C)第五天

目录 L1-026 I Love GPLT 题解&#xff1a; L1-027 出租 题解&#xff08;c&#xff09;&#xff1a; L1-028 判断素数 题解&#xff1a; L1-029 是不是太胖了 题解&#xff1a; L1-030 一帮一 题解&#xff1a; L1-031 到底是不是太胖了 题解&#xff1a; L1-026 I…

智慧城市与数字经济:共创城市新价值

随着科技的快速发展&#xff0c;智慧城市与数字经济已成为推动城市现代化进程的重要引擎。它们不仅提升了城市治理的效率和公共服务水平&#xff0c;还为城市经济发展注入了新的活力。本文旨在探讨智慧城市与数字经济如何共同创造城市新价值&#xff0c;并分析其面临的挑战与发…

R语言复现:如何利用logistic逐步回归进行影响因素分析?

Logistic回归在医学科研、特别是观察性研究领域&#xff0c;无论是现况调查、病例对照研究、还是队列研究中都是大家经常用到的统计方法&#xff0c;而在影响因素研究筛选自变量时&#xff0c;大家习惯性用的比较多的还是先单后多&#xff0c;P&#xff1c;0.05纳入多因素研究&…

【考研学子必看 ★2024考研国家线及调剂策略(2)】

----------------------------------------------------------------------------------------------------- 考研复试科研背景提升班 教你快速深入了解掌握考研复试面试中的常见问题以及注意事项&#xff0c;系统的教你如何在短期内快速提升自己的专业知识水平和编程以及英语…

鸿蒙Harmony应用开发—ArkTS声明式开发(基础手势:TimePicker)

时间选择组件&#xff0c;根据指定参数创建选择器&#xff0c;支持选择小时及分钟。 说明&#xff1a; 该组件从API Version 8开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。 子组件 无 接口 TimePicker(options?: TimePickerOptions)…

银河麒麟V10SP3操作系统-网络时间配置

1、动态网络配置 打开终端&#xff0c;以网口 eth0 为例&#xff1a; nmcli conn add connection.id eth0-dhcp type ether ifname eth0 ipv4.method auto其中“eth0-dhcp”为连接的名字&#xff0c;可以根据自己的需要命名方便记忆和操作 的名字&#xff1b;“ifname eth0”…

CVE-2023-49442 利用分析

1. 漏洞介绍 JEECG(J2EE Code Generation) 是开源的代码生成平台&#xff0c;目前官方已停止维护。JEECG 4.0及之前版本中&#xff0c;由于/api接口鉴权时未过滤路径遍历&#xff0c;攻击者可构造包含 ../ 的url绕过鉴权。攻击者可构造恶意请求利用 jeecgFormDemoController.do…

Redis安装(单机、主从、哨兵、集群)

一、单机安装Redis 首先需要安装Redis所需要的依赖&#xff1a; yum install -y gcc tcl 复制 下载Redis wget https://gitcode.net/weixin_44624117/software/-/raw/master/software/Linux/Redis/redis-6.2.4.tar.gz 复制 创建安装目录 mkdir /usr/local/redis 复制 …

走进AI新时代:织信低代码的实践与启示

最近 AIGC 很火&#xff0c;在各个领域都玩出了一些新花样。 比如在“低代码”领域&#xff0c;可以通过 AI 自动生成一个网站门户。 但这会带来开发效率的提升吗&#xff1f;如果 AI 能快速开发网站、APP等业务应用&#xff0c;那么 AI 生成能否完全取代低代码的可视化配置&a…

产品实操——立项阶段

一、项目开发设计流程&#xff1a; 立项阶段&#xff1a;基本信息、主要方案、市场调研、用户调研、分析得出结论 设计阶段&#xff1a;原型、UI效果图、结构流程设计 开发阶段&#xff1a;前端、后端、数据库、运维等 测试阶段&#xff1a;可用性测试、性能测试、单元测试、集…