【深度学习】(7)--神经网络之保存最优模型

文章目录

  • 保存最优模型
    • 一、两种保存方法
      • 1. 保存模型参数
      • 2. 保存完整模型
    • 二、迭代模型
  • 总结

保存最优模型

我们在迭代模型训练时,随着次数初始的增多,模型的准确率会逐渐的上升,但是同时也随着迭代次数越来越多,由于模型会开始学习到训练数据中的噪声或非共性特征,发生过拟合现象,使得模型的准确率会上下震荡甚至于下降。

本篇就是介绍我们如何在进行那么多次迭代之中,找到训练最好效果时,模型的参数或完整模型。也方便以后使用模型时直接使用。

一、两种保存方法

我们知道,一个模型到底好不好,主要体现在对测试集数据结果上的表现,所以我们的方法主要从测试集入手,计算每次迭代测试集数据的准确率,取到准确率最大时对应的模型和参数

那么,我们该如何保存模型和参数呢?介绍一个小东西:

  • 文件拓展名pt\pth,t7,使用pt\pth或t7作为模型文件扩展名,保存模型的整个状态(包括模型架构和参数)或仅保存模型的参数(即状态字典,state_dict)。

1. 保存模型参数

方法

torch.save(model.state_dict(),path)
# model.state_dict()是一个从参数名称映射到参数张量的字典对象,它包含了模型的所有权重和偏置项
# path为创建的保存模型的文件

通过比较每一次迭代准确率的大小,取准确率最大时模型的参数

best_acc = 0
"""-----测试集-----"""
def test(dataloader,model,loss_fn):global best_accsize = len(dataloader.dataset) # 总数据大小num_batches = len(dataloader) # 划分的小批次数量model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item() # 预测正确的个数test_loss /= num_batchescorrect /= sizecorrect = round(correct, 4)print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")# 保存最优模型的方法(文件扩展名一般:pt\pth,t7)if correct > best_acc:best_acc = correct# 1. 保存模型参数方法:torch.save(model.state_dict(),path)  (w,b)print(model.state_dict().keys()) # 输出模型参数名称cnntorch.save(model.state_dict(),"best.pth") 

2. 保存完整模型

方法

torch.save(model,path)
# 直接得到整个模型

依旧是通过比较每一次迭代准确率的大小,但是取准确率最大时的整个模型

def test(dataloader,model,loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= sizecorrect = round(correct, 4)print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")# 保存最优模型的方法(文件扩展名一般:pt\pth,t7)if correct > best_acc:best_acc = correct# 2. 保存完整模型(w,b,模型cnn)torch.save(model,"best1.pt")

二、迭代模型

接下来就要迭代模型,得到最优的模型:

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001,weight_decay=0.0001)epochs = 150
# training_data、test_data:数据预处理好的数据
train_dataloader = DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=64,shuffle=True)
for t in range(epochs):print(f"Epoch {t+1} \n-------------------------")train(train_dataloader,model,loss_fn,optimizer)test(test_dataloader,model,loss_fn)
print("Done!")

在每轮数据迭代后,project工程栏中的best1.ptbest.pth文件中模型会随着迭代及时更新,迭代结束后,文件中保存的就是最优模型以及最优的模型参数。

在这里插入图片描述

总结

本篇介绍了:

  1. 为什么随着迭代次数越来越多,模型的准确率会上下震荡甚至于下降。—> 过拟合
  2. pt\pth,t7三个扩展名,用于保存完整模型或者模型参数。
  3. 模型的好坏,通过体现在测试集的结果上。
  4. 保存最优模型的两种方法:保存模型参数和保存完整模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/54731.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【从0开始自动驾驶】用python做一个简单的自动驾驶仿真可视化界面

【从0开始自动驾驶】用python做一个简单的自动驾驶仿真可视化界面 废话几句废话不多说,直接上源码目录结构init.pysimulator.pysimple_simulator_app.pyvehicle_config.json 废话几句 自动驾驶开发离不开仿真软件成品仿真软件种类多https://zhuanlan.zhihu.com/p/3…

【CSS】鼠标 、轮廓线 、 滤镜 、 堆叠层级

cursor 鼠标outline 轮廓线filter 滤镜z-index 堆叠层级 cursor 鼠标 值说明值说明crosshair十字准线s-resize向下改变大小pointer \ hand手形e-resize向右改变大小wait表或沙漏w-resize向左改变大小help问号或气球ne-resize向上右改变大小no-drop无法释放nw-resize向上左改变…

蓝桥杯1.小蓝的漆房

样例输入 2 5 2 1 1 2 2 1 6 2 1 2 2 3 3 3样例输出 1 2 import math import os import sys tint(input())#执行的次数 for j in range(t):n,kmap(int,input().split())#n为房间数 k为一次能涂的个数alist(map(int,input().split()))#以列表的形式存放房间的颜色maxvaluemath…

处理RabbitMQ连接和认证问题

在使用RabbitMQ进行消息队列管理时,我们可能会遇到各种连接和认证问题。本文将介绍如何诊断和解决这些问题,并通过使用RabbitMQ的管理端进行登录验证来确保配置正确。 1. 问题概述 在最近的一次部署中,我们遇到了两个主要问题: …

IPSec隧道协议学习(一)

前情回顾 前面介绍的GRE隧道协议,可以字LAN之间通过Internet建立隧道,实现网络间资源共享,但是GRE隧道协议不能实现加密功能,传输的数据不受加密保护,为了实现在隧道间传输数据包收到加密保护,需要使用IPS…

GitLab发送邮件功能详解:如何配置自动化?

GitLab发送邮件的设置指南?怎么优化GitLab发送邮件? GitLab作为一个强大的代码管理平台,不仅提供了代码托管、CI/CD等功能,还集成了发送邮件的功能,使得开发团队能够及时获取项目动态。AokSend将详细介绍如何配置GitL…

代码随想录 -- 回溯 -- 非递减子序列

491. 非递减子序列 - 力扣(LeetCode) 思路:重点是去重 收集结果:每次进入递归先判断path中的元素数量,如果大于1了,就将path收集到result中。 递归参数:nums,index,pa…

2024 go-zero社交项目实战

背景 一位商业大亨,他非常看好国内的社交产品赛道,想要造一款属于的社交产品,于是他找到了负责软件研发的小明。 小明跟张三一拍即合,小明决定跟张三大干一番。 社交产品MVP版本需求 MVP指:Minimum Viable Product&…

职场能力强的人都在做什么---今日头条

【职场里,能力强的人都在做哪些事... - 今日头条】https://m.toutiao.com/is/ikn6kt9q/ 知识雷达 2024-09-21 16:33 目录 职场里,能力强的人都在做哪些事呢? 1、复盘; 2、多角度思考;3、记录信息; 4、永远积极主动;5、主动获取信息差; 6、明确人和人的关系;7、…

【Altium Designer程序开发】BGA芯片自动扇出

BGA自动扇出功能支持将BGA器件从4个方向上扇出,里面有无空白区域均可支持,执行速度非常快,通常在秒级的时间内即可处理完成,程序可以通过以下几种方式启动。 ➡️支持从菜单栏启动 ➡️支持从工具栏启动 ➡️支持从服务器面板启动…

Go weak包前瞻:弱指针为内存管理带来新选择

在介绍Go 1.23引入的unique包的《Go unique包:突破字符串局限的通用值Interning技术实现》一文中,我们知道了unique包底层是基于internal/weak包实现的,internal/weak是一个弱指针功能的Go实现。所谓弱指针(Weak Pointer,也称为弱…

HarmonyOS鸿蒙开发实战(5.0)自定义路由栈管理

鸿蒙HarmonyOS NEXT开发实战往期文章必看(持续更新......) HarmonyOS NEXT应用开发性能实践总结 HarmonyOS NEXT应用开发案例实践总结合集 最新版!“非常详细的” 鸿蒙HarmonyOS Next应用开发学习路线!(从零基础入门…

真实数据,告诉你3S相关专业本硕毕业生就业去向

本期推文将基于2015届-2023届3S相关专业毕业生(包括本硕博所有毕业生)的生源地、性别分布、行业岗位等数据进行分析,为各位同学提供一些参考,希望可以对各位同学的职业规划与有一定的帮助。 GIS开发资料分享https://www.wjx.cn/v…

10.Lab Nine —— file system-上

首先切换分支到fs git checkout fs make clean 预备知识 mkfs程序创建xv6文件系统磁盘映像,并确定文件系统的总块数,这个大小在kernel/param.h中的FSSIZE写明 // kernel/params.h #define FSSIZE 200000 // size of file system in blocks Make…

Redisson分布式锁的概念和使用

Redisson分布式锁的概念和使用 一 简介1.1 什么是分布式锁?1.2 Redisson分布式锁的原理1.3 Redisson分布式锁的优势1.4 Redisson分布式锁的应用场景 二 案例2.1 锁竞争案例2.2 看门狗案例2.3 参考文章 前言 这是我在这个网站整理的笔记,有错误的地方请指出&#xff…

生活英语口语柯桥学英语“再确认一下“ 说成 “double confirm“?这是错误的!

在追求英语表达的过程中,我们常常会遇到一些看似合理实则错误的表达习惯。今天,我们就来聊聊一个常见的误区——“再确认一下”被误译为“double confirm”。 “再次确认”不是double confirm 首先,我们需要明确,“double confi…

2.1 HuggingFists系统架构(二)

部署架构 上图为HuggingFists的部署架构。从架构图可知,HuggingFists主要分为服务器(Server)、计算节点(Node)以及数据库(Storage)三部分。这三部分可以分别部署在不同的机器上,以满足系统的性能需求。为部署方便,HuggingFists社区版将这三部…

YOLOv9改进,YOLOv9主干网络替换为GhostNetV2(华为提出的轻量化架构)

摘要 摘要:轻量级卷积神经网络(CNN)专为移动设备上的应用而设计,具有更快的推理速度。卷积操作只能在窗口区域内捕捉局部信息,这限制了性能的进一步提升。将自注意力引入卷积可以很好地捕捉全局信息,但会极大地拖累实际速度。本文提出了一种硬件友好的注意力机制(称为 D…

前端文件上传全过程

特别说明:ui框架使用的是蚂蚁的antd 这里主要是学习前端上传接口的传递参数包括前端上传之前对于代码的整理 一、第一步将前端页面画出来 源代码: /** 费用管理 - IT费用管理 - 费用数据上传 */ import { useState } from "react"; import {…

Prometheus篇之利用promtool工具校验配置正确性

promtool工具 promtool是Prometheus的一个命令行工具,它提供了一些功能来帮助用户进行Prometheus配置文件(如prometheus.yml)的检查、规则检查和调试。 解释 promtool check config: 验证Prometheus配置文件的语法和设置。 promtool命令&…