什么是正则化?Regularization: The Stabilizer of Machine Learning Models(中英双语)

正则化:机器学习模型的稳定器


1. 什么是正则化?

正则化(Regularization)是一种在机器学习模型训练中,通过约束模型复杂性以防止过拟合的技术
它的核心目标是让模型不仅在训练集上表现良好,还能在测试集上具有良好的泛化能力。


2. 为什么正则化起作用?

2.1 过拟合的本质

过拟合通常发生在模型参数过多、数据量不足或数据噪声较大时,模型学到了数据中的噪声和不相关的模式,从而导致泛化能力下降。

2.2 正则化的作用原理

正则化通过引入额外的约束条件来抑制模型的复杂性,限制其自由度,使得模型更倾向于学习数据的总体模式而非局部噪声。

数学原理
正则化通过在损失函数中添加正则项,改变了优化目标,从而约束模型的参数空间。以常见的线性回归为例:

  • 原始损失函数(最小化误差):
    L = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 \mathcal{L} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2 L=n1i=1n(yiy^i)2
  • 加入正则化后的损失函数:
    L reg = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 + λ R ( θ ) \mathcal{L}_{\text{reg}} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2 + \lambda R(\theta) Lreg=n1i=1n(yiy^i)2+λR(θ)

其中:

  • ( R ( θ ) R(\theta) R(θ) ) 是正则项,用于约束模型参数 ( θ \theta θ )。
  • ( λ \lambda λ ) 是正则化强度的超参数,用于权衡数据拟合与正则化之间的关系。

3. 常见的正则化方法

3.1 参数正则化:L1 和 L2 正则化
  • L1 正则化(Lasso Regression)
    在损失函数中加入 ( L 1 L1 L1 ) 范数的约束:
    R ( θ ) = ∥ θ ∥ 1 = ∑ j = 1 p ∣ θ j ∣ R(\theta) = \|\theta\|_1 = \sum_{j=1}^p |\theta_j| R(θ)=θ1=j=1pθj

    • 优点:促使部分参数变为零,从而实现特征选择。
    • 缺点:在高维数据中可能会丢失部分信息。
  • L2 正则化(Ridge Regression)
    在损失函数中加入 ( L2 ) 范数的约束:
    R ( θ ) = ∥ θ ∥ 2 2 = ∑ j = 1 p θ j 2 R(\theta) = \|\theta\|_2^2 = \sum_{j=1}^p \theta_j^2 R(θ)=θ22=j=1pθj2

    • 优点:通过惩罚较大的参数值,抑制模型复杂性。
    • 缺点:不会稀疏参数,所有特征都会保留。

代码示例(以线性回归为例):

import numpy as np
from sklearn.linear_model import Ridge, Lasso
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error# 模拟数据
np.random.seed(42)
X = np.random.rand(100, 5)
y = 3 * X[:, 0] + 2 * X[:, 1] + np.random.randn(100)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# L2 正则化(Ridge)
ridge = Ridge(alpha=1.0)  # alpha 控制正则化强度
ridge.fit(X_train, y_train)
y_pred_ridge = ridge.predict(X_test)# L1 正则化(Lasso)
lasso = Lasso(alpha=0.1)
lasso.fit(X_train, y_train)
y_pred_lasso = lasso.predict(X_test)print("Ridge MSE:", mean_squared_error(y_test, y_pred_ridge))
print("Lasso MSE:", mean_squared_error(y_test, y_pred_lasso))

3.2 数据增强(Data Augmentation)
  • 数据增强是通过对训练数据进行扩充(如图像翻转、裁剪、旋转等),使模型看到更多变种,从而提升泛化能力。
  • 常用于计算机视觉和自然语言处理领域。

代码示例(以 PyTorch 图像增强为例):

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader# 数据增强
transform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),
])# 加载数据集
train_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 打印增强后的图像形状
for images, labels in train_loader:print(images.shape)  # (64, 3, 32, 32)break

3.3 Dropout
  • Dropout 是一种在训练过程中随机“丢弃”一部分神经元的正则化技术,用于防止神经网络过拟合。
  • 训练时,随机将一部分神经元的输出置为零;推理时,使用所有神经元,但缩放其输出。

数学原理
假设 Dropout 比例为 ( p p p ),每个神经元有 ( 1 − p 1-p 1p ) 的概率被激活:
输出 = 激活值 ⋅ 掩码 / ( 1 − p ) \text{输出} = \text{激活值} \cdot \text{掩码} / (1-p) 输出=激活值掩码/(1p)

代码示例

import torch
import torch.nn as nn# 定义一个简单的网络
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(784, 256)self.dropout = nn.Dropout(p=0.5)  # Dropout 概率为 0.5self.fc2 = nn.Linear(256, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return x# 使用 Dropout 的网络
model = SimpleNN()
print(model)

3.4 大模型中的正则化方法

在深度学习领域(尤其是 2022-2023 年的大模型训练),一些新的正则化方法逐渐被广泛应用:

  1. LayerNorm 和 WeightNorm

    • LayerNorm 对每一层进行归一化,减少梯度消失或爆炸问题。
    • WeightNorm 通过分离权重的幅度和方向,提升模型收敛速度。
  2. Label Smoothing

    • 通过在训练目标上引入少量噪声,避免模型过度自信。
      y ~ = ( 1 − ϵ ) ⋅ y + ϵ / K \tilde{y} = (1 - \epsilon) \cdot y + \epsilon / K y~=(1ϵ)y+ϵ/K
  3. 梯度裁剪(Gradient Clipping)

    • 限制梯度更新的幅度,避免梯度爆炸。
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
  4. 正则化优化器

    • AdamW 是一种带权重衰减的优化器,直接在更新权重时加入 L2 正则化效果。

4. 正则化在大模型中的实际应用

以 GPT-3 或 BERT 等大语言模型的训练为例,正则化方法的组合应用非常重要:

  • 使用 LayerNormDropout 作为网络层内的正则化手段。
  • 在优化器中应用 AdamW,并设置适当的权重衰减参数。
  • 在大数据集上进行分布式训练,同时引入数据增强策略。

5. 总结

正则化技术是机器学习和深度学习中不可或缺的一部分,帮助模型在复杂场景下提升泛化能力并防止过拟合。
不同场景适合的正则化方法如下:

场景常用正则化方法
传统机器学习(线性模型)L1 正则化、L2 正则化
神经网络训练Dropout、数据增强
大模型训练(2022-2023)LayerNorm、AdamW、梯度裁剪、Label Smoothing

正则化方法的选择依赖于具体任务和模型的需求,但其核心思想始终是限制模型的复杂性,提升模型的稳定性和泛化能力。

Regularization: The Stabilizer of Machine Learning Models


1. What is Regularization?

Regularization is a set of techniques used in machine learning to constrain model complexity and prevent overfitting.
The primary goal of regularization is to ensure that the model performs well not only on the training data but also generalizes effectively to unseen test data.


2. Why Does Regularization Work?

2.1 The Nature of Overfitting

Overfitting happens when a model learns noise and irrelevant patterns in the training data, leading to poor generalization on new data. This is more common in cases with:

  • Insufficient training data
  • High model complexity
  • Noisy datasets
2.2 How Regularization Works

Regularization works by imposing constraints on the model’s complexity. This discourages it from fitting noise and forces it to focus on learning the underlying patterns in the data.

Mathematical Insight:
By adding a regularization term to the loss function, we effectively change the optimization objective, which restricts the parameter space.

For example, in linear regression:

  • Original loss function:
    L = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 \mathcal{L} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2 L=n1i=1n(yiy^i)2
  • Regularized loss function:
    L reg = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 + λ R ( θ ) \mathcal{L}_{\text{reg}} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2 + \lambda R(\theta) Lreg=n1i=1n(yiy^i)2+λR(θ)

Where:

  • ( R ( θ ) R(\theta) R(θ) ) is the regularization term that penalizes complex models.
  • ( λ \lambda λ ) controls the trade-off between fitting the data and regularization strength.

3. Common Regularization Techniques

3.1 Parameter Regularization: L1 and L2 Regularization
  • L1 Regularization (Lasso)
    Adds the ( L 1 L1 L1 )-norm of the parameters to the loss function:
    R ( θ ) = ∥ θ ∥ 1 = ∑ j = 1 p ∣ θ j ∣ R(\theta) = \|\theta\|_1 = \sum_{j=1}^p |\theta_j| R(θ)=θ1=j=1pθj

    • Advantages: Encourages sparsity, making some parameters zero. Useful for feature selection.
    • Disadvantages: May lose some information in high-dimensional data.
  • L2 Regularization (Ridge)
    Adds the ( L 2 L2 L2 )-norm of the parameters to the loss function:
    R ( θ ) = ∥ θ ∥ 2 2 = ∑ j = 1 p θ j 2 R(\theta) = \|\theta\|_2^2 = \sum_{j=1}^p \theta_j^2 R(θ)=θ22=j=1pθj2

    • Advantages: Shrinks large parameter values, reducing model complexity.
    • Disadvantages: Does not produce sparse parameters; retains all features.

Code Example (Linear Regression with L1 and L2):

import numpy as np
from sklearn.linear_model import Ridge, Lasso
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error# Generate synthetic data
np.random.seed(42)
X = np.random.rand(100, 5)
y = 3 * X[:, 0] + 2 * X[:, 1] + np.random.randn(100)# Train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# Ridge (L2) Regularization
ridge = Ridge(alpha=1.0)
ridge.fit(X_train, y_train)
y_pred_ridge = ridge.predict(X_test)# Lasso (L1) Regularization
lasso = Lasso(alpha=0.1)
lasso.fit(X_train, y_train)
y_pred_lasso = lasso.predict(X_test)print("Ridge MSE:", mean_squared_error(y_test, y_pred_ridge))
print("Lasso MSE:", mean_squared_error(y_test, y_pred_lasso))

3.2 Data Augmentation

Data augmentation expands the training dataset by applying transformations (e.g., flips, rotations, cropping) to existing data, increasing model robustness and improving generalization.

Example (Image Augmentation in PyTorch):

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader# Define data augmentation
transform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),
])# Load dataset with augmentation
train_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# Print augmented image shape
for images, labels in train_loader:print(images.shape)  # Example: (64, 3, 32, 32)break

3.3 Dropout

Dropout randomly deactivates a subset of neurons during training, reducing reliance on specific neurons and preventing co-adaptation.

Mathematical Insight:
For a dropout rate ( p p p ), each neuron’s output is retained with probability ( 1 − p 1-p 1p ). During inference, the full network is used but scaled by ( 1 − p 1-p 1p ).

Code Example:

import torch
import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(784, 256)self.dropout = nn.Dropout(p=0.5)  # 50% dropoutself.fc2 = nn.Linear(256, 10)def forward(self, x):x = torch.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return xmodel = SimpleNN()
print(model)

3.4 Advanced Regularization Techniques for Large Models

With the advent of large-scale models (2022-2023), new regularization techniques have been widely adopted:

  1. LayerNorm and WeightNorm

    • LayerNorm normalizes activations across features within a layer.
    • WeightNorm separates weight vectors into magnitude and direction, improving optimization stability.
  2. Label Smoothing
    Prevents overconfidence in predictions by softening the target distribution:
    y ~ = ( 1 − ϵ ) ⋅ y + ϵ / K \tilde{y} = (1 - \epsilon) \cdot y + \epsilon / K y~=(1ϵ)y+ϵ/K

  3. Gradient Clipping
    Limits the magnitude of gradients to prevent exploding gradients:

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
  4. AdamW Optimizer
    Combines the Adam optimizer with weight decay for improved regularization.


4. Regularization in Large Model Training

For models like GPT-3 and BERT, regularization involves combining multiple techniques:

  • LayerNorm and Dropout to stabilize training and reduce overfitting.
  • AdamW with appropriate weight decay settings.
  • Label Smoothing for classification tasks to prevent overconfidence.
  • Gradient Clipping to handle gradient explosion in deep networks.

5. Conclusion

Regularization is crucial for building robust machine learning models. The right choice of technique depends on the specific task and model requirements. Below is a summary of common regularization techniques:

ScenarioRegularization Methods
Traditional ML (linear models)L1, L2 regularization
Neural Network TrainingDropout, Data Augmentation
Large Model TrainingLayerNorm, AdamW, Label Smoothing

By constraining model complexity, regularization ensures models are stable, generalizable, and less prone to overfitting.

后记

2024年12月14日15点55分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/889902.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Day9 神经网络的偏导数基础

多变量函数与神经网络 在神经网络中,我们经常遇到多变量函数。这些函数通常描述了网络的输入、权重、偏置与输出之间的关系。例如,一个简单的神经元输出可以表示为: z f ( w 1 x 1 w 2 x 2 … w n x n b ) z f(w_1x_1 w_2x_2 \ldots…

map和set题目练习

一、习题一:随机链表的复制 1.1题目详情 1.2思路 在没有学习map和set之前,解决这道题最大的问题就在于无法建立原链表与拷贝链表的映射关系,只能通过在原链表每个节点后面新建一个新的链表来进行节点间的对应,而学习了map之后&a…

Hw亮度省电

1. 亮度控制策略 /decompile-hw/decompile/app/HwPowerGenieEngine3/src/main/res/xml/backlight_policy.xml <?xml version"1.0" encoding"utf-8"?> 2 <backlight_policy xmlns:android"http://schemas.android.com/apk/res/android&qu…

C语言入门(一):A + B _ 基础输入输出

前言 本专栏记录C语言入门100例&#xff0c;这是第&#xff08;一&#xff09;例。 目录 一、【例题1】 1、题目描述 2、代码详解 二、【例题2】 1、题目描述 2、代码详解 三、【例题3】 1、题目描述 2、代码详解 四、【例题4】 1、题目描述 2、代码详解 一、【例…

【21天学习AI底层概念】day8 什么是类意识?

类意识&#xff08;Quasi-Consciousness&#xff09; 是一个用来描述人工智能或复杂系统表现出的类似意识的行为或特性的概念。虽然这种系统不具备真正的意识&#xff08;即主观体验、情感和自我觉知&#xff09;&#xff0c;但在外部表现上&#xff0c;它们可能表现出与有意识…

Docker 镜像源 阿里镜像源限制后其他镜像源

要在Docker中修改镜像源&#xff0c;你需要编辑或创建Docker的配置文件来指定新的镜像源地址。以下是如何为Docker配置中国镜像源的步骤&#xff1a; 找到或创建Docker的配置文件daemon.json。 在Linux系统中&#xff0c;该文件通常位于/etc/docker/目录下。 编辑daemon.jso…

渗透测试学习笔记(五)网络

一.IP地址 1. IP地址详解 ip地址是唯一标识&#xff0c;一段网络编码局域网&#xff08;内网&#xff09;&#xff1a;交换机-网线-pcx.x.x.x 32位置2进制&#xff08;0-255&#xff09; IP地址五大类 IP类型IP范围A类0.0.0.0 到 127.255.255.255B类128.0.0.0 到191.255.25…

《自制编译器》--青木峰郎 -读书笔记 编译hello

在该书刚开始编译hello.cb时就遇到了问题。 本人用的是wsl&#xff0c;环境如下&#xff0c; 由于是64位&#xff0c;因此根据书中的提示&#xff0c;从git上下载了64位的cb编译器 cbc-64bit 问题一: 通过如下命令编译时,总是报错。 cbc -Wa,"--32" -Wl,"-…

LruCache(本地cache)生产环境中遇到的问题及改进

问题&#xff1a;单机qps增加时请求摘要后端&#xff0c;耗时也会增加&#xff0c;因为超过了后端处理能力&#xff08;最大qps&#xff0c;存在任务堆积&#xff09;。 版本一 引入LruCache。为了避免数据失效&#xff0c;cache数据的时效性要小于摘要后端物料的更新时间&…

jedis使用及注意事项

Jedis Jedis 是一个 Java 客户端&#xff0c;用于与 Redis 数据库进行交互。它提供了一系列简单易用的 API&#xff0c;使得在 Java 应用程序中使用 Redis 变得非常方便。以下是 Jedis 的使用方法及一些注意事项。 Jedis的优势 Lettuce客户端及Jedis客户端比较如下&#xff1a;…

CSDN博客:如何使用Python的`datasets`库转换音频采样率

CSDN博客&#xff1a;如何使用Python的datasets库转换音频采样率 什么是采样率&#xff1f;代码用途&#xff1a;调整音频数据的采样率完整代码示例代码详解运行结果&#xff08;示例&#xff09;总结 在这篇文章中&#xff0c;我们将学习如何使用Python的datasets库对音频数据…

浏览器执行机制

主线程 任务1&#xff0c;任务2 微队列微队列任务1&#xff0c; 微队列任务2延时队列延时队列任务1&#xff0c; 延时队列任务2交互队列.... 事件循环的工作原理 主线程执行同步任务&#xff1a; 主线程首先执行所有同步任务&#xff08;即栈中的任务&#xff09;。这些任务会…

Java 基础知识——part 4

8.成员方法&#xff1a;Java中必须通过方法才能对类和对象的属性操作&#xff1b;成员方法只在类的内部声明并加以实现。一般声明成员变量后再声明方法。 9.方法定义 方法的返回值是向外界输出的信息&#xff0c;方法类型和返回值类型同&#xff1b;返回值通过return返回&…

设计模式12:抽象工厂模式

系列总链接&#xff1a;《大话设计模式》学习记录_net 大话设计-CSDN博客 参考&#xff1a; C设计模式&#xff1a;抽象工厂模式&#xff08;风格切换案例&#xff09;_c 抽象工厂-CSDN博客 1.概念 抽象工厂模式&#xff08;Abstract Factory Pattern&#xff09;是软件设计…

【YashanDB知识库】kettle同步大表提示java内存溢出

【问题分类】数据导入导出 【关键字】数据同步&#xff0c;kettle&#xff0c;数据迁移&#xff0c;java内存溢出 【问题描述】kettle同步大表提示ERROR&#xff1a;could not create the java virtual machine! 【问题原因分析】java内存溢出 【解决/规避方法】 ①增加JV…

适配体技术在新药发现中的应用

适配体筛选技术在新药发现中的具体应用 适配体筛选技术&#xff0c;特别是SELEX&#xff08;Systematic Evolution of Ligands by Exponential Enrichment&#xff0c;指数富集的配体系统进化技术&#xff09;&#xff0c;在新药发现中扮演着至关重要的角色。这种技术能够从庞…

C/S软件授权注册系统(Winform+WebApi+.NET8+EFCore版)

适用软件&#xff1a;C/S系统、Winform桌面应用软件。 运行平台&#xff1a;Windows .NETCore&#xff0c;.NET8 开发工具&#xff1a;Visual Studio 2022&#xff0c;C#语言 数据库&#xff1a;Microsoft SQLServer 2012&#xff0c;Oracle 21c&#xff0c;MySQL8&#xf…

go语言使用websocket发送一条消息A,持续接收返回的消息

在Go语言中实现一个WebSocket客户端&#xff0c;可以使用gorilla/websocket这个非常流行的库来处理WebSocket连接。下面是一个简单的示例&#xff0c;展示了如何创建一个WebSocket客户端&#xff0c;向服务器发送消息"A"&#xff0c;并持续接收来自服务器的响应。 首…

监控易 IDC 数据中心一体化智能运维平台:新质生产力的典范

一、引言 在当今数字化飞速发展的时代&#xff0c;IDC 数据中心作为信息产业的核心基础设施&#xff0c;其稳定、高效运行对于企业和社会的重要性不言而喻。随着数据量的爆炸式增长和业务复杂度的提升&#xff0c;传统的运维模式已难以满足需求&#xff0c;数据中心面临着诸多挑…

活着就好20241218

亲爱的朋友们&#xff0c;大家早上好&#xff01;&#x1f31e; 今天是18号&#xff0c;星期三&#xff0c;2024年12月的第十八天&#xff0c;同时也是第50周的第九天&#xff0c;农历甲辰[龙]年十一月初十四日。在这晨光初照的美丽时刻&#xff0c;愿那温柔而灿烂的阳光轻轻洒…