Pytorch 复习总结 3

Pytorch 复习总结,仅供笔者使用,参考教材:

  • 《动手学深度学习》
  • Stanford University: Practical Machine Learning

本文主要内容为:Pytorch 多层感知机。

本文先介绍了多层感知机的用法,再就训练过程中经常出现的过拟合现象提出解决办法。


Pytorch 语法汇总:

  • Pytorch 张量的常见运算、线性代数、高等数学、概率论 部分 见 Pytorch 复习总结1;
  • Pytorch 线性神经网络 部分 见 Pytorch 复习总结2;
  • Pytorch 多层感知机 部分 见 Pytorch 复习总结3;
  • Pytorch 深度学习计算 部分 见 Pytorch 复习总结4;
  • Pytorch 卷积神经网络 部分 见 Pytorch 复习总结5;
  • Pytorch 现代卷积神经网络 部分 见 Pytorch 复习总结6;

目录

  • 一. 多层感知机
    • 1. 读取数据集
    • 2. 神经网络模型
    • 3. 激活函数
    • 4. 损失函数
    • 5. 优化器
    • 6. 训练
  • 二. 过拟合的缓解
    • 1. 权重衰减
    • 2. Dropout

一. 多层感知机

虽然线性模型易于实现和理解、计算成本低、泛化能力强,但是对于一些非线性问题,可能会违反线性模型的单调性。为此,多层感知器引入了隐藏层来克服线性模型的限制,并且加入激活函数以增强网络非线性建模能力。

1. 读取数据集

同 Pytorch 复习总结 2 中 Softmax 回归的数据读取,继续使用 Fashion-MNIST 图像分类数据集:

import torch
import torchvision
from torch.utils import data
from torchvision import transformsdef load_data_fashion_mnist(batch_size, resize=None):"""下载Fashion-MNIST数据集并将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True),data.DataLoader(mnist_test, batch_size, shuffle=False))batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)

2. 神经网络模型

先将输入的图像展平,然后使用 2 个全连接层进行处理,中间的全连接层需要使用激活函数激活,最后一层全连接层作为输出:

from torch import nn
net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),nn.Linear(256, 10)
)

仍然使用 init_weights() 函数按正态分布初始化所有全连接层的权重:

def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)

3. 激活函数

上一节使用了 ReLU 函数进行激活,在实际应用中,还可以使用 sigmoid、tanh 等函数激活。ReLU、sigmoid、tanh 函数的梯度可视化如下:

import torch
from matplotlib import pyplot as pltx = torch.arange(-8.0, 8.0, 0.1, requires_grad=True)
# y = torch.relu(x)
# y = torch.sigmoid(x)
y = torch.tanh(x)
y.backward(torch.ones_like(x), retain_graph=True)
plt.figure(figsize=(5, 2.5))
plt.plot(x.detach(), x.grad)
plt.show()

4. 损失函数

同 Softmax 回归:

loss = nn.CrossEntropyLoss(reduction='none')

5. 优化器

同 Softmax 回归:

trainer = torch.optim.SGD(net.parameters(), lr=0.1)

6. 训练

同 Softmax 回归,可以将训练过程封装成函数:

def accuracy(y_hat, y):"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def train_net(net, train_iter, test_iter, loss, num_epochs, trainer):for epoch in range(num_epochs):     # 迭代训练轮次net.train()                     # 将模型设置为训练模式train_loss_sum = 0.0            # 训练损失总和train_acc_sum = 0.0             # 训练准确度总和sample_num = 0                  # 样本数for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y)trainer.zero_grad()l.mean().backward()trainer.step()train_loss_sum += l.sum()train_acc_sum += accuracy(y_hat, y)sample_num += y.numel()train_loss = train_loss_sum / sample_numtrain_acc = train_acc_sum / sample_numnet.eval()                      # 将模型设置为评估模式test_acc_sum = 0.0test_sample_num = 0for X, y in test_iter:test_acc_sum += accuracy(net(X), y)test_sample_num += y.numel()test_acc = test_acc_sum / test_sample_numprint(f'epoch {epoch + 1}, 'f'train loss {train_loss:.4f}, train acc {train_acc:.4f}, 'f'test acc {test_acc:.4f}')num_epochs = 10
train_net(net, train_iter, test_iter, loss, num_epochs, trainer)

二. 过拟合的缓解

当模型过于复杂、训练数据太少、迭代轮数太多时,就会出现过拟合现象。解决过拟合的方法有很多:

  • 增加数据量:增加训练数据可以帮助模型更好地学习数据的真实规律,减少过拟合的发生;
  • 简化模型:降低模型的复杂度,可以通过减少模型的参数数量、使用正则化等方法来实现;
  • 交叉验证:使用交叉验证来评估模型的泛化能力,选择最优的模型;
  • 提前停止:即 Dropout,在训练过程中监控模型在验证集上的表现,当验证集误差不再下降甚至开始上升时,及时停止训练,防止模型过拟合;
  • 集成学习:使用集成学习方法(如随机森林、梯度提升树等)降低模型的方差,提高泛化能力。

下面介绍几种常用的正则化方法。

1. 权重衰减

权重衰减 (Weight Decay) 通过向损失函数中添加一个惩罚项来减小模型复杂度,以防止过拟合。惩罚项也叫 正则项,通常是权重的平方和(即 L2 范数)或权重的绝对值和(即 L1 范数)乘以一个正则化系数。

以线性回归的损失函数 L ( w , b ) L(\mathbf{w}, b) L(w,b) 为例,使用优化器训练时,在损失函数 L ( w , b ) L(\mathbf{w}, b) L(w,b) 上添加 L2 范数如下:
L ( w , b ) + λ 2 ∥ w ∥ 2 = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 + λ 2 ∥ w ∥ 2 L(\mathbf{w}, b)+\frac{\lambda}{2}\|\mathbf{w}\|^2\\ =\frac{1}{n} \sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^{\top} \mathbf{x}^{(i)}+b-y^{(i)}\right)^2+\frac{\lambda}{2}\|\mathbf{w}\|^2\\ L(w,b)+2λw2=n1i=1n21(wx(i)+by(i))2+2λw2

损失函数中没有添加偏置 b b b 的惩罚项,因为一般情况下,网络输出层的偏置项不需要正则化。代入 w \mathbf{w} w 的参数更新表达式为:
w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) \mathbf{w} \leftarrow(1-\eta \lambda) \mathbf{w}-\frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)}\left(\mathbf{w}^{\top} \mathbf{x}^{(i)}+b-y^{(i)}\right) w(1ηλ)wBηiBx(i)(wx(i)+by(i))

要想对模型进行权重衰减,只需要在实例化优化器时通过 weight_decay 指定权重衰减参数。默认情况下,PyTorch 同时衰减权重和偏移:

trainer = torch.optim.SGD(net.parameters(), lr=lr)

如果想要只衰减权重,需要指定参数:

params_to_optimize = [{"params": net[0].weight, 'weight_decay': wd},{"params":net[0].bias}
]
trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)

2. Dropout

Dropout 通过在训练过程中随机地将网络 内部 的一部分神经元的输出设置为零,即以一定的概率 “丢弃” 这些神经元。这样可以防止神经元在训练过程中过于依赖其他神经元,从而降低了网络对特定神经元的依赖性,使得网络更具鲁棒性:
在这里插入图片描述

通常情况下,Dropout 只在训练过程中使用,不在推理阶段使用,因为推理时模型需要产生确定性的输出。

Dropout 需要在网络中添加 Dropout 层,一般位于激活函数后,并且给定 dropout 概率:

dropout1, dropout2 = 0.2, 0.5net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),nn.Dropout(dropout1),nn.Linear(256, 256),nn.ReLU(),nn.Dropout(dropout2),nn.Linear(256, 10)
)def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)

Dropout 概率的设置技巧是靠近输入层的地方设置较低的概率,远离输入层的地方设置较高的概率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/702083.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring 容器、核心容器总结

目录 创建容器获取 bean容器类层次结构图核心容器总结容器相关bean 相关依赖注入相关 创建容器 方式一: 类路径加载配置文件 ApplicationContext ctx new ClassPathXmlApplicationContext("applicationContext.xml");方式二: 文件路径加载配…

消息队列-RabbitMQ:延迟队列、rabbitmq 插件方式实现延迟队列、整合SpringBoot

十六、延迟队列 1、延迟队列概念 延时队列内部是有序的,最重要的特性就体现在它的延时属性上,延时队列中的元素是希望在指定时间到了以后或之前取出和处理,简单来说,延时队列就是用来存放需要在指定时间被处理的元素的队列。 延…

移动端自动化常用的元素定位工具 介绍

在移动端自动化测试和开发中,元素定位是非常关键的一步。以下是一些常用的工具和技术来帮助开发者或测试工程师在移动设备上定位元素: 1. **UiAutomator**: - **UiAutomator** 是 Android 官方提供的自动化测试框架。它可以用来编写测试脚本&…

Linux之项目部署与发布

目录 一、Nginx配置安装(自启动) 1.一键安装4个依赖 2. 下载并解压安装包 3. 安装Nginx 4. 启动 nginx 服务 5. 对外开放端口 6. 配置开机自启动 7.修改/etc/rc.d/rc.local的权限 二、后端部署tomcat负载均衡 1. 准备2个tomcat 2. 修改端口 3…

4.寻找两个正序数组的中位数

题目:给定两个大小分别为 m 和 n 的正序(从小到大)数组 nums1 和 nums2。请你找出并返回这两个正序数组的 中位数 。 解题思路:用二分法查找。使用归并的方式,合并两个有序数组,得到一个大的有序数组。大的…

树状数组与线段树<2>——线段树初步

这个系列终于更新了(主要因为树状数组初步比较成功) 话不多说,切入正题。 什么是线段树? 线段树是一种支持单点修改区间查询(树状数组也行) and 区间修改单点查询(树状数组不行) and 区间修改区间查询(树状数组更不行)的高级数据结构,相当…

Spring Boot利用Kaptcha生成验证码

生成验证码 我们在登录或注册某个网站的时候,会需要我们输入验证码,才能登录注册,那么如何生成验证码呢?其实,生成验证码我们可以用Java Swing在后台内存里的区域画一个出来,但是非常麻烦,所以…

在IDEA中创建vue hello-world项目

工作中最近在接触vue前端项目,记录一下从0搭建一个vue hello world项目的步骤 1、本地电脑安装配置node、npm D:\Project\vue\hello-world>node -v v14.21.3 D:\Project\vue\hello-world>npm -v 6.14.18 D:\Project\vue\hello-world> 2、设置npm国内淘宝的景象 …

unity学习(41)——创建(create)角色脚本(panel)——UserHandler(收)+CreateClick(发)——发包!

1.客户端的程序结构被我精简过,现在去MessageManager.cs中增加一个UserHandler函数,根据收到的包做对应的GameInfo赋值。 2.在Model文件夹下新增一个协议文件UserProtocol,内容很简单。 using System;public class UserProtocol {public co…

涵盖5大领域的机器学习工具介绍

随着数据的产生及其使用量的不断增加,对机器学习模型的需求也在成倍增加。由于ML系统包含了算法和丰富的ML库,它有助于分析数据和做出决策。难怪机器学习的知名度越来越高,因为ML应用几乎主导了现代世界的每一个方面。随着企业对这项技术的探…

Java中PDF文件传输有哪些方法?

专栏集锦,大佬们可以收藏以备不时之需: Spring Cloud 专栏:http://t.csdnimg.cn/WDmJ9 Python 专栏:http://t.csdnimg.cn/hMwPR Redis 专栏:http://t.csdnimg.cn/Qq0Xc TensorFlow 专栏:http://t.csdni…

记录一些mac电脑重装mysql和pgsql的坑

为什么要重装,是想在mac电脑 创建data目录…同事误操作,导致电脑重启不了.然后重装系统后,.就连不上数据库了.mysql和pgsql两个都连不上.网上也查了很多资料.实在不行,.就重装了… 重装mysql. 1.官网下载 https://www.mysql.com/downloads/ 滑到最下面 选择 选择对应的芯片版本…

设计推特(Leetcode355)

例题: https://leetcode.cn/problems/design-twitter/ 分析: 推特其实类似于微博,在微博中可以发送文章。 求解这类题目,我们需要根据题目需求,利用面向对象的思想,先对需求做一个抽象,看看能…

字符串(算法竞赛)--Manacher(马拉车)算法

1、B站视频链接&#xff1a;F05 Manacher(马拉车)_哔哩哔哩_bilibili 题目链接&#xff1a;【模板】manacher - 洛谷 ​ #include <bits/stdc.h> using namespace std; const int N3e7; char a[N],s[N]; int d[N];//回文半径函数void get_d(char*s,int n){d[1]1;for(int…

领域驱动设计(Domain-Driven Design DDD)——通过重构找到深层次模型2

五、应用分析模式 深层模型和柔性设计并非唾手可得。想要取得进展&#xff0c;必须学习大量领域知识并进行充分的讨论&#xff0c;还需要经历大量的尝试和失败。在实际的研究领域问题实践时&#xff0c;有一些成熟的模式可以供我们借鉴和套用。这样我们可以从这个起点来重构和试…

vim恢复.swp [BJDCTF2020]Cookie is so stable1

打开题目 扫描目录得到 关于 .swp 文件 .swp 文件一般是 vim 编辑器在编辑文件时产生的&#xff0c;当用 vim 编辑器编辑文件时就会产生&#xff0c;正常退出时 .swp 文件被删除&#xff0c;但是如果直接叉掉&#xff08;非正常退出&#xff09;&#xff0c;那么 .swp 文件就会…

spring-security 过滤器 (三)

spring-security过滤器 版本信息过滤器配置过滤器配置相关类图过滤器加载过程创建 HttpSecurity Bean 对象创建过滤器 过滤器作用ExceptionTranslationFilter 自定义过滤器 本章介绍 spring-security 过滤器配置类 HttpSecurity&#xff0c;过滤器加载过程&#xff0c;自定义过…

信息抽取(UIE):使用自然语言处理技术提升证券投资决策效率

一、引言 在当今快速变化的证券市场中&#xff0c;信息的价值不言而喻。作为一名资深项目经理&#xff0c;我曾领导一个关键项目&#xff0c;旨在通过先进的信息抽取技术&#xff0c;从海量的文本数据中提取关键事件&#xff0c;如企业并购、新产品发布以及政策环境的变动。这些…

Open CASCADE学习|几何数据结构

在几何引擎内一般把数据分成两类&#xff1a;几何信息与拓扑信息。二者可以完整地表达出实体模型&#xff0c;彼此相互独立、又互相关联。几何信息是指构成几何实体的各几何元素在欧式空间中的位置、大小、尺寸和形状信息。例如一条空间的直线&#xff0c;可以用两端点的位置矢…

五种多目标优化算法(MOCS、MOFA、NSWOA、MOAHA、MOPSO)性能对比(提供MATLAB代码)

一、5种多目标优化算法简介 多目标优化算法是用于解决具有多个目标函数的优化问题的一类算法。其求解流程通常包括以下几个步骤&#xff1a; 1. 定义问题&#xff1a;首先需要明确问题的目标函数和约束条件。多目标优化问题通常涉及多个目标函数&#xff0c;这些目标函数可能…