《动手学深度学习(PyTorch版)》笔记6.3

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。

Chapter6 Convolutional Neural Network(CNN)

6.3 LeNet

在这里插入图片描述

LeNet模型中每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层。虽然ReLU和最大汇聚层更有效,但它们在20世纪90年代还没有出现。每个卷积层使用 5 × 5 5\times 5 5×5卷积核和一个sigmoid激活函数,这些层将输入映射到多个二维特征输出,通常同时增加通道的数量。第一卷积层有6个输出通道,而第二个卷积层有16个输出通道,每个 2 × 2 2\times2 2×2池操作(步幅2)通过空间下采样将维数减少4倍。为了将卷积块的输出传递给稠密块,我们必须在小批量中展平每个样本。换言之,我们将这个四维输入转换成全连接层所期望的二维输入(第一个维度索引小批量中的样本,第二个维度给出每个样本的平面向量表示)。LeNet的稠密块有三个全连接层,分别有120、84和10个输出。因为我们在执行分类任务,所以输出层的10维对应于最后输出结果的数量。

import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as pltnet = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).devicemetric = d2l.Accumulator(2)#metric[0]:正确预测的数量,metric[1]:总预测的数量。with torch.no_grad():for X, y in data_iter:if isinstance(X, list):X = [x.to(device) for x in X]# BERT微调所需,将输入转移到GPU上(之后将介绍)else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):#@save"""用GPU训练模型"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):metric = d2l.Accumulator(3)#metric[i]分别是训练损失之和,训练准确率之和,样本数net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

训练结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/663971.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python实现PDF到HTML的转换

PDF文件是共享和分发文档的常用选择,但提取和再利用PDF文件中的内容可能会非常麻烦。而利用Python将PDF文件转换为HTML是解决此问题的理想方案之一,这样做可以增强文档可访问性,使文档可搜索,同时增强文档在不同场景中的实用性。此…

【FPGA原型验证】附录基础知识:FPGA/CPLD基本结构与实现原理

聚焦Xilinx ISE 介绍Xilinx公司及其产品的基本情况,并在此基础上描述了CPLD和FPGA的内部结构及基本原理。 1.1 Xilinx公司及其产品介绍 总部设在加利福尼亚圣何塞市(San Jose)的Xilinx是全球领先的可编程逻辑解决方案的供应商,图1-1为公司标志。 Xilinx公司的业务是研发…

后端——go系统学习笔记(不断更新中......)

数组 固定大小 初始化 arr1 : [3]int{1, 2, 3} arr2 : [...]int{1, 2, 3} var arr3 []int var arr4 [4]int切片 长度是动态的 初始化 arr[0:3] slice : []int{1,2,3} slice : make([]int, 10)len和cap len是获取切片、数组、字符串的长度——元素的个数cap是获取切片的容量—…

Android PMS——ADB命令安装流程(七)

前面的文章我们介绍了系统应用解析流程和通过 PackageInstaller.apk安装应用程序的相关流程,这一篇我们来分析使用 ADB 命令来实现 APK 安装流程。 一、ADB安装命令 ADB命令使用 adb install [选项] [APK绝对路径] 常见选项如下: -r:覆盖安装,保存原有数据; -t:…

深度学习入门笔记(七)卷积神经网络CNN

我们先来总结一下人类识别物体的方法: 定位。这一步对于人眼来说是一个很自然的过程,因为当你去识别图标的时候,你就已经把你的目光放在了图标上。虽然这个行为不是很难,但是很重要。看线条。有没有文字,形状是方的圆的,还是长的短的等等。看细节。纹理、颜色、方向等。卷…

Java正则表达式之Pattern和Matcher

目录 前言一、Pattern和Matcher的简单使用二、Pattern详解2.1 Pattern 常用方法2.1.1 compile(String regex)2.1.2 matches(String regex, CharSequence input)2.1.3 split(CharSequence input)2.1.4 pattern()2.1.5 matcher(CharSequence input) 三、Matcher详解3.1 Matcher 常…

JSP和JSTL板块:第三节 JSP四大域对象 来自【汤米尼克的JAVAEE全套教程专栏】

JSP和JSTL板块:第三节 JSP四大域对象 一、page范围二、request范围三、session范围四、application范围 在服务器和客户端之间、各个网页之间、哪怕同一个网页之内,总是需要传递各种参数值,这时JSP的内置对象就是传递这些参数的载具。内置对象…

Redis面试题38

人工智能在医疗领域有哪些应用? 答:人工智能在医疗领域有许多应用,下面是一些常见的例子: 图像识别和辅助诊断:人工智能可以用于图像识别和辅助诊断,例如在医学影像领域,人工智能可以辅助医生分…

​(四)hive的搭建2

在&#xff08;三&#xff09;hive的搭建1中我们搭建好了hive环境&#xff0c;但是只能本地访问&#xff0c;在本节中配置Hive的访问方式。 1.元数据服务的方式 1.1 编辑hive-site.xml sudo vi hive-site.xml 在文件最后增加以下内容 <!– 指定存储元数据要连接的地址 –…

WebGL 1.0 内置函数

前言 本篇文章介绍WebGL 1.0 shader中支持的内置函数 角度弧度转化 角度转弧度radians 计算公式&#xff1a; R π d e g r e e 180 R \pi \times degree \div 180 Rπdegree180 float radians(float degree) vec2 radians(vec2 degree) vec3 radians(vec3 degree)…

无里程计下的纯跟踪算法实现

文章目录 概要生成相机坐标系下的三维坐标无里程计下的纯跟踪算法实现 概要 当你只有一个相机的时候&#xff0c;想要快速实现机器人跟随功能&#xff0c;没有里程计的情况下&#xff0c;就可以看这里了。这篇博文实现了一个无里程计下的纯跟踪算法。 生成相机坐标系下的三维…

1、安全开发-Python爬虫EDUSRC目标FOFA资产Web爬虫解析库

用途&#xff1a;个人学习笔记&#xff0c;有所借鉴&#xff0c;欢迎指正 前言&#xff1a; 主要包含对requests库和Web爬虫解析库的使用&#xff0c;python爬虫自动化&#xff0c;批量信息收集 Python开发工具&#xff1a;PyCharm 2022.1 激活破解码_安装教程 (2022年8月25日…

For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

环境&#xff1a; wsl ubuntu22.04 vits2 问题描述&#xff1a; RuntimeError: CUDA error: unknown error [rank0]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. [rank0]: For debugging …

【Android新版本兼容】startActivityForResult()方法被弃用的解决方案

提示&#xff1a;此文章仅作为本人记录日常学习使用&#xff0c;若有存在错误或者不严谨得地方欢迎指正。 文章目录 一、使用registerForActivityResult()方法 一、使用registerForActivityResult()方法 startActivityForResult()方法在appcompat库1.3.0或更高版本中被废弃了&…

Linux下find命令详解

find #查找文件 #按照文件名、大小、时间、权限、类型、所属者、所属组来搜索文件 格式&#xff1a; find 查找路径 查找条件 具体条件&#xff08;按文件名或时间大小等&#xff09; 操作 注意&#xff1a; find命令默认的操作是print输出 find是检索…

MATLAB绘制电磁场

MATLAB绘制电磁场举例: clc;close all;clear all;warning off;%清除变量 rand(seed, 100); randn(seed, 100); format long g; m12 for k1:m for j1:m if k1 V(j,k)1; elseif((j1)|(jm)|(km)) V(j,k)0; else …

PKG系统安装包及IPSW固件:MacOS 11-14 Sonoma 正式版

MacOS 14 Sonoma&#xff0c;为提高生产力和创造力带来了全新的功能&#xff0c;有了更多使用小部件和令人惊叹的新屏幕保护程序进行个性化设置的方法&#xff0c;对Safari浏览器和视频会议进行了重大更新&#xff0c;以及优化的游戏体验——Mac体验比以往任何时候都更好。 mac…

贝叶斯的缺点

贝叶斯方法是一种统计学习方法&#xff0c;通过利用贝叶斯定理来计算给定先验概率的情况下&#xff0c;后验概率的条件概率。虽然贝叶斯方法在许多领域中应用广泛且有效&#xff0c;但也存在一些缺点。以下是一些贝叶斯方法的缺点的例子&#xff1a; 1、先验概率的选择 贝叶斯方…

体验数学之美:绘制曲线

文章目录 一、实战概述二、实战步骤(一)圆锥曲线1、绘制圆2、绘制椭圆3、绘制双曲线4、绘制抛物线(二)心形线(三)雅各布线一、实战概述 通过Python编程,我们可以借助matplotlib与numpy库绘制一系列迷人的数学曲线,展现数学之美。例如,利用极坐标绘制椭圆(圆锥曲线的一…

第二十四天| 77. 组合

Leetcode 77. 组合 题目链接&#xff1a;77 组合 题干&#xff1a;给定两个整数 n 和 k&#xff0c;返回范围 [1, n] 中所有可能的 k 个数的组合。你可以按 任何顺序 返回答案。 思考&#xff1a;回溯法。把回溯法的搜索过程抽象为树形结构。 每次从集合中选取元素&#xff0…