AI学习指南深度学习篇-丢弃法Python实践

AI学习指南深度学习篇-丢弃法Python实践

引言

在深度学习的领域中,丢弃法(Dropout)是一种有效的防止过拟合的随机正则化技术。过拟合是指模型在训练集上表现良好,但在测试集或未见过的数据上表现较差的现象。丢弃法通过随机地“丢弃”一部分神经元节点,使模型在训练过程中不能过于依赖某个特定的特征。本文将详细介绍丢弃法的基本概念、原理,以及在TensorFlow和PyTorch等深度学习库中的具体实现,并通过实际代码示例展示如何调参以优化模型性能。

1. 丢弃法的基本概念

丢弃法首次提出于2014年的论文《Dropout: A Simple Way to Prevent Neural Networks from Overfitting》中。其主要思想是在每次训练迭代中随机选取一定比例的神经元,并将其输出置为零。这样可以迫使网络以不同方式学习,使得模型具有更好的泛化能力。

1.1 丢弃率(Dropout Rate)

丢弃率是指在每次训练中被丢弃的神经元的比例。常用的丢弃率有0.2、0.5等,具体值需要根据实验结果进行调整。较高的丢弃率可能会导致欠拟合,而较低的丢弃率则可能不足以起到正则化的效果。

2. 使用TensorFlow实现丢弃法的示例

2.1 环境配置

在开始之前,请确保你已经安装了TensorFlow。如果你没有安装,可以使用以下命令进行安装:

pip install tensorflow

2.2 加载数据集

我们将使用MNIST手写数字数据集作为示例。TensorFlow提供了方便的接口来加载MNIST数据集。以下是加载数据集的代码:

import tensorflow as tf
from tensorflow.keras import layers, models# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()# 预处理数据
x_train = x_train.reshape((60000, 28, 28, 1)).astype("float32") / 255
x_test = x_test.reshape((10000, 28, 28, 1)).astype("float32") / 255

2.3 构建模型

下面的代码示例展示了如何构建包含丢弃法的卷积神经网络(CNN)模型。

def create_model(dropout_rate=0.5):model = models.Sequential()model.add(layers.Conv2D(32, (3, 3), activation="relu", input_shape=(28, 28, 1)))model.add(layers.MaxPooling2D((2, 2)))model.add(layers.Conv2D(64, (3, 3), activation="relu"))model.add(layers.MaxPooling2D((2, 2)))model.add(layers.Conv2D(64, (3, 3), activation="relu"))# 添加丢弃层model.add(layers.Dropout(dropout_rate))model.add(layers.Flatten())model.add(layers.Dense(64, activation="relu"))model.add(layers.Dropout(dropout_rate))model.add(layers.Dense(10, activation="softmax"))return model

2.4 编译和训练模型

接下来,我们需要编译模型,并使用训练数据进行训练。我们将创建一个包含不同丢弃率的模型,对比它们的表现。

# 编译和训练模型
def train_model(dropout_rate):model = create_model(dropout_rate)model.compile(optimizer="adam",loss="sparse_categorical_crossentropy",metrics=["accuracy"])history = model.fit(x_train, y_train, epochs=10, validation_data=(x_test, y_test))return history# 测试不同的丢弃率
dropout_rates = [0.2, 0.5, 0.7]
histories = {}for rate in dropout_rates:print(f"Training with dropout rate: {rate}")histories[rate] = train_model(rate)

2.5 可视化训练过程

我们可以通过绘制训练和验证损失及准确率图来比较不同丢弃率模型的表现。

import matplotlib.pyplot as pltdef plot_history(histories):plt.figure(figsize=(12, 5))for rate, history in histories.items():plt.subplot(1, 2, 1)plt.plot(history.history["accuracy"], label=f"train acc (dropout {rate})")plt.plot(history.history["val_accuracy"], label=f"val acc (dropout {rate})")plt.subplot(1, 2, 2)plt.plot(history.history["loss"], label=f"train loss (dropout {rate})")plt.plot(history.history["val_loss"], label=f"val loss (dropout {rate})")plt.subplot(1, 2, 1)plt.title("Training and Validation Accuracy")plt.xlabel("Epochs")plt.ylabel("Accuracy")plt.legend()plt.subplot(1, 2, 2)plt.title("Training and Validation Loss")plt.xlabel("Epochs")plt.ylabel("Loss")plt.legend()plt.show()plot_history(histories)

3. 使用PyTorch实现丢弃法的示例

接下来,我们将使用PyTorch实现相同功能。首先,请确保你安装了PyTorch。如果尚未安装,可以使用以下命令:

pip install torch torchvision

3.1 加载数据集

同样,我们将使用MNIST数据集。以下是加载数据集的代码:

import torch
import torchvision
import torchvision.transforms as transforms# 定义数据转换
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])# 加载MNIST数据集
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

3.2 构建模型

下面是使用丢弃法的CNN模型代码:

import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self, dropout_rate=0.5):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3)self.conv2 = nn.Conv2d(32, 64, 3)self.fc1 = nn.Linear(64 * 24 * 24, 128)self.fc2 = nn.Linear(128, 10)self.dropout = nn.Dropout(dropout_rate)def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2)x = x.view(-1, 64 * 24 * 24)x = F.relu(self.fc1(x))x = self.dropout(x)  # Apply Dropoutx = self.fc2(x)return x

3.3 训练和测试模型

接下来,我们需要定义训练和测试过程:

def train_model(dropout_rate, num_epochs=10):model = Net(dropout_rate)criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters())for epoch in range(num_epochs):model.train()running_loss = 0.0for i, data in enumerate(trainloader):inputs, labels = dataoptimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch + 1}, Loss: {running_loss / (i + 1)}")return model# 测试不同丢弃率模型
dropout_rates = [0.2, 0.5, 0.7]
models = {}for rate in dropout_rates:print(f"Training model with dropout rate: {rate}")models[rate] = train_model(rate)

3.4 测试模型

接下来,我们来评估模型的准确率:

def test_model(model):correct = 0total = 0model.eval()with torch.no_grad():for data in testloader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Accuracy: {100 * correct / total}%")# 测试不同丢弃率模型
for rate, model in models.items():print(f"Testing model with dropout rate: {rate}")test_model(model)

结论

在本篇文章中,我们详细介绍了丢弃法的基本概念及其在深度学习中的应用。通过使用TensorFlow和PyTorch,我们实现了包含丢弃法的卷积神经网络,并对不同丢弃率模型的性能进行了比较和分析。

丢弃法作为一种简单有效的正则化技术,可以帮助我们减少模型的复杂度,提高模型在未见数据上的泛化能力。在实际应用中,还需根据具体的任务场景调整丢弃率以及其他超参数,以寻求最佳的模型性能。希望本文能够为深度学习爱好者提供一个清晰的丢弃法实践指南!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/54909.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【电商搜索】现代工业级电商搜索技术-Facebook语义搜索技术QueSearch

【电商搜索】现代工业级电商搜索技术-Facebook语义搜索技术Que2Search 目录 文章目录 【电商搜索】现代工业级电商搜索技术-Facebook语义搜索技术Que2Search目录0. 论文信息1. 研究背景:2. 技术背景和发展历史:3. 算法建模3.1 模型架构3.1.1 双塔与分类 …

Rust调用tree-sitter支持自定义语言解析

要使用 Rust 调用 tree-sitter 解析自定义语言,你需要遵循一系列步骤来定义语言的语法,生成解析器,并在 Rust 中使用这个解析器。下面是详细步骤: 1. 定义自定义语言的语法 首先,你需要创建一个 tree-sitter 语言定义…

NLP:BERT的介绍

1. BERT 1.1 Transformer Transformer架构是一种基于自注意力机制(self-attention)的神经网络架构,它代替了以前流行的循环神经网络和长短期记忆网络,已经应用到多个自然语言处理方向。   Transformer架构由两个主要部分组成:编码器(Encod…

【HarmonyOS】应用引用media中的字符串资源如何拼接字符串

【HarmonyOS】应用引用media中的字符串资源如何拼接字符串 一、问题背景: 鸿蒙应用中使用字符串资源加载,一般文本放置在resoutces-base-element-string.json字符串配置文件中。便于国际化的处理。当然小项目一般直接引用字符串,不需要加载s…

[dp+dfs]砝码称重

题目描述 现有 n n n 个砝码,重量分别为 a 1 , a 2 , … , a n a_1, a_2, \ldots,a_n a1​,a2​,…,an​ ,在去掉 m m m 个砝码后,问最多能称量出多少不同的重量(不包括 0 0 0 )。 输入格式 第一行为有两个整数…

python爬虫:从12306网站获取火车站信息

代码逻辑 初始化 (init 方法): 设置请求头信息。设置车站版本号。 同步车站信息 (synchronization 方法): 发送GET请求获取车站信息。返回服务器响应的文本。 提取信息 (extract 方法): 从服务器响应中提取车站信息字符串。去掉字符串末尾的…

如何通过Dockfile更改docker中ubuntu的apt源

首先明确我们有一个宿主机和一个docker环境,接下来的步骤是基于他们两个完成的 1.在宿主机上创建Dockerfile 随便将后面创建的Dockerfile放在一个位置,我这里选择的是 /Desktop 使用vim前默认你已经安装好了vim 2.在输入命令“vim Dockerfile”之后,…

知识付费APP开发指南:基于在线教育系统源码的技术详解

本篇文章,我们将探讨基于在线教育系统源码的知识付费APP开发的技术细节,帮助开发者和企业快速入门。 一、选择合适的在线教育系统源码 选择合适的在线教育系统源码是开发的关键一步。市场上有许多开源和商业化的在线教育系统源码,开发者需要…

花都狮岭寄宿自闭症学校:开启孩子的生命之门

在花都狮岭这片充满温情的土地上,有一所特别的学校,它像一把钥匙,轻轻旋转,为自闭症儿童们开启了一扇通往无限可能的生命之门——这就是广州星贝育园自闭症儿童寄宿制学校。这所学校不仅是知识的摇篮,更是孩子们心灵成…

React 启动时webpack版本冲突报错

报错信息: 解决办法: 找到全局webpack的安装路径并cmd 删除全局webpack 安装所需要的版本

Python(六)-拆包,交换变量名,lambda

目录 拆包 交换变量值 引用 lambda函数 lambda实例 字典的lambda 推导式 列表推导式 列表推导式if条件判断 for循环嵌套列表推导式 字典推导式 集合推导式 拆包 看一下在Python程序中的拆包:把组合形成的元组形式的数据,拆分出单个元素内容…

影响上证50股指期货价格的因素有哪些?

上证50股指期货,作为反映上海证券交易所最具代表性50只股票整体表现的期货合约,其价格同样受到一系列复杂因素的驱动。以下是对影响上证50股指期货价格的主要因素进行的详细分析。 因素一、期货合约的供求关系 股指期货市场是一个由多头和空头双方共同…

具身智能综述:鹏城实验室中大调研近400篇文献,深度解析具身智能

具身智能是实现通用人工智能的必经之路,其核心是通过智能体与数字空间和物理世界的交互来完成复杂任务。近年来,多模态大模型和机器人技术得到了长足发展,具身智能成为全球科技和产业竞争的新焦点。然而,目前缺少一篇能够全面解析…

面试遇到的质量体系10个问题(深度思考)

在某大型公司的招聘面试中关于质量体系本身及建设实践方面的10个问题,这些问题都是偏理论性强一些,但是可以通过这些问题来了解大型公司对质量体系的一些想法和预期的内容,本期先抛出来这10个问题,不附答案,目的就是让…

AI绘画:Stable Diffusion 终极炼丹宝典:从入门到精通

前言 我是Lison,以浅显易懂的方式,与大家分享那些实实在在可行之宝藏。 历经耗时数十个小时,总算将这份Stable Diffusion的使用教程整理妥当。 从最初的安装与配置,细至界面功能的详解,再至实战案例的制作&#xff…

数组基础(c++)

第1题 精挑细选 时限:1s 空间:256m 小王是公司的仓库管理员,一天,他接到了这样一个任务:从仓库中找出一根钢管。这听起来不算什么,但是这根钢管的要求可真是让他犯难了,要求如下&#x…

从细胞到临床:表观组学分析技术在精准医疗中的角色

中国科学院等科研院所的顶尖人才发起,专注于多组学、互作组、生物医学等领域的研究与服务。在Nature等国际知名期刊发表多篇论文,提供实验整体打包、免费SCI论文润色等四大优势服务。在表观组学分析技术方面,提供DAP-seq、ATAC-seq、H3K4me3 …

使用mendeley生成APA格式参考文献

mendeley 是一款文献管理工具,可以在word中方便的插入引用文献。 效果对比: 注:小绿鲸有三种导出格式,分别为复制、导出为Bibtex和导出为Endnote三种。 mendeley 下载与安装 Download Mendeley Reference Manager For Desktop mac…

98问答网是一个怎样的平台?它主要提供哪些服务?

98问答网是一个集知识分享、问题解答与社区交流为一体的综合性在线问答平台。该平台旨在通过汇聚来自各行各业的专家、学者以及广大网友的智慧,为用户提供一个快速获取准确信息、解决生活工作中遇到的各种问题的渠道。 主要服务包括: 问题提问与解答&am…

10.C++程序中的循环语句

C中提供了三种循环语句(for循环,while循环以及do-while循环)来使程序员可以更方便地对数据进行迭代操作。 if语句 for语句的格式为: for(初始化语句;循环条件;迭代语句) { 代码块 &#x…