深度学习之PyTorch实现卷积神经网络(CNN)

在深度学习领域,卷积神经网络(Convolutional Neural Networks,CNN)是一种非常强大的模型,专门用于处理图像数据。CNN通过卷积操作和池化操作来提取图像中的特征,具有较好的特征学习能力,特别适用于图像识别和计算机视觉任务。PyTorch作为一种流行的深度学习框架,提供了方便易用的工具来构建和训练CNN模型。本文将介绍如何使用PyTorch构建一个简单的CNN,并通过一个图像分类任务来演示其效果。

1. CNN的结构

典型的CNN结构包括卷积层、池化层和全连接层。卷积层通过卷积操作提取图像特征,池化层通过降采样操作减小特征图的尺寸,全连接层用于最终的分类。

在这里插入图片描述

2. 环境配置

在开始之前,确保已经安装了PyTorch和相关的Python库。可以通过以下命令安装:

pip install torch torchvision matplotlib

3. 数据集准备

在这个示例中,我们将使用PyTorch提供的CIFAR-10数据集,它包含了10个类别的60000张32x32彩色图像。我们将图像分为训练集和测试集,并加载到PyTorch的数据加载器中。

import torch
import torchvision
import torchvision.transforms as transforms# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)# 测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')

4. 构建CNN模型

我们将构建一个简单的CNN模型,包括卷积层、池化层、全连接层和激活函数。这个模型将接受3通道的32x32图像作为输入,并输出10个类别的概率分布。

import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super().__init__()# 定义网络层:卷积层+池化层self.conv1 = nn.Conv2d(3, 6, stride=1, kernel_size=3)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(6, 16, stride=1, kernel_size=3)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)# 全连接层self.linear1 = nn.Linear(576, 120)self.linear2 = nn.Linear(120, 84)self.out = nn.Linear(84, 10)def forward(self, x):# 卷积+relu+池化x = F.relu(self.conv1(x))x = self.pool1(x)# 卷积+relu+池化x = F.relu(self.conv2(x))x = self.pool2(x)# 将特征图做成以为向量的形式:相当于特征向量x = x.reshape(x.size(0), -1)# 全连接层x = F.relu(self.linear1(x))x = F.relu(self.linear2(x))# 返回输出结果return self.out(x)net = Net()

5. 模型训练

定义损失函数和优化器,并在训练集上训练CNN模型。

import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)# 训练网络
for epoch in range(2):  # 多次遍历数据集running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 2000 == 1999:  # 每2000个小批量数据打印一次损失值print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('训练结束!!!')

6. 模型测试

在测试集上评估训练好的模型的性能。

correct = 0
total = 0
# 禁用梯度追踪
with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('10000 测试图片的准确率: %d %%' % (100 * correct / total))

7. 结果分析

通过训练和测试过程,我们可以得到CNN模型在CIFAR-10数据集上的准确率。进一步分析模型在每个类别上的表现,以及可视化模型的特征图等,可以帮助我们更好地理解模型的行为。

%%’ % (100 * correct / total))


## 7. 结果分析通过训练和测试过程,我们可以得到CNN模型在CIFAR-10数据集上的准确率。进一步分析模型在每个类别上的表现,以及可视化模型的特征图等,可以帮助我们更好地理解模型的行为。这篇博客介绍了如何使用PyTorch构建和训练一个简单的卷积神经网络。通过实际的代码示例,读者可以了解CNN模型的基本原理,并掌握如何在PyTorch中实现和训练这样的模型。希望本文能够对你理解和应用深度学习模型有所帮助!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/819191.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

优先队列处理文件的锁定和自动解锁

【背景】最近要做一个防篡改的功能,一开始是采用事件型的方式实现的,结果发现会有一种情况"如果某个文件短时间一直被外部进行多次恶意操作"时,一直防也不是个事,应该在加一层防护—文件锁定,这样就舒服多了…

云原生:企业数字化转型的引擎与未来

一,引言 随着信息技术的飞速发展,企业数字化转型已成为时代的必然趋势。在这场深刻的变革中,云原生技术以其独特的优势,逐渐成为推动企业数字化转型的核心动力。本文将详细探讨云原生技术的内涵、发展历程,以及在企业数…

【Java开发指南 | 第八篇】Java变量、构造方法、创建对象

专栏:Java开发指南 CSDN秋说 文章目录 Java变量构造方法创建对象 Java变量 局部变量:在方法、构造方法或者语句块中定义的变量被称为局部变量。变量声明和初始化都是在方法中,方法结束后,变量就会自动销毁。成员变量(…

研究生,该学单片机还是plc。?

PLC门槛相对较低,但是在深入学习和应用时,仍然有很高的技术要求。我这里有一套单片机入门教程,不仅包含了详细的视频 讲解,项目实战。如果你渴望学习单片机,不妨点个关注,给个评论222,私信22&am…

OpenHarmony实战开发-图片选择和下载保存案例。

介绍 本示例介绍图片相关场景的使用:包含访问手机相册图片、选择预览图片并显示选择的图片到当前页面,下载并保存网络图片到手机相册或到指定用户目录两个场景。 效果图预览 使用说明 从主页通用场景集里选择图片选择和下载保存进入首页。分两个场景点…

Linux UDP通信系统

目录 一、socket编程接口 1、socket 常见API socket():创建套接字 bind():将用户设置的ip和port在内核中和我们的当前进程关联 listen() accept() 2、sockaddr结构 3、inet系列函数 二、UDP网络程序—发送消息 1、服务器udp_server.hpp initS…

Hadoop HDFS:海量数据的存储解决方案

引言 在大数据时代,数据的存储与处理成为了业界面临的一大挑战。Hadoop的分布式文件系统(Hadoop Distributed File System,简称HDFS)作为一个高可靠性、高扩展性的文件系统,提供了处理海量数据的有效解决方案。本文将…

stm32开发之threadx整合letter-shell 组件记录

前言 使用过rt-thread的shell 命令交互的方式,觉得比较方便,所以在threadx中也移植个shell的组件。这里使用的是letter-shellletter-shell 核心的逻辑在于组件通过链接文件自动初始化或自动添加的两种方式,方便开发源码仓库 实验(核心代码) shell 线程…

rhce day1

一 . 在系统中设定延迟任务要求如下 在系统中建立 easylee 用户,设定其密码为 easylee 延迟任务由 root 用户建立 要求在 5 小时后备份系统中的用户信息文件到 /backup 中 确保延迟任务是使用非交互模式建立 确保系统中只有 root 用户和 easylee 用户可以执行延…

✌粤嵌—2024/3/11—跳跃游戏

代码实现&#xff1a; 方法一&#xff1a;递归记忆化 int path; int used[10000];bool dfs(int *nums, int numsSize) {if (path numsSize - 1) {return true;}for (int i 1; i < nums[path]; i) {if (used[path i]) {continue;}path i;used[path] 1;if (dfs(nums, num…

“华为杯“华南理工大学程序设计竞赛 L-再一道好题

题目 #include<bits/stdc.h> using namespace std; #define int long long #define pb push_back #define fi first #define se second const int maxn 1e6 5; const int inf 1e9 5;using namespace std;int n, m;void solve(){int res 0;int q;string s;int k;cin …

北京市为例的空气质量分析报告分析【免费送】

原始数据&#xff1a; 日期名称类型所属区拥挤指数速度客流指数20240405世界之花假日广场购物;购物中心大兴区2.46621.369.4920240405华润五彩城购物;购物中心海淀区2.01329.7111.1720240405北京市百货大楼购物;购物中心东城区1.85615.938.2320240405apm购物;购物中心东城区1.…

C#开源项目推荐

winform界面开发 SunnyUI SharpSCADALite工业控制数据采集 https://github.com/qwe7922142/SharpSCADALite net中国集合优秀Net项目 https://gitee.com/dotnetchina 数据库管理系统 https://gitee.com/dotnetchina/SmartSQL 工作流项目 RoadFlow-UnMean 网口通讯 weaving-…

Grok-1.5 Vision:X AI发布突破性的多模态AI模型,超越GPT 4V

在人工智能领域&#xff0c;多模态模型的发展一直是科技巨头们竞争的焦点。 近日&#xff0c;马斯克旗下的X AI公司发布了其最新的多模态模型——Grok-1.5 Vision&#xff08;简称Grok-1.5V&#xff09;&#xff0c;这一模型在处理文本和视觉信息方面展现出了卓越的能力&#x…

Python统计模型线性推理事件前因后果

&#x1f3af;要点 经典统计方法&#xff1a;&#x1f58a; A/B测试&#xff0c;计算两个均值样本的置信区间&#xff0c;&#x1f58a;最小二乘法计算变量估值&#xff0c;&#x1f58a;使用非线性关系式表示线性回归。&#x1f58a;实例&#xff1a;高等教育和数学高分的事件…

批量导入照片

clear clc close all % 创建或获取演示文稿对象 ppt Presentation(new_presentation.pptx, 演示文稿1.pptx); open(ppt); % 添加新的幻灯片 slide1 add(ppt, Title Slide); % 指定第一张图片路径 imagePath1 C:\Users\Administrator\Desktop\001\阵风因子-横风向图20…

Linux usermod命令教程:如何修改用户属性(附案例详解和注意事项)

Linux usermod命令介绍 usermod命令是Linux系统中用来修改用户属性的命令。它可以修改用户的登录名、家目录、登录shell、用户组等信息。 Linux usermod命令适用的Linux版本 usermod命令在大多数Linux发行版中都是可用的&#xff0c;包括Debian、Ubuntu、Alpine、Arch Linux…

即席查询笔记

文章目录 一、Kylin4.x1、Kylin概述1.1 定义1.2 Kylin 架构1.3 Kylin 特点1.4 Kylin4.0 升级 2、Kylin 环境搭建2.1 简介2.2 Spark 安装和部署2.3 Kylin 安装和部署2.4 Kylin 启动环境准备2.5 Kylin 启动和关闭 3、快速入门3.1 数据准备3.2 Kylin项目创建入门3.3 Hive 和 Kylin…

Canvas图形编辑器-数据结构与History(undo/redo)

Canvas图形编辑器-数据结构与History(undo/redo) 这是作为 社区老给我推Canvas&#xff0c;于是我也学习Canvas做了个简历编辑器 的后续内容&#xff0c;主要是介绍了对数据结构的设计以及History能力的实现。 在线编辑: https://windrunnermax.github.io/CanvasEditor开源地…