【深度学习】VGG-16原理及代码实现

1.原理及介绍

2.代码实现

2.1model.py
import torch
from torch import nn
from torchsummary import summary
import torch.nn.functional as Fclass VGG16(nn.Module):def __init__(self):super(VGG16, self).__init__()self.block1 = nn.Sequential(  # 用一个序列(块)来表示nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.block2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.block3 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.block4 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.block5 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2))self.block6 = nn.Sequential(nn.Flatten(),  # 平展层nn.Linear(7 * 7 * 512, 128),nn.ReLU(),nn.Dropout(0.5),  # 过拟合时可使用nn.Linear(128, 64),nn.ReLU(),nn.Dropout(0.5),nn.Linear(64, 10),)# 参数初始化for m in self.modules():  # 权重初始化,防止模型的参数随机生成,从而产生不收敛的现象if isinstance(m, nn.Conv2d):  # 卷积层参数初始化nn.init.kaiming_normal_(m.weight, nonlinearity='relu')  # 使用何恺明初始化,优化参数wif m.bias is not None:  # b参数优化,初始置为0,如果有的话nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):  # 全连接层参数初始化nn.init.normal_(m.weight, 0, 0.01)  # 正态分布if m.bias is not None:nn.init.constant_(m.bias, 0)def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.block5(x)x = self.block6(x)return xif __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = VGG16().to(device)print(summary(model, (1, 224, 224)))  # 打印模型信息
2.2model_train.py
import copy
import time
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data as data
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from model import VGG16def train_val_data_process():train_dataset = FashionMNIST(root='./data',  # 指定下载文件夹train=True,  # 只下载训练集# 归一化处理数据集transform=transforms.Compose([transforms.Resize(size=224),transforms.ToTensor()]),  # 数据转换成张量形式,方便模型应用download=True  # 下载数据)# 划分训练集和验证集train_data, val_data = data.random_split(train_dataset,[round(0.8 * len(train_dataset)), round(0.2 * len(train_dataset))])# 训练集加载train_dataloader = data.DataLoader(dataset=train_data,batch_size=28,  # 一个批次数据的数量shuffle=True,  # 数据打乱num_workers=2)  # 分配的进程数目# 验证集加载val_dataloader = data.DataLoader(dataset=val_data,batch_size=28,shuffle=True,num_workers=2)return train_dataloader, val_dataloaderdef train_model_process(model, train_dataloader, val_dataloader, num_epochs):# 定义训练使用的设备,有GPU则用,没有则用CPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 使用Adam优化器进行模型参数更新,学习率为0.001optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 损失函数为交叉熵损失函数criterion = nn.CrossEntropyLoss()# 将模型放入训练设备内model = model.to(device)# 复制当前模型参数(w,b等),以便将最好的模型参数权重保存下来best_model_wts = copy.deepcopy(model.state_dict())# 初始化参数# 最高准确度best_acc = 0.0# 训练集损失值列表train_loss_all = []# 验证集损失值列表val_loss_all = []# 训练集准确度列表train_acc_all = []# 验证集准确度列表val_acc_all = []# 当前时间since = time.time()for epoch in range(num_epochs):print("Epoch {}/{}".format(epoch, num_epochs - 1))print("-" * 10)# 初始化参数# 训练集损失值train_loss = 0.0# 训练集精确度train_corrects = 0# 验证集损失值val_loss = 0.0# 验证集精确度val_corrects = 0# 训练集样本数量train_num = 0# 验证集样本数量val_num = 0# 对每一个mini-batch训练和计算for step, (b_x, b_y) in enumerate(train_dataloader):# 将特征放入到训练设备中b_x = b_x.to(device)  # batch_size*28*28*1的tensor数据# 将标签放入到训练设备中b_y = b_y.to(device)  # batch_size大小的向量tensor数据# 设置模型为训练模式model.train()# 前向传播过程,输入为一个batch,输出为一个batch中对应的预测output = model(b_x)  # 输出为:batch_size大小的行和10列组成的矩阵# 查找每一行中最大值对应的行标pre_lab = torch.argmax(output, dim=1)  # batch_size大小的向量表示属于物品的标签# 计算每一个batch的损失函数,向量形式的交叉熵损失函数loss = criterion(output, b_y)# 将梯度初始化为0,防止梯度累积optimizer.zero_grad()# 反向传播计算loss.backward()# 根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用optimizer.step()# 对损失函数进行累加,该批次的loss值乘于该批次数量得到批次总体loss值,在将其累加得到轮次总体loss值train_loss += loss.item() * b_x.size(0)# 如果预测正确,则准确度train_corrects加1train_corrects += torch.sum(pre_lab == b_y.data)# 当前用于训练的样本数量train_num += b_x.size(0)for step, (b_x, b_y) in enumerate(val_dataloader):# 将特征放入到验证设备中b_x = b_x.to(device)# 将标签放入到验证设备中b_y = b_y.to(device)# 设置模型为评估模式model.eval()# 前向传播过程,输入为一个batch,输出为一个batch中对应的预测output = model(b_x)# 查找每一行中最大值对应的行标pre_lab = torch.argmax(output, dim=1)# 计算每一个batch的损失函数loss = criterion(output, b_y)# 对损失函数进行累加val_loss += loss.item() * b_x.size(0)# 如果预测正确,则准确度val_corrects加1val_corrects += torch.sum(pre_lab == b_y.data)# 当前用于验证的样本数量val_num += b_x.size(0)# 计算并保存每一轮次迭代的loss值和准确率# 计算并保存训练集的loss值train_loss_all.append(train_loss / train_num)# 计算并保存训练集的准确率train_acc_all.append(train_corrects.double().item() / train_num)# 计算并保存验证集的loss值val_loss_all.append(val_loss / val_num)# 计算并保存验证集的准确率val_acc_all.append(val_corrects.double().item() / val_num)# 打印每一轮次的loss值和准确度print("{} train loss:{:.4f} train acc: {:.4f}".format(epoch, train_loss_all[-1], train_acc_all[-1]))print("{} val loss:{:.4f} val acc: {:.4f}".format(epoch, val_loss_all[-1], val_acc_all[-1]))if val_acc_all[-1] > best_acc:# 保存当前最高准确度best_acc = val_acc_all[-1]# 保存当前最高准确度的模型参数best_model_wts = copy.deepcopy(model.state_dict())# 计算训练和验证的耗时time_use = time.time() - sinceprint("训练和验证耗费的时间{:.0f}m{:.0f}s".format(time_use // 60, time_use % 60))# 选择最优参数,保存最优参数的模型torch.save(best_model_wts, "./model_save/VGG16_best_model.pth")# 将产生的数据保存成表格,方便查看train_process = pd.DataFrame(data={"epoch": range(num_epochs),"train_loss_all": train_loss_all,"val_loss_all": val_loss_all,"train_acc_all": train_acc_all,"val_acc_all": val_acc_all})return train_processdef matplot_acc_loss(train_process):# 显示每一次迭代后的训练集和验证集的损失函数和准确率plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)  # 表示一行两列的第一张图plt.plot(train_process['epoch'], train_process.train_loss_all, "ro-", label="Train loss")plt.plot(train_process['epoch'], train_process.val_loss_all, "bs-", label="Val loss")plt.legend()plt.xlabel("epoch")plt.ylabel("Loss")plt.subplot(1, 2, 2)  # 表示一行两列的第二张图plt.plot(train_process['epoch'], train_process.train_acc_all, "ro-", label="Train acc")plt.plot(train_process['epoch'], train_process.val_acc_all, "bs-", label="Val acc")plt.xlabel("epoch")plt.ylabel("acc")plt.legend()plt.show()if __name__ == '__main__':# 加载需要的模型VGG16 = VGG16()# 加载数据集train_data, val_data = train_val_data_process()# 利用现有的模型进行模型的训练train_process = train_model_process(VGG16, train_data, val_data, num_epochs=3)  # 注:由于硬件限制,只设置训练3轮(可修改)matplot_acc_loss(train_process)
2.3model.test.py
import torch
import torch.utils.data as data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import VGG16def test_data_process():test_data = FashionMNIST(root='./data',train=False,  # 用测试集进行测试transform=transforms.Compose([transforms.Resize(size=224), transforms.ToTensor()]),download=True)test_dataloader = data.DataLoader(dataset=test_data,batch_size=1,  # 该批次设为1shuffle=True,num_workers=0)return test_dataloaderdef test_model_process(model, test_dataloader):# 设定测试所用到的设备,有GPU用GPU没有GPU用CPUdevice = "cuda" if torch.cuda.is_available() else 'cpu'# 讲模型放入到训练设备中model = model.to(device)# 初始化参数test_corrects = 0.0test_num = 0# 只进行前向传播计算,不计算梯度,从而节省内存,加快运行速度with torch.no_grad():for test_data_x, test_data_y in test_dataloader:# 将特征放入到测试设备中test_data_x = test_data_x.to(device)# 将标签放入到测试设备中test_data_y = test_data_y.to(device)# 设置模型为评估模式model.eval()# 前向传播过程,输入为测试数据集,输出为对每个样本的预测值output = model(test_data_x)# 查找每一行中最大值对应的行标pre_lab = torch.argmax(output, dim=1)# 如果预测正确,则准确度test_corrects加1test_corrects += torch.sum(pre_lab == test_data_y.data)# 将所有的测试样本进行累加test_num += test_data_x.size(0)# 计算测试准确率test_acc = test_corrects.double().item() / test_numprint("测试的准确率为:", test_acc)if __name__ == "__main__":# 加载模型model = VGG16()model.load_state_dict(torch.load('./model_save/VGG16_best_model.pth'))  # 调用训练好的参数权重# 加载测试数据test_dataloader = test_data_process()# 加载模型测试的函数test_model_process(model, test_dataloader)

可以点个免费的赞吗!!!   

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/48285.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

51单片机嵌入式开发:13、STC89C52RC 之 RS232与电脑通讯

STC89C52RC 之 RS232与电脑通讯 第十三节课,RS232与电脑通讯1 概述2 Uart介绍2.1 概述2.2 STC89C52UART介绍2.3 STC89C52 UART寄存器介绍2.4 STC89C52 UART操作 3 C51 UART总结 第十三节课,RS232与电脑通讯 1 概述 RS232(Recommended Stand…

Github报错:Kex_exchange_identification: Connection closed by remote host

文章目录 1. 背景介绍2. 排查和解决方案 1. 背景介绍 Github提交或者拉取代码时,报错如下: Kex_exchange_identification: Connection closed by remote host fatal: Could not read from remote repository.Please make sure you have the correct ac…

HTML5大作业三农有机,农产品,农庄,农旅网站源码

文章目录 1.设计来源1.1 轮播图页面头部效果1.2 栏目列表页面效果1.3 页面底部导航效果 2.效果和源码2.1 源代码 源码下载万套模板,程序开发,在线开发,在线沟通 作者:xcLeigh 文章地址:https://blog.csdn.net/weixin_4…

计算机三级嵌入式笔记(一)—— 嵌入式系统概论

目录 考点1 嵌入式系统 考点2 嵌入式系统的组成与分类 考点3 嵌入式系统的分类与发展 考点4 SOC芯片 考点5 数字(电子)文本 考点6 数字图像 考点7 数字音频与数字视频 考点8 数字通信 考点9 计算机网络 考点10 互联网 考纲(2023&am…

2、如何发行自己的数字代币(truffle智能合约项目实战)

2、如何发行自己的数字代币(truffle智能合约项目实战) 1-Atom IDE插件安装2-truffle tutorialtoken3-tutorialtoken源码框架分析4-安装openzeppelin代币框架(代币发布成功) 1-Atom IDE插件安装 正式介绍基于web的智能合约开发 推…

【Vue3】响应式数据

【Vue3】响应式数据 背景简介开发环境基本数据类型对象数据类型使用 reactive 定义对象类型响应式数据使用 ref 定义对象类型响应式数据 ref 和 reactive 的对比使用原则建议 背景 随着年龄的增长,很多曾经烂熟于心的技术原理已被岁月摩擦得愈发模糊起来&#xff0…

牛客:TOP101链表相加(二)

文章目录 1. 题目描述2. 解题思路3. 代码实现 1. 题目描述 2. 解题思路 按照我们习惯的加法运算,肯定是要从个位开始相加,然后十位……,但是在链表中如果我们先运算后面的,那么接下来我们是无法找到前一位的。想要解决这个问题也很…

数模·插值和拟合算法

插值 将离散的点连成曲线或者线段的一种方法 题目中有"任意时刻任意的量"时使用插值,因为插值一定经过样本点 插值函数的概念 插值函数与样本离散的点一一重合 插值函数往往有多个区间,多个区间插值函数样态不完全一样,简单来说就…

【系统架构设计 每日一问】二 MySql主从复制延迟可能是什么原因,怎么解决

主从复制的架构设计如下图所示: 同步原理 具体到数据库之间是通过binlog和复制线程操作的: Master的更新事件(update、insert、delete)会按照顺序写入bin-log中。当Slave连接到Master的后,Master机器会为Slave开启,binlog dump线程,该线程…

H3CNE(计算机网络的概述)

1. 计算机网络的概述 1.1 计算机网络的三大基本功能 1. 资源共享 2. 分布式处理与负载均衡 3. 综合信息服务 1.2 计算机网络的三大基本类型 1.3 网络拓扑 定义: 网络设备连接排列的方式 网络拓扑的类型: 总线型拓扑: 所有的设备共享一…

Vue3 --- 路由

路由就是一组key-value的对应关系;多个路由,需要经过路由器的管理。 1. 基本切换效果 安装路由器 npm i vue-router /router/index.ts // import { createRouter, createWebHistory } from vue-router import Home from /components/Home.vue import…

萝卜快跑爆火的背后,美格智能如何助力无人车商业化?

近期,“订单量超过600万单”等夺人眼球的信息,让无人驾驶出租车“萝卜快跑”从江城武汉爆火出圈,在2024年的炎炎夏日为这座大火炉再添了一把火。热度背后,不少地方主管部门,近期也纷纷针对无人驾驶出租车、无人驾驶运输…

基于术语词典干预的机器翻译挑战赛笔记Task2 #Datawhale AI 夏令营

上回: 基于术语词典干预的机器翻译挑战赛笔记Task1 跑通baseline Datawhale AI 夏令营-CSDN博客文章浏览阅读718次,点赞11次,收藏8次。基于术语词典干预的机器翻译挑战赛笔记Task1 跑通baselinehttps://blog.csdn.net/qq_23311271/article/d…

C++树形结构(2 树的直径)

目录 1.定义: 2.直径的性质: 3.树的直径求解方法: 4.直径端点求解方法: 朴素方法: 优化方法: 5.例题: 6.直径公共点: 7.例题: 8.去掉再加上: 9.例…

最新版kubeadm搭建k8s(已成功搭建)

kubeadm搭建k8s(已成功搭建) 环境配置 主节点 k8s-master:4核8G、40GB硬盘、CentOS7.9(内网IP:10.16.64.67) 从节点 k8s-node1: 4核8G、40GB硬盘、CentOS7.9(内网IP:10…

框架使用及下载

Bootstrap5 安装使用 | 菜鸟教程 (runoob.com) https://github.com/twbs/bootstrap/releases/download/v5.1.3/bootstrap-5.1.3-dist.zip(下载链接) Staticfile CDN(html的所有框架合集) 直接在w3cschool里面看参考文件进行搜索自…

RHCSA —— 第八节 (编辑器、编辑命令等)

Vi/vim编辑器 vim 编辑器 就是相当于在windows中创建一个记事本,一个word文档里面进行编辑所需要的内容。在linux中编辑文本文件,包括但不限于编辑源代码、配置文件、日志文件等文件内容。 三种模式 这是在编辑器中存在三种模式:命令模式、…

[经验] 驰这个汉字的拼音是什么 #学习方法#其他#媒体

驰这个汉字的拼音是什么 驰,是一个常见的汉字,其拼音为“ch”,音调为第四声。它既可以表示动词,也可以表示形容词或副词,意义广泛,经常出现在生活和工作中。下面就让我们一起来了解一下“驰”的含义和用法。…

Deepfake detection【Datawhale AI夏令营】数据增强方法

deepfake detection比赛链接https://www.kaggle.com/competitions/multi-ffdi 训练分类模型判别图片是否为AI生成图片,探究不同数据增强方法对模型表现的影响。 1、数据增强方法 图像分类任务中常见的数据增强方法: (1) 几何变换…

【BUG】已解决:xlrd.biffh.XLRDError: Excel xlsx file; not supported

已解决:xlrd.biffh.XLRDError: Excel xlsx file; not supported 目录 已解决:xlrd.biffh.XLRDError: Excel xlsx file; not supported 【常见模块错误】 错误原因 解决办法: 欢迎来到英杰社区https://bbs.csdn.net/…