pytorch学习(十一)checkpoint

当训练一个大模型数据的时候,中途断电就可以造成已经训练几天或者几个小时的工作白做了,再此训练的时候需要从epoch=0开始训练,因此中间要不断保存(epoch,net,optimizer,scheduler)等几个内容,这样才能在发生意外之后快速恢复工作。

通过本博客的学习,你将学会最优模型保存和模型自动加载的方法

2.首先看代码

    1.2保存最优模型

保存最优模型的代码如下:

        if epoch > int(num_epoches/3) and test_loss_value < min_loss_val:min_loss_val = test_loss_valuecheckpoint = {"epoch": epoch,"net": model.state_dict(),"optimizer":optimer.state_dict(),"lr_schedule":scheduler_1.state_dict()}if not os.path.isdir(r'tf_logs/' + "save_module"):os.makedirs("tf_logs/" + "save_module")PATH = r'tf_logs/'+"save_module" + "/ckpt_best_%s.pth"%(str(epoch+1))torch.save(checkpoint, PATH)

 min_loss_val 定义成全局的变量之后,应该在用到的函数中,使用global min_loss_val再次定义,否则会报错误

 local variable 'min_loss_val' referenced before assignment 

1.2自动加载模型

#找文件夹中数字最大的文件
def getBestModuleFilename(browser):file_name = browser             #"tf_logs/save_module"filenames = os.listdir(file_name)pattern = r"d+"result = []for i in range(len(filenames)):rst = int(filenames[i][10:-4])result.append(rst)val = max(result)index = result.index(val)file_best = filenames[index]print(file_best)return file_bestResume = Falseif Resume == False:start_epoch = 0else:#找到数字最大的pth文件path_checkpoint = r'tf_logs/'+"save_module"best_path_checkpoint = getBestModuleFilename(path_checkpoint)if(best_path_checkpoint == ""):returnelse:checkpointResume = torch.load(path_checkpoint)start_epoch = checkpointResume["epoch"]model.load_state_dict(checkpointResume["net"])optimer.load_state_dict(checkpointResume["optimizer"])scheduler_1.load_state_dict(checkpointResume["lr_schedule"])

2.代码

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import LambdaLR
import os
import recur_pwd_path = os.getcwd()def getBestModuleFilename(browser):file_name = browser             #"tf_logs/save_module"filenames = os.listdir(file_name)pattern = r"d+"result = []for i in range(len(filenames)):rst = int(filenames[i][10:-4])result.append(rst)val = max(result)index = result.index(val)file_best = filenames[index]print(file_best)return file_besttensor = torch.randn(3,3)
bTensor = type(tensor) == torch.Tensor
print(bTensor)
print("tensor is on ", tensor.device)
#数据转到GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
if torch.cuda.is_available():tensor = tensor.to(device)print("tensor is on ",tensor.device)
#数据转到CPU
if tensor.device == 'cuda:0':tensor = tensor.to(torch.device("cpu"))print("tensor is on", tensor.device)
if tensor.device == "cpu":tensor = tensor.to(torch.device("cuda:0"))print("tensor is on", tensor.device)trainning_data =datasets.MNIST(root="data",train=True,transform=ToTensor(),download=True)
print(len(trainning_data))
test_data = datasets.MNIST(root="data",train=True,transform=ToTensor(),download=False)train_loader = DataLoader(trainning_data, batch_size=64,shuffle=True)
test_loader = DataLoader(test_data, batch_size=64,shuffle=True)print(len(train_loader)) #分成了多少个batch
print(len(trainning_data)) #总共多少个图像
# for x, y in train_loader:
#     print(x.shape)
#     print(y.shape)class MinistNet(nn.Module):def __init__(self):super().__init__()# self.flat = nn.Flatten()self.conv1 = nn.Conv2d(1,1,3,1,1)self.hideLayer1 = nn.Linear(28*28,256)self.hideLayer2 = nn.Linear(256,10)def forward(self,x):x= self.conv1(x)x = x.view(-1,28*28)x = self.hideLayer1(x)x = torch.sigmoid(x)x = self.hideLayer2(x)# x = nn.Sigmoid(x)return xmodel = MinistNet()
model = model.to(device)
cuda = next(model.parameters()).device
print(model)
criterion = nn.CrossEntropyLoss()
optimer = torch.optim.RMSprop(model.parameters(),lr= 0.001)scheduler_1 = LambdaLR(optimer, lr_lambda=lambda epoch: 1/(epoch+1))num_epoches =10
min_loss_val = 100000
Resume = Falsedef train():global min_loss_valstart_epoch = -1if Resume == False:start_epoch = 0else:#找到数字最大的pth文件path_checkpoint = r'tf_logs/'+"save_module"best_path_checkpoint = getBestModuleFilename(path_checkpoint)if(best_path_checkpoint == ""):returnelse:checkpointResume = torch.load(path_checkpoint)start_epoch = checkpointResume["epoch"]model.load_state_dict(checkpointResume["net"])optimer.load_state_dict(checkpointResume["optimizer"])scheduler_1.load_state_dict(checkpointResume["lr_schedule"])train_losses = []train_acces = []eval_losses = []eval_acces = []#训练model.train()tensorboard_ind =0;for epoch in range(num_epoches):batchsizeNum = 0train_loss = 0train_acc = 0train_correct = 0for x,y in train_loader:# print(epoch)# print(x.shape)# print(y.shape)x = x.to('cuda')y = y.to('cuda')bte = type(x)==torch.Tensorbte1 = type(y)==torch.TensorA = x.deviceB = y.devicepred_y = model(x)loss = criterion(pred_y,y)optimer.zero_grad()loss.backward()optimer.step()loss_val = loss.item()batchsizeNum = batchsizeNum +1train_acc += (pred_y.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()tensorboard_ind += 1train_losses.append(train_loss / len(trainning_data))train_acces.append(train_acc / len(trainning_data))#测试test_loss_value = 0model.eval()with torch.no_grad():num_batch = len(test_data)numSize = len(test_data)test_loss, test_correct = 0,0for x,y in test_loader:x = x.to(device)y = y.to(device)pred_y = model(x)test_loss += criterion(pred_y, y).item()test_correct += (pred_y.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchtest_correct /= numSizeeval_losses.append(test_loss)eval_acces.append(test_correct)test_loss_value = test_lossprint("test result:",100 * test_correct,"%  avg loss:",test_loss)scheduler_1.step()#设置checkpointif epoch > int(num_epoches/3) and test_loss_value < min_loss_val:min_loss_val = test_loss_valuecheckpoint = {"epoch": epoch,"net": model.state_dict(),"optimizer":optimer.state_dict(),"lr_schedule":scheduler_1.state_dict()}if not os.path.isdir(r'tf_logs/' + "save_module"):os.makedirs("tf_logs/" + "save_module")PATH = r'tf_logs/'+"save_module" + "/ckpt_best_%s.pth"%(str(epoch+1))torch.save(checkpoint, PATH)# Press the green button in the gutter to run the script.if __name__ == '__main__':train()

3.运行结果为

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/873959.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入探索:Stable Diffusion 与传统方法对比:优劣分析

深入探索&#xff1a;Stable Diffusion 与传统方法对比&#xff1a;优劣分析 一、引言 随着人工智能和深度学习的发展&#xff0c;优化算法在神经网络训练中的重要性日益凸显。传统的优化方法&#xff0c;如随机梯度下降&#xff08;SGD&#xff09;、动量法和Adam等&#xf…

动手学深度学习——5.卷积神经网络

1.卷积神经网络特征 现在&#xff0c;我们将上述想法总结一下&#xff0c;从而帮助我们设计适合于计算机视觉的神经网络架构。 平移不变性&#xff08;translation invariance&#xff09;&#xff1a;不管检测对象出现在图像中的哪个位置&#xff0c;神经网络的前面几层应该对…

《昇思 25 天学习打卡营第 15 天 | 基于MindNLP+MusicGen生成自己的个性化音乐 》

《昇思 25 天学习打卡营第 15 天 | 基于MindNLPMusicGen生成自己的个性化音乐 》 活动地址&#xff1a;https://xihe.mindspore.cn/events/mindspore-training-camp 签名&#xff1a;Sam9029 MusicGen概述 MusicGen是由Meta AI的Jade Copet等人提出的一种基于单个语言模型&…

密码学原理精解【8】

文章目录 概率分布哈夫曼编码实现julia官方文档建议的变量命名规范&#xff1a;julia源码 熵一、信息熵的定义二、信息量的概念三、信息熵的计算步骤四、信息熵的性质五、应用举例 哈夫曼编码&#xff08;Huffman Coding&#xff09;基本原理编码过程特点应用具体过程1. 排序概…

Bubbliiiing 的 Retinaface rknn python推理分析

Bubbliiiing 的 Retinaface rknn python推理分析 项目说明 使用的是Bubbliiiing的深度学习教程-Pytorch 搭建自己的Retinaface人脸检测平台的模型&#xff0c;下面是项目的Bubbliiiing视频讲解地址以及源码地址和博客地址&#xff1b; 作者的项目讲解视频&#xff1a;https:…

FFmpeg音视频流媒体的顶级项目

搞音视频、流媒体的圈子,没法躲开ffmpeg这个神级项目。 FFmpeg 是一个功能强大且广泛使用的多媒体处理工具。FFmpeg 具备众多出色的特性。它支持多种音频和视频格式的转换,能轻松将一种格式的文件转换为另一种,满足不同设备和应用的需求。不仅如此,它还可以进行视频的裁剪、…

php编译安装

一、基础环境准备 # php使用www用户 useradd -s /sbin/nologin -M www二、下载php包 # 下载地址 https://www.php.net/downloads wget https://www.php.net/distributions/php-8.3.9.tar.gz三、配置编译安装 编译安装之前需要处理必要的依赖&#xff0c;在编译配置安装&…

使用多进程和多线程实现服务器并发【C语言实现】

在TCP通信过程中&#xff0c;服务器端启动之后可以同时和多个客户端建立连接&#xff0c;并进行网络通信&#xff0c;但是在一个单进程的服务器的时候&#xff0c;提供的服务器代码却不能完成这样的需求&#xff0c;先简单的看一下之前的服务器代码的处理思路&#xff0c;再来分…

Python中自定义上下文管理器的创建与应用

Python中自定义上下文管理器的创建与应用 在Python编程中,上下文管理器是一个非常重要的概念。它们允许我们更好地管理资源,如文件句柄、网络连接或数据库连接等,确保这些资源在使用后能够被正确关闭或释放,从而避免资源泄露。Python标准库中的contextlib模块提供了强大的…

代码随想录学习 day54 图论 Bellman_ford 算法精讲

Bellman_ford 算法精讲 卡码网&#xff1a;94. 城市间货物运输 I 题目描述 某国为促进城市间经济交流&#xff0c;决定对货物运输提供补贴。共有 n 个编号为 1 到 n 的城市&#xff0c;通过道路网络连接&#xff0c;网络中的道路仅允许从某个城市单向通行到另一个城市&#xf…

【HarmonyOS】网络连接 - Http 请求数据

在日常开发应用当中&#xff0c;应用内部有很多数据并不是保存在应用内部&#xff0c;而是在服务端。所以就需要向服务端发起请求&#xff0c;由服务端返回数据。这种请求方式就是 Http 请求。 一、申请网络权限 在 module.json5 文件中&#xff0c;添加网络权限&#xff1a; …

手持式气象站:便携科技,掌握微观气象的利器

手持式气象站&#xff0c;顾名思义&#xff0c;是一种可以随身携带的气象监测设备。它小巧轻便&#xff0c;通常配备有温度、湿度、风速、风向、气压等多种传感器&#xff0c;能够实时测量并显示各种气象参数。不仅如此&#xff0c;它还具有数据存储、数据传输、远程控制等多种…

Django教程(003):orm操作数据库

文章目录 1 orm连接Mysql1.1 安装第三方模块1.2 ORM1.2.1、创建数据库1.2.2、Django连接数据库1.2.3、django操作表1.2.4、创建和修改表结构1.2.5、增删改查1.2.5.1 增加数据1.2.5.2 删除数据1.2.5.3 获取数据1.2.5.4 修改数据 1 orm连接Mysql Django为了使操作数据库更加简单…

Linux shell编程学习笔记65: nice命令 显示和调整进程优先级

0 前言 我们前面学习了Linux命令ps和top&#xff0c;命令的返回信息中包括优先序&#xff08;NI&#xff0c;nice&#xff09; &#xff0c;我们可以使用nice命令来设置进程优先级。 1 nice命令 的功能、格式和选项说明 1.1 nice命令 的功能 nice命令的功能是用于调整进程的…

AP ERP与汉得SRM系统集成案例(制药行业)

一、项目环境 江西某医药集团公司&#xff0c;是一家以医药产业为主营、资本经营为平台的大型民营企业集团。公司成立迄今&#xff0c;企业经营一直呈现稳健、快速发展的态势&#xff0c; 2008 年排名中国医药百强企业前 20 强&#xff0c;2009年集团总销售额约38亿元人民币…

原码、补码、反码、移码是什么?

计算机很多术语翻译成中文之后&#xff0c;不知道是译者出于什么目的&#xff0c;往往将其翻译成一个很难懂的名词。 奇怪的数学定义 下面是关于原码的“吐槽”&#xff0c;可以当作扩展。你可以不看&#xff0c;直接去下一章&#xff0c;没有任何影响。 原码的吐槽放在前面是…

配置单区域OSPF

目录 引言 一、搭建基础网络 1.1 配置网络拓扑图如下 1.2 IP地址表 二、测试每个网段都能单独连通 2.1 PC0 ping通Router1所有接口 2.2 PC1 ping通Router1所有接口 2.3 PC2 ping通Router2所有接口 2.4 PC3 ping通Router2所有接口 2.5 PC4 ping通Router3所有接口 2.…

Git仓库拆分和Merge

1. 问题背景 我们原先有一个项目叫open-api&#xff0c;后来想要做租户独立发展&#xff0c;每个租户独立成一个项目&#xff0c;比如租户akc独立部署一个akc-open-api&#xff0c;租户yhd独立部署一个yhd-open-api&#xff0c;其中大部分代码是相同的&#xff0c;少量租户定制…

2024牛客暑期多校训练营1——A,B

题解&#xff1a; 更新&#xff1a; k1的时候要乘n 代码&#xff1a; #include<bits/stdc.h> #define int long long using namespace std; const int N5e35; typedef long long ll; typedef pair<int,int> PII; int T; int n,m,mod; int fac[N][N]; int dp[N][…

设计模式使用场景实现示例及优缺点(结构型模式——外观模式)

在一个繁忙而复杂的城市中&#xff0c;有一座名为“技术森林”的巨大图书馆。这座图书馆里藏着各种各样的知识宝典&#xff0c;从古老的卷轴到现的电子书籍&#xff0c;无所不包。但是&#xff0c;图书馆之所以得名“技术森林”&#xff0c;是因为它的结构异常复杂&#xff0c;…