Pytorch网络模型训练

现有网络模型的使用与修改

vgg16_false = torchvision.models.vgg16(pretrained=False)        # 加载一个未预训练的模型
vgg16_true = torchvision.models.vgg16(pretrained=True)
# 把数据分为了1000个类别print(vgg16_true)

以下是vgg16预训练模型的输出 

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

预训练模型的输出从1000类别转为10类别

import torchvision
from torch import nn
# 因为数据集过大,所以注释掉此行代码
# train_data = torchvision.datasets.ImageNet("./data_image_net", split='train', download=True,
#                                            transform=torchvision.transforms.ToTensor())vgg16_false = torchvision.models.vgg16(pretrained=False)        # 加载一个未预训练的模型
vgg16_true = torchvision.models.vgg16(pretrained=True)
# 把数据分为了1000个类别print(vgg16_true)# vgg16_true.add_module("add_linear", nn.Linear(1000, 10))
vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))
# 在预训练模型的最后添加了一个新的全连接层,用于将最后的输出转化为10个类别
print(vgg16_true)print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
# 未预训练模型的最后一层的输出特征数更改为了10
print(vgg16_false)

网络模型的保存与读取

加载未预训练的模型

vgg16 = torchvision.models.vgg16(pretrained=False)

方式一

# 保存方式1  保存的模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pyth")#读取方式1
model = torch.load("vgg16_method1.pth")

方式二

# 保存方式2  不再保存模型结构,而是保存模型的参数为字典结构    推荐
torch.save(vgg16.state_dict(), "vgg16_method2.pyth")# 方式2,加载模型
# model = torch.load("vgg16_method2.pth")     #这样输出的是字典类型
# print(model)
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))      # 将其恢复为网络模型
print(vgg16)

完整的模型训练套路

准备数据集

# 准备数据集
train_data = torchvision.datasets.CIFAR10("../data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),download=True)train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为{}".format(train_data_size))    # 50000
print("测试数据集的长度为{}".format(test_data_size))     # 10000# 利用Dataloader来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

创建网络模型

# 创建网络模型  神经网络的代码在train_module文件
tudui = Tudui()

train_module文件

# 搭建神经网络
class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()# 简化操作,并且按顺序进行操作self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return x

构建损失函数

# 损失函数
loss_fn = nn.CrossEntropyLoss()

构建优化器

# 优化器
# 如果学习率过大,模型可能会在最小值附近震荡而无法收敛;如果学习率过小,模型训练可能会过于缓慢
learning_rate = 0.01
# 使用随机梯度下降算法来更新模型的权重
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

设置训练集参数

# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10

添加tensorboard

# 将数据写入 TensorBoard 可视化的日志文件中
writer = SummaryWriter("../logs_train")

训练步骤

# tudui.train()
for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)# 优化器优化模型optimizer.zero_grad()# 将优化器中的梯度缓存(如果有的话)清零loss.backward()# 计算损失函数(loss)相对于模型参数的梯度optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:# .item()是将tensor张量变为正常的数字print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))# loss.item()是当前步骤的损失值writer.add_scalar("train_loss", loss.item(), total_train_step)# 使用add_scalar可以将一个标量添加到之前的所有标量值中,# 这样就可以在TensorBoard中绘制一个标量随时间变化的图表

测试步骤

# 测试步骤开始
# tudui.eval()
total_test_loss = 0
total_accuracy = 0
# 不会对以下的代码进行调优
with torch.no_grad():for data in test_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()# argmax(1)是横向看,argmax(0)是纵向看accuracy = (outputs.argmax(1) == targets).sum()# argmax在找到模型预测的最大概率对应的类别# 预测正确的个数total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))
print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))
# 测试集上的总损失
writer.add_scalar("test_loss", total_test_loss, total_test_step)
writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
total_test_step = total_test_step + 1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/130005.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FFmpeg直播能力更新计划与新版本发布

// 编者按:客户端作为直接面向用户大众的接口,随着技术的发展进化与时俱进,实现更好的服务是十分必要的。FFmpeg作为最受欢迎的视频和图像处理开源软件,被相关行业的大量用户青睐,而随着HEVC标准的发布到广泛使用&am…

【jvm】虚拟机栈

目录 一、背景二、栈与堆三、声明周期四、作用五、特点(优点)六、可能出现的异常七、设置栈内存大小八、栈的存储单位九、栈运行原理十、栈帧的内部结构10.1 说明10.2 局部变量表10.3 操作数栈10.4 动态链接10.5 方法返回地址10.6 一些附加信息 十一、代…

整理10个地推拉新app接单平台,免费一手推广渠道平台干货分享

1. 聚量推客: “聚量推客”汇聚了众多市场上有的和没有的地推网推拉新接单项目,目前比较火热,我们做地推和网推从业者如果长期在这行业去做推广可以使用这个平台,价格高数据也好,大部分拉新项目也都是官签一手资源 一…

关于Intel Press出版的《Bedyong BIOS》第2版的观后感

文章目录 此书的背景UEFI运行时DXE基础CPU架构协议PCI协议UEFI驱动的初始化串口DXE驱动示例 《Beyond BIOS》首先介绍一个简单的UEFI应用程序模块,用于展示UEFI应用程序的行为。作者为Waldo。该模块名为“InitializeHelloApplication”,它接受两个参数&a…

Leetcode—101.对称二叉树【简单】

2023每日刷题(十九) Leetcode—101.对称二叉树 利用Leetcode101.对称二叉树的思想的实现代码 /*** Definition for a binary tree node.* struct TreeNode {* int val;* struct TreeNode *left;* struct TreeNode *right;* };*/ bool isSa…

【深度学习基础】Pytorch框架CV开发(1)基础铺垫

📢:如果你也对机器人、人工智能感兴趣,看来我们志同道合✨ 📢:不妨浏览一下我的博客主页【https://blog.csdn.net/weixin_51244852】 📢:文章若有幸对你有帮助,可点赞 👍…

3 — NLP 中的标记化:分解文本数据的艺术

一、说明 这是一个系列文章的第三篇文章, 文章前半部分分别是: 1 、NLP 的文本预处理技术 2、NLP文本预处理技术:词干提取和词形还原 在本文中,我们将介绍标记化主题。在开始之前,我建议您阅读我之前介绍的关…

Docker的简单安装

安装环境 CentOS Linux release 8.1.1911 (Core)内核4.18.0-147.el8.x86_64Mini Installation 安装前的准备工作 切换国内源 由于centos源已经过期,所以切换为阿里云的yum源,第二个是docker的仓库 wget -O /etc/yum.repos.d/CentOS-Base.repo https:…

云闪付app拉新推广一手渠道 附详细教程

云闪付推广拉新可以通过“聚量推客”申请 云闪付是什么呢?是中国银联出的支付平台,在地推和网推项目里也算是比较火热的app拉新产品,属于地推和网推的百搭项目,操作也简单 只需要动账就算一个数据,目前主要招收地推、…

重新思考边缘负载均衡

本文介绍了Netflix在基于轮询的负载均衡的基础上,集成了包括服务器使用率在内的多因素指标,并对冷启动服务器进行了特殊处理,从而优化了负载均衡逻辑,提升了整体业务性能。原文: Rethinking Netflix’s Edge Load Balancing[1] 我…

第十五章 EM期望极大算法及其推广

文章目录 导读符号说明混合模型伯努利混合模型(三硬币模型)问题描述三硬币模型的EM算法1.初值2.E步3.M步初值影响p,q 含义 EM算法另外视角Q 函数BMM的EM算法目标函数LEM算法导出 高斯混合模型GMM的EM算法1. 明确隐变量, 初值2. E步,确定Q函数3. M步4. 停止条件 如何应用GMM在聚…

软测推荐第二期:10本高质量测试书籍

在不断发展的软件开发领域,测试是质量的守护者,确保产品不仅满足功能要求,而且提供无缝的用户体验。随着软件复杂性的增加,对完善的测试方法和见解的需求也随之增加。 上次给大家推荐了五本书,获得了大家的积极反馈&a…

RT-Thread系统使用常见问题处理记录

1.使用telnet连接系统时发送help指令显示不全的问题。 原因:telnet发送缓存太小。 解决办法:更改agile_telnet软件包里Set agile_telnet tx buffer size的大小。 2.使用Paho MQTT软件包过一段时间报错hard fault on thread: mqtt0 解决办法&#xff1…

UE5加载websocket模块为空

今天测试UE 发现工程启动不了,后来看到原来是websocket模块无法加载。 解决的它的方法很简单,这种问题一般会出现在源码版本的引擎或者是停电了,导致UElaunch版本损坏,解决方法是来到源码版本的引擎 这个目录下: D:\…

稳定性测试—fastboot和monkey区别

一、什么是稳定性测试 稳定性测试是指检验程序在一定时间内能否稳定地运行,在不同的场景下能否正常地工作的过程。主要目的是检测崩溃、内存泄漏、堆栈错误等缺陷。 二、Monkey 1.什么是Monkey 是一个命令行工具,通常在adb安卓调试运行,模…

ABAP简单的队列设置QRFC

场景:用job的方式在接口里启用job,如果接口调用比较频繁,存在同一时间启动相同job的情况,会导致锁表锁程序这种情况。 查阅job函数,发现在JOB_CLOSE函数里自带了类似队列的参数,但是因为是接口&#xff0c…

如何卸载干净 IDEA(图文讲解)windows和Mac教程

大家好,我是sun~ 很多小伙伴会问 Windows / Mac 系统上要怎么彻底卸载 IDEA 呢? 本文通过图片文字,详细讲解具体步骤: 如何卸载干净 IDEA(图文讲解) Windows1、卸载 IDEA 程序2、注册表清理3、残留清理 M…

重生奇迹mu下载后仅仅只是挂机吗?

挂挂机、聊聊天,打打怪,如此简单、轻松的游戏或许有,但绝对不是重生奇迹mu!因为重生奇迹mu挂机也不是那么容易,即便是多名高端玩家组队挂机,也有可能是全队惨灭,这样的情况时常发生在游戏中。 …

【入门Flink】- 05Flink运行时架构以及一些核心概念

系统架构 Flink运行时架构Standalone会话模式为例 1)作业管理器(JobManager) JobManager 是一个 Flink 集群中任务管理和调度的核心,是控制应用执行的主进程。每个应用都应该被唯一的 JobManager 所控制执行。 JobManger 又包含…

聚观早报 |盒马参战双11;真我GT5 Pro将压轴登场

【聚观365】11月4日消息 盒马参战双11 真我GT5 Pro将压轴登场 奇瑞汽车10月销量创新高 iQOO 12系列将首发电竞芯片Q1 苹果CEO库克称正改善供需平衡 盒马参战双11 不少消费者反映,今年盒马的双11已悄然开始:10月20日起,盒马APP很多商品页…