Pytorch训练LeNet模型MNIST数据集

如何用torch框架训练深度学习模型(详解)

0. 需要的包

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

1. 数据加载和导入

以MNIST数据集为例

# 1.1 需要设置数据归一化
train_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
# 1.2 用dataset.MNIST函数下载和加载训练集与测试集 
train_dataset = datasets.MNIST(dataset_path, train=True, download=False, transform=train_transform)
test_dataset = datasets.MNIST(dataset_path, train=False, download=False, transform=test_transform)
# 1.3 加载进dataload用于后续数据按batch取用
batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

补充:这里的transform根据不同的数据集选择不同的值
datasets加载数据集时path的路径为:'.\data\' 该目录下包括\MNIST文件夹

2. 加载模型和设置超参数

# 2.1 这里需要提前定义model的class,包括层结构和forward函数
model = LeNet_Mnist().to(device)
# 2.2 设置优化器、损失函数、训练轮次
learning_rate = 1e-2
# 传入模型参数,用于优化更新
sgd = SGD(model.parameters(), lr=learning_rate)  
loss_fn = CrossEntropyLoss()
all_epoch = 20

3. 训练

# 3.1 首先设置训练模式
model.train()
# 3.2 按照batch从train_loader中批量选择数据
for idx, (train_x, train_label) in enumerate(train_loader):train_x = train_x.to(device)train_label = train_label.to(device)sgd.zero_grad()predict_y = model(train_x.float())loss = loss_fn(predict_y, train_label.long())loss.backward()sgd.step()

补充:可以在外面再套一层迭代次数

for current_epoch in range(all_epoch):  # local training

4. 测试

# 4.1 记录测试结果
all_correct_num = 0
all_sample_num = 0
# 4.2 进入模型验证模式,该模式下不会修改梯度
model.eval()
# 4.3 按批次测试
for idx, (test_x, test_label) in enumerate(test_loader):test_x = test_x.to(device)test_label = test_label.to(device)predict_y = model(test_x.float()).detach()predict_y = torch.argmax(predict_y, dim=-1)current_correct_num = predict_y == test_labelall_correct_num += np.sum(current_correct_num.to('cpu').numpy(), axis=-1)all_sample_num += current_correct_num.shape[0]
# 4.4 记录结果并输出
acc = all_correct_num / all_sample_num
print('accuracy: {:.3f}'.format(acc), flush=True)

5. 保存结果

# 5.1 保存参数
print("Save the model state dict")
torch.save(model.state_dict(), "./lenet_mnist.pt")
# 5.2 或者也可以选择保存checkpoint,每轮都保存一次,万一中断能继续
checkpoint = {"model": model.state_dict(),"optim": sgd.state_dict(),}
print("Save the checkpoint")
torch.save(checkpoint, "./checkpoint{}.pt".format(current_epoch))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/18774.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python图形界面(GUI)Tkinter笔记(九):用【Button()】功能按钮实现人机交互

在Tkinter库中,功能按钮(Button)是实现人机交互的一个非常重要的组件: 【一】主要可实现功能及意义: (1)响应用户交互: Button组件允许用户通过点击来触发某个事件或动作。当用户点击按钮时,可以执行一个指定的函数或方法。 (2)提供用户输入: Button组件是图形用户界面(G…

持续总结中!2024年面试必问 20 道 Rocket MQ面试题(三)

上一篇地址:持续总结中!2024年面试必问 20 道 Rocket MQ面试题(二)-CSDN博客 五、什么是生产者(Producer)和消费者(Consumer)在RocketMQ中? RocketMQ是一个高性能、高吞…

Linux完整版命令大全(二十五)

pine 功能说明&#xff1a;收发电子邮件&#xff0c;浏览新闻组。语  法&#xff1a;pine [-ahikorz][-attach<附件>][-attach_and_delete<附件>][-attachlist<附件清单>][-c<邮件编号>][-conf][-create_lu<地址薄><排序法>][-f<收件…

剧本杀小程序开发,探索市场发展新的商业机遇

剧本杀游戏作为一个新兴行业&#xff0c;经历了爆发式的增长&#xff0c;剧本杀游戏在市场中的热度不断升高。 不过&#xff0c;在市场的火热下&#xff0c;竞争也在逐渐加大。因此&#xff0c;在市场竞争下&#xff0c;成本低、主题多样、有趣的线上剧本杀小程序成为了创业者…

竹云董事长在第二届ICT技术发展与企业数字化转型高峰论坛作主题演讲

5月25日&#xff0c;由中国服务贸易协会指导&#xff0c;中国服务贸易协会信息技术服务委员会主办的 “第二届ICT技术发展与企业数字化转型高峰论坛” 在北京隆重召开。 本次论坛以 “数据驱动&#xff0c;AI引领&#xff0c;打造新质生产力” 为主题&#xff0c;特邀业内200余…

WebGL实现医学教学软件

使用WebGL实现医学教学软件是一个复杂但非常有益的项目&#xff0c;可以显著提升医学教育的互动性和效果。以下是详细的实现步骤&#xff0c;包括需求分析、技术选型、开发流程和注意事项。北京木奇移动技术有限公司&#xff0c;专业的软件外包开发公司&#xff0c;欢迎交流合作…

redis-cli help使用

1. redis-cli命令使用—先连接上服务器 连接到 Redis 服务器&#xff1a; 使用 redis-cli 命令即可连接到本地运行的 Redis 服务器&#xff0c;默认连接到本地的 6379 端口。 redis-cli如果 Redis 服务器不在本地或者端口不同&#xff0c;可以使用 -h 和 -p 参数指定主机和端…

华为校招机试 - LRU模拟(20240515)

题目描述 LRU(Least Recently Used)缓存算法是一种常用于管理缓存的策略,其目标是保留最近使用过的数据,而淘汰最久未被使用的数据。 实现简单的LRU缓存算法,支持查询、插入、删除操作。 最久未被使用定义:查询、插入和删除操作均为一次访问操作,每个元素均有一个最后…

探索Django 5: 从零开始,打造你的第一个Web应用

今天我们将一起探索 Django 5&#xff0c;一个备受开发者喜爱的 Python Web 框架。我们会了解 Django 5 的简介&#xff0c;新特性&#xff0c;如何安装 Django&#xff0c;以及用 Django 编写一个简单的 “Hello, World” 网站。最后&#xff0c;我会推荐一本与 Django 5 相关…

苏洵,大器晚成的家风塑造者

&#x1f4a1; 如果想阅读最新的文章&#xff0c;或者有技术问题需要交流和沟通&#xff0c;可搜索并关注微信公众号“希望睿智”。 苏洵&#xff0c;字明允&#xff0c;号老泉&#xff0c;生于宋真宗大中祥符二年&#xff08;公元1009年&#xff09;&#xff0c;卒于宋英宗治平…

量产导入 | 产品可靠性测试标准完整大集合(JEDEC/IEC/SAE…)

产品可靠性测试标准完整大集合(JEDEC/IEC/SAE…) 产品可靠性测试是产品质量保证中的重要一环, 包含有Pre-con, aging(寿命)和ESD(静电)等, 下面就收集了权威标准JEDEC全系列, 请参照如下 同时也附上其它的可靠性标准供大家参考及交叉理解, 可能侧重点不同, 大家可以参…

go语言同一包中的同一变量实现不同平台设置不同的默认值 //go:build 编译语法使用示例

在使用go来开发跨平台应用的时候&#xff0c;比如配置文件的路径&#xff0c;我们希望设置一个默认值&#xff0c;windows下的路径是类似 d:\myapp\app.conf 这样的&#xff0c; unix系统中的路径是 /opt/myapp/app.conf 这样的&#xff0c; 而我们在使用的时候需要使用的是同…

PPT忘记保存?教你如何轻松恢复

在日常办公中PPT文件作为主流文档格式&#xff0c;承载着我们大量的工作成果。然而当不小心误点了“不保存”按钮&#xff0c;或是遭遇软件崩溃等意外情况导致文档丢失时&#xff0c;文件内容是否还能够能恢复&#xff0c;往往成为我们最关心的问题。本文将为您提供五大免费且实…

NetCore PetaPoco 事务处理分享

PetaPoco是一个轻量级的.NET和Mono数据库访问库&#xff0c;它以单个C#文件的形式存在&#xff0c;便于集成到任何项目中。PetaPoco的主要特点包括无依赖性、快速的性能和对简单事务的支持。它适用于严格的没有装饰的Poco类以及几乎全部加了特性的Poco类&#xff0c;并提供了多…

现在版本的ultralytics没有setup.py以后,本地代码中修改了ultralytics源码,怎么安装到python环境中。

问题&#xff0c;在使用ultralytics训练yolov8-obb模型时&#xff0c;修改了ultralytics源码的网络结构&#xff0c;发现调用的还是pip install安装的ultralytics库&#xff0c;新版本源码中还没有setup.py&#xff0c;该怎么把源码中的ultralytics安装到环境中。 解决方法&am…

《探索网络七层模型:构建高效通信架构的关键》

在当今数字化时代&#xff0c;网络通信已经成为人们生活和工作中不可或缺的一部分。而网络七层模型作为计算机网络体系结构的重要基础&#xff0c;其技术架构对于构建高效、稳定的通信系统具有重要意义。本文将深入探讨网络七层模型的技术架构设计&#xff0c;以及其在构建现代…

轻松掌握图片批量处理,赶紧学习这些小技巧!

在现今数字化的社会中&#xff0c;我们每天都会接触到大量的图片&#xff0c;无论是在工作中还是日常生活中。要想高效处理这些图片&#xff0c;掌握图片批量处理的技巧就显得尤为重要。幸运的是&#xff0c;有许多小技巧和工具可以让这一过程变得轻松愉快。 在本文中&#xf…

长安链使用Golang编写智能合约教程(三)

本篇主要介绍长安链Go SDK写智能合约的一些常见方法的使用方法或介绍 资料来源&#xff1a; 官方文档官方示例合约库官方SDK接口文档 一、获取参数、获取状态、获取历史记录的方法解析 注意&#xff01; 这些查询链上数据的方法&#xff1a;只能是查询本合约之前上链的数据&a…

信息学一周赛事安排

本周比赛提醒 本周有以下几场比赛即将开始&#xff1a; 1.ABC-356 比赛时间&#xff1a;6月1日&#xff08;周六&#xff09;晚20:00 比赛链接&#xff1a;https://atcoder.jp/contests/abc356 2.ARC-179 比赛时间&#xff1a;6月2日&#xff08;周日&#xff09;晚20:00 …

【Go】十、路由配置以及ZAP 高性能日志库的使用

Project 目录创建 mxshop-api user-web api ---- 服务接口 config ---- 配置信息 forms ---- 表单验证信息 global ---- 全局信息 initialize ---- 初始化信息 middlewares ---- 中间件信息 proto ---- 数据信息 router ---- 路由信息 utils ---- 公用工具信息 validator ----…