模型训练-保存训练数据

1.目的

找到一个可运行的代码,可以每个epoch打印训练数据,但是不会保存。因为在改进模型需要这些训练数据进行对比,所以需要将每个epoch的训练数据保存下来,写到一个文件中。

2.解决方案

直接问ChatGPT,提示词如下:

做一个深度学习的图像分类任务,现在需要将每个epoch的训练数据存到一个文件,用python来实现

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import json# 定义一个简单的网络
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.relu = nn.ReLU()self.fc1 = nn.Linear(32 * 28 * 28, 10)def forward(self, x):x = self.conv1(x)x = self.relu(x)x = x.view(x.size(0), -1)x = self.fc1(x)return x# 设置超参数
batch_size = 64
num_epochs = 10
learning_rate = 0.01# 创建数据集和数据加载器
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='data/', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)# 创建网络、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 文件名
output_file = "training_data.json"# 开始训练
training_data = []for epoch in range(num_epochs):epoch_loss = 0.0correct = 0total = 0# 训练一个 epochfor inputs, labels in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()epoch_loss += loss.item() * inputs.size(0)_, predicted = outputs.max(1)correct += predicted.eq(labels).sum().item()total += labels.size(0)# 计算平均损失和准确度avg_loss = epoch_loss / totalaccuracy = correct / total# 将 epoch 的数据添加到列表中epoch_data = {"epoch": epoch + 1,"avg_loss": avg_loss,"accuracy": accuracy}training_data.append(epoch_data)# 打印 epoch 的训练结果print(f"Epoch {epoch + 1}: Avg Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")# 将所有 epoch 的训练数据保存到文件中
with open(output_file, 'w') as f:json.dump(training_data, f)print(f"Training data saved to {output_file}")

由于缺少数据集,也懒得去找一个数据集,就将自己py文件中训练部分摘出来给ChatGPT,让它在训练方法中实现这个功能。 

def train():data_format(os.path.join(config.root_path, './data/train.txt'),os.path.join(config.root_path, './data/data'), os.path.join(config.root_path, './data/train.json'))data = read_from_file(config.train_data_path, config.data_dir, config.only)train_data, val_data = train_val_split(data)train_loader = processor(train_data, config.train_params)val_loader = processor(val_data, config.val_params)best_acc = 0epoch = config.epochfor e in range(epoch):print('-' * 20 + ' ' + 'Epoch ' + str(e+1) + ' ' + '-' * 20)# 训练模型tloss, tloss_list = trainer.train(train_loader)print('Train Loss: {}'.format(tloss))# writer.add_scalar('Training/loss', tloss, e)# 验证模型vloss, vacc = trainer.valid(val_loader)print('Valid Loss: {}'.format(vloss))print('Valid Acc: {}'.format(vacc))# writer.add_scalar('Validation/loss', vloss, e)# writer.add_scalar('Validacc/acc', vacc, e)# 保存训练数据training_data = {"epoch": e + 1,"train_loss": tloss,"valid_loss": vloss,"valid_acc": vacc}with open('training_data.json', 'a') as f:json.dump(training_data, f)f.write('\n')print("数据保存完成")# 保存最佳模型if vacc > best_acc:best_acc = vaccsave_model(config.output_path, config.fuse_model_type, model)print('Update best model!')print('-' * 20 + ' ' + 'Training Finished' + ' ' + '-' * 20)print('Best Validation Accuracy: {}'.format(best_acc))

在我的代码中具体加入的是下列几行代码

# 保存训练数据
training_data = {"epoch": e + 1,"train_loss": tloss,"valid_loss": vloss,"valid_acc": vacc
}
with open('training_data.json', 'a') as f:json.dump(training_data, f)f.write('\n')
print("数据保存完成")

 

代码意思如下: 

  1. with open('training_data.json', 'a') as f:: 打开名为 'training_data.json' 的文件,以追加模式 'a',并将其赋给变量 f。如果文件不存在,将会创建一个新文件。
  2. json.dump(training_data, f): 将变量 training_data 中的数据以 JSON 格式写入到文件 f 中。这个操作会将 training_data 中的内容转换成 JSON 格式,并写入到文件中。
  3. f.write('\n'): 写入一个换行符 \n 到文件 f 中,确保每次写入 JSON 数据后都有一个新的空行,使得每个 JSON 对象都独占一行,便于后续处理。

这段代码的作用是将变量 training_data 中的数据以 JSON 格式写入到文件 'training_data.json' 中,并确保每次写入后都有一个换行符分隔。

3.结果

可以在每个epoch训练完成后,将训练损失,验证损失和验证准确率保存在training_data.json文件中。

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/809557.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringMVC原理及工作流程

组件 SpringMVC的原理主要基于它的各个组件之间的相互协作交互,从而实现了Web请求的接收,处理和响应。 它的组件有如下几个: DispatcherServlet前端控制器 HandlerMapping处理器映射器 Controller处理器 ModelAndView ViewResolver视图…

0基础刷图论最短路 3(从ATcoder 0分到1800分)

AT最短路刷题3(本文难度rated 1200~ 1400) 题目来源:Atcoder 题目收集: https://atcoder-tags.herokuapp.com/tags/Graph/Shortest-Path (里面按tag分类好了Atcoder的所有题目,类似cf) &#x…

Ubuntu22.04安装ffmpeg(v7.0)

需下载文件:ffmpeg-7.0.tar.gz 安装步骤 1. 创建目录 mkdir -p /ffmpeg && cd ffmpeg2. 下载文件 wget https://ffmpeg.org/releases/ffmpeg-7.0.tar.gz3. 解压 tar -zxvf ffmpeg-7.0.tar.gz && cd ffmpeg-7.04. 安装环境依赖 官网说明&#…

练习题(2024/4/11)

1每日温度 给定一个整数数组 temperatures ,表示每天的温度,返回一个数组 answer ,其中 answer[i] 是指对于第 i 天,下一个更高温度出现在几天后。如果气温在这之后都不会升高,请在该位置用 0 来代替。 示例 1: 输入…

Leetcode刷题之消失的数字(C语言版)

Leetcode刷题之消失的数字(C语言版) 一、题目描述二、题目解析 一、题目描述 数组nums包含从0到n的所有整数,但其中缺了一个。请编写代码找出那个缺失的整数。你有办法在O(n)时间内完成吗? 注意:本题相对书上原题稍作…

STM32 文档整理

//***********************************************************************************************************// 英文缩写名称NVIC嵌套向量中断控制器SysTick系统滴答定时器RCC复位和时钟控制GPIO通用IO口AFIO复用IO口EXTI外部中断TIM定时器ADC模数转换器DMA直接内存访…

Java中实现监听UDP协议的指定端口并收到数据按照十六进制输出

场景 对接协议中需要监听UDP协议的指定端口并监听数据,且数据格式为十六进制。 如果是在linux服务上,可以快速通过C或者python脚本等方式实现。 这里使用Java代码实现,可便于后续做其他存储数据等的扩展,且只需要在服务器上安装…

华为OD七日集训第6期 - 按算法分类,由易到难,循序渐进,玩转OD

目录 一、适合人群二、本期训练时间三、如何参加四、七日集训第 6 期五、精心挑选21道高频经典题目,作为入门。第1天、逻辑分析第2天、双指针第3天、滑动窗口第4天、二叉树第5天、矩阵第6天、分治递归第7天、深度优先搜索 大家好,我是哪吒。 最近一直在…

《安静的力量》探寻自我的心灵之旅,找到内心的宁静和真正的幸福 - 三余书屋 3ysw.net

安静的力量:通往止境的冒险 大家好,今天我们要解读的书籍是《安静的力量》。让我们先设想一个画面:在纽约曼哈顿,紧邻繁华的时代广场,一位29岁的青年在他的公寓里工作。这里毗邻纽约最富有人群的聚居区,而…

Windows Edge 兼容性问题修复:提升用户体验的关键步骤

🌟 前言 欢迎来到我的技术小宇宙!🌌 这里不仅是我记录技术点滴的后花园,也是我分享学习心得和项目经验的乐园。📚 无论你是技术小白还是资深大牛,这里总有一些内容能触动你的好奇心。🔍 &#x…

如何设置MySQL的IP白名单

当我们谈论设置MySQL数据库的IP白名单时,我们通常是在指定哪些IP地址被允许连接到数据库服务器。这是一种安全措施,可确保只有受信任的主机可以访问数据库。以下是一个分步指南,以及如何设置MySQL的IP白名单的说明。 步骤1: 登录到MySQL服务…

Django框架的基础知识

Django(英文发音:dʒŋgəʊ)是一个开放源代码的Web应用框架,使用高性能的Python语言编写而成。Django框架的诞生,最初是用来开发和管理Lawrence Publishing Group(劳伦斯出版集团)旗下的新闻网…

【vscode】在本地加载远端环境并开发

【vscode】在本地利用远程服务器显卡跑代码 写在最前面vscode:远程到本地1、安装ssh插件2、添加服务器连接配置3、连接服务器4. SSH配置5. 在ssh中安装python解释器 vscode基本操作 🌈你好呀!我是 是Yu欸 🌌 2024每日百字篆刻时光…

JS搜索关键字匹配变色

使用场景:用户通过搜索关健字(keyword),对文本进行匹配,并对匹配到的文本进行一些高亮处理 解析:使用的是JavaScript中的 RegExp(正则表达式)对象,var regExp new RegExp(keyword,…

银河麒麟操作系统修改dns(唯一一篇可以解决DNS生效问题)

背景: Kylin V10 SP1 系统修改dns 1.修改/etc/resolv.conf 临时生效,不满足生产要求 2.修改/etc/network/interface 不生效 3.修改/etc/systemd/resolved.conf,遇到问题,最终解决永久修改DNS 系统版本: root@node01:~# cat /etc/issue Kylin V10 SP1 \n \l 一、如何在…

BLIP 算法阅读记录---一个许多多模态大语言模型的基本组件

论文地址:😈 目录 一、环境配置以及数据集准备 数据集准备 数据集格式展示 环境配置,按照官网所述即可 二、一些调整 vit_base的预训练模型 远程debug的设置 Tokenizer初始化失败 读入网络图片的调整 三、训练过程 Image Encoder …

FebHost:英国.UK域名注册使用中存在哪些侵权行为?

截至2023年6月,英国.uk域名作为全球第五大热门顶级域名,注册数量超过1100万,成为全球最知名和广泛使用的域名之一。英国域名家族包括四个独特的域名后缀——.uk、.co.uk、.org.uk 和 .me.uk——每个都有其独特的特点,并根据数字领…

Mac下用adb命令安装apk到android设备笔记

查询了些资料记录备用。以下是在Mac上使用命令行安装APK文件的步骤: 1. 下载并安装ADB: 如果您的Mac上没有安装ADB,请从官方的Android开发者网站下载Android SDK Platform Tools:Android SDK Platform Tools。将下载的ZIP文件解…

python使用Flask框架开发API

Flask是一个基于Python的轻量级Web应用程序框架。 安装依赖库 pip install flask pip install werkzeug 上传接口 Python from flask import Flask, request from werkzeug.utils import secure_filenameapp Flask(__name__)app.route(/upload, methods[POST]) def uploa…

三次 Bspline(B样条曲线) NURBS曲线的绘制 matlab

先来了解几个概念: 1.1 节点向量: B-Spline需要定义曲线的节点向量U,它可以对应到Bezier曲线的参数u。 其元素个数 (m1) 和曲线阶数 k 、控制点个数n满足:m1k1n1 如果U的每段的距离是相等,那么这个B-Spline就被称为均…