深度学习之PyTorch实现卷积神经网络(CNN)

在深度学习领域,卷积神经网络(Convolutional Neural Networks,CNN)是一种非常强大的模型,专门用于处理图像数据。CNN通过卷积操作和池化操作来提取图像中的特征,具有较好的特征学习能力,特别适用于图像识别和计算机视觉任务。PyTorch作为一种流行的深度学习框架,提供了方便易用的工具来构建和训练CNN模型。本文将介绍如何使用PyTorch构建一个简单的CNN,并通过一个图像分类任务来演示其效果。

1. CNN的结构

典型的CNN结构包括卷积层、池化层和全连接层。卷积层通过卷积操作提取图像特征,池化层通过降采样操作减小特征图的尺寸,全连接层用于最终的分类。

在这里插入图片描述

2. 环境配置

在开始之前,确保已经安装了PyTorch和相关的Python库。可以通过以下命令安装:

pip install torch torchvision matplotlib

3. 数据集准备

在这个示例中,我们将使用PyTorch提供的CIFAR-10数据集,它包含了10个类别的60000张32x32彩色图像。我们将图像分为训练集和测试集,并加载到PyTorch的数据加载器中。

import torch
import torchvision
import torchvision.transforms as transforms# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)# 测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')

4. 构建CNN模型

我们将构建一个简单的CNN模型,包括卷积层、池化层、全连接层和激活函数。这个模型将接受3通道的32x32图像作为输入,并输出10个类别的概率分布。

import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super().__init__()# 定义网络层:卷积层+池化层self.conv1 = nn.Conv2d(3, 6, stride=1, kernel_size=3)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(6, 16, stride=1, kernel_size=3)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)# 全连接层self.linear1 = nn.Linear(576, 120)self.linear2 = nn.Linear(120, 84)self.out = nn.Linear(84, 10)def forward(self, x):# 卷积+relu+池化x = F.relu(self.conv1(x))x = self.pool1(x)# 卷积+relu+池化x = F.relu(self.conv2(x))x = self.pool2(x)# 将特征图做成以为向量的形式:相当于特征向量x = x.reshape(x.size(0), -1)# 全连接层x = F.relu(self.linear1(x))x = F.relu(self.linear2(x))# 返回输出结果return self.out(x)net = Net()

5. 模型训练

定义损失函数和优化器,并在训练集上训练CNN模型。

import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)# 训练网络
for epoch in range(2):  # 多次遍历数据集running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 2000 == 1999:  # 每2000个小批量数据打印一次损失值print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('训练结束!!!')

6. 模型测试

在测试集上评估训练好的模型的性能。

correct = 0
total = 0
# 禁用梯度追踪
with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('10000 测试图片的准确率: %d %%' % (100 * correct / total))

7. 结果分析

通过训练和测试过程,我们可以得到CNN模型在CIFAR-10数据集上的准确率。进一步分析模型在每个类别上的表现,以及可视化模型的特征图等,可以帮助我们更好地理解模型的行为。

%%’ % (100 * correct / total))


## 7. 结果分析通过训练和测试过程,我们可以得到CNN模型在CIFAR-10数据集上的准确率。进一步分析模型在每个类别上的表现,以及可视化模型的特征图等,可以帮助我们更好地理解模型的行为。这篇博客介绍了如何使用PyTorch构建和训练一个简单的卷积神经网络。通过实际的代码示例,读者可以了解CNN模型的基本原理,并掌握如何在PyTorch中实现和训练这样的模型。希望本文能够对你理解和应用深度学习模型有所帮助!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/819191.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

云原生:企业数字化转型的引擎与未来

一,引言 随着信息技术的飞速发展,企业数字化转型已成为时代的必然趋势。在这场深刻的变革中,云原生技术以其独特的优势,逐渐成为推动企业数字化转型的核心动力。本文将详细探讨云原生技术的内涵、发展历程,以及在企业数…

【Java开发指南 | 第八篇】Java变量、构造方法、创建对象

专栏:Java开发指南 CSDN秋说 文章目录 Java变量构造方法创建对象 Java变量 局部变量:在方法、构造方法或者语句块中定义的变量被称为局部变量。变量声明和初始化都是在方法中,方法结束后,变量就会自动销毁。成员变量(…

研究生,该学单片机还是plc。?

PLC门槛相对较低,但是在深入学习和应用时,仍然有很高的技术要求。我这里有一套单片机入门教程,不仅包含了详细的视频 讲解,项目实战。如果你渴望学习单片机,不妨点个关注,给个评论222,私信22&am…

OpenHarmony实战开发-图片选择和下载保存案例。

介绍 本示例介绍图片相关场景的使用:包含访问手机相册图片、选择预览图片并显示选择的图片到当前页面,下载并保存网络图片到手机相册或到指定用户目录两个场景。 效果图预览 使用说明 从主页通用场景集里选择图片选择和下载保存进入首页。分两个场景点…

Linux UDP通信系统

目录 一、socket编程接口 1、socket 常见API socket():创建套接字 bind():将用户设置的ip和port在内核中和我们的当前进程关联 listen() accept() 2、sockaddr结构 3、inet系列函数 二、UDP网络程序—发送消息 1、服务器udp_server.hpp initS…

stm32开发之threadx整合letter-shell 组件记录

前言 使用过rt-thread的shell 命令交互的方式,觉得比较方便,所以在threadx中也移植个shell的组件。这里使用的是letter-shellletter-shell 核心的逻辑在于组件通过链接文件自动初始化或自动添加的两种方式,方便开发源码仓库 实验(核心代码) shell 线程…

rhce day1

一 . 在系统中设定延迟任务要求如下 在系统中建立 easylee 用户,设定其密码为 easylee 延迟任务由 root 用户建立 要求在 5 小时后备份系统中的用户信息文件到 /backup 中 确保延迟任务是使用非交互模式建立 确保系统中只有 root 用户和 easylee 用户可以执行延…

✌粤嵌—2024/3/11—跳跃游戏

代码实现&#xff1a; 方法一&#xff1a;递归记忆化 int path; int used[10000];bool dfs(int *nums, int numsSize) {if (path numsSize - 1) {return true;}for (int i 1; i < nums[path]; i) {if (used[path i]) {continue;}path i;used[path] 1;if (dfs(nums, num…

“华为杯“华南理工大学程序设计竞赛 L-再一道好题

题目 #include<bits/stdc.h> using namespace std; #define int long long #define pb push_back #define fi first #define se second const int maxn 1e6 5; const int inf 1e9 5;using namespace std;int n, m;void solve(){int res 0;int q;string s;int k;cin …

北京市为例的空气质量分析报告分析【免费送】

原始数据&#xff1a; 日期名称类型所属区拥挤指数速度客流指数20240405世界之花假日广场购物;购物中心大兴区2.46621.369.4920240405华润五彩城购物;购物中心海淀区2.01329.7111.1720240405北京市百货大楼购物;购物中心东城区1.85615.938.2320240405apm购物;购物中心东城区1.…

Grok-1.5 Vision:X AI发布突破性的多模态AI模型,超越GPT 4V

在人工智能领域&#xff0c;多模态模型的发展一直是科技巨头们竞争的焦点。 近日&#xff0c;马斯克旗下的X AI公司发布了其最新的多模态模型——Grok-1.5 Vision&#xff08;简称Grok-1.5V&#xff09;&#xff0c;这一模型在处理文本和视觉信息方面展现出了卓越的能力&#x…

即席查询笔记

文章目录 一、Kylin4.x1、Kylin概述1.1 定义1.2 Kylin 架构1.3 Kylin 特点1.4 Kylin4.0 升级 2、Kylin 环境搭建2.1 简介2.2 Spark 安装和部署2.3 Kylin 安装和部署2.4 Kylin 启动环境准备2.5 Kylin 启动和关闭 3、快速入门3.1 数据准备3.2 Kylin项目创建入门3.3 Hive 和 Kylin…

【个人博客搭建】(3)添加SqlSugar ORM

1、安装sqlsugar。在models下的依赖项那右击选择管理Nuget程序包&#xff0c;输入sqlsugarcore&#xff08;因为我们用的是netcore&#xff0c;而不是net famework所以也对应sqlsugarcore&#xff09;&#xff0c;出来的第一个就是了&#xff0c;然后点击选择版本&#xff0c;一…

密码学 | 椭圆曲线 ECC 密码学入门(四)

目录 正文 1 曲线方程 2 点的运算 3 求解过程 4 补充&#xff1a;有限域 ⚠️ 知乎&#xff1a;【密码专栏】动手计算双线性对&#xff08;中&#xff09; - 知乎 ⚠️ 写在前面&#xff1a;本文属搬运博客&#xff0c;自己留着学习。注意&#xff0c;这篇博客与前三…

代码随想录算法训练营Day56|LC583 两个字符串的删除操作LC72 编辑距离

一句话总结&#xff1a;看起来复杂&#xff0c;动规分析以后就比较简单。 原题链接&#xff1a;583 两个字符串的删除操作 本质就是求两个字符串的最短子序列的长度。已经做过&#xff0c;不再详解。 class Solution {public int minDistance(String word1, String word2) {/…

Python(11):网络编程

文章目录 一、一些基本概念二、软件的开发架构&#xff08;c/s架构和b/s架构&#xff09;三、OSI模型四、socket套接字编程1.socket编程过程2.python中的socket编程 一、一些基本概念 来了解一些网络的基本概念 名词解释IP&#xff08;互联网协议地址&#xff09;IP用来标识网…

PCB基础介绍

一&#xff0c;单层板&#xff1a; 1&#xff0c;铜皮 和导线类似&#xff0c;提供电路板上的电信号传导路径。 因为铜具有良好的导热性能&#xff0c;因此铜皮还可以用于散热。在高功率电子设备中&#xff0c;通过在PCB上增加铜皮面积和散热片&#xff0c;可以提高散热效果…

数字晶体管数字三极管

数字晶体管 指内部集成了电阻的三极管&#xff0c;有PNP和NPN型&#xff0c;也有双管&#xff0c;双管有3种形式&#xff0c;其中一种是PNPNPN。下面以双NPN示例&#xff0c;好处是外面没有电阻&#xff0c;批量应用时&#xff0c;焊点费用就可省下不少。双NPN的用在串口自动下…

开源相机管理库Aravis例程学习(二)——连续采集multiple-acquisition-main-thread

开源相机管理库Aravis例程学习&#xff08;二&#xff09;——连续采集multiple-acquisition-main-thread 简介例程代码函数说明arv_camera_set_acquisition_modearv_camera_create_streamarv_camera_get_payloadarv_buffer_newarv_stream_push_bufferarv_camera_start_acquisi…