模型训练-保存训练数据

1.目的

找到一个可运行的代码,可以每个epoch打印训练数据,但是不会保存。因为在改进模型需要这些训练数据进行对比,所以需要将每个epoch的训练数据保存下来,写到一个文件中。

2.解决方案

直接问ChatGPT,提示词如下:

做一个深度学习的图像分类任务,现在需要将每个epoch的训练数据存到一个文件,用python来实现

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import json# 定义一个简单的网络
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.relu = nn.ReLU()self.fc1 = nn.Linear(32 * 28 * 28, 10)def forward(self, x):x = self.conv1(x)x = self.relu(x)x = x.view(x.size(0), -1)x = self.fc1(x)return x# 设置超参数
batch_size = 64
num_epochs = 10
learning_rate = 0.01# 创建数据集和数据加载器
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='data/', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)# 创建网络、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 文件名
output_file = "training_data.json"# 开始训练
training_data = []for epoch in range(num_epochs):epoch_loss = 0.0correct = 0total = 0# 训练一个 epochfor inputs, labels in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()epoch_loss += loss.item() * inputs.size(0)_, predicted = outputs.max(1)correct += predicted.eq(labels).sum().item()total += labels.size(0)# 计算平均损失和准确度avg_loss = epoch_loss / totalaccuracy = correct / total# 将 epoch 的数据添加到列表中epoch_data = {"epoch": epoch + 1,"avg_loss": avg_loss,"accuracy": accuracy}training_data.append(epoch_data)# 打印 epoch 的训练结果print(f"Epoch {epoch + 1}: Avg Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")# 将所有 epoch 的训练数据保存到文件中
with open(output_file, 'w') as f:json.dump(training_data, f)print(f"Training data saved to {output_file}")

由于缺少数据集,也懒得去找一个数据集,就将自己py文件中训练部分摘出来给ChatGPT,让它在训练方法中实现这个功能。 

def train():data_format(os.path.join(config.root_path, './data/train.txt'),os.path.join(config.root_path, './data/data'), os.path.join(config.root_path, './data/train.json'))data = read_from_file(config.train_data_path, config.data_dir, config.only)train_data, val_data = train_val_split(data)train_loader = processor(train_data, config.train_params)val_loader = processor(val_data, config.val_params)best_acc = 0epoch = config.epochfor e in range(epoch):print('-' * 20 + ' ' + 'Epoch ' + str(e+1) + ' ' + '-' * 20)# 训练模型tloss, tloss_list = trainer.train(train_loader)print('Train Loss: {}'.format(tloss))# writer.add_scalar('Training/loss', tloss, e)# 验证模型vloss, vacc = trainer.valid(val_loader)print('Valid Loss: {}'.format(vloss))print('Valid Acc: {}'.format(vacc))# writer.add_scalar('Validation/loss', vloss, e)# writer.add_scalar('Validacc/acc', vacc, e)# 保存训练数据training_data = {"epoch": e + 1,"train_loss": tloss,"valid_loss": vloss,"valid_acc": vacc}with open('training_data.json', 'a') as f:json.dump(training_data, f)f.write('\n')print("数据保存完成")# 保存最佳模型if vacc > best_acc:best_acc = vaccsave_model(config.output_path, config.fuse_model_type, model)print('Update best model!')print('-' * 20 + ' ' + 'Training Finished' + ' ' + '-' * 20)print('Best Validation Accuracy: {}'.format(best_acc))

在我的代码中具体加入的是下列几行代码

# 保存训练数据
training_data = {"epoch": e + 1,"train_loss": tloss,"valid_loss": vloss,"valid_acc": vacc
}
with open('training_data.json', 'a') as f:json.dump(training_data, f)f.write('\n')
print("数据保存完成")

 

代码意思如下: 

  1. with open('training_data.json', 'a') as f:: 打开名为 'training_data.json' 的文件,以追加模式 'a',并将其赋给变量 f。如果文件不存在,将会创建一个新文件。
  2. json.dump(training_data, f): 将变量 training_data 中的数据以 JSON 格式写入到文件 f 中。这个操作会将 training_data 中的内容转换成 JSON 格式,并写入到文件中。
  3. f.write('\n'): 写入一个换行符 \n 到文件 f 中,确保每次写入 JSON 数据后都有一个新的空行,使得每个 JSON 对象都独占一行,便于后续处理。

这段代码的作用是将变量 training_data 中的数据以 JSON 格式写入到文件 'training_data.json' 中,并确保每次写入后都有一个换行符分隔。

3.结果

可以在每个epoch训练完成后,将训练损失,验证损失和验证准确率保存在training_data.json文件中。

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/809557.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringMVC原理及工作流程

组件 SpringMVC的原理主要基于它的各个组件之间的相互协作交互,从而实现了Web请求的接收,处理和响应。 它的组件有如下几个: DispatcherServlet前端控制器 HandlerMapping处理器映射器 Controller处理器 ModelAndView ViewResolver视图…

练习题(2024/4/11)

1每日温度 给定一个整数数组 temperatures ,表示每天的温度,返回一个数组 answer ,其中 answer[i] 是指对于第 i 天,下一个更高温度出现在几天后。如果气温在这之后都不会升高,请在该位置用 0 来代替。 示例 1: 输入…

Leetcode刷题之消失的数字(C语言版)

Leetcode刷题之消失的数字(C语言版) 一、题目描述二、题目解析 一、题目描述 数组nums包含从0到n的所有整数,但其中缺了一个。请编写代码找出那个缺失的整数。你有办法在O(n)时间内完成吗? 注意:本题相对书上原题稍作…

Java中实现监听UDP协议的指定端口并收到数据按照十六进制输出

场景 对接协议中需要监听UDP协议的指定端口并监听数据,且数据格式为十六进制。 如果是在linux服务上,可以快速通过C或者python脚本等方式实现。 这里使用Java代码实现,可便于后续做其他存储数据等的扩展,且只需要在服务器上安装…

华为OD七日集训第6期 - 按算法分类,由易到难,循序渐进,玩转OD

目录 一、适合人群二、本期训练时间三、如何参加四、七日集训第 6 期五、精心挑选21道高频经典题目,作为入门。第1天、逻辑分析第2天、双指针第3天、滑动窗口第4天、二叉树第5天、矩阵第6天、分治递归第7天、深度优先搜索 大家好,我是哪吒。 最近一直在…

《安静的力量》探寻自我的心灵之旅,找到内心的宁静和真正的幸福 - 三余书屋 3ysw.net

安静的力量:通往止境的冒险 大家好,今天我们要解读的书籍是《安静的力量》。让我们先设想一个画面:在纽约曼哈顿,紧邻繁华的时代广场,一位29岁的青年在他的公寓里工作。这里毗邻纽约最富有人群的聚居区,而…

Windows Edge 兼容性问题修复:提升用户体验的关键步骤

🌟 前言 欢迎来到我的技术小宇宙!🌌 这里不仅是我记录技术点滴的后花园,也是我分享学习心得和项目经验的乐园。📚 无论你是技术小白还是资深大牛,这里总有一些内容能触动你的好奇心。🔍 &#x…

Django框架的基础知识

Django(英文发音:dʒŋgəʊ)是一个开放源代码的Web应用框架,使用高性能的Python语言编写而成。Django框架的诞生,最初是用来开发和管理Lawrence Publishing Group(劳伦斯出版集团)旗下的新闻网…

【vscode】在本地加载远端环境并开发

【vscode】在本地利用远程服务器显卡跑代码 写在最前面vscode:远程到本地1、安装ssh插件2、添加服务器连接配置3、连接服务器4. SSH配置5. 在ssh中安装python解释器 vscode基本操作 🌈你好呀!我是 是Yu欸 🌌 2024每日百字篆刻时光…

BLIP 算法阅读记录---一个许多多模态大语言模型的基本组件

论文地址:😈 目录 一、环境配置以及数据集准备 数据集准备 数据集格式展示 环境配置,按照官网所述即可 二、一些调整 vit_base的预训练模型 远程debug的设置 Tokenizer初始化失败 读入网络图片的调整 三、训练过程 Image Encoder …

FebHost:英国.UK域名注册使用中存在哪些侵权行为?

截至2023年6月,英国.uk域名作为全球第五大热门顶级域名,注册数量超过1100万,成为全球最知名和广泛使用的域名之一。英国域名家族包括四个独特的域名后缀——.uk、.co.uk、.org.uk 和 .me.uk——每个都有其独特的特点,并根据数字领…

Mac下用adb命令安装apk到android设备笔记

查询了些资料记录备用。以下是在Mac上使用命令行安装APK文件的步骤: 1. 下载并安装ADB: 如果您的Mac上没有安装ADB,请从官方的Android开发者网站下载Android SDK Platform Tools:Android SDK Platform Tools。将下载的ZIP文件解…

三次 Bspline(B样条曲线) NURBS曲线的绘制 matlab

先来了解几个概念: 1.1 节点向量: B-Spline需要定义曲线的节点向量U,它可以对应到Bezier曲线的参数u。 其元素个数 (m1) 和曲线阶数 k 、控制点个数n满足:m1k1n1 如果U的每段的距离是相等,那么这个B-Spline就被称为均…

关于UCG游戏平台的一些思考

UCG游戏平台,全称User Generated Content,即用户生成内容。它涵盖了所有玩家可以自主编辑的部分,包含并不限于换装、捏脸、关卡摆放等内容。 UCG概念在最近又火了起来,但这个模式出现的并不早。早在10多年前,war3编辑器…

为linux和windows系统备份还原点,防止系统出问题无法恢复

一、linux系统操作办法: sudo apt update sudo apt install timeshift timeshift --create 输出结果如下: 等待约5分钟就会创建成功: 这个备份功能只备份系统,不备份文件,但也不会删除文件。 工作站系统的保存位置&a…

Win10安装sqlplus遇到报错的解决办法

1.下载安装sqlplus.exe的错误解决过程 最近有用到sqlplus连接Oracle数据库执行自动化脚本,Orcle服务器版本是11.2.0.1。在Navicat工具上通过如下语句查询到的版本信息截图如图1所示: SELECT * FROM v$version; 图1 Oracle服务器版本信息 其中“Oracle Da…

Docker部署SpringBoot+Vue前后端分离项目

文章目录 1. 安装Docker1. 1 卸载旧版Docker1.2 配置yum仓库1.3 安装Docker1.4 添加自启动配置1.5 配置阿里云镜像加速1.6 测试 2. 安装Nginx2.1 拉取镜像2.2 安装Nginx2.3 测试 3. 安装MySQL3.1 拉取镜像3.2 安装MySQL3.3 连接MySQL 4. 部署SpringBoot项目4.1 Maven打包4.2 编…

深度学习Vue框架生命周期(三)

一.什么是生命周期? 在vue中,生命周期就是vue实例程序从创建到销毁的这个过程,在生命周期中,不同阶段我们可以做不同的事情。vue的生命周期是创建阶段、挂载阶段、更新阶段、销毁阶段 二.什么是钩子函数? 钩子函数就是…

数据库数据恢复—Sql Server数据库文件丢失如何恢复数据?

服务器数据恢复环境: 一台安装windows server操作系统的服务器。一组由8块硬盘组建的RAID5,划分LUN供这台服务器使用。 在windows服务器内装有SqlServer数据库。存储空间LUN划分了两个逻辑分区。 服务器故障&初检: 由于未知原因&#xf…

Windows联网状态工具TCPView

文章目录 TCPView命令行工具更多Sysinternals Suite工具 TCPView TCPView用于显示系统上所有 TCP 和 UDP 终结点的详细列表,包括本地和远程地址以及 TCP 连接的状态,界面如下。 列表的表头含义如下 表头含义表头含义Process name应用名称Process id进程…