pytorch学习(十一)checkpoint

当训练一个大模型数据的时候,中途断电就可以造成已经训练几天或者几个小时的工作白做了,再此训练的时候需要从epoch=0开始训练,因此中间要不断保存(epoch,net,optimizer,scheduler)等几个内容,这样才能在发生意外之后快速恢复工作。

通过本博客的学习,你将学会最优模型保存和模型自动加载的方法

2.首先看代码

    1.2保存最优模型

保存最优模型的代码如下:

        if epoch > int(num_epoches/3) and test_loss_value < min_loss_val:min_loss_val = test_loss_valuecheckpoint = {"epoch": epoch,"net": model.state_dict(),"optimizer":optimer.state_dict(),"lr_schedule":scheduler_1.state_dict()}if not os.path.isdir(r'tf_logs/' + "save_module"):os.makedirs("tf_logs/" + "save_module")PATH = r'tf_logs/'+"save_module" + "/ckpt_best_%s.pth"%(str(epoch+1))torch.save(checkpoint, PATH)

 min_loss_val 定义成全局的变量之后,应该在用到的函数中,使用global min_loss_val再次定义,否则会报错误

 local variable 'min_loss_val' referenced before assignment 

1.2自动加载模型

#找文件夹中数字最大的文件
def getBestModuleFilename(browser):file_name = browser             #"tf_logs/save_module"filenames = os.listdir(file_name)pattern = r"d+"result = []for i in range(len(filenames)):rst = int(filenames[i][10:-4])result.append(rst)val = max(result)index = result.index(val)file_best = filenames[index]print(file_best)return file_bestResume = Falseif Resume == False:start_epoch = 0else:#找到数字最大的pth文件path_checkpoint = r'tf_logs/'+"save_module"best_path_checkpoint = getBestModuleFilename(path_checkpoint)if(best_path_checkpoint == ""):returnelse:checkpointResume = torch.load(path_checkpoint)start_epoch = checkpointResume["epoch"]model.load_state_dict(checkpointResume["net"])optimer.load_state_dict(checkpointResume["optimizer"])scheduler_1.load_state_dict(checkpointResume["lr_schedule"])

2.代码

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import LambdaLR
import os
import recur_pwd_path = os.getcwd()def getBestModuleFilename(browser):file_name = browser             #"tf_logs/save_module"filenames = os.listdir(file_name)pattern = r"d+"result = []for i in range(len(filenames)):rst = int(filenames[i][10:-4])result.append(rst)val = max(result)index = result.index(val)file_best = filenames[index]print(file_best)return file_besttensor = torch.randn(3,3)
bTensor = type(tensor) == torch.Tensor
print(bTensor)
print("tensor is on ", tensor.device)
#数据转到GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
if torch.cuda.is_available():tensor = tensor.to(device)print("tensor is on ",tensor.device)
#数据转到CPU
if tensor.device == 'cuda:0':tensor = tensor.to(torch.device("cpu"))print("tensor is on", tensor.device)
if tensor.device == "cpu":tensor = tensor.to(torch.device("cuda:0"))print("tensor is on", tensor.device)trainning_data =datasets.MNIST(root="data",train=True,transform=ToTensor(),download=True)
print(len(trainning_data))
test_data = datasets.MNIST(root="data",train=True,transform=ToTensor(),download=False)train_loader = DataLoader(trainning_data, batch_size=64,shuffle=True)
test_loader = DataLoader(test_data, batch_size=64,shuffle=True)print(len(train_loader)) #分成了多少个batch
print(len(trainning_data)) #总共多少个图像
# for x, y in train_loader:
#     print(x.shape)
#     print(y.shape)class MinistNet(nn.Module):def __init__(self):super().__init__()# self.flat = nn.Flatten()self.conv1 = nn.Conv2d(1,1,3,1,1)self.hideLayer1 = nn.Linear(28*28,256)self.hideLayer2 = nn.Linear(256,10)def forward(self,x):x= self.conv1(x)x = x.view(-1,28*28)x = self.hideLayer1(x)x = torch.sigmoid(x)x = self.hideLayer2(x)# x = nn.Sigmoid(x)return xmodel = MinistNet()
model = model.to(device)
cuda = next(model.parameters()).device
print(model)
criterion = nn.CrossEntropyLoss()
optimer = torch.optim.RMSprop(model.parameters(),lr= 0.001)scheduler_1 = LambdaLR(optimer, lr_lambda=lambda epoch: 1/(epoch+1))num_epoches =10
min_loss_val = 100000
Resume = Falsedef train():global min_loss_valstart_epoch = -1if Resume == False:start_epoch = 0else:#找到数字最大的pth文件path_checkpoint = r'tf_logs/'+"save_module"best_path_checkpoint = getBestModuleFilename(path_checkpoint)if(best_path_checkpoint == ""):returnelse:checkpointResume = torch.load(path_checkpoint)start_epoch = checkpointResume["epoch"]model.load_state_dict(checkpointResume["net"])optimer.load_state_dict(checkpointResume["optimizer"])scheduler_1.load_state_dict(checkpointResume["lr_schedule"])train_losses = []train_acces = []eval_losses = []eval_acces = []#训练model.train()tensorboard_ind =0;for epoch in range(num_epoches):batchsizeNum = 0train_loss = 0train_acc = 0train_correct = 0for x,y in train_loader:# print(epoch)# print(x.shape)# print(y.shape)x = x.to('cuda')y = y.to('cuda')bte = type(x)==torch.Tensorbte1 = type(y)==torch.TensorA = x.deviceB = y.devicepred_y = model(x)loss = criterion(pred_y,y)optimer.zero_grad()loss.backward()optimer.step()loss_val = loss.item()batchsizeNum = batchsizeNum +1train_acc += (pred_y.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()tensorboard_ind += 1train_losses.append(train_loss / len(trainning_data))train_acces.append(train_acc / len(trainning_data))#测试test_loss_value = 0model.eval()with torch.no_grad():num_batch = len(test_data)numSize = len(test_data)test_loss, test_correct = 0,0for x,y in test_loader:x = x.to(device)y = y.to(device)pred_y = model(x)test_loss += criterion(pred_y, y).item()test_correct += (pred_y.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchtest_correct /= numSizeeval_losses.append(test_loss)eval_acces.append(test_correct)test_loss_value = test_lossprint("test result:",100 * test_correct,"%  avg loss:",test_loss)scheduler_1.step()#设置checkpointif epoch > int(num_epoches/3) and test_loss_value < min_loss_val:min_loss_val = test_loss_valuecheckpoint = {"epoch": epoch,"net": model.state_dict(),"optimizer":optimer.state_dict(),"lr_schedule":scheduler_1.state_dict()}if not os.path.isdir(r'tf_logs/' + "save_module"):os.makedirs("tf_logs/" + "save_module")PATH = r'tf_logs/'+"save_module" + "/ckpt_best_%s.pth"%(str(epoch+1))torch.save(checkpoint, PATH)# Press the green button in the gutter to run the script.if __name__ == '__main__':train()

3.运行结果为

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/873959.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

动手学深度学习——5.卷积神经网络

1.卷积神经网络特征 现在&#xff0c;我们将上述想法总结一下&#xff0c;从而帮助我们设计适合于计算机视觉的神经网络架构。 平移不变性&#xff08;translation invariance&#xff09;&#xff1a;不管检测对象出现在图像中的哪个位置&#xff0c;神经网络的前面几层应该对…

Bubbliiiing 的 Retinaface rknn python推理分析

Bubbliiiing 的 Retinaface rknn python推理分析 项目说明 使用的是Bubbliiiing的深度学习教程-Pytorch 搭建自己的Retinaface人脸检测平台的模型&#xff0c;下面是项目的Bubbliiiing视频讲解地址以及源码地址和博客地址&#xff1b; 作者的项目讲解视频&#xff1a;https:…

FFmpeg音视频流媒体的顶级项目

搞音视频、流媒体的圈子,没法躲开ffmpeg这个神级项目。 FFmpeg 是一个功能强大且广泛使用的多媒体处理工具。FFmpeg 具备众多出色的特性。它支持多种音频和视频格式的转换,能轻松将一种格式的文件转换为另一种,满足不同设备和应用的需求。不仅如此,它还可以进行视频的裁剪、…

使用多进程和多线程实现服务器并发【C语言实现】

在TCP通信过程中&#xff0c;服务器端启动之后可以同时和多个客户端建立连接&#xff0c;并进行网络通信&#xff0c;但是在一个单进程的服务器的时候&#xff0c;提供的服务器代码却不能完成这样的需求&#xff0c;先简单的看一下之前的服务器代码的处理思路&#xff0c;再来分…

手持式气象站:便携科技,掌握微观气象的利器

手持式气象站&#xff0c;顾名思义&#xff0c;是一种可以随身携带的气象监测设备。它小巧轻便&#xff0c;通常配备有温度、湿度、风速、风向、气压等多种传感器&#xff0c;能够实时测量并显示各种气象参数。不仅如此&#xff0c;它还具有数据存储、数据传输、远程控制等多种…

Django教程(003):orm操作数据库

文章目录 1 orm连接Mysql1.1 安装第三方模块1.2 ORM1.2.1、创建数据库1.2.2、Django连接数据库1.2.3、django操作表1.2.4、创建和修改表结构1.2.5、增删改查1.2.5.1 增加数据1.2.5.2 删除数据1.2.5.3 获取数据1.2.5.4 修改数据 1 orm连接Mysql Django为了使操作数据库更加简单…

Linux shell编程学习笔记65: nice命令 显示和调整进程优先级

0 前言 我们前面学习了Linux命令ps和top&#xff0c;命令的返回信息中包括优先序&#xff08;NI&#xff0c;nice&#xff09; &#xff0c;我们可以使用nice命令来设置进程优先级。 1 nice命令 的功能、格式和选项说明 1.1 nice命令 的功能 nice命令的功能是用于调整进程的…

AP ERP与汉得SRM系统集成案例(制药行业)

一、项目环境 江西某医药集团公司&#xff0c;是一家以医药产业为主营、资本经营为平台的大型民营企业集团。公司成立迄今&#xff0c;企业经营一直呈现稳健、快速发展的态势&#xff0c; 2008 年排名中国医药百强企业前 20 强&#xff0c;2009年集团总销售额约38亿元人民币…

原码、补码、反码、移码是什么?

计算机很多术语翻译成中文之后&#xff0c;不知道是译者出于什么目的&#xff0c;往往将其翻译成一个很难懂的名词。 奇怪的数学定义 下面是关于原码的“吐槽”&#xff0c;可以当作扩展。你可以不看&#xff0c;直接去下一章&#xff0c;没有任何影响。 原码的吐槽放在前面是…

配置单区域OSPF

目录 引言 一、搭建基础网络 1.1 配置网络拓扑图如下 1.2 IP地址表 二、测试每个网段都能单独连通 2.1 PC0 ping通Router1所有接口 2.2 PC1 ping通Router1所有接口 2.3 PC2 ping通Router2所有接口 2.4 PC3 ping通Router2所有接口 2.5 PC4 ping通Router3所有接口 2.…

Git仓库拆分和Merge

1. 问题背景 我们原先有一个项目叫open-api&#xff0c;后来想要做租户独立发展&#xff0c;每个租户独立成一个项目&#xff0c;比如租户akc独立部署一个akc-open-api&#xff0c;租户yhd独立部署一个yhd-open-api&#xff0c;其中大部分代码是相同的&#xff0c;少量租户定制…

2024牛客暑期多校训练营1——A,B

题解&#xff1a; 更新&#xff1a; k1的时候要乘n 代码&#xff1a; #include<bits/stdc.h> #define int long long using namespace std; const int N5e35; typedef long long ll; typedef pair<int,int> PII; int T; int n,m,mod; int fac[N][N]; int dp[N][…

笔记:Few-Shot Learning小样本分类问题 + 孪生网络 + 预训练与微调

内容摘自王老师的B站视频&#xff0c;大家还是尽量去看视频&#xff0c;老师讲的特别好&#xff0c;不到一小时的时间就缕清了小样本学习的基础知识点~Few-Shot Learning (1/3): 基本概念_哔哩哔哩_bilibili Few-Shot Learning&#xff08;小样本分类&#xff09; 假设现在每类…

【Linux】基础I/O——动静态库的制作

我想把我写的头文件和源文件给别人用 1.把源代码直接给他2.把我们的源代码想办法打包为库 1.制作静态库 1.1.制作静态库的过程 我们先看看怎么制作静态库的&#xff01; makefile 所谓制作静态库 需要将所有的.c源文件都编译为(.o)目标文件。使用ar指令将所有目标文件打包…

谷粒商城实战笔记-38-前端基础-Vue-指令-单向绑定双向绑定

文章目录 一&#xff0c;插值表达式注意事项1&#xff1a;不适合复杂的逻辑处理注意事项2&#xff1a;插值表达式支持文本拼接注意事项3&#xff1a;插值表达式只能在标签体中 二&#xff0c;v-html和v-textv-textv-html区别总结&#xff1a;最佳实践 三&#xff0c;v-model复选…

FastAPI 学习之路(五十六)将token缓存到redis

在之前的文章中&#xff0c;FastAPI 学习之路&#xff08;二十九&#xff09;使用&#xff08;哈希&#xff09;密码和 JWT Bearer 令牌的 OAuth2&#xff0c;FastAPI 学习之路&#xff08;二十八&#xff09;使用密码和 Bearer 的简单 OAuth2&#xff0c;FastAPI 学习之路&…

阵列信号处理学习笔记(二)--空域滤波基本原理

阵列信号 阵列信号处理学习笔记&#xff08;一&#xff09;–阵列信号处理定义 阵列信号处理学习笔记&#xff08;二&#xff09;–空域滤波基本原理 文章目录 阵列信号前言一、阵列信号模型1.1 信号的基本模型1.2 阵列的几何构型1.3 均匀直线阵的阵列信号基本模型 总结 前言…

嵌入式面试总结

C语言中struct和union的区别 struct和union都是常见的复合结构。 结构体和联合体虽然都是由多个不同的数据类型成员组成的&#xff0c;但不同之处在于联合体中所有成员共用一块地址空间&#xff0c;即联合体只存放了一个被选中的成员&#xff0c;结构体中所有成员占用空间是累…

【网络】windows和linux互通收发

windows和linux互通收发 一、windows的udp客户端代码1、代码剖析2、总体代码 二、linux服务器代码三、成果展示 一、windows的udp客户端代码 1、代码剖析 首先我们需要包含头文件以及lib的一个库&#xff1a; #include <iostream> #include <WinSock2.h> #inclu…

【模板代码】用于编写Threejs Demo的模板代码

基础模板代码 使用须知常规模板代码常规Shader模板代码 使用须知 本模板代码&#xff0c;主要用于编写Threejs的Demo&#xff0c;因为本人在早期学习的过程中&#xff0c;大量抄写Threejs/examples下的代码以及各个demo站的代码&#xff0c;所以养成了编写Threejs的demo的习惯…