pytorch学习(四)绘制loss和correct曲线

这一次学习的时候静态绘制loss和correct曲线,也就是在模型训练完成后,对统计的数据进行绘制。

以minist数据训练为例子

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as npdevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')trainning_data =datasets.MNIST(root="data",train=True,transform=ToTensor(),download=True)
print(len(trainning_data))
test_data = datasets.MNIST(root="data",train=True,transform=ToTensor(),download=False)train_loader = DataLoader(trainning_data, batch_size=64,shuffle=True)
test_loader = DataLoader(test_data, batch_size=64,shuffle=True)print(len(train_loader)) #分成了多少个batch
print(len(trainning_data)) #总共多少个图像
# for x, y in train_loader:
#     print(x.shape)
#     print(y.shape)class MinistNet(nn.Module):def __init__(self):super().__init__()# self.flat = nn.Flatten()self.conv1 = nn.Conv2d(1,1,3,1,1)self.hideLayer1 = nn.Linear(28*28,256)self.hideLayer2 = nn.Linear(256,10)def forward(self,x):x= self.conv1(x)x = x.view(-1,28*28)x = self.hideLayer1(x)x = torch.sigmoid(x)x = self.hideLayer2(x)# x = nn.Sigmoid(x)return xmodel = MinistNet()
model = model.to(device)
cuda = next(model.parameters()).device
print(model)
criterion = nn.CrossEntropyLoss()
optimer = torch.optim.RMSprop(model.parameters(),lr= 0.001)def train():train_losses = []train_acces = []eval_losses = []eval_acces = []#训练model.train()for epoch in range(10):batchsizeNum = 0train_loss = 0train_acc = 0train_correct = 0for x,y in train_loader:# print(epoch)# print(x.shape)# print(y.shape)x = x.to('cuda')y = y.to('cuda')bte = type(x)==torch.Tensorbte1 = type(y)==torch.TensorA = x.deviceB = y.devicepred_y = model(x)loss = criterion(pred_y,y)optimer.zero_grad()loss.backward()optimer.step()loss_val = loss.item()batchsizeNum = batchsizeNum +1train_acc += (pred_y.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()# print("loss: ",loss_val,"  ",epoch, "  ", batchsizeNum)train_losses.append(train_loss / len(trainning_data))train_acces.append(train_acc / len(trainning_data))#测试model.eval()with torch.no_grad():num_batch = len(test_data)numSize = len(test_data)test_loss, test_correct = 0,0for x,y in test_loader:x = x.to(device)y = y.to(device)pred_y = model(x)test_loss += criterion(pred_y, y).item()test_correct += (pred_y.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchtest_correct /= numSizeeval_losses.append(test_loss)eval_acces.append(test_correct)print("test result:",100 * test_correct,"%  avg loss:",test_loss)PATH = "dict_model_%d_dict.pth"%(epoch)torch.save({"epoch": epoch,"model_state_dict": model.state_dict(), }, PATH)plt.plot(np.arange(len(train_losses)), train_losses, label="train loss")plt.plot(np.arange(len(train_acces)), train_acces, label="train acc")plt.plot(np.arange(len(eval_losses)), eval_losses, label="valid loss")plt.plot(np.arange(len(eval_acces)), eval_acces, label="valid acc")plt.legend()  # 显示图例plt.xlabel('epoches')# plt.ylabel("epoch")plt.title('Model accuracy&loss')plt.show()torch.save(model,"mode_con_line2.pth")#保存网络模型结构# torch.save(model,) #保存模型中的参数torch.save(model.state_dict(),"model_dict.pth")# Press the green button in the gutter to run the script.
if __name__ == '__main__':train()

绘制的图如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/47456.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【zabbix6监控java-tomcat全流程】

目录 一、监控主机安装zabbix-server1、zabbix的安装2、配置数据库3、为zabbix server配置数据库4、启动服务,web界面安装 二、被监控主机安装tomcat1、安装JDK2、安装tomcat 三、zabbix的服务端安装zabbix-java-gateway四、被监控主机tomcat的配置五、web界面添加主机 一、监控…

使用 Web APi - MediaRecorder 获取麦克风资源,报错:Cannot find name ‘MediaRecorder‘ 的解决方法

目录 一、背景: 二、具体解决方法 一、背景: angular 调用 MediaRecorder 来使用麦克风获取声音,(具体要求:angular 前端 按键调用 麦克风,松开按键生成音频文件)代码如下(来自通…

【树莓派3B+】控制引脚输出高低电平

前言一、安装RPI.GPIO库二、编写简单的输出高低电平的程序三、运行程序总结 前言 首先检查一下自己的板子有没有带库 我这个是有的。 ok,正式进入步骤 一、安装RPI.GPIO库 如果还没有安装RPi.GPIO库,可以通过以下命令在树莓派上安装: p…

Ubuntu20.04从零开搭PX4MavrosGazebo环境并测试

仅仅是个人搭建记录 参考链接: https://zhuanlan.zhihu.com/p/686439920 仿真平台基础配置(对应PX4 1.13版) 语雀 mkdir -p ~/tzb/catkin_ws/src mkdir -p ~/tzb/catkin_ws/scripts cd catkin_ws && catkin init catkin build cd…

数据结构day2

一、思维导图 内存分配 二、课后习题 分文件编译 //sys.h #ifndef TEST_H #define TEST_H #define MAX_SIZE 100//定义学生类型 typedef struct Stu {char name[20]; //姓名int age; //年龄double score; //分数 }stu;//定义班级类型 typedef struct Class {struct …

实战:详解Spring创建bean的流程(图解+示例+源码)

概叙 这篇主要总结Spring中bean的创建过程,主要分为加载bean信息–>实例化bean–>属性填充–>初始化阶段–>后置处理等步骤,且每个步骤Spring做的事情都很多,这块源码还是很值得我们都去看一看的。而Spring中Bean的声明周期其实…

GEO数据挖掘从数据下载处理质控到差异分析全流程分析步骤指南

0. 综合的教学视频介绍 GEO数据库挖掘分析作图全流程每晚11点在线教学直播录屏回放视频: https://www.bilibili.com/video/BV1rm42157CT/ GEO数据从下载到各种挖掘分析全流程详解: https://www.bilibili.com/video/BV1nm42157ii/ 一篇今年近期发表的转…

捷配总结的SMT工厂安全防静电规则

SMT工厂须熟记的安全防静电规则! 安全对于我们非常重要,特别是我们这种SMT加工厂,通常我们所讲的安全是指人身安全。 但这里我们须树立一个较为全面的安全常识就是在强调人身安全的同时亦必须注意设备、产品的安全。 电气: 怎样预…

IDEA 调试 Ja-Netfilter

首先本地需要有两款IDEA 可以是相同版本,也可以是不同版本。反正要有两个,一个用来调试代码,一个启动。 移除原有ja-netfiler 打开你的ja-netfiler的vmoptions目录,修改其中的idea.vmoptions文件。移除最后一行-javaagent ...参…

分享 .NET EF6 查询并返回树形结构数据的 2 个思路和具体实现方法

前言 树形结构是一种很常见的数据结构,类似于现实生活中的树的结构,具有根节点、父子关系和层级结构。 所谓根节点,就是整个树的起始节点。 节点则是树中的元素,每个节点可以有零个或多个子节点,节点按照层级排列&a…

AI智能名片S2B2C商城小程序在社群去中心化管理中的应用与价值深度探索

摘要:随着互联网技术的飞速发展,社群经济作为一种新兴的商业模式,正逐渐成为企业与用户之间建立深度连接、促进商业增长的重要途径。本文深入探讨了AI智能名片S2B2C商城小程序在社群去中心化管理中的应用,通过详细分析社群去中心化…

【DGL系列】DGLGraph.out_edges简介

转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn] 如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~ 目录 函数说明 用法示例 示例 1: 获取所有边的源节点和目标节点 示例 2: 获取特定节点的出边 示例 3: 获取所有边的边ID 示例 4: 获取所有信息&a…

中国机器视觉行业上市公司市场竞争格局分析

中国机器视觉产业上市公司汇总:分布在各产业链环节 机器视觉就是用机器来代替人眼做测量和判断的系统,机器检测相较于人工视觉检测优势明显。目前,我国机器视觉产业的上市公司数量较多,分布在各产业链环节。具体包括:…

LeetCode-返回链表倒数第K个节点、链表的回文结构,相交链表

一、返回链表倒数第k个节点 . - 力扣(LeetCode) 本体思路参展寻找中间节点的方法,寻找中间节点是定义快慢指针,快指针每次走两步,慢指针每次走一步,当快指针为空或者快指针的下一个节点是空时,…

4000厂商默认账号密码、默认登录凭证汇总.pdf

获取方式: 链接:https://pan.baidu.com/s/1F8ho42HTQhebKURWWVW1BQ?pwdy2u5 提取码:y2u5

音视频开发入门教程(2)配置FFmpeg编译 ~共210节

在上一篇博客介绍了安装,音视频开发入门教程(1)如何安装FFmpeg?共210节-CSDN博客 感兴趣的小伙伴,可以继续跟着老铁,一起开始音视频剪辑功能,😄首先查看一下自己的电脑是几核的&…

SCSA第七天

防火墙的可靠性 因为防火墙上不仅需要同步配置信息,还需要同步状态信息(会话表等),所以,防火墙不能 像路由器那样单纯的靠动态协议来实现切换,需要用到双机热备技术。 1,双机 --- 目前双机热…

Golang面试题整理(持续更新...)

文章目录 Golang面试题总结一、基础知识1、defer相关2、rune 类型3、context包4、Go 竞态、内存逃逸分析5、Goroutine 和线程的区别6、Go 里面并发安全的数据类型7、Go 中常用的并发模型8、Go 中安全读写共享变量方式9、Go 面向对象是如何实现的10、make 和 new 的区别11、Go 关…

破解反爬虫策略 /_guard/auto.js(二)实战

这次我们用上篇文章讲到的方法来真正破解一下反爬虫策略,这两个案例是两个不同的网站,一个用的是 /_guard/auto.js,另一个用的是/_guard/delay_jump.js。经过解析发现这两个网站用的反爬虫策略基本是一模一样,只不过在js混淆和生成…

HTML2048小游戏(最新版)

比上一篇文章的2048更好一点。 控制方法&#xff1a;WASD键&#xff08;小写&#xff09;或页面上四个按钮 效果图如下&#xff1a; 源代码在图片后面 源代码 HTML <!DOCTYPE html> <html lang"en"> <head><meta charset&…