LeNet训练集详细实现

一、下载训练集

导包

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

 ToTensor()函数:

把图像[heigh x width x channels] 转换为 [channels x height x width]

Normalize() 数据标准化函数:

最后一行是标准化数值计算公式

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 50000张训练图片
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)

参数解释: 

root='./data':数据集下载的路径,我下载到当前目录下的data文件夹,下载完成后会自动创建 

train=True:当前为训练集

download=True:下载数据集时设置为True,下载完成后改为False

transform=transform :设置对图像进行预处理的函数

运行下载数据集结果为: 

 下载完成后生成了data文件夹

二、导入训练集 

# 导入训练集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=36,shuffle=True, num_workers=0)

参数解释: 

        trainset:把刚刚下载的数据导入进来

        batch_size=36:一批数据的大小

        shuffle=True:训练集中的数据是否打乱(一般默认打乱)

        num_workers=0:载入数据的现成数,在lunix操作系统下,可以设置为别的参数,在windows操作系统系统下,默认为0.

三、下载测试集

# 10000张测试图片
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=10000,shuffle=False, num_workers=0)
test_data_iter = iter(testloader)
test_image, test_lable = test_data_iter.next()classes = ('plane', 'car', 'bird', 'cat',   # 数据集中的分类,设置为元组,不可变类'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

参数解释:

test_data_iter = iter(testloader):通过iter()函数把testloader转化成可迭代的迭代器
test_image, test_lable = test_data_iter.next():通过next()方法可以获得测试的图像和图像对应的标签值。

 四、查看导入的图片

在中间过程打印图片进行查看,后续会注释掉

def imshow(img):img = img / 2 + 0.5nping = img.numpy()plt.imshow(np.transpose(nping, (1, 2, 0)))plt.show()# print labels
print(' '.join('%5s' % classes[test_lable[j]] for j in range(4)))
# show images
imshow(torchvision.utils.make_grid(test_image))

运行结果

 图片很模糊,因为像素很低。

上面识别出来的结果都对了。

 我遇到的问题:
一开始有结果但是没有图片,我以为时matplotlib的问题,我重新安装并且更新了版本,但是我再运行后报错更多了,报错提示我 AttributeError: module 'numpy' has no attribute 'bool',我就知道是numpy的问题了,我重新安装并且更新了版本结果还是不行,我百度了一下,发现不是越新的版本越好,我重新下载了1.23.2这个版本的numpy,下载完成后运行就出来结果了。

pip install numpy==1.23.2

这个也只是中间过程,后续会注释或者删了。


五、将创建的模型实例化

创建模型请看PyTorch搭建LeNet神经网络-CSDN博客

for epoch in range(5):running_loss = 0.0for step, data in enumerate(trainloader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = loss_fuction(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if step % 500 == 499:with torch.no_grad():outputs = net(test_image)predict_y = torch.max(outputs, dim=1)[1]accuracy = (predict_y == test_lable).sum().item() / test_lable.size(0)print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')

运行结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/725072.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

视频二维码加密怎么设置?验证密码看内容的二维码做法

现在为了保障内容的私密性,很多人会采用生成二维码的方式来展现或者传递自己的内容,比如文件、视频、音频等等。如果我们想要制作一个视频二维码,并且给二维码加密后需要输入正确密码才能查看视频内容,这种类型的二维码的制作方式…

就业班 2401--3.5 Linux Day11--DNS解析

DNS域名解析服务 少壮不努力,老大徒伤悲。一、 1.DNS介绍 DNS 是域名系统 (Domain Name System) 的缩写,是因特网的一项核心服务,它作为可以将域名和IP地址相互映射的一个分布式数据库,能够使人更方便的访问互联网,…

numpy——基础知识(创建/类型/形状/运算)(python)

简介 NumPy 是一个 Python 包。它代表 Numeric Python。 它是一个由多维数组对象和用于处理数组的例程集合组成的库。 Numeric,即 NumPy 的前身,是由 Jim Hugunin 开发的。 也开发了另一个包 Numarray ,它拥有一些额外的功能。 2005年&#…

VueCli的安装与卸载

文章目录 一.Node安装包的报读网盘提取码二、Vue脚手架Cli三、Vue-CLI使用步骤(自定义安装)1.转换路径并创建项目2.创建步骤的解释(保姆级)3.等待vue项目自己创建好(保姆级) 四、通过npm对vue的安装与卸载 一.Node安装包的报读网盘提取码 下面的链接为地址: Node的百度网盘提取…

代码随想录算法训练营Day52 | 300.最长递增子序列、674.最长连续递增序列、718.最长重复子数组

300.最长递增子序列 这题的重点是DP数组的定义,子序列必须以nums[i]为最后一个元素,这样dp数组中后面的元素才能与前面的元素进行对比 1、DP数组定义:dp[i]表示以nums[i]为最后一个元素的最长递增子序列长度 2、DP数组初始化:全部…

Redis(5.0)

1、什么是Redis Redis是一种开源的、基于内存、支持持久化的高性能Key-Value的NoSQL数据库,它同时也提供了多种数据结构来满足不同场景下的数据存储需求。 2、安装Redis(Linux) 2.1、去官网(http://www.redis.cn/)下…

C++开发基础之简单的计时器也有适配场景

一、前言 计时器的开发通常涉及到计算时间间隔的方法和计算时间的方式。一般计时器的开发步骤: 获取起始时间点:在开始计时时,记录当前的时间戳作为起始时间点。 获取结束时间点:在结束计时时,记录当前的时间戳作为结…

linux安装ngnix

一、将nginx-1.20.1.tar.gz上传至linux服务器目录下 二、将nginx安装包解压到/usr/local目录下 tar -zxvf /home/local/nginx-1.20.1.tar.gz -C /usr/local/三、预先安装依赖 yum -y install pcre-devel yum -y install openssl openssl-devel yum -y install gcc gcc-c auto…

二分查找算法:高效搜索有序数据的利器

二分查找算法:高效搜索有序数据的利器 在计算机科学中,搜索是一项基本而重要的操作。对于有序数据,二分查找算法是一种高效的搜索方法。本文将介绍二分查找算法的原理、实现以及其在实际应用中的优势,帮助读者理解和应用这一常用的…

C++学习第七天(string类)

1、学习string的原因? C语言中的字符串 C语言中,字符串是以‘\0’结尾的一些字符的集合,为了操作方便,C标准库中提供了一些str系列的库函数,但是这些库函数与字符串是分离开的,而且底层空间需要用户自己管…

day.js和moment.js的区别

day.js 和 moment.js 都是非常流行的 JavaScript 日期处理库,它们都提供了丰富的 API 来处理日期和时间。以下是它们的一些主要区别: 大小:day.js 的大小只有 2KB,而 moment.js 的大小约为 60KB。如果你关心你的项目的大小&#x…

【前端】尚硅谷Webpack教程笔记

文章目录 1. 基本使用1.1 功能介绍1.2 开始使用 参考视频:尚硅谷Webpack5入门到原理 课件地址 【前端目录贴】 1. 基本使用 1.1 功能介绍 Webpack 是一个静态资源打包工具。 它会以一个或多个文件作为打包的入口,将我们整个项目所有文件编译组合成一个或多个文件输…

13. C++类使用方式

【类】 C语言使用函数将程序代码模块化,C通过类将代码进一步模块化,类用于将实现一种功能的全局数据、以及操作这些数据的函数集中存储在一起,同时可以设置类成员的访问权限,禁止外部代码使用和篡改本类成员,类成员访…

vscode中开发goalng,debug时遇到的tools报错问题

版本 vscode最新版本golang1.18.10dlv>1.8.3gopls0.16.0 > 0.14.2 1、vscode开发golang,delve dlv版本1.19高于golang版本 Failed to launch is too old for this version of Delve 1.0、前言 下载vscode之后,安装golang1.80.10的版本&#xf…

3月每日一题笔记

感谢我的好朋友的鼓励 3月4日 两种等价方式?都是错误的 ->加减中不能使用等价无穷小? ->不全面。 两项无穷小相减, 那么两项无穷小比值的极限不等于 1 时, 或者两项无穷小相加时, 其比值极限不等于 −1 时, 代数和差各项可以用等价无穷小替换 等价无穷小不精确

腾达路由器检测环境功能破解MISP基础

在虚拟机上用qemu运行腾达路由器的网站固件会遇到无法识别网络的问题,这篇主要是破解这个功能,使腾达路由器成功在虚拟机上运行,方便漏洞复现 本次用到的腾达路由器版本: https://www.tenda.com.cn/download/detail-3683.html下…

讲享元设计模式,顺便~学会了数据库连接池,

设计模式-详细说明享元模式设计,保准一听就会,不会你来打我 1.前言 今天呢,我们来说下享元模式,享元模式是结构型模式,我的感觉,结构型模式都是相对比较简单的设计模式,这个也是,之…

数据伪列

目录 数据伪列 rownum 查询 emp 表中的记录并且取得第一行数据 取得 emp 表的前 5 行记录 rowid 面试题:表中有许多完全重复的数据,要求将重复的数据删除掉(只剩最早的一个) Oracle从入门到总裁:https://blog.csdn.net/weixin…

Guava处理异常

guava由Google开发,它提供了大量的核心Java库,例如:集合、缓存、原生类型支持、并发库、通用注解、字符串处理和I/O操作等。 异常处理 传统的Java异常处理通常包括try-catch-finally块和throws关键字。 遇到FileNotFoundException或IOExce…

【算法可视化】图论专题

运行平台 Algorithm Visualizer 图的深度优先遍历 // import visualization libraries { const { Tracer, Array1DTracer, GraphTracer, LogTracer, Randomize, Layout, VerticalLayout } require(algorithm-visualizer); // }// define tracer variables { const graphTra…