【深度学习】(7)--神经网络之保存最优模型

文章目录

  • 保存最优模型
    • 一、两种保存方法
      • 1. 保存模型参数
      • 2. 保存完整模型
    • 二、迭代模型
  • 总结

保存最优模型

我们在迭代模型训练时,随着次数初始的增多,模型的准确率会逐渐的上升,但是同时也随着迭代次数越来越多,由于模型会开始学习到训练数据中的噪声或非共性特征,发生过拟合现象,使得模型的准确率会上下震荡甚至于下降。

本篇就是介绍我们如何在进行那么多次迭代之中,找到训练最好效果时,模型的参数或完整模型。也方便以后使用模型时直接使用。

一、两种保存方法

我们知道,一个模型到底好不好,主要体现在对测试集数据结果上的表现,所以我们的方法主要从测试集入手,计算每次迭代测试集数据的准确率,取到准确率最大时对应的模型和参数

那么,我们该如何保存模型和参数呢?介绍一个小东西:

  • 文件拓展名pt\pth,t7,使用pt\pth或t7作为模型文件扩展名,保存模型的整个状态(包括模型架构和参数)或仅保存模型的参数(即状态字典,state_dict)。

1. 保存模型参数

方法

torch.save(model.state_dict(),path)
# model.state_dict()是一个从参数名称映射到参数张量的字典对象,它包含了模型的所有权重和偏置项
# path为创建的保存模型的文件

通过比较每一次迭代准确率的大小,取准确率最大时模型的参数

best_acc = 0
"""-----测试集-----"""
def test(dataloader,model,loss_fn):global best_accsize = len(dataloader.dataset) # 总数据大小num_batches = len(dataloader) # 划分的小批次数量model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item() # 预测正确的个数test_loss /= num_batchescorrect /= sizecorrect = round(correct, 4)print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")# 保存最优模型的方法(文件扩展名一般:pt\pth,t7)if correct > best_acc:best_acc = correct# 1. 保存模型参数方法:torch.save(model.state_dict(),path)  (w,b)print(model.state_dict().keys()) # 输出模型参数名称cnntorch.save(model.state_dict(),"best.pth") 

2. 保存完整模型

方法

torch.save(model,path)
# 直接得到整个模型

依旧是通过比较每一次迭代准确率的大小,但是取准确率最大时的整个模型

def test(dataloader,model,loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= sizecorrect = round(correct, 4)print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")# 保存最优模型的方法(文件扩展名一般:pt\pth,t7)if correct > best_acc:best_acc = correct# 2. 保存完整模型(w,b,模型cnn)torch.save(model,"best1.pt")

二、迭代模型

接下来就要迭代模型,得到最优的模型:

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001,weight_decay=0.0001)epochs = 150
# training_data、test_data:数据预处理好的数据
train_dataloader = DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=64,shuffle=True)
for t in range(epochs):print(f"Epoch {t+1} \n-------------------------")train(train_dataloader,model,loss_fn,optimizer)test(test_dataloader,model,loss_fn)
print("Done!")

在每轮数据迭代后,project工程栏中的best1.ptbest.pth文件中模型会随着迭代及时更新,迭代结束后,文件中保存的就是最优模型以及最优的模型参数。

在这里插入图片描述

总结

本篇介绍了:

  1. 为什么随着迭代次数越来越多,模型的准确率会上下震荡甚至于下降。—> 过拟合
  2. pt\pth,t7三个扩展名,用于保存完整模型或者模型参数。
  3. 模型的好坏,通过体现在测试集的结果上。
  4. 保存最优模型的两种方法:保存模型参数和保存完整模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/54731.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【从0开始自动驾驶】用python做一个简单的自动驾驶仿真可视化界面

【从0开始自动驾驶】用python做一个简单的自动驾驶仿真可视化界面 废话几句废话不多说,直接上源码目录结构init.pysimulator.pysimple_simulator_app.pyvehicle_config.json 废话几句 自动驾驶开发离不开仿真软件成品仿真软件种类多https://zhuanlan.zhihu.com/p/3…

【CSS】鼠标 、轮廓线 、 滤镜 、 堆叠层级

cursor 鼠标outline 轮廓线filter 滤镜z-index 堆叠层级 cursor 鼠标 值说明值说明crosshair十字准线s-resize向下改变大小pointer \ hand手形e-resize向右改变大小wait表或沙漏w-resize向左改变大小help问号或气球ne-resize向上右改变大小no-drop无法释放nw-resize向上左改变…

蓝桥杯1.小蓝的漆房

样例输入 2 5 2 1 1 2 2 1 6 2 1 2 2 3 3 3样例输出 1 2 import math import os import sys tint(input())#执行的次数 for j in range(t):n,kmap(int,input().split())#n为房间数 k为一次能涂的个数alist(map(int,input().split()))#以列表的形式存放房间的颜色maxvaluemath…

处理RabbitMQ连接和认证问题

在使用RabbitMQ进行消息队列管理时,我们可能会遇到各种连接和认证问题。本文将介绍如何诊断和解决这些问题,并通过使用RabbitMQ的管理端进行登录验证来确保配置正确。 1. 问题概述 在最近的一次部署中,我们遇到了两个主要问题: …

IPSec隧道协议学习(一)

前情回顾 前面介绍的GRE隧道协议,可以字LAN之间通过Internet建立隧道,实现网络间资源共享,但是GRE隧道协议不能实现加密功能,传输的数据不受加密保护,为了实现在隧道间传输数据包收到加密保护,需要使用IPS…

GitLab发送邮件功能详解:如何配置自动化?

GitLab发送邮件的设置指南?怎么优化GitLab发送邮件? GitLab作为一个强大的代码管理平台,不仅提供了代码托管、CI/CD等功能,还集成了发送邮件的功能,使得开发团队能够及时获取项目动态。AokSend将详细介绍如何配置GitL…

Spring 的依赖注入原理

Spring 的依赖注入(Dependency Injection,DI)是其核心特性之一,它的主要作用是管理对象之间的依赖关系,降低对象之间的耦合度,提高代码的可维护性和可测试性。其原理如下: 一、基本概念 控制反…

代码随想录 -- 回溯 -- 非递减子序列

491. 非递减子序列 - 力扣(LeetCode) 思路:重点是去重 收集结果:每次进入递归先判断path中的元素数量,如果大于1了,就将path收集到result中。 递归参数:nums,index,pa…

2024 go-zero社交项目实战

背景 一位商业大亨,他非常看好国内的社交产品赛道,想要造一款属于的社交产品,于是他找到了负责软件研发的小明。 小明跟张三一拍即合,小明决定跟张三大干一番。 社交产品MVP版本需求 MVP指:Minimum Viable Product&…

职场能力强的人都在做什么---今日头条

【职场里,能力强的人都在做哪些事... - 今日头条】https://m.toutiao.com/is/ikn6kt9q/ 知识雷达 2024-09-21 16:33 目录 职场里,能力强的人都在做哪些事呢? 1、复盘; 2、多角度思考;3、记录信息; 4、永远积极主动;5、主动获取信息差; 6、明确人和人的关系;7、…

ISO8583包简介(一)

简介 ISO8583包(简称8583包)是一个国际标准的包格式,最多由128个字段域组成,每个域都有统一的规定,并有定长与变长之分。8583包前面一段为位图,用来确定包的字段域组成情况。 其中位图是8583包的灵魂&#…

【Altium Designer程序开发】BGA芯片自动扇出

BGA自动扇出功能支持将BGA器件从4个方向上扇出,里面有无空白区域均可支持,执行速度非常快,通常在秒级的时间内即可处理完成,程序可以通过以下几种方式启动。 ➡️支持从菜单栏启动 ➡️支持从工具栏启动 ➡️支持从服务器面板启动…

Go weak包前瞻:弱指针为内存管理带来新选择

在介绍Go 1.23引入的unique包的《Go unique包:突破字符串局限的通用值Interning技术实现》一文中,我们知道了unique包底层是基于internal/weak包实现的,internal/weak是一个弱指针功能的Go实现。所谓弱指针(Weak Pointer,也称为弱…

HarmonyOS鸿蒙开发实战(5.0)自定义路由栈管理

鸿蒙HarmonyOS NEXT开发实战往期文章必看(持续更新......) HarmonyOS NEXT应用开发性能实践总结 HarmonyOS NEXT应用开发案例实践总结合集 最新版!“非常详细的” 鸿蒙HarmonyOS Next应用开发学习路线!(从零基础入门…

真实数据,告诉你3S相关专业本硕毕业生就业去向

本期推文将基于2015届-2023届3S相关专业毕业生(包括本硕博所有毕业生)的生源地、性别分布、行业岗位等数据进行分析,为各位同学提供一些参考,希望可以对各位同学的职业规划与有一定的帮助。 GIS开发资料分享https://www.wjx.cn/v…

10.Lab Nine —— file system-上

首先切换分支到fs git checkout fs make clean 预备知识 mkfs程序创建xv6文件系统磁盘映像,并确定文件系统的总块数,这个大小在kernel/param.h中的FSSIZE写明 // kernel/params.h #define FSSIZE 200000 // size of file system in blocks Make…

牛客小白月赛101

考点为:A题 滑动窗口、B题 栈、C题 找规律、D 差分、E 筛约数。C题可能会卡住,不过手搓几组,或者模拟几组规律就显而易见了 A: 思路: 无论去头还是去尾,最后所留下的数据长度一定为:n - k &am…

Dbt自动化测试实战教程

数据团队关键核心资产是给消费者提供可信赖的数据。如果提供了不被信任的数据,那么支持决策智能依赖于猜测和直觉。原始数据从不同来源被摄取智数据仓库,数据产品团队有责任定义转换逻辑,将源数据整合到有意义的数据产品中,用于报…

Redisson分布式锁的概念和使用

Redisson分布式锁的概念和使用 一 简介1.1 什么是分布式锁?1.2 Redisson分布式锁的原理1.3 Redisson分布式锁的优势1.4 Redisson分布式锁的应用场景 二 案例2.1 锁竞争案例2.2 看门狗案例2.3 参考文章 前言 这是我在这个网站整理的笔记,有错误的地方请指出&#xff…

生活英语口语柯桥学英语“再确认一下“ 说成 “double confirm“?这是错误的!

在追求英语表达的过程中,我们常常会遇到一些看似合理实则错误的表达习惯。今天,我们就来聊聊一个常见的误区——“再确认一下”被误译为“double confirm”。 “再次确认”不是double confirm 首先,我们需要明确,“double confi…