笔记3:torch训练测试VGG网络

(1)利用Netron查看网络实际情况

在这里插入图片描述
上图链接
python生成上图代码如下,其中GETVGGnet是搭建VGG网络的程序GETVGGnet.py,VGGnet是该程序中的搭建网络类。netron是需要pip安装的可视化库,注意do_constant_folding=False可以防止Netron中不显示Batchnorm2D层,禁用参数隐藏。

import torch
from torch.autograd import Variable
from GetVGGnet import VGGnet
import netronnet = VGGnet()
x = Variable(torch.FloatTensor(1,3,28,28))
y = net(x)
print(y.data.shape)
onnx_path = "./save_model/VGGnet.onnx"
torch.onnx.export(net, x, onnx_path,do_constant_folding=False)
print(net)
netron.start(onnx_path)

(2)VGG训练测试全过程

此次训练在CPU上进行,迭代次epoch = 10,迭代内轮次batch=300,训练集10000张,测试集2000张。
train loss和train corre分别代表损失和正确率,横轴是不同迭代下每一个伦次的loss&corre累加,一个迭代进行33个轮次,每个迭代最后一个伦次数据不足被网络舍弃,10个迭代总共320次。test loss和test corre是每个一个迭代下所有伦次的正确率平均值。根据图可以看出,训练和测试结果都较好。
在这里插入图片描述
训练的损失和正确率在波动,但总体趋势较好。
在这里插入图片描述
数据集大小可以在此处修改:在这里插入图片描述

代码:cifar10_handle和GetVGGnet在上几篇文章有说明

#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: 楠楠星球
@time: 2024/5/10 10:15 
@file: VGGTrain.py-->test
@project: pythonProject
@# ------------------------------------------(one)--------------------------------------
@# ------------------------------------------(two)--------------------------------------
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from GetVGGnet import VGGnet
from cifar10_handle import train_dataset,test_dataset
import matplotlib.pyplot as pltepoch = 10  #迭代次数
learn_rate = 0.01 #初始学习率net = VGGnet().to(device='cpu') #模型实例化
loss_fun = nn.CrossEntropyLoss() #调用损失函数
train_data_loder = DataLoader(dataset=train_dataset,batch_size=300,  #每一次迭代的调用的波次shuffle=True,    #这个波次是否打乱数据集num_workers=4,   # 线程数drop_last=True)  # 最后一个波次数据不足是否舍去test_data_loder = DataLoader(dataset=test_dataset,batch_size=300,shuffle=False,num_workers=4,drop_last=True)# optimizer = torch.optim.Adam(net.parameters(), lr=learn_rate)
optimizer = torch.optim.SGD(net.parameters(), lr=learn_rate, momentum=0.5) #优化器# scheduler = torch.optim.lr_scheduler.StepLR(optijumizer, step_size=5, gamma=0.9) #step_size=1表示每迭代一次更新一下学习率
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.7) #学习率调整器def train(epoch_num,train_net):# ------------------------------------------()--------------------------------------loss_base = []corre_base = []test_loss_base = []test_corre_base =[]for epoch in range(epoch_num):# ------------------------------------------(TRAIN)--------------------------------------train_net.train()for i, data in enumerate(train_data_loder):input_tensor, label = datainput_tensor = input_tensor.to(device='cpu')label = label.to(device='cpu')output_tensor = train_net(input_tensor)loss = loss_fun(output_tensor, label)optimizer.zero_grad()loss.backward()optimizer.step()_, pred = torch.max(output_tensor.data, dim=1)correct = pred.eq(label.data).cpu().sum()print(f"训练中:第{epoch + 1}次迭代的小迭代{i}的损失率为:{1.00 * loss.item()},正确率为:{100.00 * correct / 300}")loss_base.append(loss.item())corre_base.append(100.00 * correct.item() / 300)scheduler.step()# ------------------------------------------(TEST)--------------------------------------sum_test_loss = 0sum_test_corre = 0train_net.eval()for i, test_data in enumerate(test_data_loder):input_tensor, label = test_datainput_tensor = input_tensor.to(device='cpu')label = label.to(device='cpu')output_tensor = train_net(input_tensor)loss = loss_fun(output_tensor, label)_, pred = torch.max(output_tensor.data, dim=1)correct = pred.eq(label.data).cpu().sum()sum_test_loss += loss.item()sum_test_corre += correct.item()test_loss = sum_test_loss * 1.0 / len(test_data_loder)test_corre = sum_test_corre * 100.0 / len(test_data_loder) / 300test_loss_base.append(test_loss)test_corre_base.append(test_corre)print(f"测试中:当前迭代的测试集损失为:{test_loss},正确率为:{test_corre}")return loss_base,corre_base,test_loss_base,test_corre_base# ------------------------------------------()--------------------------------------if __name__ == '__main__':[train_loss,train_corre,test_loss,test_corr] = train(epoch,net)fig, axes = plt.subplots(2, 2)axes[0, 0].plot(list(range(1, len(train_loss)+1 )), train_loss,color ='r')axes[0, 0].set_title('train loss')axes[0, 1].plot(list(range(1, len(train_corre) + 1)), train_corre, color ='r')axes[0, 1].set_title('train corre')axes[1, 0].plot(list(range(1, len(test_loss) + 1)), test_loss,color ='r')axes[1, 0].set_title('test loss')axes[1, 1].plot(list(range(1, len(test_corr) + 1)), test_corr,color ='r')axes[1, 1].set_title('test corre')plt.show()# torch.save(net.state_dict(), './save_model/example1.pt')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/10818.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【简单介绍下Sass】

🎥博主:程序员不想YY啊 💫CSDN优质创作者,CSDN实力新星,CSDN博客专家 🤗点赞🎈收藏⭐再看💫养成习惯 ✨希望本文对您有所裨益,如有不足之处,欢迎在评论区提出…

Windows 查找端口号关闭端口号关闭进程的操作流程

Windows 查找端口号关闭端口号关闭进程 8000为端口号 1.查看端口占用程序的ID号 netstat -aon|findstr "8000"比如结果是5684 2.查看ID对应的程序进程 tasklist|findstr "6884"3.关闭进程 taskkill -PID 6884 -F成功: 已终止 PID 为 5684 的进程。

华为机试打卡 HJ2 计算某字符出现次数

要机试了,华孝子求捞,功德 描述 写出一个程序,接受一个由字母、数字和空格组成的字符串,和一个字符,然后输出输入字符串中该字符的出现次数。(不区分大小写字母) 数据范围: 1≤&a…

【复杂网络】如何用简易通俗的方式快速理解什么是“相对重要节点挖掘”?

什么是相对重要节点? 一、相对重要节点的定义二、如何区分相对重要节点与重要节点?1. 相对重要性与节点相似性2. 识别相对重要节点的两个阶段第一阶段:个体重要性值的计算第二阶段:累积重要性值的计算 三、相对重要节点挖掘算法1.…

条件变量解决同步问题之打印金鱼

说明 本代码为jyy老师上课演示条件变量解决同步问题示例(本人只做记录与分享) 本人未使用老师封装的POSIX线程库, 直接在单文件中调试并注释 问题描述 有三类线程 T1 若干: 死循环打印< T2 若干: 死循环打印> T3 若干: 死循环打印_ 任务: 对线程同步&#xff0c;使得屏幕…

ASP.NET一种基于C2C模式的网上购物系统的设计与实现

摘 要 网络购物已经慢慢地从一个新鲜的事物逐渐变成日常生活的一部分&#xff0c;以其特殊的优势而逐渐深入人心。本课题是设计开发一种基于C2C模式的网上购物系统。让各用户使用浏览器进行商品浏览。注册用户可以轻松的展示自己的网络商店&#xff0c;能对自己的用户信息进行…

Vagrant + docker搭建Jenkins 部署环境

有人问&#xff0c;为什么要用Jenkins&#xff1f;我说下我以前开发的痛点&#xff0c;在一些中小型企业&#xff0c;每次开发一个项目完成后&#xff0c;需要打包部署&#xff0c;可能没有专门的运维人员&#xff0c;只能开发人员去把项目打成一个war包&#xff0c;可能这个项…

钉钉群定时发送消息1.0软件【附源码】

内容目录 一、详细介绍二、效果展示1.部分代码2.效果图展示 三、学习资料下载 一、详细介绍 有时候需要在钉钉群里提醒一些消息。要通知的群成员又不方便用定时钉的功能&#xff0c;所以写了这么一个每日定时推送群消息的工具。 易语言程序&#xff0c;附上源码与模块&#x…

C++中vector的简单实现

文章目录 一、主要任务1. 查看文档的网站的链接2.内部模拟的函数 二、本人的模拟实现过程1. 所需模拟实现的函数a.构造、拷贝构造b. reverse()扩容c.insert()、push_back()插入数据d. erase()、pop_back()删除数据e. swap()交换f. begin()、end()非const与const迭代器g. 完善构…

mysql的存储结构

一个表就是一个ibd文件 .ibd文件大小取决于数据和索引&#xff0c;在5.7之后才会为每个表生成一个独立表空间即一个ibd文件&#xff0c;在此之前&#xff0c;所有表默认下都会存储在“系统表空间”&#xff08;共享表空间&#xff09;&#xff0c;所有表都在一个ibd文件。 inn…

示例六、湿敏传感器

通过以下几个示例来具体展开学习,了解湿敏传感器原理及特性&#xff0c;学习湿敏传感器的应用&#xff1a; 示例六、湿敏传感器 一、基本原理&#xff1a;随着人们生活水平的不断提高&#xff0c;湿度监控逐步提到议事日程上。由于北方地区秋冬季干燥&#xff0c;需要控制室内…

16.接口自动化学习-编码处理与装饰器

1.编码和解码 编码&#xff1a;将自然语言翻译成计算机可以识别的语言 hello–01010 解码&#xff1a;将机器识别的语言翻译成自然语言 2.编码格式 UTF-8 GBK unicode 3.编码操作 #编码操作str1"hello呀哈哈哈"str2str1.encode(gbk)print(str2)print(type(str2))…

js原型链与继承笔记

前置阅读&#xff1a;https://developer.mozilla.org/zh-CN/docs/Web/JavaScript/Inheritance_and_the_prototype_chain js中的“类”是一个函数。function test() {}中&#xff0c;test是由Function生成的。prototype与__proto__的区别&#xff1a; 前者是js函数&#xff08;C…

Linux学习之路 -- 文件系统 -- 缓冲区

前面介绍了文件描述符的相关知识&#xff0c;下面我们将介绍缓冲区的相关知识。 本质上来说&#xff0c;缓冲区就是一块内存区域&#xff0c;因为内核上的缓冲区较复杂&#xff0c;所以本文主要介绍C语言的缓冲区。 目录 1.为什么要有缓冲区 2.应用层缓冲区的默认刷新策略 …

如何在bud里弄3d模型?---模大狮模型网

随着数字化设计的不断发展&#xff0c;越来越多的设计软件提供了对3D模型的支持&#xff0c;为设计师们带来了更广阔的创作空间。Bud作为一款功能强大的设计工具&#xff0c;也提供了添加和编辑3D模型的功能&#xff0c;让用户能够更加灵活地进行设计创作。本文将为您详细介绍如…

【计算机网络】计算机网络体系结构

&#x1f6a9;本文已收录至专栏&#xff1a;计算机网络学习之旅 一.常见的三种结构 (1) OSI参考模型 为了使不同体系结构的计算机网络都能互连起来&#xff0c;国际标准化组织于1977年成立了专门机构研究该问题&#xff0c;提出了著名的开放系统互连基本参考模型&#xff0c…

pycharm 将项目连同库一起打包及虚拟环境的使用

目录 一、创建虚拟环境 1、用 anaconda 创建 2、Pycharm 直接创建 二、虚拟环境安装第三方库 1、创建项目后&#xff0c;启动终端(Alt F12)&#xff0c;或者点击下方标记处。 2、使用 pip 或者 conda 来进行三方库的安装或卸载 3、将项目中的库放入文档&#xff0c;便于…

李宏毅-注意力机制详解

原视频链接&#xff1a;attention 一. 基本问题分析 1. 模型的input 无论是预测视频观看人数还是图像处理&#xff0c;输入都可以看作是一个向量&#xff0c;输出是一个数值或类别。然而&#xff0c;若输入是一系列向量&#xff0c;长度可能会不同&#xff0c;例如把句子里的…

Spring STOMP-消息处理流程

一旦STOMP的接口被公布&#xff0c;Spring应用程序就成为连接客户端的STOMP代理。本节描述服务端消息处理的流程。 spring-messaging模块包含消息类应用的基础功能&#xff0c;这些功能起源于Spring Integration项目。并且&#xff0c;后来被提取整合到Spring框架&#xff0c;…