【深度学习实验】卷积神经网络(六):卷积神经网络模型(VGG)训练、评价

目录

一、实验介绍

二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. 构建数据集(CIFAR10Dataset)

a. read_csv_labels()

b. CIFAR10Dataset

2. 构建模型(FeedForward)

3.整合训练、评估、预测过程(Runner)

4. __main__

代码整合


一、实验介绍

        本实验实现了一个简化版VGG网络,并基于此完成图像分类任务。(包括模型训练、评价)
       

        VGG网络是深度卷积神经网络中的经典模型之一,由牛津大学计算机视觉组(Visual Geometry Group)提出。它在2014年的ImageNet图像分类挑战中取得了优异的成绩(分类任务第二,定位任务第一),被广泛应用于图像分类、目标检测和图像生成等任务。

        VGG网络的主要特点是使用了非常小的卷积核尺寸(通常为3x3)和更深的网络结构。该网络通过多个卷积层和池化层堆叠在一起,逐渐增加网络的深度,从而提取图像的多层次特征表示。VGG网络的基本构建块是由连续的卷积层组成,每个卷积层后面跟着一个ReLU激活函数。在每个卷积块的末尾,都会添加一个最大池化层来减小特征图的尺寸。VGG网络的这种简单而有效的结构使得它易于理解和实现,并且在不同的任务上具有很好的泛化性能。

        VGG网络有几个不同的变体,如VGG11、VGG13、VGG16和VGG19,它们的数字代表网络的层数。这些变体在网络深度和参数数量上有所区别,较深的网络通常具有更强大的表示能力,但也更加复杂。

二、实验环境

    本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        卷积神经网络(Convolutional Neural Network,简称CNN)是一种深度学习模型,广泛应用于图像识别、计算机视觉和模式识别等领域。它的设计灵感来自于生物学中视觉皮层的工作原理。

        卷积神经网络通过多个卷积层、池化层全连接层组成。

  • 卷积层主要用于提取图像的局部特征,通过卷积操作和激活函数的处理,可以学习到图像的特征表示。
  • 池化层则用于降低特征图的维度,减少参数数量,同时保留主要的特征信息。
  • 全连接层则用于将提取到的特征映射到不同类别的概率上,进行分类或回归任务。

        卷积神经网络在图像处理方面具有很强的优势,它能够自动学习到具有层次结构的特征表示,并且对平移、缩放和旋转等图像变换具有一定的不变性。这些特点使得卷积神经网络成为图像分类、目标检测、语义分割等任务的首选模型。除了图像处理,卷积神经网络也可以应用于其他领域,如自然语言处理和时间序列分析。通过将文本或时间序列数据转换成二维形式,可以利用卷积神经网络进行相关任务的处理。

0. 导入必要的工具包

import torch
from torch import nn
import torch.nn.functional as F

1. 构建数据集(CIFAR10Dataset)

a. read_csv_labels()

        从CSV文件中读取标签信息并返回一个标签字典。

def read_csv_labels(fname):"""读取fname来给标签字典返回一个文件名"""with open(fname, 'r') as f:# 跳过文件头行(列名)lines = f.readlines()[1:]tokens = [l.rstrip().split(',') for l in lines]return dict(((name, label) for name, label in tokens))
  •  使用open函数打开指定文件名的CSV文件,并将文件对象赋值给变量f。这里使用'r'参数以只读模式打开文件。

  • 使用文件对象的readlines()方法读取文件的所有行,并将结果存储在名为lines的列表中。通过切片操作[1:],跳过了文件的第一行(列名),将剩余的行存储在lines列表中。

  • 列表推导式(list comprehension):对lines列表中的每一行进行处理。对于每一行,使用rstrip()方法去除行末尾的换行符,并使用split(',')方法将行按逗号分割为多个标记。最终,将所有行的标记组成的子列表存储在tokens列表中。

  • 使用字典推导式(dictionary comprehension)将tokens列表中的子列表转换为字典。对于tokens中的每个子列表,将子列表的第一个元素作为键(name),第二个元素作为值(label),最终返回一个包含这些键值对的字典。

b. CIFAR10Dataset

class CIFAR10Dataset(Dataset):def __init__(self, folder_path, fname):self.labels = read_csv_labels(os.path.join(folder_path, fname))self.folder_path = os.path.join(folder_path, 'train')def __len__(self):return len(self.labels)def __getitem__(self, idx):img = read_image(self.folder_path + '/' + str(idx + 1) + '.png')label = self.labels[str(idx + 1)]return img, torch.tensor(int(label))
  • 构造函数:

    • 接受两个参数

      • folder_path表示数据集所在的文件夹路径

      • fname表示包含标签信息的文件名。

    • 调用read_csv_labels函数,传递folder_pathfname作为参数,以读取CSV文件中的标签信息,并将返回的标签字典存储在self.labels变量中。

    • 通过拼接folder_path和字符串'train'来构建数据集的文件夹路径,将结果存储在self.folder_path变量中。

  • def __len__(self)

    • 这是CIFAR10Dataset类的方法,用于返回数据集的长度,即样本的数量。

  • def __getitem__(self, idx): 这是CIFAR10Dataset类的方法,用于根据给定的索引idx获取数据集中的一个样本。它首先根据索引idx构建图像文件的路径,并调用read_image函数来读取图像数据,将结果存储在img变量中。然后,它通过将索引转换为字符串,并使用该字符串作为键来从self.labels字典中获取相应的标签,将结果存储在label变量中。最后,它返回一个元组,包含图像数据和经过torch.tensor转换的标签。

2. 构建模型(FeedForward)

        参考前文:

【深度学习实验】卷积神经网络(五):深度卷积神经网络经典模型——VGG网络(卷积层、池化层、全连接层)_QomolangmaH的博客-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/m0_63834988/article/details/133350927?spm=1001.2014.3001.5501

3.整合训练、评估、预测过程(Runner)

        参考前文:

【深度学习实验】前馈神经网络(九):整合训练、评估、预测过程(Runner)_QomolangmaH的博客-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/m0_63834988/article/details/133219448?spm=1001.2014.3001.5501

        (略有改动:)

class Runner(object):def __init__(self, model, optimizer, loss_fn, metric=None):self.model = modelself.optimizer = optimizerself.loss_fn = loss_fn# 用于计算评价指标self.metric = metric# 记录训练过程中的评价指标变化self.dev_scores = []# 记录训练过程中的损失变化self.train_epoch_losses = []self.dev_losses = []# 记录全局最优评价指标self.best_score = 0# 模型训练阶段def train(self, train_loader, dev_loader=None, **kwargs):# 将模型设置为训练模式,此时模型的参数会被更新self.model.train()num_epochs = kwargs.get('num_epochs', 0)log_steps = kwargs.get('log_steps', 100)save_path = kwargs.get('save_path','best_model.pth')eval_steps = kwargs.get('eval_steps', 0)# 运行的step数,不等于epoch数global_step = 0if eval_steps:if dev_loader is None:raise RuntimeError('Error: dev_loader can not be None!')if self.metric is None:raise RuntimeError('Error: Metric can not be None')# 遍历训练的轮数for epoch in range(num_epochs):total_loss = 0# 遍历数据集for step, data in enumerate(train_loader):x, y = datalogits = self.model(x.float())loss = self.loss_fn(logits, y.long())total_loss += lossif step%log_steps == 0:print(f'loss:{loss.item():.5f}')loss.backward()self.optimizer.step()self.optimizer.zero_grad()# 每隔一定轮次进行一次验证,由eval_steps参数控制,可以采用不同的验证判断条件if eval_steps != 0 :if (epoch+1) % eval_steps ==  0:dev_score, dev_loss = self.evaluate(dev_loader, global_step=global_step)print(f'[Evalute] dev score:{dev_score:.5f}, dev loss:{dev_loss:.5f}')if dev_score > self.best_score:self.save_model(f'model_{epoch+1}.pth')print(f'[Evaluate]best accuracy performance has been updated: {self.best_score:.5f}-->{dev_score:.5f}')self.best_score = dev_score# 验证过程结束后,请记住将模型调回训练模式   self.model.train()global_step += 1# 保存当前轮次训练损失的累计值train_loss = (total_loss/len(train_loader)).item()self.train_epoch_losses.append((global_step,train_loss))self.save_model(f'{save_path}.pth')   print('[Train] Train done')# 模型评价阶段def evaluate(self, dev_loader, **kwargs):assert self.metric is not None# 将模型设置为验证模式,此模式下,模型的参数不会更新self.model.eval()global_step = kwargs.get('global_step',-1)total_loss = 0self.metric.reset()for batch_id, data in enumerate(dev_loader):x, y = datalogits = self.model(x.float())loss = self.loss_fn(logits, y.long()).item()total_loss += loss self.metric.update(logits, y)dev_loss = (total_loss/len(dev_loader))self.dev_losses.append((global_step, dev_loss))dev_score = self.metric.accumulate()self.dev_scores.append(dev_score)return dev_score, dev_loss# 模型预测阶段,def predict(self, x, **kwargs):self.model.eval()logits = self.model(x)return logits# 保存模型的参数def save_model(self, save_path):torch.save(self.model.state_dict(), save_path)# 读取模型的参数def load_model(self, model_path):self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

4. __main__

 batch_size = 20# 构建训练集train_data = CIFAR10Dataset('cifar10_tiny', 'trainLabels.csv')train_iter = DataLoader(train_data, batch_size=batch_size)# 构建测试集num_classes = 10# 定义模型model = VGG_S(num_classes)# 定义损失函数loss_fn = F.cross_entropy# 定义优化器optimizer = torch.optim.SGD(model.parameters(), lr=0.1)runner = Runner(model, optimizer, loss_fn, metric=None)runner.train(train_iter, num_epochs=10, save_path='chapter_5')

本文有待进一步完善……

代码整合

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/90879.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【网络协议】TCP

TCP协议全称为传输控制协议(Transmission Control Protocol).要理解TCP就要从他的特性开始说,这些特性各自之间或多或少各有联结,需要以宏观视角来看待。 目录: 1.TCP报文格式 因为报文解释过于繁琐,具体内容请看这篇文章TCP报文…

问题 - 谷歌浏览器 network 看不到接口请求解决方案

谷歌浏览器 -> 设置 -> 重置设置 -> 将设置还原为其默认值 查看接口情况,选择 All 或 Fetch/XHR,勾选 Has blocked cookies 即可 如果万一还不行,卸载浏览器重装。 参考:https://www.cnblogs.com/tully/p/16479528.html

微信小程序开发基础(二)基本组件

本帖开始介绍小程序中的一些基本组件~ 微信小程序是一种轻量、快速、跨平台的应用程序,是微信公众号的重要组成部分。随着微信小程序的普及,越来越多的开发者和企业开始使用微信小程序来搭建自己的应用,但是对于初次接触微信小程序的开发者…

CSS 基础 3

目录 🚀 导读 -- target 盒子模型 看透网页布局的本质 盒子模型组成 边框(border) border-style ​编辑border-color border-width 边框写法 简写 分开写 表格细线边框 边框会影响盒子实际大小 内边距 内容 内边距-padding padding属性简写 pad…

《PPT 自我介绍》:一本让你的职场表现更加出色的秘籍?

这里提供一个2000字左右的PPT自我介绍模板制作指南: 自我介绍是面试或工作中常见的情况,利用PPT可以给人留下更深刻的印象。但如何快速且专业地制作一个自我介绍PPT呢?这里给大家介绍几点技巧: 1. 选择一个简洁大方的PPT模板 首先要选择一…

STM32F4X UCOSIII任务信号量

STM32F4X UCOSIII任务信号量 任务信号量与内核信号量对比内核信号量任务信号量 UCOSIII任务信号量API任务信号量发送函数任务信号量接收函数 UCOSIII任务信号量例程 之前的章节中讲解过信号量这个机制,UCOSIII除了有内核信号量之外,还有任务信号量。在UC…

前端项目练习(练习-002-NodeJS项目初始化)

首先,创建一个web-002项目,内容和web-001一样。 下一步,规范一下项目结构,将html,js,css三个文件放到 src/view目录下面: 由于html引入css和js时,使用的是相对路径,所以…

详解C语言—文件操作

目录 1. 为什么使用文件 2. 什么是文件 3. 文件的使用 文件指针 文件的打开和关闭 三个标准的输入/输出流: 4. 文件的顺序读写 对字符操作: fputc: fgetc: 练习复制整个文件: 对字符串操作:…

WebGL雾化

目录 前言 如何实现雾化 线性雾化公式 雾化因子关系图 根据雾化因子计算片元颜色公式 示例程序(Fog.js) 代码详解​编辑 详解如何计算雾化因子(clamp()) 详解如何计算最终片元颜色(根据雾化因子计算片元颜色…

搭建自己的搜索引擎之四

一、前言 搭建自己的搜索引擎之三 介绍了通过HTTP RESTful 对ES进行增删改查,这一般手工运维ES时使用,程序代码中最好还是使用Java API去操作ES会更容易维护,但ES API竟然贼多,本篇介绍一下 四种 API及其简单使用。 注&#xff…

第1篇 目标检测概述 —(2)目标检测算法介绍

前言:Hello大家好,我是小哥谈。目标检测算法是一种计算机视觉算法,用于在图像或视频中识别和定位特定的目标物体。常见的目标检测算法包括传统的基于特征的方法(如Haar特征和HOG特征)以及基于深度学习的方法&#xff0…

Docker(三)、Dockerfile探究

Dockerfile探究 一、镜像层概念1、通过执行命令显化docker的机制 二、Dockerfile基础命令1、FROM 基于基准镜像【即构建镜像的时候,依托原有镜像做拓展】2、LABEL & MAINTAINER -说明信息3、WORKDIR 设置工作目录4、ADD & COPY 复制文件5、ENV 设置环境常量…

crypto:Alice与Bob

题目 根据题目描述,将98554799767分解成两个素数 得到两个素数101999和966233。根据题目提示把它们拼接起来进行md5的32位小写哈希

提升您的Mac文件拖拽体验——Dropzone 4 for mac

大家都知道,在Mac上进行文件拖拽是一件非常方便的事情。然而,随着我们在工作和生活中越来越多地使用电脑,我们对于这个简单操作的需求也越来越高。为了让您的文件拖拽体验更加高效和便捷,今天我们向大家介绍一款强大的工具——Dro…

[plugin:vite:css] [sass] Undefined mixin.

前言: vite vue3 TypeScript环境 scss报错: [plugin:vite:css] [sass] Undefined mixin. 解决方案: 在vite.config.ts文件添加配置 css: {preprocessorOptions: {// 导入scss预编译程序scss: {additionalData: use "/resources/_ha…

8章:scrapy框架

文章目录 scrapy框架如何学习框架?什么是scarpy?scrapy的使用步骤1.先转到想创建工程的目录下:cd ...2.创建一个工程3.创建之后要转到工程目录下4.在spiders子目录中创建一个爬虫文件5.执行工程setting文件中的参数 scrapy数据解析scrapy持久…

计算机视觉与深度学习-循环神经网络与注意力机制-RNN(Recurrent Neural Network)、LSTM-【北邮鲁鹏】

目录 举例应用槽填充(Slot Filling)解决思路方案使用前馈神经网络输入1-of-N encoding(One-hot)(独热编码) 输出 问题 循环神经网络(Recurrent Neural Network,RNN)定义如何工作学习目标深度Elm…

uniapp ssr发行后一直Hydration completed but contains mismatches Cannot find module

最开始我用前端网页托管的地址访问一直是 Hydration completed but contains mismatches 解决方案 要从云函数的地址访问项目。 先绑定域名,否则用uniapp自带地址访问一直是下载文件 设置路径 最后效果 uniapp ssr 云函数访问 MODULE_NOT_FOUND:Cannot fin…

2023 年 Bitget Wallet 测评

对Bitget Wallet钱包的看法 Bitget Wallet在安全性、产品实力和使用体验方面可与Metamask媲美,甚至有所超越,唯一稍显不足的是知名度稍逊一筹。在众多钱包中,Bitget Wallet是拥有最全面的钱包之一,尤其适合那些希望一步到位&…

最小生成树 | 市政道路拓宽预算的优化 (Minimum Spanning Tree)

任务描述: 市政投资拓宽市区道路,本着执政为民,节省纳税人钱的目的,论证是否有必要对每一条路都施工拓宽? 这是一个连问带答的好问题。项目制学习可以上下半场,上半场头脑风暴节省投资的所有可行的思路&a…