LeNet原理及代码实现

目录

1.原理及介绍 

2.代码实现

2.1model.py

2.2model_train.py

2.3model.test.py


1.原理及介绍 

2.代码实现

2.1model.py

import torch
from torch import nn
from torchsummary import summaryclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()# 定义第一层卷积(in_channels;输入通道;out_channels:输出通道;kernel_size:卷积核大小;stride:步长,默认为1;padding:填充,默认为0)self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)self.sig = nn.Sigmoid()  # 激活函数self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)  # 第一次池化self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)  # 定义第二层卷积self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)  # 第二次池化self.flatten = nn.Flatten()  # 平展操作self.fc1 = nn.Linear(5 * 5 * 16, 120)  # 第一个全连接层self.fc2 = nn.Linear(120, 84)  # 第二个全连接层self.fc3 = nn.Linear(84, 10)  # 第三个全连接层def forward(self, x):x = self.conv1(x)x = self.sig(x)x = self.pool1(x)x = self.conv2(x)x = self.sig(x)x = self.pool2(x)x = self.flatten(x)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return xif __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = LeNet().to(device)print(summary(model, (1, 28, 28)))  # 打印模型信息

2.2model_train.py

import copy
import time
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data as data
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from model import LeNetdef train_val_data_process():train_dataset = FashionMNIST(root='./data',  # 指定下载文件夹train=True,  # 只下载训练集# 归一化处理数据集transform=transforms.Compose([transforms.Resize(size=28),transforms.ToTensor()]),  # 数据转换成张量形式,方便模型应用download=True  # 下载数据)# 划分训练集和验证集train_data, val_data = data.random_split(train_dataset,[round(0.8 * len(train_dataset)), round(0.2 * len(train_dataset))])# 训练集加载train_dataloader = data.DataLoader(dataset=train_data,batch_size=32,  # 一个批次数据的数量shuffle=True,  # 数据打乱num_workers=2)  # 分配的进程数目# 验证集加载val_dataloader = data.DataLoader(dataset=val_data,batch_size=32,shuffle=True,num_workers=2)return train_dataloader, val_dataloaderdef train_model_process(model, train_dataloader, val_dataloader, num_epochs):# 定义训练使用的设备,有GPU则用,没有则用CPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 使用Adam优化器进行模型参数更新,学习率为0.001optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 损失函数为交叉熵损失函数criterion = nn.CrossEntropyLoss()# 将模型放入训练设备内model = model.to(device)# 复制当前模型参数(w,b等),以便将最好的模型参数权重保存下来best_model_wts = copy.deepcopy(model.state_dict())# 初始化参数# 最高准确度best_acc = 0.0# 训练集损失值列表train_loss_all = []# 验证集损失值列表val_loss_all = []# 训练集准确度列表train_acc_all = []# 验证集准确度列表val_acc_all = []# 当前时间since = time.time()for epoch in range(num_epochs):print("Epoch {}/{}".format(epoch, num_epochs - 1))print("-" * 10)# 初始化参数# 训练集损失值train_loss = 0.0# 训练集精确度train_corrects = 0# 验证集损失值val_loss = 0.0# 验证集精确度val_corrects = 0# 训练集样本数量train_num = 0# 验证集样本数量val_num = 0# 对每一个mini-batch训练和计算for step, (b_x, b_y) in enumerate(train_dataloader):# 将特征放入到训练设备中b_x = b_x.to(device)  # batch_size*28*28*1的tensor数据# 将标签放入到训练设备中b_y = b_y.to(device)  # batch_size大小的向量tensor数据# 设置模型为训练模式model.train()# 前向传播过程,输入为一个batch,输出为一个batch中对应的预测output = model(b_x)  # 输出为:batch_size大小的行和10列组成的矩阵# 查找每一行中最大值对应的行标pre_lab = torch.argmax(output, dim=1)  # batch_size大小的向量表示属于物品的标签# 计算每一个batch的损失函数,向量形式的交叉熵损失函数loss = criterion(output, b_y)# 将梯度初始化为0,防止梯度累积optimizer.zero_grad()# 反向传播计算loss.backward()# 根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用optimizer.step()# 对损失函数进行累加,该批次的loss值乘于该批次数量得到批次总体loss值,在将其累加得到轮次总体loss值train_loss += loss.item() * b_x.size(0)# 如果预测正确,则准确度train_corrects加1train_corrects += torch.sum(pre_lab == b_y.data)# 当前用于训练的样本数量train_num += b_x.size(0)for step, (b_x, b_y) in enumerate(val_dataloader):# 将特征放入到验证设备中b_x = b_x.to(device)# 将标签放入到验证设备中b_y = b_y.to(device)# 设置模型为评估模式model.eval()# 前向传播过程,输入为一个batch,输出为一个batch中对应的预测output = model(b_x)# 查找每一行中最大值对应的行标pre_lab = torch.argmax(output, dim=1)# 计算每一个batch的损失函数loss = criterion(output, b_y)# 对损失函数进行累加val_loss += loss.item() * b_x.size(0)# 如果预测正确,则准确度val_corrects加1val_corrects += torch.sum(pre_lab == b_y.data)# 当前用于验证的样本数量val_num += b_x.size(0)# 计算并保存每一轮次迭代的loss值和准确率# 计算并保存训练集的loss值train_loss_all.append(train_loss / train_num)# 计算并保存训练集的准确率train_acc_all.append(train_corrects.double().item() / train_num)# 计算并保存验证集的loss值val_loss_all.append(val_loss / val_num)# 计算并保存验证集的准确率val_acc_all.append(val_corrects.double().item() / val_num)# 打印每一轮次的loss值和准确度print("{} train loss:{:.4f} train acc: {:.4f}".format(epoch, train_loss_all[-1], train_acc_all[-1]))print("{} val loss:{:.4f} val acc: {:.4f}".format(epoch, val_loss_all[-1], val_acc_all[-1]))if val_acc_all[-1] > best_acc:# 保存当前最高准确度best_acc = val_acc_all[-1]# 保存当前最高准确度的模型参数best_model_wts = copy.deepcopy(model.state_dict())# 计算训练和验证的耗时time_use = time.time() - sinceprint("训练和验证耗费的时间{:.0f}m{:.0f}s".format(time_use // 60, time_use % 60))# 选择最优参数,保存最优参数的模型torch.save(best_model_wts, "./model_save/LeNet_best_model.pth")# 将产生的数据保存成表格,方便查看train_process = pd.DataFrame(data={"epoch": range(num_epochs),"train_loss_all": train_loss_all,"val_loss_all": val_loss_all,"train_acc_all": train_acc_all,"val_acc_all": val_acc_all})return train_processdef matplot_acc_loss(train_process):# 显示每一次迭代后的训练集和验证集的损失函数和准确率plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)  # 表示一行两列的第一张图plt.plot(train_process['epoch'], train_process.train_loss_all, "ro-", label="Train loss")plt.plot(train_process['epoch'], train_process.val_loss_all, "bs-", label="Val loss")plt.legend()plt.xlabel("epoch")plt.ylabel("Loss")plt.subplot(1, 2, 2)  # 表示一行两列的第二张图plt.plot(train_process['epoch'], train_process.train_acc_all, "ro-", label="Train acc")plt.plot(train_process['epoch'], train_process.val_acc_all, "bs-", label="Val acc")plt.xlabel("epoch")plt.ylabel("acc")plt.legend()plt.show()if __name__ == '__main__':# 加载需要的模型LeNet = LeNet()# 加载数据集train_data, val_data = train_val_data_process()# 利用现有的模型进行模型的训练train_process = train_model_process(LeNet, train_data, val_data, num_epochs=20)matplot_acc_loss(train_process)

2.3model.test.py

import torch
import torch.utils.data as data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNetdef test_data_process():test_data = FashionMNIST(root='./data',train=False,  # 用测试集进行测试transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),download=True)test_dataloader = data.DataLoader(dataset=test_data,batch_size=1,  # 该批次设为1shuffle=True,num_workers=0)return test_dataloaderdef test_model_process(model, test_dataloader):# 设定测试所用到的设备,有GPU用GPU没有GPU用CPUdevice = "cuda" if torch.cuda.is_available() else 'cpu'# 讲模型放入到训练设备中model = model.to(device)# 初始化参数test_corrects = 0.0test_num = 0# 只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度with torch.no_grad():for test_data_x, test_data_y in test_dataloader:# 将特征放入到测试设备中test_data_x = test_data_x.to(device)# 将标签放入到测试设备中test_data_y = test_data_y.to(device)# 设置模型为评估模式model.eval()# 前向传播过程,输入为测试数据集,输出为对每个样本的预测值output = model(test_data_x)# 查找每一行中最大值对应的行标pre_lab = torch.argmax(output, dim=1)# 如果预测正确,则准确度test_corrects加1test_corrects += torch.sum(pre_lab == test_data_y.data)# 将所有的测试样本进行累加test_num += test_data_x.size(0)# 计算测试准确率test_acc = test_corrects.double().item() / test_numprint("测试的准确率为:", test_acc)if __name__ == "__main__":# 加载模型model = LeNet()model.load_state_dict(torch.load('./model_save/LeNet_best_model.pth'))  # 调用训练好的参数权重# 加载测试数据test_dataloader = test_data_process()# 加载模型测试的函数test_model_process(model, test_dataloader)

可以点个免费的赞吗!!!   

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/43915.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

nuxt、vue树形图d3.js

直接上代码 //安装 npm i d3 --save<template><div class"d3"><div :id"id" class"d3-content"></div></div> </template> <script> import * as d3 from "d3";export default {props: {d…

Github Actions 构建Vue3 + Vite项目

本篇文章以自己创建的项目为例&#xff0c;用Github Actions构建。 Github地址&#xff1a;https://github.com/ling08140814/myCarousel 访问地址&#xff1a;https://ling08140814.github.io/myCarousel/ 具体步骤&#xff1a; 1、创建一个Vue3的项目&#xff0c;并完成代…

接口基础知识1:认识接口

课程大纲 一、定义 接口&#xff1a;外部与系统之间、内部各子系统之间的交互点。 比如日常使用的电脑&#xff0c;有电源接口、usb接口、耳机接口、显示器接口等&#xff0c;分别可以实现&#xff1a;与外部的充电、文件数据传输、声音输入输出、图像输入输出等功能。 接口的本…

262个地级市-市场潜力指数(do文件+原始文件)

全国262个地级市-市场潜力指数&#xff08;市场潜力计算方法代码数据&#xff09;_市场潜力数据分析资源-CSDN文库 市场潜力指数&#xff1a;洞察未来发展的指南针 市场潜力指数是一个综合性的评估工具&#xff0c;它通过深入分析市场需求、竞争环境、政策支持和技术创新等多个…

面向字节流传输数据

当提到“传输数据面向字节流”&#xff0c;这是指在网络通信中&#xff0c;数据被视作一连串的无结构字节&#xff0c;而不是按照特定的数据块或记录进行传输。这种传输方式是面向传输层协议&#xff08;如TCP&#xff09;的一个特性&#xff0c;它允许数据以连续的字节流形式在…

phpstudy框架,window平台,如何开端口给局域网访问?

Windows平台上使用phpstudy框架开端口给同事访问&#xff0c;主要涉及到几个步骤&#xff1a;查看并确认本机IP地址、配置phpstudy及网站项目、开放防火墙端口以及确保同事能够通过局域网访问。以下是详细的步骤说明&#xff1a; 1. 查看并确认本机IP地址 首先&#xff0c;需…

SQLAlchemy pool_pre_ping

pool_pre_ping 是 SQLAlchemy 中 create_engine 函数的一个参数&#xff0c;它用于配置连接池的行为。当设置为 True 时&#xff0c;pool_pre_ping 启用了连接池在每次从池中取出&#xff08;即“签出”或“checkout”&#xff09;连接之前&#xff0c;先测试该连接是否仍然活跃…

(2)滑动窗口算法练习:无重复字符的最长子串

无重复字符的最长子串 题目链接&#xff1a;3. 无重复字符的最长子串 - 力扣&#xff08;LeetCode&#xff09; 给定一个字符串 s &#xff0c;请你找出其中不含有重复字符的最长子串的长度。 输入: s "abcabcbb" 输出: 3 解释: 因为无重复字符的最长子串是"a…

mov视频怎么改成mp4?把mov改成MP4的四个方法

mov视频怎么改成mp4&#xff1f;选择合适的视频格式对于确保内容质量和流通性至关重要。尽管苹果公司的mov格式因其出色的视频表现备受赞誉&#xff0c;但在某些情况下&#xff0c;它并非最佳选择&#xff0c;因为使用mov格式可能面临一些挑战。MP4格式在各种设备&#xff08;如…

构造二进制字符串

目录 LeetCode3221 生成不含相邻零的二进制字符串 #include <iostream> #include <vector> using namespace std;void dfs(string s,int n,vector<string>& res){if(s.size()n){res.push_back(s);return;}dfs(s"0",n,res);dfs(s"1"…

使用redis进行短信登录验证(验证码打印在控制台)

使用redis进行短信登录验证 一、流程1. 总体流程图2. 流程文字讲解&#xff1a;3.代码3.1 UserServiceImpl&#xff1a;&#xff08;难点&#xff09;3.2 拦截器LoginInterceptor&#xff1a;3.3 拦截器配置类&#xff1a; 4 功能实现&#xff0c;成功存入redis &#xff08;黑…

搜维尔科技为空气分离、氢气、石化和天然气工厂的现场操作员提供虚拟现实(VR)培训

搜维尔科技为空气分离、氢气、石化和天然气工厂的现场操作员提供虚拟现实(VR)培训 搜维尔科技为空气分离、氢气、石化和天然气工厂的现场操作员提供虚拟现实(VR)培训

python 中关于append和extend的区别用法

#方法1 d[1,2,[3,4]] c[] for i in d:if type(i) int:c.append(i)else:c.extend(i)# append方法用于将单个元素添加到列表的末尾&#xff0c;这意味着无论元素是什么类型# &#xff08;如整数、字符串等&#xff09;&#xff0c;它都将作为一个独立的元素添加到列表中。# exten…

UE5.2 AI实时抠像(无需绿幕) + OBS推流直播 全流程

最近通过2个UE5.2插件实现了从AI实时抠像到OBS推流的直播流程搭建&#xff0c;也为了水一篇博客&#xff0c;就在这里记录一下了&#xff0c;觉得没有意思的朋友&#xff0c;这里先说为敬了。 具体教程参考&#xff1a;【UE5 AI抠像OBS推流全流程&#xff08;简单免费&#xf…

华为机考真题 -- 寻找身高相近的小朋友

题目描述: 小明今年升学到z小学—年级,来到新班级后发现其他小朋友们身高参差不齐,然后就想基于各4朋友和自己的身高差q对他们进行排序,请帮他实现排序。 输入描述: 有一行为正整数h和n,0<h<200,为小明的身高,0<n<50,为新班级其他小朋友个数。 第二行为…

java中 使用数组实现需求小案例

Date: 2024.04.08 18:32:57 author: lijianzhan 需求实现&#xff1a; 设计一个java类&#xff0c;java方法&#xff0c;根据用户手动输入的绩点&#xff0c;从而获取到绩点最高的成绩。 实现业务逻辑的代码块 import java.util.Scanner;public class PointDemo {/*** 需求&…

Spring相关面试题(四)

49 JavaConfig方式如何启用AOP?如何强制使用cglib&#xff1f; 在JavaConfig类&#xff0c;加上EnableAspectJAutoProxy 如果要强制使用CGLIB动态代理 &#xff0c;加上(proxyTargetClass true) 加上(exposeProxy true) 就是将对象暴露到线程池中。 50 介绍AOP在Spring中…

【3】迁移学习模型

【3】迁移学习模型 文章目录 前言一、安装相关模块二、训练代码2.1. 管理预训练模型2.2. 模型训练代码2.3. 可视化结果2.4. 类别函数 总结 前言 主要简述一下训练代码 三叶青图像识别研究简概 一、安装相关模块 #xingyun的笔记本 print(xingyun的笔记本) %pip install d2l %…

详解TCP和UDP通信协议

目录 OSI的七层模型的主要功能 tcp是什么 TCP三次握手 为什么需要三次握手&#xff0c;两次握手不行吗 TCP四次挥手 挥手会什么需要四次 什么是TCP粘包问题&#xff1f;发生的原因 原因 解决方案 UDP是什么 TCP和UDP的区别 网络层常见协议 利用socket进行tcp传输代…

【js面试题】深入理解DOM操作:创建、查询、更新、添加和删除节点

面试题&#xff1a;DOM常见的操作有哪些 引言&#xff1a; 在前端开发中&#xff0c;DOM&#xff08;文档对象模型&#xff09;操作是日常工作中不可或缺的一部分。DOM提供了一种以编程方式访问和更新文档内容、结构和样式的接口。 任何html或 xml 文档都可以用dom表示为一个由…