卷积神经网络(含案例代码)

概述

        卷积神经网络(Convolutional Neural Network,CNN)是一类专门用于处理具有网格结构数据的神经网络。它主要被设计用来识别和提取图像中的特征,但在许多其他领域也取得了成功,例如自然语言处理中的文本分类任务。

        CNN 的主要特点是它使用了卷积层(convolutional layer)来处理输入数据。卷积层通过卷积操作在输入数据上滑动一个或多个卷积核(也称为滤波器),从而学习局部特征。这种局部感知能力使得 CNN 能够有效地捕捉输入数据中的空间结构和模式。

基本组成部分

        卷积层(Convolutional Layer)

        由多个卷积核组成,每个卷积核用于检测输入数据中的特定特征。卷积操作通过在输入数据上滑动卷积核并计算局部区域的加权和来提取特征。

        池化层(Pooling Layer)

        用于减小数据的空间维度,降低计算复杂度,并且增强模型对平移变化的鲁棒性。最大池化是常用的池化操作,它选择输入区域中的最大值作为输出。

        激活函数(Activation Function)

通常在卷积层之后应用,引入非线性特性。常用的激活函数包括ReLU(Rectified Linear Unit)。

        全连接层(Fully Connected Layer)

        在卷积层和输出层之间,用于整合卷积层提取的特征并生成最终的输出。全连接层将前一层的所有节点与当前层的每个节点连接。

        CNN 在图像处理任务中表现出色,因为它能够学习到图像的局部和全局特征,具有平移不变性(通过共享权重)、参数共享和稀疏交互等特性。这些特性使得 CNN 在图像分类、目标检测、图像生成等任务中取得了显著的成功。

基本实现原理

        卷积神经网络(CNN)的实现原理涉及卷积层、池化层、激活函数、全连接层等关键组件。

输入层(Input Layer)

        接收原始输入数据,通常是图像或其他具有网格结构的数据。

卷积层(Convolutional Layer)

        使用卷积核(filter)对输入数据进行卷积操作,通过在输入数据上滑动卷积核,提取局部特征。卷积操作通过计算局部区域的加权和来生成输出特征图(feature map)。

激活函数(Activation Function)

        在卷积操作后,应用激活函数引入非线性,增加网络的表示能力。常用的激活函数包括ReLU(Rectified Linear Unit)。

池化层(Pooling Layer)

        对卷积层的输出进行下采样,减小空间维度,提高计算效率,并增强网络对平移变化的鲁棒性。最大池化是常用的池化操作,选择局部区域中的最大值作为输出。

全连接层(Fully Connected Layer)

        将池化层的输出扁平化,并通过全连接层连接到输出层。全连接层负责整合卷积层和池化层提取的特征,并生成最终的输出。

输出层(Output Layer)

        输出层根据任务的性质确定,可以是分类问题的softmax层,回归问题的线性层,或者其他适当的输出层结构。

        在训练过程中,通过反向传播算法更新网络参数,以最小化损失函数。这个过程包括前向传播(计算预测输出)、计算损失、反向传播(计算梯度),以及使用优化算法(如梯度下降)来更新权重。

        CNN 的关键之一是参数共享,即卷积核在整个输入上共享权重,这减少了参数数量,提高了模型的效率和泛化能力。此外,卷积操作和池化操作的重复使用使得网络能够逐渐构建出对输入数据的抽象表示。卷积神经网络通过多层次的特征提取和抽象,能够学习到输入数据的有用表示,从而在图像分类、目标检测等任务中表现出色。

案例代码

        下面是一个使用PyTorch的简单卷积神经网络(CNN)的代码案例。

        在这个例子中,使用PyTorch来构建一个简单的CNN模型,以进行图像分类。确保你已经安装了PyTorch:

pip install torch torchvision

        代码:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader# 设置随机种子,以保证实验的可重复性
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False# 定义CNN模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.flatten = nn.Flatten()self.fc1 = nn.Linear(64 * 7 * 7, 128)self.relu3 = nn.ReLU()self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.pool1(self.relu1(self.conv1(x)))x = self.pool2(self.relu2(self.conv2(x)))x = self.flatten(x)x = self.relu3(self.fc1(x))x = self.fc2(x)return x# 数据预处理和加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 5
for epoch in range(num_epochs):for images, labels in train_loader:optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 在测试集上评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / total
print('Test Accuracy: {:.2%}'.format(accuracy))

        这个示例中,使用PyTorch构建了一个包含两个卷积层和两个全连接层的简单CNN模型,并在MNIST手写数字数据集上进行训练和测试。可以根据自己的需求修改模型架构、训练参数等。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader# 定义CNN模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.flatten = nn.Flatten()self.fc1 = nn.Linear(64 * 7 * 7, 128)self.relu3 = nn.ReLU()self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.pool1(self.relu1(self.conv1(x)))x = self.pool2(self.relu2(self.conv2(x)))x = self.flatten(x)x = self.relu3(self.fc1(x))x = self.fc2(x)return x# 数据预处理和加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 5
for epoch in range(num_epochs):for images, labels in train_loader:optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 保存模型
torch.save(model.state_dict(), 'simple_cnn_model.pth')
print("Model has been saved.")# 加载模型
new_model = SimpleCNN()
new_model.load_state_dict(torch.load('simple_cnn_model.pth'))
new_model.eval()# 在测试集上评估加载的模型
correct = 0
total = 0
with torch.no_grad():for images, labels in test_loader:outputs = new_model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / total
print('Test Accuracy of Loaded Model: {:.2%}'.format(accuracy))

        这个示例中,添加了模型的保存和加载过程。模型在训练完成后被保存到simple_cnn_model.pth文件,然后通过加载这个文件,可以重新创建模型并在测试集上进行评估。这对于在训练后的应用部署和模型共享上都是非常有用的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/221869.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Nginx快速入门

nginx准备 文本概述参考笔记 狂神:https://www.kuangstudy.com/bbs/1353634800149213186 前端vue打包 参考:https://blog.csdn.net/weixin_44813417/article/details/121329335 打包命令: npm run build:prod nginx 下载 网址&#x…

Java集合--Map

1、Map集合概述 在Java的集合框架中&#xff0c;Map为双列集合&#xff0c;在Map中的元素是成对以<K,V>键值对的形式存在的&#xff0c;通过键可以找对所对应的值。Map接口有许多的实现类&#xff0c;各自都具有不同的性能和用途。常用的Map接口实现类有HashMap、Hashtab…

mysql迁移步骤

MySQL迁移是指将MySQL数据库从一台服务器迁移到另一台服务器。这可能是因为您需要升级服务器、增加存储空间、提高性能或改变数据库架构。 以下是MySQL迁移的一般步骤&#xff1a; 以上是MySQL迁移的一般步骤&#xff0c;具体步骤可能因您的环境和需求而有所不同。在进行迁移之…

uniapp+vue3使用canvas保存海报的使用示例,各种奇奇怪怪的问题解决办法

我们这里这里有一个需求&#xff0c;是将当前页面保存为海报分享给朋友或者保存到本地相册&#xff0c;因为是在小程序端开发的&#xff0c;所以不能使用html2canvas这个库&#xff0c;而且微信官方新推出Snapshot.takeSnapshot这个api还不是很完善&#xff0c;如果你是纯小程序…

如果一个嵌套类需要在单个方法之外仍然是可见,或者它太长,不适合放在方法内部,就应该使用成员类。

当一个嵌套类需要在单个方法之外仍然是可见&#xff0c;或者它太长不适合放在方法内部时&#xff0c;可以考虑使用成员类&#xff08;成员内部类&#xff09;。成员类是声明在类的内部但不是在方法内部的类&#xff0c;可以访问外部类的实例成员。 以下是一个示例&#xff0c;…

【问题处理】—— lombok 的 @Data 大小写区分不敏感

问题描述 今天在项目本地编译的时候&#xff0c;发现有个很奇怪的问题&#xff0c;一直提示某位置找不到符号&#xff0c; 但是实际在Idea中显示确实正常的&#xff0c;一开始以为又是IDEA的故障&#xff0c;所以重启了IDEA&#xff0c;并执行了mvn clean然后重新编译。但是问…

ASF-YOLO开源 | SSFF融合+TPE编码+CPAM注意力,精度提升!

目录 摘要 1 Introduction 2 Related work 2.1 Cell instance segmentation 2.2 Improved YOLO for instance segmentation 3 The proposed ASF-YOLO model 3.1 Overall architecture 3.2 Scale sequence feature fusion module 3.3 Triple feature encoding module …

【Python网络爬虫入门教程3】成为“Spider Man”的第三课:从requests到scrapy、爬取目标网站

Python 网络爬虫入门&#xff1a;Spider man的第三课 写在最前面从requests到scrapy利用scrapy爬取目标网站更多内容 结语 写在最前面 有位粉丝希望学习网络爬虫的实战技巧&#xff0c;想尝试搭建自己的爬虫环境&#xff0c;从网上抓取数据。 前面有写一篇博客分享&#xff0…

【实用技巧】从文件夹内批量筛选指定文件并将其复制到目标文件夹

原创文章&#xff0c;转载请注明出处&#xff01; 从文件夹中批量提取指定文件。 使用DOS命令&#xff0c;根据TXT文件中列出指定文件名&#xff0c;批量实现查找指定文件夹里的文件并复制到新的文件夹。 文中给出使用DOS命令和建立批处理文件两种方法。 文件准备 工作文件…

vite(一)——基本了解和依赖预构建

文章目录 一、什么是构建工具&#xff1f;1.为什么使用构建工具&#xff1f;2.构建工具的作用&#xff1f;3.构建工具怎么用&#xff1f; 二、经典面试题&#xff1a;webpack和vite的区别1.编译方式不同2.基础概念不同3.开发效率不同4.扩展性不同5.应用场景不同6.总结&#xff…

vue 组件实现v-model

ChatgGPT4.0国内站点: 海鲸AI 在Vue中&#xff0c;可以通过使用v-model指令来实现双向数据绑定。如果你想在自定义组件中使用v-model&#xff0c;需要做一些额外的工作。 首先&#xff0c;在组件的props中定义一个名为value的属性&#xff0c;用于接收父组件传递的值。然后&a…

QT- QT-lximagerEidtor图片编辑器

QT- QT-lximagerEidtor图片编辑器 一、演示效果二、关键程序三、下载链接 功能如下&#xff1a; 1、缩放、旋转、翻转和调整图像大小 2、幻灯片 3、缩略图栏&#xff08;左、上或下&#xff09;&#xff1b;不同的缩略图大小 4、Exif数据栏 5、内联图像重命名 6、自定义快捷方式…

MybatisPlus的分页插件

PaginationInnerInterceptor 此插件是核心插件,目前代理了 Executor#query 和 Executor#update 和 StatementHandler#prepare 方法。 在SpringBoot环境中配置方式如下&#xff1a; /*** author giserDev* description 配置分页插件、方言、mapper包扫描等* date 2023-12-13 …

删除一个字符串中的指定字母,如:字符串 “aca“,删除其中的 a 字母。

#include<stdio.h> #include<stdlib.h> #include<string.h> // 删除字符串中指定字母函数 char* deleteCharacters(char * str, char * charSet) { int hash [256]; if(NULL charSet) return str; for(int i 0; i < 256; i) …

B - Team Gym - 102801B ( 网络流问题)

题目链接 先占个坑&#xff0c;有空写一下思路 #include <bits/stdc.h> using namespace std; #define pi acos(-1) #define xx first #define yy second #define endl "\n" #define lowbit(x) x & (-x) #define int long long #define ull unsigned lo…

Vue3安装使用Mock.js--解决跨域

首先使用axios发送请求到模拟服务器上&#xff0c;再将mock.js模拟服务器数据返回给客户端。打包工具使用的是vite。 1.安装 npm i axios -S npm i mockjs --save-dev npm i vite-plugin-mock --save-dev 2.在vite.config.js文件中配置vite-plugin-mock等消息 import { viteMo…

RedisHelper

Redis面试题&#xff1a; 1、什么是事务&#xff1f;2、Redis中有事务吗&#xff1f;3、Redis中的事务可以回滚吗&#xff1f; 答&#xff1a; 1、事务是指一个完整的动作&#xff0c;要么全部执行&#xff0c;要么什么也没有做 2、Redis中有事务&#xff0c;Redis 事务不是严…

分页操作中使用LIMIT和OFFSET后出现慢查询的原因分析

事情经过 最近在做批量数据处理的相关业务&#xff0c;在和下游对接时&#xff0c;发现拉取他们的业务数据刚开始很快&#xff0c;后面会越来越慢&#xff0c;40万数据一个小时都拉不完。经过排查后&#xff0c;发现对方用了很坑的分页查询方式 —— LIMIT OFFSET&#xff0c;…

【前端学习记录】Vue前端规范整理

文章目录 前言一、文件及文件夹命名二、钩子顺序三、注释规范四、组件封装五、CSS编码规范六、JS编码规范 前言 优秀的项目源码&#xff0c;即使是多人开发&#xff0c;看代码也如一人之手。统一的编码规范&#xff0c;可使代码更易于阅读&#xff0c;易于理解&#xff0c;易于…

mysql中NULL值

mysql中NULL值表示“没有值”&#xff0c;它跟空字符串""是不同的 例如&#xff0c;执行下面两个插入记录的语句&#xff1a; insert into test_table (description) values (null); insert into test_table (description) values ();执行以后&#xff0c;查看表的…