神经网络基础知识:LeNet的搭建-训练-预测

1.参考视频:

2.1 pytorch官方demo(Lenet)_哔哩哔哩_bilibili

2.总结:

(1)LeNet网络就是 我最开始用来预测mnist数据集的那个网络,简单的2个conv+2个maxpool+3个linear层

(2)up主整理的train.py等内容里面的细节分析值得学习

(3)对于预测代码的撰写,可以参考代码的predict.py文件

3.几个文件的源代码我都贴一下(都不多——但很精):

(1)首先是 model.py:

import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 16, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)x = self.pool1(x)            # output(16, 14, 14)x = F.relu(self.conv2(x))    # output(32, 10, 10)x = self.pool2(x)            # output(32, 5, 5)x = x.view(-1, 32*5*5)       # output(32*5*5)x = F.relu(self.fc1(x))      # output(120)x = F.relu(self.fc2(x))      # output(84)x = self.fc3(x)              # output(10)return x

模型 == 2个conv + 2个max_pool + 3个linear

(2) train.py训练模型的文件:

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transformsdef main():# 定义transform的数据增强transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 处理cifar10的 train和val的数据集的问题# 50000张训练图片# 第一次使用时要将download设置为True才会自动去下载数据集train_set = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,shuffle=True, num_workers=0)# 10000张验证图片# 第一次使用时要将download设置为True才会自动去下载数据集val_set = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,shuffle=False, num_workers=0)val_data_iter = iter(val_loader)val_image, val_label = next(val_data_iter)# classes = ('plane', 'car', 'bird', 'cat',#            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 训练前的准备: 实例化model网络net , 定义 loss函数 CrossEntropyLoss() 和 Adam优化器net = LeNet()loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)# 开始训练:zero_grad() + outputs + loss backward + optim stepfor epoch in range(5):  # loop over the dataset multiple timesrunning_loss = 0.0for step, data in enumerate(train_loader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = loss_function(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if step % 500 == 499:    # print every 500 mini-batcheswith torch.no_grad():outputs = net(val_image)  # [batch, 10]predict_y = torch.max(outputs, dim=1)[1]accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')# 最后把 model的 参数save 为一个.pth文件save_path = './Lenet.pth'torch.save(net.state_dict(), save_path)if __name__ == '__main__':main()

分析:数据集划分 + 实例化网络_优化器_loss函数 + 分epoch开始寻 + save_pth权重

(3)predict.py:

import torch
import torchvision.transforms as transforms
from PIL import Imagefrom model import LeNetdef main():# 将需要检测图像 裁剪为32*32transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')#实例化网络 + 才入权重net = LeNet()net.load_state_dict(torch.load('Lenet.pth'))# 打开图像,转换格式im = Image.open('1.jpg')im = transform(im)  # [C, H, W]im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]# 输入到网络中, 得到预测的结果with torch.no_grad():outputs = net(im)predict = torch.max(outputs, dim=1)[1].numpy()print(classes[int(predict)])if __name__ == '__main__':main()

predict == 处理图像 + 实例化权重 + 得到预测结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/714562.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQL面试题(2)

第一题 创建trade_orders表: create table `trade_orders`( `trade_id` varchar(255) NULL DEFAULT NULL, `uers_id` varchar(255), `trade_fee` int(20), `product_id` varchar(255), `time` varchar(255) )ENGINE = InnoDB CHARACTER SET = utf8mb4 COLLATE = utf8mb4_0900_…

web自动化笔记九:验证码的处理方式

一、验证码常用的处理方式 ①、说明:Selenium中并没有对验证码处理的方法,在这里我们介绍一下针对验证码的几种常用处理方式 ②、方式: 1)、去掉验证码(测试环境下采用) …

RDD算子介绍

1. RDD算子 RDD算子也叫RDD方法,主要分为两大类:转换和行动。转换,即一个RDD转换为另一个RDD,是功能的转换与补充,比如map,flatMap。行动,则是触发任务的执行,比如collect。所谓算子…

LeetCode 1551.是数组中所有元素相等的最小操作数

存在一个长度为 n 的数组 arr &#xff0c;其中 arr[i] (2 * i) 1 &#xff08; 0 < i < n &#xff09;。 一次操作中&#xff0c;你可以选出两个下标&#xff0c;记作 x 和 y &#xff08; 0 < x, y < n &#xff09;并使 arr[x] 减去 1 、arr[y] 加上 1 &…

Mac专用投屏工具AirServer 7.27 for Mac中文版2024最新图文教程

Mac专用投屏工具AirServer 7.27 for Mac中文版是一款适用于Mac的投屏工具&#xff0c;可以将Mac屏幕快速投影到其他设备上&#xff0c;如电视、投影仪、平板等。 Mac专用投屏工具AirServer 7.27 for Mac中文版具有优秀的兼容性&#xff0c;可以与各种设备配合使用。无论是iPhon…

基于springboot+vue的在线考试系统(源码+论文)

文章目录 目录 文章目录 前言 一、功能设计 二、功能页面 三、论文 前言 现在我国关于在线考试系统的发展以及专注于对无纸化考试的完善程度普遍不高&#xff0c;关于对考试的模式还大部分还停留在纸介质使用的基础上&#xff0c;这种教学模式已不能解决现在的时代所产生的考试…

【MySQL】数据库的操作

【MySQL】数据库的操作 目录 【MySQL】数据库的操作创建数据库数据库的编码集和校验集查看系统默认字符集以及校验规则查看数据库支持的字符集查看数据库支持的字符集校验规则校验规则对数据库的影响数据库的删除 数据库的备份和恢复备份还原不备份整个数据库&#xff0c;而是备…

YOLOv9改进|增加SPD-Conv无卷积步长或池化:用于低分辨率图像和小物体的新 CNN 模块

专栏介绍&#xff1a;YOLOv9改进系列 | 包含深度学习最新创新&#xff0c;主力高效涨点&#xff01;&#xff01;&#xff01; 一、文章摘要 卷积神经网络(CNNs)在计算即使觉任务中如图像分类和目标检测等取得了显著的成功。然而&#xff0c;当图像分辨率较低或物体较小时&…

【LeetCode刷题】146. LRU 缓存

请你设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构。 实现 LRUCache 类&#xff1a; LRUCache(int capacity) 以 正整数 作为容量 capacity 初始化 LRU 缓存int get(int key) 如果关键字 key 存在于缓存中&#xff0c;则返回关键字的值&#xff0c;否则返回 -…

全量知识系统问题及SmartChat给出的答复 之9 三套工具之4语法解析器 之2

Q23. 一个语言的语法简约规则 这些规则显示show 在一个给定单词&#xff08;a given word&#xff09;的右边或左边可能出现的单词的类别。句型的多样性variety不是复杂文法&#xff08;a complex grammar&#xff09;的结果&#xff0c;而是简单语法&#xff08;a simple gra…

【InternLM 实战营笔记】浦语·灵笔的图文理解及创作部署、 Lagent 工具调用 Demo

浦语灵笔的图文理解及创作部署 浦语灵笔是基于书生浦语大语言模型研发的视觉-语言大模型&#xff0c;提供出色的图文理解和创作能力&#xff0c;结合了视觉和语言的先进技术&#xff0c;能够实现图像到文本、文本到图像的双向转换。使用浦语灵笔大模型可以轻松的创作一篇图文推…

进程间的通信 -- 共享内存

一 共享内存的概念 1. 1 共享内存的原理 之前我们学过管道通信&#xff0c;分为匿名管道和命名管道&#xff0c;匿名管道通过父子进程的属性继承原理来完成父子进程看到同一份资源的目的&#xff0c;而命名管道则是通过路径与文件名来唯一标识管道文件&#xff0c;来让不同的进…

学习Android的第二十一天

目录 Android ProgressDialog (进度条对话框) 例子 Android DatePickerDialog 日期选择对话框 例子 Android TimePickerDialog 时间选择对话框 Android PopupWindow 悬浮框 构造函数 方法 例子 官方文档 Android OptionMenu 选项菜单 例子 官方文档 Android Progr…

Java实战:Spring Boot中各类参数校验机制

引言 在开发Web应用程序时&#xff0c;对客户端传入的参数进行有效校验是保证系统安全性和稳定性的重要环节。Spring Boot作为一个现代化的Java开发框架&#xff0c;提供了多种参数校验的方法和工具&#xff0c;以满足不同场景下的需求。本文将深入探讨Spring Boot中实现各种参…

typescript 的常用方式

文章目录 前言一、绑定props 默认值的方式&#xff1a;withDefaults1.vue2 的props设置默认值2.vue3 的props设置默认值(1) 不设置默认值的写法(2) 设置默认值的写法&#xff08;分离模式&#xff09;(3) 设置默认值的写法&#xff08;组合模式&#xff09; 二、定义一个二维数…

Matlab在同一张图中如何加入多个图例

根据代码最终画出的图片如下&#xff1a; 其实原理很简单&#xff0c;就是在一张figure中画多个坐标轴&#xff0c;每个坐标轴都有对应的图例&#xff0c;之后再将多余坐标轴隐藏&#xff0c;只保留一个即可。 代码如下&#xff1a; clear all; close all;dd_linewidth 1;a …

maven archetype 项目原型

拓展阅读 maven 包管理平台-01-maven 入门介绍 Maven、Gradle、Ant、Ivy、Bazel 和 SBT 的详细对比表格 maven 包管理平台-02-windows 安装配置 mac 安装配置 maven 包管理平台-03-maven project maven 项目的创建入门 maven 包管理平台-04-maven archetype 项目原型 ma…

Spring学习笔记(六)利用Spring的jdbc实现学生管理系统的用户登录功能

一、案例分析 本案例要求学生在控制台输入用户名密码&#xff0c;如果用户账号密码正确则显示用户所属班级&#xff0c;如果登录失败则显示登录失败。 &#xff08;1&#xff09;为了存储学生信息&#xff0c;需要创建一个数据库。 &#xff08;2&#xff09;为了程序连接数…

洛谷P1927防护伞

题目描述 据说 20122012 的灾难和太阳黑子的爆发有关。于是地球防卫小队决定制造一个特殊防护伞&#xff0c;挡住太阳黑子爆发的区域&#xff0c;减少其对地球的影响。由于太阳相对于地球来说实在是太大了&#xff0c;我们可以把太阳表面看作一个平面&#xff0c;中心定为(0,0…

C 基本语法

我们已经看过 C 程序的基本结构&#xff0c;这将有助于我们理解 C 语言的其他基本的构建块。 C 的令牌&#xff08;Token&#xff09; C 程序由各种令牌组成&#xff0c;令牌可以是关键字、标识符、常量、字符串值&#xff0c;或者是一个符号。例如&#xff0c;下面的 C 语句…