深度学习——CNN卷积神经网络

基本概念

概述

卷积神经网络(Convolutional Neural Network,CNN)是一种深度学习中常用于处理具有网格结构数据的神经网络模型。它在计算机视觉领域广泛应用于图像分类、目标检测、图像生成等任务。

核心思想

CNN 的核心思想是通过利用局部感知和参数共享来捕捉输入数据的空间结构信息。相比于传统的全连接神经网络,CNN 在网络结构中引入了卷积层和池化层,从而减少了参数量,并且能够更好地处理高维输入数据。

其他概念

输入层:接收原始图像或其他形式的输入数据。
卷积层(Convolutional Layer):使用卷积操作提取输入特征,通过设置滤波器(卷积核)在输入数据上滑动并执行卷积运算。这样可以学习到局部的特征,如边缘、纹理等。
激活函数(Activation Function):在每个卷积层后面通常紧跟一个非线性的激活函数,如ReLU(Rectified Linear Unit),以增加网络的非线性表达能力。
池化层(Pooling Layer):通过减少特征图的尺寸来降低模型复杂性。常用的池化操作是最大池化(Max Pooling),它选取每个池化窗口内的最大特征值作为输出。
全连接层(Fully Connected Layer):将卷积层和池化层的输出连接到全连接层,使用传统的神经网络模式进行分类、回归等任务。
Dropout 层:在训练过程中以一定概率随机将部分神经元的输出置为0,以减少模型的过拟合。
Softmax 层:多分类问题中常用的输出层,在最后一层进行 softmax 操作将输出转化为类别上的概率分布。

代码与详细注释

import os# third-party library
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt# torch.manual_seed(1)    # reproducible# Hyper Parameters
#  轮次
EPOCH = 1               # train the training data n times, to save time, we just train 1 epoch
# 批大小为50
BATCH_SIZE = 50
# 学习率
LR = 0.001
# 是否下载mnist数据集
DOWNLOAD_MNIST = False# 下载minist数据集
if not(os.path.exists('./mnist/')) or not os.listdir('./mnist/'):# not mnist dir or mnist is empyt dirDOWNLOAD_MNIST = True# torchvision本身就是一个数据库
train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,                                     # this is training datatransform=torchvision.transforms.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to# torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]download=DOWNLOAD_MNIST,
)# 输出训练数据尺寸
print(train_data.train_data.size())                 # (60000, 28, 28)
# 输出标签数据尺寸
print(train_data.train_labels.size())               # (60000)
# 展示训练数据集中的第0个图片
plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
# 图片的标题是标签
plt.title('%i' % train_data.train_labels[0])
plt.show()# Data Loader for easy mini-batch return in training, the image batch shape will be (50, 1, 28, 28)
# 批大小为50,shuffle为True意思是设置为随机
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)# pick 2000 samples to speed up testing
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
# 使用unsqueeze增加一个维度
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.   # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
test_y = test_data.test_labels[:2000]class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 快速搭建神经网络self.conv1 = nn.Sequential(         # input shape (1, 28, 28)nn.Conv2d(in_channels=1,              # input heightout_channels=16,            # n_filterskernel_size=5,              # filter sizestride=1,                   # filter movement/steppadding=2,                  # if want same width and length of this image after Conv2d, padding=(kernel_size-1)/2 if stride=1),                              # output shape (16, 28, 28)nn.ReLU(),                      # activationnn.MaxPool2d(kernel_size=2),    # choose max value in 2x2 area, output shape (16, 14, 14))self.conv2 = nn.Sequential(         # input shape (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),     # output shape (32, 14, 14)nn.ReLU(),                      # activationnn.MaxPool2d(2),                # output shape (32, 7, 7))self.out = nn.Linear(32 * 7 * 7, 10)   # fully connected layer, output 10 classes# 前向传播def forward(self, x):# 第一层卷积x = self.conv1(x)# 第二层卷积x = self.conv2(x)x = x.view(x.size(0), -1)           # flatten the output of conv2 to (batch_size, 32 * 7 * 7)output = self.out(x)return output, x    # return x for visualizationcnn = CNN()
print(cnn)  # net architecture# 选择优化器
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)   # optimize all cnn parameters
# 选择损失函数
loss_func = nn.CrossEntropyLoss()                       # the target label is not one-hotted# following function (plot_with_labels) is for visualization, can be ignored if not interested
from matplotlib import cm
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):plt.cla()X, Y = lowDWeights[:, 0], lowDWeights[:, 1]for x, y, s in zip(X, Y, labels):c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)plt.ion()# training and testing
for epoch in range(EPOCH):for step, (b_x, b_y) in enumerate(train_loader):   # gives batch data, normalize x when iterate train_loaderoutput = cnn(b_x)[0]            # cnn outputloss = loss_func(output, b_y)   # cross entropy lossoptimizer.zero_grad()           # clear gradients for this training steploss.backward()                 # backpropagation, compute gradientsoptimizer.step()                # apply gradientsif step % 50 == 0:test_output, last_layer = cnn(test_x)pred_y = torch.max(test_output, 1)[1].data.numpy()accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)if HAS_SK:# Visualization of trained flatten layer (T-SNE)tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)plot_only = 500low_dim_embs = tsne.fit_transform(last_layer.data.numpy()[:plot_only, :])labels = test_y.numpy()[:plot_only]plot_with_labels(low_dim_embs, labels)
plt.ioff()# print 10 predictions from test data
test_output, _ = cnn(test_x[:10])
pred_y = torch.max(test_output, 1)[1].data.numpy()
print(pred_y, 'prediction number')
print(test_y[:10].numpy(), 'real number')

运行结果

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/4942.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux内核源代码的目录结构包括部分:

内核核心代码:这部分代码包括内核的各个子系统和模块,如进程管理、内存管理、文件系统、网络协议栈等。这些代码构成了Linux内核的核心功能。 非核心代码:除了核心代码之外,还包括一些非核心的代码和文件,如库文件、固…

和chatgpt学架构03-引入UI框架(elment-plus)

目录 1 项目目录及文件的具体作用1.1 App.vue1.2 main.js的作用1.3 main.js什么时候被调用1.4 npm run serve干了什么事情1.5 package.json的作用 2 安装UI框架2.1 安装命令2.2 全局引入 3 启动工程总结 我们已经安装好了我们的vue脚手架,用vscode打开工程目录 要自…

【FPGA】基于C5的第一个SoC工程

文章目录 前言SoC的Linux系统搭建 前言 本文是在毕业实习期间学习FPGA的SoC开发板运行全连接神经网络实例手写体的总结。 声明:本文仅作记录和操作指南,涉及到的操作会尽量细致,但是由于文件过大不会分享文件,具体软件可以自行搜…

Leetcode-每日一题【109.有序链表转换二叉搜索树】

题目 给定一个单链表的头节点 head ,其中的元素 按升序排序 ,将其转换为高度平衡的二叉搜索树。 本题中,一个高度平衡二叉树是指一个二叉树每个节点 的左右两个子树的高度差不超过 1。 示例 1: 输入: head [-10,-3,0,5,9]输出: [0,-3,9,-…

VS报错E1696 无法打开类似于stdio.h等头文件的解决办法

VS报错E1696 无法打开类似于stdio.h等头文件的解决办法 我的VS版本是2022的,然后我今天把同事在VS2017上的code(一个完整的解决方案)从svn上拿过来。结果发现,一大堆E1696的错误。主要表现就是项目中include的一些常用的c语言基础…

算法竞赛字符串常用操作大全

算法竞赛字符串常用操作总结来啦~ 👊 大家好 我是寸铁💪 考前需要刷大量真题,大家一起相互监督,每日做N题,一起上岸吧✌️ ~ 冲刺蓝桥杯省一模板大全来啦 💥 ~ 蓝桥杯4月8号就要开始了 🙏 ~ 还没背熟模…

字幕切分视频

Whisper 仓库地址: https://github.com/openai/whisper 可用模型信息: 测试视频:18段,总共447S视频(11段前:有11段开头有停顿的视频) Tiny: 跑完:142S ,11段前&#xf…

(栈队列堆) 剑指 Offer 09. 用两个栈实现队列 ——【Leetcode每日一题】

❓ 剑指 Offer 09. 用两个栈实现队列 难度:简单 用两个栈实现一个队列。队列的声明如下,请实现它的两个函数 appendTail 和 deleteHead ,分别完成在队列尾部插入整数和在队列头部删除整数的功能。(若队列中没有元素,deleteHead …

vscode远程连接提示:过程试图写入的管道不存在(删除C:\Users\<用户名>\.ssh\known_hosts然后重新连接)

文章目录 复现过程原因解决方法总结 复现过程 我是在windows上用vscode远程连接到我的ubuntu虚拟机上,后来我的虚拟机出了点问题,我把它回退了,然后再连接就出现了这个问题 原因 本地的known_hosts文件记录服务器信息与现服务器的信息冲突了…

虚拟机挂载USB设备/USB serial 连接开发板

虚拟机挂载USB设备 1、添加USB设备 2、终端输入:sudo fdisk -l 查看Device设备: 3、创建挂载目录:mkdir /mnt/usb 4、执行挂载命令:sudo mount /dev/sdb1 /mnt/usb ,查看/mnt/usb目录下是否存在U盘中的数据: 5、用…

设计模式——桥梁模式

桥梁模式 定义 桥梁模式(Bridge Pattern)也叫做桥接模式。 将抽象和显示解耦,使得两者可以独立地变化。 优缺点、应用场景 优点 抽象和实现的解耦。 这是桥梁模式的主要特点,它完全是为了解决继承的缺点而提出的设计模式。优…

成为一个年薪30W+的DFT工程师是一种什么体验?

一直以来,DFT都是数字IC设计行业中相对神秘的一个岗位。 你说他重要吧,并不是所有芯片设计公司都有这个岗位,你说他不重要吧,但凡芯片产品达到一定规模后,就必须设置DFT部门。 一、什么是DFT? DFT&#x…

原生信息流广告APP应用内增收及计费模式

比起传统的广告宣传,信息流最大的优势就在于流量的庞大。与此同时,多样化的信息流广告形式和精准的定向,还可以帮助广告主准确获取意向流量。此外,它的广告形式不强迫推送,因此也受到了广泛用户的支持和青睐。 原生信…

音视频开发实战03-FFmpeg命令行工具移植

一,背景 作为一个音视频开发者,在日常工作中经常会使用ffmpeg 命令来做很多事比如转码ffmpeg -y -i test.mov -g 150 -s 1280x720 -codec libx265 -r 25 test_h265.mp4 ,水平翻转视频:ffmpeg -i src.mp4 -vf hflip -acodec copy …

26.JavaWeb-SpringSecurity安全框架

1.SpringSecurity安全框架 Spring Security是一个功能强大且灵活的安全框架,它专注于为Java应用程序提供身份验证(Authentication)、授权(Authorization)和其他安全功能。Spring Security可以轻松地集成到Spring框架中…

MySQL数据库(五)

目录 一、数据库的约束 1.1 约束类型 1.1.1 null约束 1.1.2unique约束 1.1.3default默认值约束 1.1.4primary key主键约束 1.1.5foreign key外键约束 二、内容重点总结 一、数据库的约束 1.1 约束类型 not null - 指示某列不能存储 null值。unique - 保证某列的每行必须有唯一…

上市公司Git分支管理规范

Git分支管理策略 主分支Master 首先,代码库应该有一个、且仅有一个主分支。所有提供给用户使用的正式版本,都在这个主分支上发布。 Git主分支的名字,默认叫做Master。它是自动建立的,版本库初始化以后,默认就是在主…

采集传感器的物联网网关怎么采集数据?

随着工业4.0和智能制造的快速发展,物联网(IoT)技术的应用越来越广泛,传感器在整个物联网系统中使用非常普遍,如温度传感器、湿度传感器、光照传感器等,对于大部分物联网应用来说,采集传感器都非…

Ubuntu学习笔记(二)——文件属性与权限

文章目录 前言一、用户与用户组1.用户(文件拥有者)2.用户组3.其他人 二、Linux用户身份与用户组记录文件1. /etc/passwd2. /etc/shadow3. /etc/group 三、文件属性与权限1. 查看文件属性的方法(ls)2.文件属性详细介绍2.1 权限2.2 …

MacOS触控板缩放暂时失灵问题解决

我的系统版本为Monterey 12.5.1,亲测有效 直接创建脚本xxx.sh,并在终端执行脚本bash xxx.sh即可解决此问题,脚本内容如下: #!/bin/bashkillall Finder #kill Finder如不需要可以删除 killall Dock #kill Dock 如不需要可以删…