pytorch 实战【以图像处理为例】

pytorch 实战【以图像处理为例】

  • 训练过程中保存模型
  • 具体在训练中断如何继续


训练过程中保存模型

在PyTorch中,模型训练过程中保存模型通常涉及以下几个步骤:

  1. 保存整个模型:
    使用 torch.save 函数,你可以保存整个模型,包括模型的结构和参数。

    torch.save(model, 'model.pth')
    

    加载模型时,使用 torch.load 函数。

    model = torch.load('model.pth')
    
  2. 保存模型的参数:
    这种方法通常更受欢迎,因为它只保存模型的参数,不保存模型的结构。这样,模型文件会比较小,并且在加载模型时可以更加灵活。

    torch.save(model.state_dict(), 'model_params.pth')
    

    加载模型时,首先创建模型的实例,然后加载参数。

    model = ModelClass()  # replace ModelClass with your model's class name
    model.load_state_dict(torch.load('model_params.pth'))
    
  3. 保存训练的检查点:
    在训练过程中,除了保存模型或模型的参数,通常还会保存其他关键信息,例如优化器的状态、当前的epoch、最佳准确率等。这样,如果训练被中断,可以从检查点继续训练,而不是从头开始。

    checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,# ... any other relevant information
    }
    torch.save(checkpoint, 'checkpoint.pth')
    

    加载检查点时:

    checkpoint = torch.load('checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    
  4. 在训练时定期保存模型:
    通常,我们会在每个epoch结束时或在验证准确率提高时保存模型。这样,如果训练过程中出现任何问题,我们可以从最近的检查点恢复。

  • 保存检查点:

在训练循环中,你可能会在每个 epoch 结束时或在模型在验证集上达到新的最佳性能时保存检查点:

# 假设以下变量已经定义:
# model: 你的模型
# optimizer: 你使用的优化器
# epoch: 当前的epoch
# loss: 最近的loss值
# best_accuracy: 迄今为止在验证集上的最佳准确率# 在每个 epoch 结束时或在验证准确率提高时:
if current_accuracy > best_accuracy:  # current_accuracy是这个epoch在验证集上的准确率best_accuracy = current_accuracycheckpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,'best_accuracy': best_accuracy}torch.save(checkpoint, 'best_checkpoint.pth')
  • 加载检查点:

当你希望从检查点继续训练或评估模型时,可以使用以下代码来加载检查点:

# 假设以下变量已经定义:
# model: 你的模型 (需要先实例化)
# optimizer: 你使用的优化器 (需要先实例化)checkpoint = torch.load('best_checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
best_accuracy = checkpoint['best_accuracy']# 如果继续训练,可以从上一个 epoch 开始
model.train()

这样,即使训练过程中断,你也可以从上次停止的地方继续,而不是重新开始。

  1. 保存在不同设备上的模型:
    如果你在GPU上训练模型,但希望在CPU上加载模型,可以使用以下方式:
    torch.save(model.state_dict(), 'model_params.pth')
    # Loading on CPU
    model.load_state_dict(torch.load('model_params.pth', map_location=torch.device('cpu')))
    

总之,保存模型是训练深度学习模型的关键部分,它允许我们在训练中断时恢复,或在训练完成后部署模型。

具体在训练中断如何继续

如果训练过程中断并且你已经定期保存了检查点,那么你可以从最近的检查点恢复。以下是一个基本流程,描述如何在训练中断后从上次停止的地方继续:

  1. 加载检查点:

    在开始训练之前,首先加载保存的检查点。

    checkpoint = torch.load('best_checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_accuracy = checkpoint.get('best_accuracy', -1)  # 默认为-1,假设你保存了这个值
    
  2. 恢复训练:

    使用从检查点中加载的 start_epoch 作为起始点,并从那里开始你的训练循环。

    for epoch in range(start_epoch, total_epochs):# 训练代码...train_one_epoch()# 验证代码...current_accuracy = validate()# 保存新的检查点,如果模型在验证集上有更好的性能if current_accuracy > best_accuracy:best_accuracy = current_accuracycheckpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'best_accuracy': best_accuracy# ... 你可以添加其他信息,如loss等}torch.save(checkpoint, 'best_checkpoint.pth')
    
  3. 注意点:

    • 学习率调整:如果你使用了学习率调度器,例如 ReduceLROnPlateauStepLR,那么你也应该保存和加载它的状态。这样可以确保学习率调整策略在中断后正确地继续。
    • 随机种子:为了确保训练的可复现性,如果你设置了随机种子,那么在恢复训练之前,你可能需要重新设置相同的随机种子。

通过这种方式,你可以在训练中断后恢复并从上次停止的地方继续,而不会丢失任何进度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/86724.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

hive分区表的元数据信息numRows显示为0

创建分区表 CREATE TABLE `dept_partition`(`deptno` int, `dname` string, `loc` string) PARTITIONED BY (

小程序如何关联公众号来发送模板消息

有时候我们可能需要通过公众号来发送一些小程序的服务通知,比如订单提醒、活动通知等。那么要如何操作呢? 1. 有一个通过了微信认证的服务号。需要确保小程序和公众号是同一个主体的。也就是说,小程序和公众号应该都是属于同一个企业。如果还…

RestTemplate:简化HTTP请求的强大工具

文章目录 什么是RestTemplateRestTemplate的作用代码示例 RestTemplate与HttpClient 什么是RestTemplate RestTemplate是一个在Java应用程序中发送RESTful HTTP请求的强大工具。本文将介绍RestTemplate的定义、作用以及与HttpClient的对比,以帮助读者更好地理解和使…

postgresql 内核源码分析 clog机制流程 commit log文件格式,分离的原因,分组优化及leader更新机制

clog 介绍 ​专栏内容: postgresql内核源码分析手写数据库toadb并发编程 ​开源贡献: toadb开源库 个人主页:我的主页 管理社区:开源数据库 座右铭:天行健,君子以自强不息;地势坤,君…

Vis.js教程(一):基础教程

1、Vis.js是什么 一个动态的、基于浏览器的可视化库。 该库的设计易于使用,能够处理大量动态数据,并能够对数据进行操作和交互。 该库由 DataSet、Timeline、Network、Graph2d 和 Graph3d 组件组成。 Vis.js官网:https://visjs.org/ github…

TongWeb8 专用机使用指导

前言 专用机要求软件以deb、rpm安装包形式提供,通过三合一安全管理工具进行安装,否则软件的可执行程序无法运行,所以TongWeb6、7版本的专用机版本遵循此原则。 TongWeb8安装使用方式 TongWeb8除可以提供deb、rpm安装包形式外,还支…

设计模式:备忘录模式

目录 组件代码示例源码中使用优缺点总结 备忘录模式(Memento Pattern)是一种行为型设计模式,用于在不破坏封装性的前提下,捕获和恢复对象的内部状态。备忘录模式可以将对象的状态保存到备忘录对象中,并在需要时从备忘录…

电脑计算机xinput1_3.dll丢失的解决方法分享,四种修复手段解决问题

日常生活中可能会遇到的问题——xinput1_3.dll丢失的解决方法。我相信,在座的很多朋友都曾遇到过这个问题,那么接下来,我将分享如何解决这个问题的解决方法。 首先,让我们来了解一下xinput1_3.dll文件。xinput1_3.dll是一个动态链…

第1篇 目标检测概述 —(1)目标检测基础知识

前言:Hello大家好,我是小哥谈。目标检测是计算机视觉领域中的一项任务,旨在自动识别和定位图像或视频中的特定目标,目标可以是人、车辆、动物、物体等。目标检测的目标是从输入图像中确定目标的位置,并使用边界框将其标…

Go基础语法:map

9 map Go 语言中提供的映射关系容器为 map ,其内部使用 散列表(hash) 实现。它是一种无序的基于 key-value 的数据结构。 Go 语言中的 map 是引用类型,必须初始化之后才能使用。 9.1 map 定义 Go 语言中 map 的定义语法为&…

sql on条件判断是要注意null值

我是因为用了merge into语法,然后on条件中判断的字段是可配置的,这就导致了,有时候判断条件多的情况下,判断的字段会碰到有null值的情况,如果on两边的字段都是null,null和null对比就会导致结果为false&…

安全防御第二次作业

1. 防火墙支持那些NAT技术,主要应用场景是什么? 防火墙支持几乎所有的NAT技术,包括源NAT、目标NAT、双向NAT等,主要应用场景是保护内部网络免受外部网络的攻击 NAT技术可以将IP数据报文头中的IP地址转换为另一个IP地址&#xff…

stc8H驱动并控制三相无刷电机综合项目技术资料综合篇

stc8H驱动并控制三相无刷电机综合项目技术资料综合篇 🌿相关项目介绍《基于stc8H驱动三相无刷电机开源项目技术专题概要》 🔨停机状态,才能进入设置状态,可以设置调速模式,以及转动方向。 ✨所有的功能基本已经完成调试,目前所想到的功能基本已经都添加和实现。引脚利…

C++入门知识

Hello,今天我们分享一些关于C入门的知识,看完至少让你为后面的类和对象有一定的基础,所以在讲类和对象的时候,我们需要来了解一些关于C入门的知识。 什么是C C语言是结构化和模块化的语言,适合处理较小规模的程序。对…

【Python从入门到进阶】37、selenium关于phantomjs的基本使用

接上篇《36、Selenium 动作交互》 上一篇我们介绍了selenium操作网页的动作内容。本篇我们来学习有关phantomjs的相关知识。 一、selenium的缺点 在介绍PhantomJS之前,让我们先讨论一下直接使用Selenium的一些缺点。 1、显示浏览器窗口:Selenium通常需…

AndroidUtil - 强大易用的安卓工具类库

官网 https://github.com/Blankj/AndroidUtilCode/blob/master/README-CN.md 项目介绍 AndroidUtilCode 🔥 是一个强大易用的安卓工具类库,它合理地封装了安卓开发中常用的函数,具有完善的 Demo 和单元测试,利用其封装好的 API…

CUDA学习笔记0924

一、nvprof分析线程束和内存读写 (1)线程束占用率分析 线程束占用率:nvprof --metrics achieved_occupancy (2)内存读写分析 内核数据读取效率:nvprof --metrics gld_throughput 程序对设备内存带宽利…

《动手学深度学习 Pytorch版》 7.4 含并行连接的网络(GoogLeNet)

import torch from torch import nn from torch.nn import functional as F from d2l import torch as d2l7.4.1 Inception块 GoogLNet 中的基本卷积块叫做 Inception 块(大概率得名于盗梦空间),由 4 条并行路径组成。 前 3 条路径使用窗口…

合规性管理如何帮助产品团队按时交付?

成功的产品和产品发布背后通常需要经过一个涉及多个监督机构、多功能团队和利益相关者的复杂流程。在组织的治理、风险管理和合规性(GRC)框架下,产品团队不仅需要追求市场创新,还需要确保符合所有适用的法规、标准和合同要求。由于…

libpcap之socket创建

一、 lipcap回调注册 在libpcap中,最重要的就是打开接口,其中关键函数为pcap_activate。这里只关注Linux平台。 只分析通用平台。 pcap_t * pcap_create(const char *device, char *errbuf) { ... p pcap_create_interface(device_str, errbuf); ... …