归一化/标准化对神经网络的训练是否有影响?

一、背景

        归一化(Normalization)和标准化(Standardization)是数据预处理中的两种常见技术,旨在调整数据的范围和分布,以提高机器学习模型或者深度学习模型的性能和训练速度。虽然它们的目标相似,但具体方法和应用场景有所不同。

1、归一化

        归一化是将数据缩放到一个特定的范围(通常是0到1之间),以确保不同特征具有相同的尺度。归一化常用于需要保持数据相对比例的场景,如图像处理和神经网络训练。最常见的归一化方法是最小-最大归一化(Min-Max Normalization),其公式为:

{x}' = \frac{x-x_{min}}{x_{max}-x_{min}}

        其中:

  • x是原始数据。
  • x′是归一化后的数据。
  • x_{min}x_{max}分别是数据集中的最小值和最大值。

2、标准化

        标准化是将数据转换为均值为0、标准差为1的标准正态分布。标准化常用于需要假设数据服从正态分布的场景,如线性回归和支持向量机:

{x}' = \frac{x-\mu}{\sigma }

        其中:

  • x是原始数据。
  • x′是标准化后的数据。
  • μ是数据的均值。
  • σ是数据的标准差。

二、python示例

        下面的代码中,我们构建了一个torch板的全连接神经网络,训练数据集采用sklearn自带的经典的乳腺癌二分类数据集。我们将展示对数据进行归一化以及不进行归一化会给模型训练带来什么样的影响(也可以换成标准化的代码试一下)。模型的结构以及构建过程不再赘述,可以参考笔者往期的博文。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import precision_score, recall_score, f1_score# 加载乳腺癌数据集
data = load_breast_cancer()
X = data.data
y = data.target# 数据标准化
scaler = MinMaxScaler()
X = scaler.fit_transform(X)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 将数据转换为PyTorch张量
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32)# 定义自定义数据集类
class BreastCancerDataset(Dataset):def __init__(self, X, y):self.X = Xself.y = ydef __len__(self):return len(self.X)def __getitem__(self, idx):return self.X[idx], self.y[idx]# 创建数据集和数据加载器
train_dataset = BreastCancerDataset(X_train, y_train)
test_dataset = BreastCancerDataset(X_test, y_test)train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)# 定义MLP模型
class MLP(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(MLP, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_size, num_classes)self.sigmoid = nn.Sigmoid()def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)out = self.sigmoid(out)return out# 定义模型参数
input_size = X_train.shape[1]
hidden_size = 16
num_classes = 1# 创建模型
model = MLP(input_size, hidden_size, num_classes)# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 5for epoch in range(num_epochs):model.train()running_loss = 0.0for inputs, labels in train_loader:# 前向传播outputs = model(inputs)loss = criterion(outputs.squeeze(), labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")print("Training complete.")# 评估模型
def evaluate(model, dataloader):model.eval()all_labels = []all_predictions = []with torch.no_grad():for inputs, labels in dataloader:outputs = model(inputs)predicted = (outputs.squeeze() > 0.5).float()all_labels.extend(labels.tolist())all_predictions.extend(predicted.tolist())precision = precision_score(all_labels, all_predictions)recall = recall_score(all_labels, all_predictions)f1 = f1_score(all_labels, all_predictions)return precision, recall, f1train_precision, train_recall, train_f1 = evaluate(model, train_loader)
test_precision, test_recall, test_f1 = evaluate(model, test_loader)print(f"Train Precision: {train_precision:.4f}")
print(f"Train Recall: {train_recall:.4f}")
print(f"Train F1 Score: {train_f1:.4f}")print(f"Test Precision: {test_precision:.4f}")
print(f"Test Recall: {test_recall:.4f}")
print(f"Test F1 Score: {test_f1:.4f}")

        未进行数据归一化直接训练的模型表现:

        归一化之后的模型表现:

        可以看到,数据归一化之后训练的模型性能相比原始数据训练出来的模型要大大提升。训练和测试是precision都提升了10%左右,F1值也提升了5%左右。

三、总结

        在机器学习建模过程中,许多模型在训练之前需要对输入数据进行归一化或标准化。这些模型通常对输入特征的尺度敏感,归一化或标准化可以提高模型的性能和训练速度。树类模型(如决策树、随机森林、梯度提升树等)不需要对特征进行归一化或标准化处理,因为它们的分裂标准与特征的尺度无关。然而,在神经网络中,输入数据的尺度过大可能导致激活函数的输出值过大或过小,影响梯度的计算,导致数值不稳定性。因此,我们在进行深度学习建模的过程中,如果特征取值之间的量纲差异较大,务必要对数据进行归一化或者标准化的处理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/61297.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis、TongRDS 可视化工具使用之 Redis Insight

题外话:除了可以连接 redis,也可以用来连接 TongRDS 1)官网下载 Redis Insight 2)安装 3)连接 4)使用 这里只是给一个使用例子

3D Gaussian Splatting在鱼眼相机中的应用与投影变换

paper:Fisheye-GS 1.概述 3D 高斯泼溅 (3DGS) 因其高保真度和实时渲染而备受关注。然而,由于独特的 3D 到 2D 投影计算,将 3DGS 适配到不同的相机型号(尤其是鱼眼镜头)带来了挑战。此外,基于图块的泼溅效率低下,尤其是对于鱼眼镜头的极端曲率和宽视野,这对于其更广泛…

C# 委托与事件

C# 委托 在C#中,委托(Delegate)是一种引用类型,用于封装方法的引用。它允许你将方法作为参数传递,或者将方法赋值给变量,从而实现方法的传递和调用。委托在C#中扮演着非常重要的角色,尤其是在事…

Node.js 安装与环境配置详解:从入门到实战

**标题:Node.js 安装与环境配置详解:从入门到实战** --- ### 一、Node.js 简介 Node.js 是一个基于 Chrome V8 引擎的 JavaScript 运行时环境,允许开发者在服务器端运行 JavaScript 代码。凭借其事件驱动、非阻塞 I/O 模型,Nod…

oracle查看锁阻塞-谁阻塞了谁

一 模拟锁阻塞 #阻塞1 一个会话正在往一个大表写入大量数据的时候,另一个会话加字段: #会话1 #会话2 会话2被阻塞了。 #阻塞2 模拟一个会话update一条记录,没提交。 另一个会话也update这一条记录: 会话2被阻塞了。 二 简单查…

django基于django的民族服饰数据分析系统的设计与实现

摘 要 随着网络科技的发展,利用大数据分析对民族服饰进行管理已势在必行;该平台将帮助企业更好地理解服饰市场的趋势,优化服装款式,提高服装的质量。 本文讲述了基于python语言开发,后台数据库选择MySQL进行数据的存储…

STM32单片机CAN总线汽车线路通断检测-分享

目录 目录 前言 一、本设计主要实现哪些很“开门”功能? 二、电路设计原理图 1.电路图采用Altium Designer进行设计: 2.实物展示图片 三、程序源代码设计 四、获取资料内容 前言 随着汽车电子技术的不断发展,车辆通信接口在汽车电子控…

iw添加wlan0导致crash问题分析

比如通过日下命令&#xff0c;创建一个wlan0接口 iw phy phy0 interface add wlan0 type managed 会产生如下panic内容 <1> [54245.466372] Unable to handle kernel NULL pointer dereference at virtual address 00000010 <1> [54245.474729] pgd c1794000 &…

k8s -20241119

用于管理云平台中多个主机上的容器化的应用&#xff0c;Kubernetes的目标是让部署容器化的应用简单并且高效&#xff08;powerful&#xff09;,Kubernetes提供了应用部署&#xff0c;规划&#xff0c;更新&#xff0c;维护的一种机制通过部署容器方式实现&#xff0c;每个容器之…

Linux 查看磁盘空间使用情况

1. df命令 功能&#xff1a;显示文件系统的整体磁盘空间使用情况。工作原理&#xff1a;读取文件系统的超级块信息&#xff0c;显示文件系统的总容量、已用空间、可用空间以及挂载点。特点&#xff1a; 显示的是整个分区的空间使用情况&#xff0c;而不是单个文件或目录的空间…

详解Rust的数据类型和语法

文章目录 基本数据类型复杂数据类型字符串基本语法 Rust是一种强调安全性和性能的系统编程语言。它的设计目标之一是防止内存安全错误同时提供丰富的功能和灵活的语法。下面介绍一下Rust语言的基本数据类型和语法。 基本数据类型 1.整数类型 有符号整数: i8, i16, i32, i64, i…

golang对日期格式化

1.对日期格式化为 YYYY-mm-dd, 并且没有数据时&#xff0c;返回空 import ("encoding/json""time" )type DateTime time.Timetype SysRole struct {RoleId int64 gorm:"type:bigint(20);primary_key;auto_increment;角色ID;" json:&quo…

MySQL系列之数据授权(privilege)

导览 前言Q&#xff1a;如何对MySQL数据库进行授权管理一、MySQL的“特权”1. 权限级别2. 权限清单 二、授权操作1. 查看权限2. 分配权限3. 回收权限 结语精彩回放 前言 看过博主上一篇的盆友&#xff0c;可以Get到一个知识点&#xff1a;数据授权&#xff08;eg&#xff1a;g…

项目进度计划表:详细的甘特图的制作步骤

甘特图&#xff08;Gantt chart&#xff09;&#xff0c;又称为横道图、条状图&#xff08;Bar chart&#xff09;&#xff0c;是一种用于管理时间和任务活动的工具。 甘特图由亨利劳伦斯甘特&#xff08;Henry Laurence Gantt&#xff09;发明&#xff0c;是一种通过条状图来…

抽象工厂方法模式

工厂方法模式&#xff08;Factory Method Pattern&#xff09; 工厂方法模式是一种 创建型设计模式&#xff0c;它定义了一个创建对象的接口&#xff0c;但让子类决定实例化哪一个具体类。通过这种方式&#xff0c;工厂方法将对象的创建延迟到子类&#xff0c;避免了直接依赖具…

【Redis】Redis实现的消息队列

一、用list实现【这是数据类型所以支持持久化】 消息基于redis存储不会因为受jvm内存上限的限制&#xff0c;支持消息的有序性&#xff0c;基于redis的持久化机制&#xff0c;只支持单一消费者订阅&#xff0c;无法避免消息丢失。 二、用PubSub【这不是数据类型&#xff0c;是…

Linux登录指令last详解

引言 在Linux系统中&#xff0c;了解用户登录记录是系统管理和安全审计的重要任务之一。last指令作为Linux系统中用于检索和展示用户登录信息的工具&#xff0c;扮演着至关重要的角色。本文将详细介绍last指令的定义、架构、原理、企业应用以及常见的命令体系&#xff0c;帮助…

CSP-X2024山东小学组T2:消灭怪兽

题目链接 题目名称 题目描述 怪兽入侵了地球&#xff01; 为了抵抗入侵&#xff0c;人类设计出了按顺序排列好的 n n n 件武器&#xff0c;其中第 i i i 件武器的攻击力为 a i a_i ai​&#xff0c;可以造成 a i a_i ai​ 的伤害。 武器已经排列好了&#xff0c;因此不…

网络安全常见练习靶场

DVWA (Dam Vulnerable Web Application) DVWA是用PHPMysql编写的一套用于常规WEB漏洞教学和检测的WEB脆弱性测试程序。包含了SQL注入、XSS、盲注等常见的一些安全漏洞。 链接地址&#xff1a;http://www.dvwa.co.uk mutillidaemutillidae mutillidaemutillidae是一个免费&am…

【操作系统笔记】目录

【操作系统笔记】操作系统框架https://blog.csdn.net/Resurgence03/article/details/142624262 【操作系统笔记】CPU管理https://blog.csdn.net/Resurgence03/article/details/142621526 【操作系统笔记】内存管理https://blog.csdn.net/Resurgence03/article/details/142669…