LeNet原理及代码实现

目录

1.原理及介绍 

2.代码实现

2.1model.py

2.2model_train.py

2.3model.test.py


1.原理及介绍 

2.代码实现

2.1model.py

import torch
from torch import nn
from torchsummary import summaryclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()# 定义第一层卷积(in_channels;输入通道;out_channels:输出通道;kernel_size:卷积核大小;stride:步长,默认为1;padding:填充,默认为0)self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)self.sig = nn.Sigmoid()  # 激活函数self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)  # 第一次池化self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)  # 定义第二层卷积self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)  # 第二次池化self.flatten = nn.Flatten()  # 平展操作self.fc1 = nn.Linear(5 * 5 * 16, 120)  # 第一个全连接层self.fc2 = nn.Linear(120, 84)  # 第二个全连接层self.fc3 = nn.Linear(84, 10)  # 第三个全连接层def forward(self, x):x = self.conv1(x)x = self.sig(x)x = self.pool1(x)x = self.conv2(x)x = self.sig(x)x = self.pool2(x)x = self.flatten(x)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return xif __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = LeNet().to(device)print(summary(model, (1, 28, 28)))  # 打印模型信息

2.2model_train.py

import copy
import time
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data as data
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from model import LeNetdef train_val_data_process():train_dataset = FashionMNIST(root='./data',  # 指定下载文件夹train=True,  # 只下载训练集# 归一化处理数据集transform=transforms.Compose([transforms.Resize(size=28),transforms.ToTensor()]),  # 数据转换成张量形式,方便模型应用download=True  # 下载数据)# 划分训练集和验证集train_data, val_data = data.random_split(train_dataset,[round(0.8 * len(train_dataset)), round(0.2 * len(train_dataset))])# 训练集加载train_dataloader = data.DataLoader(dataset=train_data,batch_size=32,  # 一个批次数据的数量shuffle=True,  # 数据打乱num_workers=2)  # 分配的进程数目# 验证集加载val_dataloader = data.DataLoader(dataset=val_data,batch_size=32,shuffle=True,num_workers=2)return train_dataloader, val_dataloaderdef train_model_process(model, train_dataloader, val_dataloader, num_epochs):# 定义训练使用的设备,有GPU则用,没有则用CPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 使用Adam优化器进行模型参数更新,学习率为0.001optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 损失函数为交叉熵损失函数criterion = nn.CrossEntropyLoss()# 将模型放入训练设备内model = model.to(device)# 复制当前模型参数(w,b等),以便将最好的模型参数权重保存下来best_model_wts = copy.deepcopy(model.state_dict())# 初始化参数# 最高准确度best_acc = 0.0# 训练集损失值列表train_loss_all = []# 验证集损失值列表val_loss_all = []# 训练集准确度列表train_acc_all = []# 验证集准确度列表val_acc_all = []# 当前时间since = time.time()for epoch in range(num_epochs):print("Epoch {}/{}".format(epoch, num_epochs - 1))print("-" * 10)# 初始化参数# 训练集损失值train_loss = 0.0# 训练集精确度train_corrects = 0# 验证集损失值val_loss = 0.0# 验证集精确度val_corrects = 0# 训练集样本数量train_num = 0# 验证集样本数量val_num = 0# 对每一个mini-batch训练和计算for step, (b_x, b_y) in enumerate(train_dataloader):# 将特征放入到训练设备中b_x = b_x.to(device)  # batch_size*28*28*1的tensor数据# 将标签放入到训练设备中b_y = b_y.to(device)  # batch_size大小的向量tensor数据# 设置模型为训练模式model.train()# 前向传播过程,输入为一个batch,输出为一个batch中对应的预测output = model(b_x)  # 输出为:batch_size大小的行和10列组成的矩阵# 查找每一行中最大值对应的行标pre_lab = torch.argmax(output, dim=1)  # batch_size大小的向量表示属于物品的标签# 计算每一个batch的损失函数,向量形式的交叉熵损失函数loss = criterion(output, b_y)# 将梯度初始化为0,防止梯度累积optimizer.zero_grad()# 反向传播计算loss.backward()# 根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用optimizer.step()# 对损失函数进行累加,该批次的loss值乘于该批次数量得到批次总体loss值,在将其累加得到轮次总体loss值train_loss += loss.item() * b_x.size(0)# 如果预测正确,则准确度train_corrects加1train_corrects += torch.sum(pre_lab == b_y.data)# 当前用于训练的样本数量train_num += b_x.size(0)for step, (b_x, b_y) in enumerate(val_dataloader):# 将特征放入到验证设备中b_x = b_x.to(device)# 将标签放入到验证设备中b_y = b_y.to(device)# 设置模型为评估模式model.eval()# 前向传播过程,输入为一个batch,输出为一个batch中对应的预测output = model(b_x)# 查找每一行中最大值对应的行标pre_lab = torch.argmax(output, dim=1)# 计算每一个batch的损失函数loss = criterion(output, b_y)# 对损失函数进行累加val_loss += loss.item() * b_x.size(0)# 如果预测正确,则准确度val_corrects加1val_corrects += torch.sum(pre_lab == b_y.data)# 当前用于验证的样本数量val_num += b_x.size(0)# 计算并保存每一轮次迭代的loss值和准确率# 计算并保存训练集的loss值train_loss_all.append(train_loss / train_num)# 计算并保存训练集的准确率train_acc_all.append(train_corrects.double().item() / train_num)# 计算并保存验证集的loss值val_loss_all.append(val_loss / val_num)# 计算并保存验证集的准确率val_acc_all.append(val_corrects.double().item() / val_num)# 打印每一轮次的loss值和准确度print("{} train loss:{:.4f} train acc: {:.4f}".format(epoch, train_loss_all[-1], train_acc_all[-1]))print("{} val loss:{:.4f} val acc: {:.4f}".format(epoch, val_loss_all[-1], val_acc_all[-1]))if val_acc_all[-1] > best_acc:# 保存当前最高准确度best_acc = val_acc_all[-1]# 保存当前最高准确度的模型参数best_model_wts = copy.deepcopy(model.state_dict())# 计算训练和验证的耗时time_use = time.time() - sinceprint("训练和验证耗费的时间{:.0f}m{:.0f}s".format(time_use // 60, time_use % 60))# 选择最优参数,保存最优参数的模型torch.save(best_model_wts, "./model_save/LeNet_best_model.pth")# 将产生的数据保存成表格,方便查看train_process = pd.DataFrame(data={"epoch": range(num_epochs),"train_loss_all": train_loss_all,"val_loss_all": val_loss_all,"train_acc_all": train_acc_all,"val_acc_all": val_acc_all})return train_processdef matplot_acc_loss(train_process):# 显示每一次迭代后的训练集和验证集的损失函数和准确率plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)  # 表示一行两列的第一张图plt.plot(train_process['epoch'], train_process.train_loss_all, "ro-", label="Train loss")plt.plot(train_process['epoch'], train_process.val_loss_all, "bs-", label="Val loss")plt.legend()plt.xlabel("epoch")plt.ylabel("Loss")plt.subplot(1, 2, 2)  # 表示一行两列的第二张图plt.plot(train_process['epoch'], train_process.train_acc_all, "ro-", label="Train acc")plt.plot(train_process['epoch'], train_process.val_acc_all, "bs-", label="Val acc")plt.xlabel("epoch")plt.ylabel("acc")plt.legend()plt.show()if __name__ == '__main__':# 加载需要的模型LeNet = LeNet()# 加载数据集train_data, val_data = train_val_data_process()# 利用现有的模型进行模型的训练train_process = train_model_process(LeNet, train_data, val_data, num_epochs=20)matplot_acc_loss(train_process)

2.3model.test.py

import torch
import torch.utils.data as data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNetdef test_data_process():test_data = FashionMNIST(root='./data',train=False,  # 用测试集进行测试transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),download=True)test_dataloader = data.DataLoader(dataset=test_data,batch_size=1,  # 该批次设为1shuffle=True,num_workers=0)return test_dataloaderdef test_model_process(model, test_dataloader):# 设定测试所用到的设备,有GPU用GPU没有GPU用CPUdevice = "cuda" if torch.cuda.is_available() else 'cpu'# 讲模型放入到训练设备中model = model.to(device)# 初始化参数test_corrects = 0.0test_num = 0# 只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度with torch.no_grad():for test_data_x, test_data_y in test_dataloader:# 将特征放入到测试设备中test_data_x = test_data_x.to(device)# 将标签放入到测试设备中test_data_y = test_data_y.to(device)# 设置模型为评估模式model.eval()# 前向传播过程,输入为测试数据集,输出为对每个样本的预测值output = model(test_data_x)# 查找每一行中最大值对应的行标pre_lab = torch.argmax(output, dim=1)# 如果预测正确,则准确度test_corrects加1test_corrects += torch.sum(pre_lab == test_data_y.data)# 将所有的测试样本进行累加test_num += test_data_x.size(0)# 计算测试准确率test_acc = test_corrects.double().item() / test_numprint("测试的准确率为:", test_acc)if __name__ == "__main__":# 加载模型model = LeNet()model.load_state_dict(torch.load('./model_save/LeNet_best_model.pth'))  # 调用训练好的参数权重# 加载测试数据test_dataloader = test_data_process()# 加载模型测试的函数test_model_process(model, test_dataloader)

可以点个免费的赞吗!!!   

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/43915.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

nuxt、vue树形图d3.js

直接上代码 //安装 npm i d3 --save<template><div class"d3"><div :id"id" class"d3-content"></div></div> </template> <script> import * as d3 from "d3";export default {props: {d…

Github Actions 构建Vue3 + Vite项目

本篇文章以自己创建的项目为例&#xff0c;用Github Actions构建。 Github地址&#xff1a;https://github.com/ling08140814/myCarousel 访问地址&#xff1a;https://ling08140814.github.io/myCarousel/ 具体步骤&#xff1a; 1、创建一个Vue3的项目&#xff0c;并完成代…

接口基础知识1:认识接口

课程大纲 一、定义 接口&#xff1a;外部与系统之间、内部各子系统之间的交互点。 比如日常使用的电脑&#xff0c;有电源接口、usb接口、耳机接口、显示器接口等&#xff0c;分别可以实现&#xff1a;与外部的充电、文件数据传输、声音输入输出、图像输入输出等功能。 接口的本…

262个地级市-市场潜力指数(do文件+原始文件)

全国262个地级市-市场潜力指数&#xff08;市场潜力计算方法代码数据&#xff09;_市场潜力数据分析资源-CSDN文库 市场潜力指数&#xff1a;洞察未来发展的指南针 市场潜力指数是一个综合性的评估工具&#xff0c;它通过深入分析市场需求、竞争环境、政策支持和技术创新等多个…

(2)滑动窗口算法练习:无重复字符的最长子串

无重复字符的最长子串 题目链接&#xff1a;3. 无重复字符的最长子串 - 力扣&#xff08;LeetCode&#xff09; 给定一个字符串 s &#xff0c;请你找出其中不含有重复字符的最长子串的长度。 输入: s "abcabcbb" 输出: 3 解释: 因为无重复字符的最长子串是"a…

mov视频怎么改成mp4?把mov改成MP4的四个方法

mov视频怎么改成mp4&#xff1f;选择合适的视频格式对于确保内容质量和流通性至关重要。尽管苹果公司的mov格式因其出色的视频表现备受赞誉&#xff0c;但在某些情况下&#xff0c;它并非最佳选择&#xff0c;因为使用mov格式可能面临一些挑战。MP4格式在各种设备&#xff08;如…

构造二进制字符串

目录 LeetCode3221 生成不含相邻零的二进制字符串 #include <iostream> #include <vector> using namespace std;void dfs(string s,int n,vector<string>& res){if(s.size()n){res.push_back(s);return;}dfs(s"0",n,res);dfs(s"1"…

使用redis进行短信登录验证(验证码打印在控制台)

使用redis进行短信登录验证 一、流程1. 总体流程图2. 流程文字讲解&#xff1a;3.代码3.1 UserServiceImpl&#xff1a;&#xff08;难点&#xff09;3.2 拦截器LoginInterceptor&#xff1a;3.3 拦截器配置类&#xff1a; 4 功能实现&#xff0c;成功存入redis &#xff08;黑…

java中 使用数组实现需求小案例

Date: 2024.04.08 18:32:57 author: lijianzhan 需求实现&#xff1a; 设计一个java类&#xff0c;java方法&#xff0c;根据用户手动输入的绩点&#xff0c;从而获取到绩点最高的成绩。 实现业务逻辑的代码块 import java.util.Scanner;public class PointDemo {/*** 需求&…

Spring相关面试题(四)

49 JavaConfig方式如何启用AOP?如何强制使用cglib&#xff1f; 在JavaConfig类&#xff0c;加上EnableAspectJAutoProxy 如果要强制使用CGLIB动态代理 &#xff0c;加上(proxyTargetClass true) 加上(exposeProxy true) 就是将对象暴露到线程池中。 50 介绍AOP在Spring中…

详解TCP和UDP通信协议

目录 OSI的七层模型的主要功能 tcp是什么 TCP三次握手 为什么需要三次握手&#xff0c;两次握手不行吗 TCP四次挥手 挥手会什么需要四次 什么是TCP粘包问题&#xff1f;发生的原因 原因 解决方案 UDP是什么 TCP和UDP的区别 网络层常见协议 利用socket进行tcp传输代…

KIVY Button¶

Button — Kivy 2.3.0 documentation Button Jump to API ⇓ Module: kivy.uix.button Added in 1.0.0 The Button is a Label with associated actions that are triggered when the button is pressed (or released after a click/touch). To configure the button, the s…

【论文速读】| 用于安全漏洞防范的人工智能技术

本次分享论文&#xff1a;Artificial Intelligence Techniques for Security Vulnerability Prevention 基本信息 原文作者&#xff1a;Steve Kommrusch 作者单位&#xff1a;Colorado State University, Department of Computer Science, Fort Collins, CO, 80525 USA 关键…

ISO/OSI七层模型

ISO:国际标准化/ OSI:开放系统互联 七层协议必背图 1.注意事项&#xff1a; 1.上三层是为用户服务的&#xff0c;下四层负责实际数据传输。 2.下四层的传输单位&#xff1a; 传输层&#xff1b; 数据段&#xff08;报文&#xff09; 网络层&#xff1a; 数据包&#xff08;报…

Vue项目openlayers中使用jsts处理wkt和geojson的交集-(geojson来源zpi解析)

Vue项目openlayers中使用jsts处理wkt和geojson的交集-(geojson来源zpi解析) 读取压缩包中的shape看上一篇笔记&#xff1a;Vue项目读取zip中的ShapeFile文件&#xff0c;并解析为GeoJson openlayers使用jsts官方示例&#xff1a;https://openlayers.org/en/latest/examples/j…

科技创新引领水利行业升级:深入分析智慧水利解决方案的核心价值,展望其在未来水资源管理中的重要地位与作用

目录 引言 一、智慧水利的概念与内涵 二、智慧水利解决方案的核心价值 1. 精准监测与预警 2. 优化资源配置 3. 智能运维管理 4. 公众参与与决策支持 三、智慧水利在未来水资源管理中的重要地位与作用 1. 推动水利行业转型升级 2. 保障国家水安全 3. 促进生态文明建设…

vb.netcad二开自学笔记5:ActiveX链接CAD的.net写法

一、必不可少的对象引用 使用activex需要在项目属性中勾选以下两个引用&#xff0c;若找不到&#xff0c;则浏览定位直接添加下面两个文件&#xff0c;可以看到位于cad的安装路径下&#xff0c;图中的3个mgd.dll也可以勾选。 C:\Program Files\Autodesk\AutoCAD 2024\Autodes…

实战 | YOLOv8使用TensorRT加速推理教程(步骤 + 代码)

导 读 本文主要介绍如何使用TensorRT加速YOLOv8模型推理的详细步骤与演示。 YOLOv8推理加速的方法有哪些? YOLOv8模型推理加速可以通过多种技术和方法实现,下面是一些主要的策略: 1. 模型结构优化 网络剪枝:移除模型中不重要的神经元或连接,减少模型复杂度。 模型精…

中文大模型基准测评2024上半年报告

中文大模型基准测评2024上半年报告 原创 SuperCLUE CLUE中文语言理解测评基准 2024年07月09日 18:09 浙江 SuperCLUE团队 2024/07 背景 自2023年以来&#xff0c;AI大模型在全球范围内掀起了有史以来规模最大的人工智能浪潮。进入2024年&#xff0c;全球大模型竞争态势日益加…

Obsidian 文档编辑器

Obsidian是一款功能强大的笔记软件 Download - Obsidian