深度学习中的正则化技术

在深度学习中,正则化是一种防止模型过拟合的重要手段。过拟合是指模型在训练数据上表现良好,但在未见数据上表现不佳的现象。正则化通过引入额外的约束或信息来限制模型的复杂性,从而提高模型的泛化能力。本文将介绍几种常见的正则化技术,包括 L1 正则化、L2 正则化、Dropout 和 Batch Normalization,并提供在 PyTorch 中的实现方法。

1. 正则化技术

1.1 L1 正则化

原理

L1 正则化(Lasso 正则化)通过在损失函数中添加权重绝对值的和来实现。其主要特点是能够导致一些权重变为零,从而实现特征选择。

公式

L1 正则化的损失函数公式如下:

L = Loss + λ ∑ j = 1 n ∣ w j ∣ L = \text{Loss} + \lambda \sum_{j=1}^{n} |w_j| L=Loss+λj=1nwj

其中:

  • L L L 是总损失。
  • Loss \text{Loss} Loss 是原始损失(如均方误差)。
  • λ \lambda λ 是正则化强度(超参数)。
  • w j w_j wj 是模型的权重。
在 PyTorch 中的实现

在 PyTorch 中,L1 正则化通常需要在损失函数中手动添加 L1 范数。可以通过遍历模型的参数,计算绝对值和并将其添加到损失中。


1.2 L2 正则化

原理

L2 正则化(Ridge 正则化)通过在损失函数中添加权重平方和来实现。它的主要作用是将权重收缩到更小的值,从而减小模型的复杂性。

公式

L2 正则化的损失函数公式如下:

L = Loss + λ ∑ j = 1 n w j 2 L = \text{Loss} + \lambda \sum_{j=1}^{n} w_j^2 L=Loss+λj=1nwj2

其中:

  • L L L 是总损失。
  • Loss \text{Loss} Loss 是原始损失。
  • λ \lambda λ 是正则化强度(超参数)。
  • w j w_j wj 是模型的权重。
在 PyTorch 中的实现

在 PyTorch 中,L2 正则化可以通过在优化器中设置 w e i g h t _ d e c a y weight\_decay weight_decay 参数来实现。例如,使用 SGD 优化器时,可以直接在优化器的初始化中添加 w e i g h t _ d e c a y weight\_decay weight_decay


1.3 Dropout

原理

Dropout 是一种随机丢弃神经元的技术。在训练过程中,随机选择一部分神经元(及其连接)不参与前向传播和反向传播。其主要目的是防止神经网络对特定神经元的过度依赖,从而提高模型的泛化能力。

在 PyTorch 中的实现

在 PyTorch 中,可以通过在模型中添加 n n . D r o p o u t nn.Dropout nn.Dropout 层来实现 Dropout。您可以指定丢弃的概率(例如,0.5 表示有 50% 的概率丢弃神经元)。


1.4 Batch Normalization

原理

Batch Normalization 是一种加速训练并提高模型稳定性的技术。它通过对每一层的输入进行标准化,使其具有零均值和单位方差,从而减少内部协变量偏移。Batch Normalization 通常可以提高模型的收敛速度并增强泛化能力。

在 PyTorch 中的实现

在 PyTorch 中,可以通过在模型中添加 n n . B a t c h N o r m 1 d nn.BatchNorm1d nn.BatchNorm1d n n . B a t c h N o r m 2 d nn.BatchNorm2d nn.BatchNorm2d 层来实现 Batch Normalization,具体取决于输入的维度(1D 或 2D)。


2. 实验代码与分析

下面是使用 MNIST 手写数字数据集的代码示例,比较了不带正则化的 CNN 模型与带有 L2 正则化、Dropout 和 Batch Normalization 的 CNN 模型。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 检查设备
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")# 数据准备
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # 归一化
])# 下载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 定义不带正则化的 CNN 模型
class SimpleCNNNoReg(nn.Module):def __init__(self):super(SimpleCNNNoReg, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # 输入通道数为1,输出通道数为32self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)  # 7x7 是经过卷积和池化后的特征图大小self.fc2 = nn.Linear(128, 10)def forward(self, x):x = torch.relu(self.conv1(x))x = nn.MaxPool2d(kernel_size=2)(x)  # 最大池化x = torch.relu(self.conv2(x))x = nn.MaxPool2d(kernel_size=2)(x)  # 最大池化x = x.view(-1, 64 * 7 * 7)  # 展平x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 定义带 L2 正则化、Dropout 和 Batch Normalization 的 CNN 模型
class SimpleCNNWithReg(nn.Module):def __init__(self):super(SimpleCNNWithReg, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(32)  # Batch Normalizationself.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.bn2 = nn.BatchNorm2d(64)  # Batch Normalizationself.fc1 = nn.Linear(64 * 7 * 7, 128)self.dropout = nn.Dropout(0.5)  # Dropoutself.fc2 = nn.Linear(128, 10)def forward(self, x):x = torch.relu(self.bn1(self.conv1(x)))  # Batch Normalizationx = nn.MaxPool2d(kernel_size=2)(x)  # 最大池化x = torch.relu(self.bn2(self.conv2(x)))  # Batch Normalizationx = nn.MaxPool2d(kernel_size=2)(x)  # 最大池化x = x.view(-1, 64 * 7 * 7)  # 展平x = torch.relu(self.fc1(x))x = self.dropout(x)  # Dropoutx = self.fc2(x)return x# 训练和评估函数
def train(model, train_loader, criterion, optimizer, epochs=10):model.train()for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)  # 移动到设备optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()# 计算平均损失avg_loss = running_loss / len(train_loader)print(f'Epoch [{epoch + 1}/{epochs}], Loss: {avg_loss:.4f}')def evaluate(model, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)  # 移动到设备outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()return correct / total# 不带正则化的模型
model_no_reg = SimpleCNNNoReg().to(device)  # 移动模型到设备
criterion = nn.CrossEntropyLoss()
optimizer_no_reg = optim.SGD(model_no_reg.parameters(), lr=0.01)# 训练不带正则化的模型
print("Training model without regularization...")
train(model_no_reg, train_loader, criterion, optimizer_no_reg, epochs=10)
accuracy_no_reg = evaluate(model_no_reg, test_loader)# 带 L2 正则化、Dropout 和 Batch Normalization 的模型
model_with_reg = SimpleCNNWithReg().to(device)  # 移动模型到设备
optimizer_with_reg = optim.SGD(model_with_reg.parameters(), lr=0.01, weight_decay=0.01)  # L2 正则化# 训练带正则化的模型
print("\nTraining model with L2 regularization, Dropout, and Batch Normalization...")
train(model_with_reg, train_loader, criterion, optimizer_with_reg, epochs=10)
accuracy_with_reg = evaluate(model_with_reg, test_loader)# 输出结果
print(f"\nAccuracy without regularization: {accuracy_no_reg:.4f}")
print(f"Accuracy with L2 regularization, Dropout, and Batch Normalization: {accuracy_with_reg:.4f}")
Accuracy without regularization: 0.9850
Accuracy with L2 regularization, Dropout, and Batch Normalization: 0.9857
  • 不带正则化的 CNN:包含两个卷积层和两个全连接层,适合于处理 MNIST 这样的图像数据。
  • 带正则化的 CNN:在卷积层后添加了 Batch Normalization 和 Dropout,帮助模型提高泛化能力和稳定性。

在这个实验中,通过比较不带正则化的模型与带有 L2 正则化、Dropout 和 Batch Normalization 的模型,我们能够观察到正则化技术对模型泛化能力的影响。正则化可以有效地减少过拟合,提高模型在未见数据上的准确性。

  • L2 正则化:通过惩罚较大的权重,帮助模型保持简单,从而提高泛化能力。
  • Dropout:通过随机丢弃神经元,减少模型对特定特征的依赖,增强模型的鲁棒性。
  • Batch Normalization:通过标准化层的输入,加速训练并提高模型稳定性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/61568.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

业务架构、数据架构、应用架构和技术架构

TOGAF(The Open Group Architecture Framework)是一个广泛应用的企业架构框架,旨在帮助组织高效地进行架构设计和管理。 TOGAF 的核心就是由我们熟知的四大架构领域组成:业务架构、数据架构、应用架构和技术架构。 企业数字化架构设计中的最常见要素是4A 架构。 4…

苹果Siri将搭载大型语言模型,近屿智能抢占AIGC大模型人才培养高地

据媒体报道,苹果公司正在研发一款全新升级、更加智能且对话能力显著提升的Siri,意在超越OpenAI的ChatGPT及其他语音服务。 报道指出,新一代Siri将搭载更为先进的大型语言模型(LLM),苹果期望其能够进行连续…

【1.4 Getting Started--->Support Matrix】

主页:支持矩阵 这些支持矩阵概述了 TensorRT API、解析器和层支持的平台、特性和硬件功能。 Support Matrix Abstract 这些支持矩阵概述了 TensorRT API、解析器和层所支持的平台、功能和硬件功能。 有关之前发布的 TensorRT 文档,请参阅 TensorRT 档…

WPF中如何让Textbox显示为一条直线

由于Textbox直接使用是一条直线 设置如下代码 可以让Textbox变为直线输入 <Style TargetType"TextBox"x:Key"UsernameTextBoxStyle"><Setter Property"Template"><Setter.Value><ControlTemplate TargetType"{x:Typ…

Mac 修改默认jdk版本

当前会话生效 这里演示将 Java 17 版本降低到 Java 8 查看已安装的 Java 版本&#xff1a; 在终端&#xff08;Terminal&#xff09;中运行以下命令&#xff0c;查看已安装的 Java 版本列表 /usr/libexec/java_home -V设置默认 Java 版本&#xff1a; 找到 Java 8 的安装路…

K8S + Jenkins 做CICD

前言 这里会做整体CICD的思路和流程的介绍&#xff0c;会给出核心的Jenkins pipeline脚本&#xff0c;最后会演示一下 实验/实操 结果 由于整体内容较多&#xff0c;所以不打算在这里做每一步的详细演示 - 本文仅作自己的实操记录和日后回顾用 要看保姆式教学的可以划走了&…

使用 前端技术 创建 QR 码生成器 API1

前言 QR码&#xff08;Quick Response Code&#xff09;是一种二维码&#xff0c;于1994年开发。它能快速存储和识别数据&#xff0c;包含黑白方块图案&#xff0c;常用于扫描获取信息。QR码具有高容错性和快速读取的优点&#xff0c;广泛应用于广告、支付、物流等领域。通过扫…

基于Java Springboot高校工作室管理系统

一、作品包含 源码数据库设计文档万字PPT全套环境和工具资源部署教程 二、项目技术 前端技术&#xff1a;Html、Css、Js、Vue、Element-ui 数据库&#xff1a;MySQL 后端技术&#xff1a;Java、Spring Boot、MyBatis 三、运行环境 开发工具&#xff1a;IDEA/eclipse 数据…

【读书】复杂性意义结构框架——Cynefin框架

Cynefin框架 《代码大全》的作者史蒂夫麦克康奈尔&#xff08;Steve McConnell&#xff09;在《卓有成效的敏捷》这本书里&#xff0c;探讨了用于理解不确定性和复杂性的Cynefin框架。 Cynefin框架是戴维斯诺登&#xff08;David Snowden&#xff09;20世纪90年代的在IBM时创…

ZYNQ-7020嵌入式系统学习笔记(1)——使用ARM核配置UART发送Helloworld

本工程实现调用ZYNQ-7000的内部ARM处理器&#xff0c;通过UART给电脑发送字符串。 硬件&#xff1a;正点原子领航者-7020 开发平台&#xff1a;Vivado 2018、 SDK 1 Vivado部分操作 1.1 新建工程 设置工程名&#xff0c;选择芯片型号。 1.2 添加和配置PS IP 点击IP INTEGR…

全面击破工程级复杂缓存难题

目录 一、走进业务中的缓存 &#xff08;一&#xff09;本地缓存 &#xff08;二&#xff09;分布式缓存 二、缓存更新模式分析 &#xff08;一&#xff09;Cache Aside Pattern&#xff08;旁路缓存模式&#xff09; 读操作流程 写操作流程 流程问题思考 问题1&#…

SpringSecurity创建一个简单的自定义表单的认证应用

1、SpringSecurity 自定义表单 在 Spring Security 中创建自定义表单认证应用是一个常见的需求&#xff0c;特别是在需要自定义登录页面、认证逻辑或添加额外的表单字段时。以下是一个详细的步骤指南&#xff0c;帮助你创建一个自定义表单认证应用。 2、基于 SpringSecurity 的…

用python简单集成一个分词工具

本部分记录如何利用Python进行分词工具集成&#xff0c;集成工具可以实现运行无环境要求&#xff0c;同时也更方便。 该文章主要是记录&#xff0c;知识点不是特别多&#xff0c;欢迎访问个人博客&#xff1a;https://blog.jiumoz.top/archives/fen-ci-gong-ju-ji-cheng 成品展…

Fakelocation Server服务器/专业版 Windows11

前言:需要Windows11系统 Fakelocation开源文件系统需求 Windows11 | Fakelocation | 任务一 打开 PowerShell&#xff08;以管理员身份&#xff09;命令安装 Chocolatey Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProto…

【云计算】腾讯云架构高级工程师认证TCP--考纲例题,知识点总结

【云计算】腾讯云架构高级工程师认证TCCP–知识点总结&#xff0c;排版整理 文章目录 1、云计算架构概论1.1 五大版块知识点&#xff08;架构设计&#xff0c;基础服务&#xff0c;高阶技术&#xff0c;安全&#xff0c;上云&#xff09;1.2 课程详细目录1.3 云基础架构设计1.4…

HarmonyOs鸿蒙开发实战(22)=>开源插件集成-城市选择案例(带字母索引可修改源码)

1.第一步>DevEco Studio集成开源插件 1.1.下载资源插件 1.2.打开Perferences页面&#xff0c;从本地导入安装插件 2.第二步>导入HarmonyOs Next组件市场的城市选择案例&#xff0c;安装 2.1. 代码空白处右键&#xff0c;打开开源组件弹窗 2.2. 安装城市选择案例 3.第三步…

ROS之什么是Node节点和Package包?

1.什么是ROS&#xff1f; 官方术语&#xff1a;ROS&#xff08;Robot Operating System&#xff0c;机器人操作系统&#xff09;是一个开源的、模块化的机器人软件框架。它为机器人开发提供了一套工具和库&#xff0c;用于实现硬件抽象、设备驱动、消息传递、多线程管理等功能…

Windows环境安装MongoDB

文章目录 1. 下载MongoDB2. 安装MongoDB3. Compass-图形化界面客户端4. 更换Compass的主题 阅读本文前可以先阅读以下文章&#xff1a; MongoDB快速入门&#xff08;MongoDB简介、MongoDB的应用场景、MongoDB中的基本概念、MongoDB的数据类型、MongoDB的安装与部署、MongoDB的常…

在线解析工具链接

在线字数统计工具-统计字符字节汉字数字标点符号-计算word文章字数字数统计,字符统计,字节统计,字数计算,统计字数,统计字节数,统计字符数,统计word字数,在线字数统计,在线查字数,计算字数,字数统计工具,支持手机移动端查询多少字数,英文:Calculate the number of words,Count …

RTL8211F 1000M以太网PHY指示灯

在RK3562 Linux5.10 SDK里面已支持该芯片kernel-5.10/drivers/net/phy/realtek.c&#xff0c;而默认是没有去修改到LED配置的&#xff0c;我们根据硬件设计修改相应的寄存器配置&#xff0c;该PHY有3个LED引脚&#xff0c;我们LED0不使用&#xff0c;LED1接绿灯&#xff08;数据…