【深度学习入门篇 ⑧】关于卷积神经网络

【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】

大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。


关于卷积神经网络,你还有哪些不知道的知识点呢,之前我们介绍了大部分,今天再来补充一下~

卷积神经网络基础

什么是卷积

Convolution,输入信息与核函数(滤波器)的乘积

  • 一维信号的时间卷积:输入x,核函数w,输出是一个连续时间段t的加权平均结果。
  • 二维图像的空间卷积:输入图像I,卷积核K,输出图像O。

单个二维图片卷积 :输入为单通道图像,输出为单通道图像。

图像的数据存储是多通道的二维矩阵:

灰度图(Gray)只有一个通道(一层),RGB彩色图就是三个通道(Red,Green,Blue),而RGBA彩色图就是四个通道(Red,Green,Blue,Alpha)。

如何表达每一个网络层中高维的图像数据?

特征图包含:通道,宽度,高度,其中输入特征图Ci,输出特征图C0,输出特征图的每一个通道,由输入图的所有通道和相同数量的卷积核先一一对应各自进行卷积计算,然后求和

 

卷积相关操作与参数

填充

padding :给卷积前的输入图像边界添加额外的行,列

  • 控 制 卷 积 后 图 像分 辨 率 , 方便计算特征图尺寸的变化
  • 弥 补 边 界 信 息 “丢 失 ” 

步长

步长(stride):卷积核在图像上移动的步子

卷积的核心思想 

为什么要进行局部连接?

  • 局部连接可以更好地利用图像中的结构信息,空间距离越相近的像素其相互影响越大

权重共享:保证不变性,图像从一个局部区域学习到的信息应用到其他区域 ,减少参数,降低学习难度。

ANN与CNN比较

传统神经网络为有监督的机器学习,输入为特征;卷积神经网络为无监督特征学习,输入为最原始的图像。


案例-图像分类

CIFAR10 数据集

CIFAR-10数据集5万张训练图像、1万张测试图像、10个类别、每个类别有6k个图像,图像大小32×32×3。

 PyTorch 中的 torchvision.datasets 计算机视觉模块封装了 CIFAR10 数据集:

from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoaderdef func1():# 加载数据集train = CIFAR10(root='data', train=True, transform=Compose([ToTensor()]))valid = CIFAR10(root='data', train=False, transform=Compose([ToTensor()]))print('训练集数量:', len(train.targets))print('测试集数量:', len(valid.targets))print("数据集形状:", train[0][0].shape)print("数据集类别:", train.class_to_idx)# 数据加载器
def func2():train = CIFAR10(root='data', train=True, transform=Compose([ToTensor()]))dataloader = DataLoader(train, batch_size=8, shuffle=True)for x, y in dataloader:print(x.shape)print(y)breakif __name__ == '__main__':func1()func2()

我们要搭建的网络结构:

我们在每个卷积计算之后应用 relu 激活函数来给网络增加非线性因素。

网络代码实现:

class ImageClassification(nn.Module):def __init__(self):super(ImageClassification, self).__init__()self.conv1 = nn.Conv2d(3, 6, stride=1, kernel_size=3)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(6, 16, stride=1, kernel_size=3)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.linear1 = nn.Linear(576, 120)self.linear2 = nn.Linear(120, 84)self.out = nn.Linear(84, 10)def forward(self, x):x = F.relu(self.conv1(x))x = self.pool1(x)x = F.relu(self.conv2(x))x = self.pool2(x)x = x.reshape(x.size(0), -1)x = F.relu(self.linear1(x))x = F.relu(self.linear2(x))return self.out(x)

 编写训练函数

训练时,使用多分类交叉熵损失函数,Adam 优化器:

def train():transgform = Compose([ToTensor()])cifar10 = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transgform)model = ImageClassification()criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=1e-3)# 训练轮数epoch = 100for epoch_idx in range(epoch):# 构建数据加载器dataloader = DataLoader(cifar10, batch_size=BATCH_SIZE, shuffle=True)# 样本数量sam_num = 0# 损失总和total_loss = 0.0# 开始时间start = time.time()correct = 0for x, y in dataloader:# 送入模型output = model(x)# 计算损失loss = criterion(output, y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()correct += (torch.argmax(output, dim=-1) == y).sum()total_loss += (loss.item() * len(y))sam_num += len(y)print('epoch:%2s loss:%.5f acc:%.2f time:%.2fs' %(epoch_idx + 1,total_loss / sam_num,correct / sam_num,time.time() - start))torch.save(model.state_dict(), 'model/image_classification.bin')

编写预测函数

我们加载训练好的模型,对测试集中的 1 万条样本进行预测,查看模型在测试集上的准确率

def test():transgform = Compose([ToTensor()])cifar10 = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transgform)# 构建数据加载器dataloader = DataLoader(cifar10, batch_size=BATCH_SIZE, shuffle=True)# 加载模型model = ImageClassification()model.load_state_dict(torch.load('model/image_classification.bin'))model.eval()total_correct = 0total_samples = 0for x, y in  dataloader:output = model(x)total_correct += (torch.argmax(output, dim=-1) == y).sum()total_samples += len(y)print('Acc: %.2f' % (total_correct / total_samples))

输出:

'Acc: 0.61

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/872734.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python和C++骨髓细胞进化解析数学模型

🎯要点 🎯 数学模型邻接矩阵及其相关的转移概率 | 🎯蒙特卡罗模拟进化动力学 | 🎯细胞进化交叉图族概率 | 🎯进化图模型及其数学因子 | 🎯混合图模式对进化概率的影响 | 🎯造血干细胞群体的空间…

汇总国内镜像提供了Redis的下载地址

文章目录 1. 清华大学开源软件镜像站:2. 中国科技大学开源软件镜像:3. 阿里云镜像:4. 华为云镜像:5. 腾讯云镜像:6. 网易开源镜像站7. 官方GitHub仓库(虽然不是镜像,但也是一个可靠的下载源&…

AI算法19-偏最小二乘法回归算法Partial Least Squares Regression | PLS

偏最小二乘法回归算法简介 算法概述 偏最小二乘法模型可分为偏最小二乘回归模型和偏最小二乘路径模型。其中偏最小二乘回归模型是一种新型的多元统计方法,它集中了主成分分析、典型相关分析和线性回归的特点,特别在解决回归中的共线性问题具有无可比拟…

# Redis 入门到精通(五)-- redis 持久化(2)

Redis 入门到精通(五)-- redis 持久化(2) 一、redis 持久化–save 配置与工作原理 1、RDB 启动方式:反复执行保存指令,忘记了怎么办?不知道数据产生了多少变化,何时保存&#xff1…

CNN之图像识别

Inception Inception网络是CNN发展史上一个重要的里程碑。在Inception出现之前,大部分流行CNN仅仅是把卷积层堆叠得越来越多,使网络越来越深,以此希望能够得到更好的性能。但是存在以下问题: 图像中突出部分的大小差别很大。由于信息位置的…

【typedb】例子:药物发现: studio运行

测试8:solution结果 测试1:获取名字为Q9NPB9的protein Let’s start by getting the names of the protein Q9NPB9:测试2:哪个基因编码了Q9NPB9 Now let’s see which gene encodes for protein Q9NPB9: 推理过程:

【Linux】基础I/O——FILE,用户缓冲区

1.FILE里的fd FILE是C语言定义的文件结构体,里面包含了各种文件信息。可以肯定的一点是,FILE结构体内一定封装了 fd 。为什么?来看接下来的思路分析: 1.使用系统接口的必然性   文件存储在磁盘上,属于外设。谁有权限访问…

RabbitMQ:基础篇

1.RabbitMQ是高性能的异步通讯组件 何为异步通讯 打电话就是同步通讯,微信聊天可以理解为异步通讯,不是实时的进行通讯:时效性差。 同步调用的缺点: 拓展性差(需求不尽提) 性能下降 级联失败 …

带你轻松玩转DevOps

一、DevOps详细介绍 软件开发最开始是由两个团队组成: 开发计划由**开发团队**从头开始设计和整体系统的构建。需要系统不停的迭代更新。**运维团队**将开发团队的Code进行测试后部署上线。希望系统稳定安全运行。 这两个看似目标不同的团队,需要协同完…

HarmonyOS 开发者联盟高级认证最新题库

本篇文章包含 Next 版本更新后高级认证题库中95%的题目。 答案正确率 50-60%,答案仅做参考。 请在考试前重点看一遍题目,勿要盲目抄答案。 欢迎在评论留言正确答案和未整理的题目。 1、下面关于方舟字节码格式PREF_IMM16_v8_v8描述正确的是 16位前缀操作…

dp or 数学问题

看一下数据量&#xff0c;只有一千&#xff0c;说明这个不是数学问题 #include<bits/stdc.h> using namespace std;#define int long long const int mo 100000007; int n, s, a, b; const int N 1005;// 2 -3 // 1 3 5 2 -1 // 1 -2 -5 -3 -1 int dp[N][N]; int fun…

算法力扣刷题记录 四十九【112. 路径总和】和【113. 路径总和ii】

前言 二叉树篇继续。 记录 四十九【112. 路径总和】和【113. 路径总和ii】 一、【112. 路径总和】题目阅读 给你二叉树的根节点 root 和一个表示目标和的整数 targetSum 。判断该树中是否存在 根节点到叶子节点 的路径&#xff0c;这条路径上所有节点值相加等于目标和 target…

django-ckeditor富文本编辑器

一.安装django-ckeditor 1.安装 pip install django-ckeditor2.注册应用 INSTALLED_APPS [...ckeditor&#xff0c; ]3.配置model from ckeditor.fields import RichTextFieldcontent RichTextField()4.在项目中manage.py文件下重新执行迁移&#xff0c;生成迁移文件 py…

R语言模型评估网格搜索

### 网格搜索 ### install.packages("gbm") set.seed(1234) library(caret) library(gbm) fitControl <- trainControl(method repeatedcv,number 10,repeats 5) # 设置网格搜索的参数池 gbmGrid <- expand.grid(interaction.depth c(3,5,9),n.trees (1:2…

轨道交通AR交互教学定制公司优选深圳华锐视点

在寻找上海AR开发制作公司作为合作伙伴的过程中&#xff0c;选择一家既技术深厚又具备丰富经验的AR开发企业&#xff0c;成为了众多客户与合作伙伴的共同追求。华锐视点上海AR开发制作公司作为业界的佼佼者&#xff0c;凭借其卓越的公司规模、丰富的行业案例以及顶尖的ar增强现…

Unity基础调色

叭叭叭 最近&#xff08;*这两天&#xff09;因为想做一些Unity的调色问题&#xff0c;尝试原文翻译一下&#xff0c;其实直接原文更好&#xff01;&#xff01; Color Grading 参考了&#xff0c;某大牛的翻译&#xff0c;实在忍不住了&#xff0c;我是不知道为什么能翻译成…

OpenSceneGraph学习笔记

目录 引言第一章&#xff1a;OSG概述一、前言&#xff08;1&#xff09;为什么要学习OSG?&#xff08;2&#xff09;OSG的组成&#xff08;3&#xff09;OSG的智能指针&#xff08;4&#xff09;OSG的安装编译 二、第一个OSG程序&#xff08;1&#xff09;Hello OSG程序&#…

美式键盘 QWERTY 布局的来历

注&#xff1a;机翻&#xff0c;未校对。 The QWERTY Keyboard Is Tech’s Biggest Unsolved Mystery QWERTY 键盘是科技界最大的未解之谜 It’s on your computer keyboard and your smartphone screen: QWERTY, the first six letters of the top row of the standard keybo…

Linux热键,shell含义及权限介绍

君子忧道不忧贫。 —— 孔丘 Linux操作系统的权限 1、几个常用的热键介绍1、1、[Tab]键1、2、[ctrl]-c1、3、[ctrl]-d1、4、[ctrl]-r 2、shell命令以及运行原理3、权限3、1、什么是权限3、2、权限的本质3、3、Linux中的用户3、4、Linux中文件的权限3、4、1、快速掌握修改权限的…

vue引用js html页面 vue引用js动态效果

要引用的index.html页面&#xff1a;&#xff08;资源来自网络&#xff09;在pubilc下建一个static文件放入js文件 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><title>数字翻转</title><meta con…