Pytorch实战(二):VGG神经网络

文章目录

  • 一、诞生背景
  • 二、VGG网络结构
    • 2.1VGG块
    • 2.2网络运行流程
    • 2.3总结
  • 三、实战
    • 3.1搭建模型
    • 3.2模型训练
    • 3.3训练结果可视化
    • 3.4模型参数初始化


一、诞生背景

在这里插入图片描述
在这里插入图片描述
  从网络结构中可看出,所有版本VGG均全部使用3×3大小、步长为1的小卷积核,3×3卷积核同时也是最小的能够表示上下左右中心的尺寸。
在这里插入图片描述
假设输入图像尺寸为假输入为5×5,使用2次3×3卷积后最终得到1×1的特征图,那么这个1×1的特征图的感受野为5×5。这和直接使用一个5×5卷积核得到1×1的特征图是一样的。也就是说2次3×3卷积可以代替一次5×5卷积同时,并且,2次3×3卷积的参数更少(2×3×3=18<5×5=25)而且会经过两次激活函数进行非线性变换,学习能力会更好。同样的3次3×3卷积可以替代一次7×7的卷积。并且,步长为1可以不会丢失信息,网络深度增加可以提高网络性能。

二、VGG网络结构

2.1VGG块

在这里插入图片描述
一个VGG_bolck的组成:

  • 带填充以保持分辨率的卷积层:指对输入特征图卷积操作时会带有填充,使得只改变通道数而不改变图像高、宽。
  • 非线性激活函数ReLU:卷积操作后将特征图输入激活函数,提供使之具有非线性性。
  • 池化层、最大池化层:使用最大池化函数,不改变图像通道数,但会缩小图像尺寸。

对于卷积层、池化层有:

  • 卷积层:使用3x3大小的卷积核,padding=1,stride=1,output=(input-3+2×1)/1+1=input,使得特征图尺寸不变。
  • 池化层:使用2x2大小的核,padding=0,stride=2,output=(input-2)/2+1=1/2input,特征图尺寸减半。

2.2网络运行流程

输入层:输入大小为 ( 224 , 224 , 3 ) (224,224,3) (224,224,3)的RGB图像。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

2.3总结

在这里插入图片描述

三、实战

3.1搭建模型

import torch
from torch import nn
from torchsummary import summarydevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class VGG16(nn.Module):def __init__(self):super(VGG16, self).__init__()self.block1 = nn.Sequential(# 本案例中使用FashionMNIST数据集,所以输入通道数为1nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.block2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.block3 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.block4 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.block5 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.fc = nn.Sequential(nn.Flatten(),nn.Linear(512 * 7 * 7, 4096),nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 4096),nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 10))def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.block5(x)x = self.fc(x)return xmodel = VGG16().to(device)
summary(model, (1, 224, 224))

在这里插入图片描述

3.2模型训练

  使用模板:

import torch
from torch import nn
import copy
import time
from torchvision.datasets import FashionMNIST
from torchvision import transforms
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.utils.data as Data
train_data = FashionMNIST(root='./', train=True, download=True,transform=transforms.Compose([transforms.Resize(size=224), transforms.ToTensor()]))
def train_val_process(train_data, batch_size=128):train_data, val_data = Data.random_split(train_data,lengths=[round(0.8 * len(train_data)), round(0.2 * len(train_data))])train_loader = Data.DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True,num_workers=8)val_loader = Data.DataLoader(dataset=val_data,batch_size=batch_size,shuffle=True,num_workers=8)return train_loader, val_loadertrain_dataloader, val_dataloader = train_val_process(train_data, batch_size=64)def train(model, train_dataloader, val_dataloader, epochs=30, lr=0.001, model_saveName=None, model_saveCsvName=None ):device = torch.device("cuda" if torch.cuda.is_available() else "cpu")optimizer = torch.optim.Adam(model.parameters(), lr=lr)criterion = nn.CrossEntropyLoss()model = model.to(device)# 复制当前模型的参数best_model_params = copy.deepcopy(model.state_dict())# 最高准确率best_acc = 0.0# 训练集损失函数列表train_loss_list = []# 验证集损失函数列表val_loss_list = []# 训练集精度列表train_acc_list = []# 验证集精度列表val_acc_list = []# 记录当前时间since = time.time()for epoch in range(epochs):print("Epoch {}/{}".format(epoch + 1, epochs))print("-" * 10)# 当前轮次训练集的损失值train_loss = 0.0# 当前轮次训练集的精度train_acc = 0.0# 当前轮次验证集的损失值val_loss = 0.0# 当前轮次验证集的精度val_acc = 0.0# 训练集样本数量train_num = 0# 验证集样本数量val_num = 0# 按批次进行训练for step, (x, y) in enumerate(train_dataloader):  # 取出一批次的数据及标签x = x.to(device)y = y.to(device)# 设置模型为训练模式model.train()out = model(x)# 查找每一行中最大值对应的行标,即为对应标签pre_label = torch.argmax(out, dim=1)# 计算损失函数loss = criterion(out, y)optimizer.zero_grad()loss.backward()optimizer.step()# 累计损失函数,其中,loss.item()是一批次内每个样本的平均loss值(因为x是一批次样本),乘以x.size(0),即为该批次样本损失值的累加train_loss += loss.item() * x.size(0)# 累计精度(训练成功的样本数)train_acc += torch.sum(pre_label == y.data)# 当前用于训练的样本数量(对应dim=0)train_num += x.size(0)# 按批次进行验证for step, (x, y) in enumerate(val_dataloader):x = x.to(device)y = y.to(device)# 设置模型为验证模式model.eval()torch.no_grad()out = model(x)# 查找每一行中最大值对应的行标,即为对应标签pre_label = torch.argmax(out, dim=1)# 计算损失函数loss = criterion(out, y)# 累计损失函数val_loss += loss.item() * x.size(0)# 累计精度(验证成功的样本数)val_acc += torch.sum(pre_label == y.data)# 当前用于验证的样本数量val_num += x.size(0)# 计算该轮次训练集的损失值(train_loss是一批次样本损失值的累加,需要除以批次数量得到整个轮次的平均损失值)train_loss_list.append(train_loss / train_num)# 计算该轮次的精度(训练成功的总样本数/训练集样本数量)train_acc_list.append(train_acc.double().item() / train_num)# 计算该轮次验证集的损失值val_loss_list.append(val_loss / val_num)# 计算该轮次的精度(验证成功的总样本数/验证集样本数量)val_acc_list.append(val_acc.double().item() / val_num)# 打印训练、验证集损失值(保留四位小数)print("轮次{} 训练 Loss: {:.4f}, 训练 Acc: {:.4f}".format(epoch+1, train_loss_list[-1], train_acc_list[-1]))print("轮次{} 验证 Loss: {:.4f}, 验证 Acc: {:.4f}".format(epoch+1, val_loss_list[-1], val_acc_list[-1]))# 如果当前轮次验证集精度大于最高精度,则保存当前模型参数if val_acc_list[-1] > best_acc:# 保存当前最高准确度best_acc = val_acc_list[-1]# 保存当前模型参数best_model_params = copy.deepcopy(model.state_dict())print("保存当前模型参数,最高准确度: {:.4f}".format(best_acc))# 训练耗费时间time_use = time.time() - sinceprint("当前轮次耗时: {:.0f}m {:.0f}s".format(time_use // 60, time_use % 60))# 加载最高准确率下的模型参数,并保存模型torch.save(best_model_params, model_saveName)train_process = pd.DataFrame(data={'epoch': range(epochs),'train_loss_list': train_loss_list,'train_acc_list': train_acc_list,'val_loss_list': val_loss_list,'val_acc_list': val_acc_list})train_process.to_csv(model_saveCsvName, index=False)return train_process
model_saveName="VGG16_best_model.pth"
model_saveCsvName="VGG16_train_process.csv"
train_process = train(model, train_dataloader, val_dataloader, epochs=15, lr=0.001, model_saveName=model_saveName, model_saveCsvName=model_saveCsvName)

3.3训练结果可视化

def train_process_visualization(train_process):plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(train_process['epoch'], train_process['train_loss_list'], 'ro-', label='train_loss')plt.plot(train_process['epoch'], train_process['val_loss_list'], 'bs-', label='val_loss')plt.legend()plt.xlabel('epoch')plt.ylabel('loss')plt.subplot(1, 2, 2)plt.plot(train_process['epoch'], train_process['train_acc_list'], 'ro-', label='train_acc')plt.plot(train_process['epoch'], train_process['val_acc_list'], 'bs-', label='val_acc')plt.legend()plt.xlabel('epoch')plt.ylabel('acc')plt.legend()plt.show()
train_process_visualization(train_process)

  训练后可能会出现如下结果:
在这里插入图片描述
训练结果可能会时好时坏。事实上,VGG16共有16层网络,当进行反向传播从输出层向输入层运算时,可能会出现梯度消失使得参数无法收敛的情况。由于参数初始化是随机的,可能相对于真实值过大或过小,此时梯度消失就可能会使得参数值无法收敛。此时就需要按照一定的方式初始化参数。

3.4模型参数初始化

import torch
from torch import nndevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class VGG16(nn.Module):def __init__(self):super(VGG16, self).__init__()self.block1 = nn.Sequential(# 本案例中使用FashionMNIST数据集,所以输入通道数为1nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.block2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.block3 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.block4 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.block5 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.fc = nn.Sequential(nn.Flatten(),nn.Linear(512 * 7 * 7, 4096),nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 4096),nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 10))# 参数初始化for m in self.modules():# 判断是否是具有参数的网络层,无参数就无需初始化if isinstance(m, nn.Conv2d):# Kaiming初始化方法常用于初始化卷积层参数w,需指定下一层使用的激活函数nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None: # 偏置项bias有约定俗成的初始化方式(初始化为0)nn.init.constant_(m.bias, val=0)elif isinstance(m, nn.Linear):# 全连接层参数初始化往往使用正态分布的方式nn.init.normal_(m.weight, mean=0, std=0.01)nn.init.zeros_(m.bias)if m.bias is not None:nn.init.constant_(m.bias, val=0)def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.block5(x)x = self.fc(x)return xmodel = VGG16().to(device)

事实上,批次大小会影响模型学习到的特征和参数更新的方向。较大的批次可以获得更稳定的梯度更新,但可能会丢失一些细节信息;较小的批次则可以捕捉到更细节的模式,但更新的梯度可能会更加不稳定。合理的批次大小选择可以在训练速度和模型性能之间达到平衡。一般的,建议批次大小为64、128左右,而若硬件性能不够,也可通过减少全连接层参数个数以换取较大的批次,因为全连接层参数过多,往往并不全都需要:

self.fc = nn.Sequential(nn.Flatten(),nn.Linear(512 * 7 * 7, 256),nn.ReLU(),nn.Dropout(0.5),nn.Linear(256, 128),nn.ReLU(),nn.Dropout(0.5),nn.Linear(128, 10)
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/40852.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java | Leetcode Java题解之第205题同构字符串

题目&#xff1a; 题解&#xff1a; class Solution {public boolean isIsomorphic(String s, String t) {Map<Character, Character> s2t new HashMap<Character, Character>();Map<Character, Character> t2s new HashMap<Character, Character>(…

Java-数据结构

数据结构概述 常见的数据结构 栈 队列 数组 链表 二叉树 二叉查找树 平衡二叉树 红黑树 示例&#xff1a;

【Go】编译frp,绕过内网安全工具

文章目录 概述常用命令编译环境配置开发环境拉取依赖打包exe输出运行打包好的exe测试 绕过安全产品实践frp使用教程 本文所提供的程序(方法)可能带有攻击性&#xff0c;仅供安全研究与教学之用。文章作者无法鉴别判断读者使用信息及工具的真实用途&#xff0c;若读者将文章中的…

2024 年第十四届 APMCM 亚太地区大学生数学建模 B题 洪水灾害的数据分析与预测--完整思路代码分享(仅供学习)

洪水是暴雨、急剧融冰化雪、风暴潮等自然因素引起的江河湖泊水量迅速增加&#xff0c;或者水位迅猛上涨的一种自然现象&#xff0c;是自然灾害。洪水又称大水&#xff0c;是河流、海洋、湖泊等水体上涨超过一定水位&#xff0c;威胁有关地区的安全&#xff0c;甚至造成灾害的水…

基于惯性加权PSO优化的目标函数最小值求解matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于惯性加权PSO优化的目标函数最小值求解matlab仿真。 2.测试软件版本以及运行结果展示 MATLAB2022A版本运行 &#xff08;完整程序运行后无水印&#xff09;…

使用java stream对集合中的对象按指定字段进行分组并统计

一、概述 有这样一个需求&#xff0c;在一个list集合中的对象有相同的name&#xff0c;我需要把相同name的对象进行汇总计算。使用java stream来实现这个需求&#xff0c;这里做一个记录&#xff0c;希望对有需求的同学提供帮助 一、根据指定字段进行分组 一、先准备好给前端要…

三菱plc gxwork3 0X121201F 报错;三菱标签区域的保留容量不足;

如果占用过多把r文件寄存器的地址范围改小&#xff0c;一般文件寄存器的地址r0-8000足够了

zk集群搭建

zk集群在搭建部署的时候&#xff0c;通常选择2n1奇数台。底层 Paxos 算法支持&#xff08;过半成功&#xff09;。 zk部署之前&#xff0c;保证服务器基础环境正常、JDK成功安装。 服务器基础环境 IP主机名hosts映射防火墙关闭时间同步ssh免密登录 JDK环境 1、虚拟机克隆 …

014-GeoGebra基础篇-快速解决滑动条的角度无法输入问题

有客户反馈&#xff0c;他的Geogebra一直有个bug&#xff0c;那就是输入角度最大值时总不按照他设定的展示&#xff0c;快被气炸了~ 目录 一、问题复现&#xff08;1&#xff09;插入一个滑动条&#xff08;2&#xff09;选择Angle&#xff08;3&#xff09;输入90&#xff0c;…

复现centernet时,报错RuntimeError: CUDA error: out of memory

运行 python test.py ctdet --dataset coco --exp_id coco_dla --load_model /root/CenterNet/exp/ctdet/coco_dla/model_last.pth --gpus 0 --test_scales 1 报错下面&#xff1a; RuntimeError: CUDA error: out of memory明明显存是够用的 解决办法&#xff1a; 找到自己…

代码随想录——无重叠区间(Leetcode435)

题目链接 贪心 排序 class Solution {public int eraseOverlapIntervals(int[][] intervals) {int res 0;if(intervals.length 1 || intervals.length 0){return res;}// 按左边界排序Arrays.sort(intervals, new Comparator<int[]>() {public int compare(int[] …

Protobuf(三):理论学习,简单总结

1. Protocol Buffers概述 Protocol Buffers&#xff08;简称protobuf&#xff09;&#xff0c;是谷歌用于序列化结构化数据的一种语言独立、平台独立且可扩展的机制&#xff0c;类似XML&#xff0c;但比XML更小、更快、更简单protobuf的工作流程如图所示 1.1 protobuf的优点…

【第17章】MyBatis-Plus自动维护DDL

文章目录 前言一、功能概述二、注意事项三、代码示例四、实战1. 准备2. ddl配置类3. 程序启动4. 效果(数据库) 总结 前言 在MyBatis-Plus的3.5.3版本中&#xff0c;引入了一项强大的功能&#xff1a;数据库DDL&#xff08;数据定义语言&#xff09;表结构的自动维护。这一功能…

告别高查重率,AI降重工具帮你快速过关

高查重率是许多毕业生的困扰。通常&#xff0c;高查重率源于过度引用未经修改的参考资料和格式错误。传统的降重方法&#xff0c;如修改文本和增添原创内容&#xff0c;虽必要但耗时且成效不一。 鉴于此&#xff0c;应用AI工具进行AIGC降重成为了一个高效的解决方案。这些工具…

Cloudflare 推出一款免费对抗 AI 机器人的可防止抓取数据工具

上市云服务提供商Cloudflare推出了一种新的免费工具&#xff0c;可防止机器人抓取其平台上托管的网站以获取数据以训练AI模型。 一些人工智能供应商&#xff0c;包括谷歌、OpenAI 和苹果&#xff0c;允许网站所有者通过修改他们网站的robots.txt来阻止他们用于数据抓取和模型训…

Word “当前页“ 与 “前一页“ (含部分内容)间有大半页空白,删除空白方法

鼠标光标选中需要向上移的句子&#xff0c;右键点击“段落”&#xff0c;然后在跳出的窗口中按照“换行和分页”中的红色方框内取消勾选后&#xff0c;点击确定即可。

@化工人|人员定位系统如何选择?从了解定位技术开始

提及化工安全管理所面临的主要难题&#xff0c;大家往往会想到”人难管、事难办、责难负“这三难&#xff0c;而每一难都离不开”人“这个主体。因此&#xff0c;在企业日渐实现数字化转型的今天&#xff0c;越来越多的化工企业选择建设以人员定位系统为核心的企业安全生产信息…

实时数仓Hologres OLAP场景核心能力介绍

作者&#xff1a;赵红梅 Hologres PD OLAP典型应用场景与痛点 首先介绍典型的OLAP场景以及在这些场景上的核心痛点&#xff0c;OLAP典型应用场景很多&#xff0c;总结有四类&#xff1a;第一类是BI报表分析类&#xff0c;例如BI报表&#xff0c;实时大屏&#xff0c;数据中台等…

Web前端开发——HTML快速入门

HTML&#xff1a;控制网页的结构CSS&#xff1a;控制网页的表现 一、什么是HTML、CSS &#xff08;1&#xff09;HTML &#xff08;HyperText Markup Languaqe&#xff1a;超文本标记语言&#xff09; 超文本&#xff1a;超越了文本的限制&#xff0c;比普通文本更强大。除了…