深入解析 Loss 减少方式:mean和sum的区别及其在大语言模型中的应用 (中英双语)

深入解析 Loss 减少方式:meansum 的区别及其在大语言模型中的应用

在训练大语言模型(Large Language Models, LLM)时,损失函数(Loss Function)的处理方式对模型的性能和优化过程有显著影响。本文以 reduce_loss 参数为例,详细探讨 meansum 两种方式的定义、适用场景及其对对话模型性能的潜在提升原因,并通过代码实例加深理解。


1. 什么是 reduce_loss

reduce_loss 决定了在每个 batch 中,如何对 token-level 的损失进行归一化或累加处理。常见的选项是:

  • mean: 取每个 token 损失的平均值。
  • sum: 将每个 token 损失直接累加。

参数定义示例(在代码中通过 dataclass 定义):参考来源:https://github.com/allenai/open-instruct

from dataclasses import dataclass, field@dataclass
class TrainingArguments:reduce_loss: str = field(default="mean",metadata={"help": ("How to reduce loss over tokens. Options are 'mean' or 'sum'.""Using 'sum' can improve chat model performance.")},)

2. meansum 的定义

2.1 mean 模式
  • 定义:将 batch 中所有 token 的损失值取平均。
  • 公式
    Loss mean = ∑ i = 1 N Loss i N \text{Loss}_{\text{mean}} = \frac{\sum_{i=1}^{N} \text{Loss}_i}{N} Lossmean=Ni=1NLossi
    其中 ( N N N) 是当前 batch 中的 token 总数。
  • 特性:每个 token 的损失对最终的 loss 贡献相等,损失值与 batch 中的 token 数无关。
2.2 sum 模式
  • 定义:将 batch 中所有 token 的损失值直接累加。
  • 公式
    Loss sum = ∑ i = 1 N Loss i \text{Loss}_{\text{sum}} = \sum_{i=1}^{N} \text{Loss}_i Losssum=i=1NLossi
  • 特性:长序列(更多 token)的损失对总 loss 的贡献更大,损失值直接与 token 数成正比。

3. meansum 的区别

模式特点优点缺点
mean损失对 token 数归一化,独立于 batch size。稳定性强,适用于 token 数差异大的批次。长序列与短序列对损失的贡献相同,可能弱化长序列的重要性。
sum损失值与 token 总数成正比,长序列贡献更大。在注重长序列表现的任务中效果更好(如对话生成)。损失值随 batch size 变化波动,需要动态调整学习率。

4. 适用场景分析

4.1 mean
  • 适用任务:大多数语言建模任务,如 GPT 或 BERT 的预训练。
  • 适用场景:当训练数据中序列长度差异较大时,mean 可以避免因长序列的损失值过大而导致梯度更新不均衡。
4.2 sum
  • 适用任务:对长序列表现要求较高的任务,如对话生成(Chat Models)和长文本生成。
  • 适用场景:长序列的损失占比更高,从而使优化过程更加关注全局上下文的建模。

5. 为什么 sum 能提升对话模型性能?

对话模型(Chat Models)的训练中,长序列往往包含丰富的上下文信息,而短序列则可能无法体现模型的上下文理解能力。在 sum 模式下:

  1. 长序列的重要性增加:长序列的损失对总损失的贡献更大,这促使模型更关注上下文的建模。
  2. 对全局一致性更敏感sum 模式下,模型的优化方向更倾向于全序列的一致性,特别适合需要长距离依赖的任务。

示例
假设一个 batch 包含以下两个样本:

  • 样本 A: 长度为 10,损失总和为 5。
  • 样本 B: 长度为 50,损失总和为 25。

计算损失贡献:

  • mean 模式
    Loss mean = 5 + 25 10 + 50 = 0.5 \text{Loss}_{\text{mean}} = \frac{5 + 25}{10 + 50} = 0.5 Lossmean=10+505+25=0.5
    样本 A 和 B 的贡献权重相同。
  • sum 模式
    Loss sum = 5 + 25 = 30 \text{Loss}_{\text{sum}} = 5 + 25 = 30 Losssum=5+25=30
    样本 B 的贡献权重显著增加,优化更关注长序列。

6. 实战代码

以下是一个完整的训练脚本,展示如何在 Hugging Face 的 transformers 框架中使用 reduce_loss 参数。

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch# 模型和数据集
model_name = "meta-llama/Llama-3.1-8B"
dataset_name = "allenai/tulu-3-sft-mixture"model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)dataset = load_dataset(dataset_name)
tokenized_dataset = dataset.map(lambda x: tokenizer(x['text'], truncation=True, padding="max_length"), batched=True)
train_loader = DataLoader(tokenized_dataset["train"], batch_size=2, shuffle=True)# 训练设置
reduce_loss = "sum"  # 改为 "mean" 可对比效果
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 训练循环
for epoch in range(2):for batch in train_loader:inputs = torch.tensor(batch["input_ids"]).to(device)labels = inputs.clone()outputs = model(inputs, labels=labels)if reduce_loss == "sum":loss = outputs.loss.sum()else:  # 默认 "mean"loss = outputs.loss.mean()loss.backward()optimizer.step()optimizer.zero_grad()print(f"Epoch: {epoch}, Loss: {loss.item()}")

7. 注意事项与优化建议

  1. 动态调整学习率

    • 使用 sum 时,由于损失值放大,建议适配学习率,如降低到 mean 模式的 ( 1 / N 1/N 1/N )。
    • 配合学习率调度器(如 linear)优化训练。
  2. 对长短序列的平衡

    • 若长序列权重过大导致模型性能退化,可结合 curriculum learning 或混合训练策略(如对长短序列按比例采样)。
  3. 性能评估

    • 在验证集上,关注长序列和短序列的生成性能对比。

8. 总结

reduce_loss 的选择对模型性能有直接影响:

  • mean 更通用,适合大多数语言建模任务。
  • sum 在对话生成等长序列敏感任务中表现更优。

希望本文能为 LLM 研究人员提供思路和参考,在具体任务中灵活选择合适的损失归一化方式,从而提升模型性能。

Understanding the Difference Between mean and sum Loss Reduction in LLM Training

When training large language models (LLMs), the way token-level loss is reduced across a batch can significantly impact optimization and model performance. This article delves into the reduce_loss parameter, exploring the differences between mean and sum reduction modes, their definitions, use cases, and why sum might improve the performance of chat-oriented models. Practical code examples are also provided for clarity.


1. What is reduce_loss?

The reduce_loss parameter determines how the token-level loss values in a batch are aggregated. The two most common options are:

  • mean: Averages the loss over all tokens in a batch.
  • sum: Sums the loss of all tokens in a batch.

Example definition (from the codebase using Python dataclass):

from dataclasses import dataclass, field@dataclass
class TrainingArguments:reduce_loss: str = field(default="mean",metadata={"help": ("How to reduce loss over tokens. Options are 'mean' or 'sum'.""Using 'sum' can improve chat model performance.")},)

2. Definitions of mean and sum

2.1 mean
  • Definition: Averages the loss across all tokens in a batch.
  • Formula:
    Loss mean = ∑ i = 1 N Loss i N \text{Loss}_{\text{mean}} = \frac{\sum_{i=1}^{N} \text{Loss}_i}{N} Lossmean=Ni=1NLossi
    where ( N N N ) is the total number of tokens in the batch.
  • Characteristics: The contribution of each token to the final loss is normalized, making the loss independent of the batch’s token count.
2.2 sum
  • Definition: Sums up the loss across all tokens in a batch.
  • Formula:
    Loss sum = ∑ i = 1 N Loss i \text{Loss}_{\text{sum}} = \sum_{i=1}^{N} \text{Loss}_i Losssum=i=1NLossi
  • Characteristics: The total loss is proportional to the number of tokens, giving longer sequences more weight in the optimization process.

3. Key Differences Between mean and sum

Reduction ModeCharacteristicsAdvantagesDisadvantages
meanNormalizes the loss by token count.Stable and robust for datasets with variable-length sequences.Long sequences are underweighted relative to short ones.
sumLoss scales with the number of tokens.Places greater emphasis on longer sequences, improving performance in tasks requiring context modeling.Loss values vary with batch size, necessitating dynamic learning rate adjustment.

4. Use Cases for mean and sum

4.1 mean
  • Best Suited For: Pretraining or general language modeling tasks like GPT or BERT.
  • Scenario: When the dataset contains sequences of widely varying lengths, mean ensures that longer sequences do not disproportionately influence gradient updates.
4.2 sum
  • Best Suited For: Tasks requiring high performance on long sequences, such as dialogue generation or document-level text generation.
  • Scenario: Encourages the model to prioritize sequences with richer contexts, as their loss contributes more to the overall optimization.

5. Why Does sum Improve Chat Model Performance?

In chat-oriented models, sequences are typically longer and require the model to understand and generate coherent responses over extended contexts. Using sum mode:

  1. Enhances Long Sequence Weighting: Longer sequences contribute more to the total loss, emphasizing the importance of context modeling.
  2. Encourages Global Consistency: By assigning more weight to longer contexts, the model better captures dependencies across the entire sequence.
  3. Balances Token Importance: Since chat models are often evaluated on dialogue-level coherence, sum ensures that tokens from the context and the response are proportionally weighted.

Example:
Consider a batch with two samples:

  • Sample A: Sequence length = 10, loss = 5.
  • Sample B: Sequence length = 50, loss = 25.

Loss calculations:

  • mean mode:
    Loss mean = 5 + 25 10 + 50 = 0.5 \text{Loss}_{\text{mean}} = \frac{5 + 25}{10 + 50} = 0.5 Lossmean=10+505+25=0.5
    Both samples contribute equally to the loss.
  • sum mode:
    Loss sum = 5 + 25 = 30 \text{Loss}_{\text{sum}} = 5 + 25 = 30 Losssum=5+25=30
    Sample B contributes much more to the total loss, focusing the optimization on longer contexts.

6. Practical Implementation

Here’s a practical training script that demonstrates the use of reduce_loss in both modes.

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch# Model and dataset
model_name = "meta-llama/Llama-3.1-8B"
dataset_name = "allenai/tulu-3-sft-mixture"model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)dataset = load_dataset(dataset_name)
tokenized_dataset = dataset.map(lambda x: tokenizer(x['text'], truncation=True, padding="max_length"), batched=True)
train_loader = DataLoader(tokenized_dataset["train"], batch_size=2, shuffle=True)# Training setup
reduce_loss = "sum"  # Change to "mean" to compare effects
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# Training loop
for epoch in range(2):for batch in train_loader:inputs = torch.tensor(batch["input_ids"]).to(device)labels = inputs.clone()outputs = model(inputs, labels=labels)if reduce_loss == "sum":loss = outputs.loss.sum()else:  # Default: "mean"loss = outputs.loss.mean()loss.backward()optimizer.step()optimizer.zero_grad()print(f"Epoch: {epoch}, Loss: {loss.item()}")

7. Practical Considerations

  1. Learning Rate Adjustment:

    • When using sum, the loss magnitude increases with batch size, so you may need to adjust the learning rate (e.g., scale it down by ( 1 / N 1/N 1/N )).
  2. Balancing Long and Short Sequences:

    • Overweighting long sequences can sometimes harm generalization. Using curriculum learning or sampling strategies (e.g., proportional sampling) can help mitigate this.
  3. Validation:

    • Evaluate model performance on both short and long sequences to confirm improvements in the intended metrics.

8. Conclusion

The choice between mean and sum loss reduction modes depends on the specific task and dataset:

  • Use mean for general-purpose language modeling tasks where sequence lengths vary significantly.
  • Use sum for tasks that prioritize long-sequence performance, such as chat models or long-text generation.

Understanding and experimenting with these settings can lead to better-optimized models, particularly in the nuanced field of LLM fine-tuning.

后记

2024年12月3日16点04分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/62183.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于 AutoFlow 快速搭建基于 TiDB 向量搜索的本地知识库问答机器人

导读 本文将详细介绍如何通过 PingCAP 开源项目 AutoFlow 实现快速搭建基于 TiDB 的本地知识库问答机器人。如果提前准备好 Docker、TiDB 环境,整个搭建过程估计在 10 分钟左右即可完成,无须开发任何代码。 文中使用一篇 TiDB 文档作为本地数据源作为示…

生信技能63 - 构建gnomAD变异位点的SQLite查询数据库

将数据量巨大的gnomAD数据库,通过SQLite数据库寻找gnomAD中存在的各种变异注释信息(如等位基因计数,深度,次要等位基因频率等),查询300.000个变量的查询需要大约40秒,通过染色体编号+位置+REF+ALT即可进行快速查询。 1. gnomAD变异注释VCF文件字段 gnomAD VCF各版本包…

【前端】将vue的方法挂载到window上供全局使用,也方便跟原生js做交互

【前端】将vue的方法挂载到window上供全局使用&#xff0c;也方便跟原生js做交互 <template><div><el-button click"start">调用方法</el-button></div> </template> <script> // import { JScallbackProc } from ./JScal…

基于XML的AOP开发

AOP 为 Aspect Oriented Programming 的缩写&#xff0c;意思为面向切面编程。 AOP相关术语&#xff1a; 目标对象(Target)&#xff1a; 你要去代理的对象&#xff0c;可以理解为之前很单纯的那个对象。 代理对象(Proxy)&#xff1a; 你把你那个单纯的对象给我&#xff0c…

记录blender学习过程中遇到的问题

物体发射的方向不对 被发射物体&#xff08;例如一棵树&#xff09;n键看旋转归0 切换正视图 将被发射物体的局部坐标的Z轴 指向 全局方向的X轴时 并且把粒子系统设置的物体旋转勾选上 方向就对了 做倒角发现有问题 检查缩放应用、面朝向、有没有重合点&#xff08;融合点&am…

Ubuntu系统中Redis的安装步骤及服务配置

目录 内容概括 系统环境 安装方式 1、apt包管理器安装 &#xff08;1&#xff09;安装redis服务 &#xff08;2&#xff09;安装客户端&#xff08;进入命令行操作使用&#xff0c;包含redis-cli&#xff09; &#xff08;3&#xff09;安装检验 &#xff08;4&#xf…

半导体设备中的微型导轨应如何选择合适的润滑油?

微型导轨的润滑对于保证其高精度和高稳定性至关重要&#xff0c;尤其是在半导体设备中&#xff0c;微型导轨的润滑油选择需要考虑多个因素&#xff0c;以确保设备的最佳性能和寿命。以下是一些关键点&#xff1a; 1、黏度&#xff1a;润滑油的黏度是影响其流动性和润滑效果的重…

RocketMq详解:六、RocketMq的负载均衡机制

上一章&#xff1a;《SpringBootAop实现RocketMq的幂等》 文章目录 1.背景1.1 什么是负载均衡1.2 负载均衡的意义 2.RocketMQ消息消费2.1 消息的流转过程2.2 Consumer消费消息的流程 3.RocketMq的负载均衡策略3.1 Broker负载均衡3.2 Producer发送消息负载均衡3.3 消费端的负载均…

yocto的xxx.bb文件在什么时候会拷贝文件到build目录

在 Yocto 中&#xff0c;.bb 文件用于描述如何构建和安装一个软件包&#xff0c;而文件在构建过程中的拷贝操作通常会在某些特定的步骤中进行。具体来说&#xff0c;文件会在以下几个阶段被拷贝到 build 目录&#xff08;或者更准确地说&#xff0c;拷贝到目标目录 ${D}&#x…

主打极致性价比,AMD RX 8600/8800显卡定了

*以下内容仅为网络爆料及传闻&#xff0c;一切以官方消息为准。 这谁能想到&#xff0c;率先掏出下一代桌面独立显卡的不是老大哥 NVIDIA&#xff0c;也不是 AMD&#xff0c;反而是三家中存在感最弱的 Intel&#xff01; 就在 12 月 3 日&#xff0c;Intel 正式发布了自家第二…

数组哪些方法会触发Vue监听,哪些不会触发监听

发现宝藏 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。【宝藏入口】。 在 Vue 中&#xff0c;数组的变化是通过 响应式 系统来监听的。Vue 使用 getter 和 setter 来追踪数组的变化&#xff0c;并在数…

npm, yarn, pnpm之间的区别

前言 在现代化的开发中&#xff0c;一个人可能同时开发多个项目&#xff0c;安装的项目越来越多&#xff0c;所随之安装的依赖包也越来越臃肿&#xff0c;而且有时候所安装的速度也很慢&#xff0c;甚至会安装失败。 因此我们就需要去了解一下&#xff0c;我们的包管理器&#…

工业检测基础-工业相机选型及应用场景

以下是一些常见的工业检测相机种类、检测原理、应用场景及选型依据&#xff1a; 2D相机 检测原理&#xff1a;基于二维图像捕获&#xff0c;通过分析图像的明暗、纹理、颜色等信息来检测物体的特征和缺陷.应用场景&#xff1a;广泛应用于平面工件的外观检测&#xff0c;如检测…

C语言连接数据库

文章目录 一、初始化数据库二、创建数据库连接三、执行增删改查语句1、增删改2、查 四、执行增删改查语句 接下来我简单的介绍一下怎么用C语言连接数据库。 初始化数据库创建数据库连接执行增删改查语句关闭数据库连接 一、初始化数据库 // 数据库初始化 MYSQL mysql; MYSQL* r…

优化LabVIEW数据运算效率的方法

在LabVIEW中进行大量数据运算时&#xff0c;提升计算效率并减少时间占用是开发过程中常遇到的挑战。为此&#xff0c;可以从多个角度着手优化&#xff0c;包括合理选择数据结构与算法、并行处理、多线程技术、硬件加速、内存管理和界面优化等。通过采用这些策略&#xff0c;可以…

开源模型应用落地-安全合规篇-用户输入价值观判断(四)

一、前言 在深度合规功能中,对用户输入内容的价值观判断具有重要意义。这一功能不仅仅是对信息合法性和合规性的简单审核,更是对信息背后隐含的伦理道德和社会责任的深刻洞察。通过对价值观的判断,系统能够识别可能引发不当影响或冲突的内容,从而为用户提供更安全、更和谐的…

计算机的错误计算(一百七十六)

摘要 利用某一大语言模型计算 的值&#xff0c;输出为 0 . 例1. 在某一大语言模型下&#xff0c;计算 的值。其中sin中值取弧度。结果保留16位有效数字。 直接贴图吧&#xff1a; 点评&#xff1a; &#xff08;1&#xff09;以上为一个大模型给的答案。从其回答可知&…

数据结构与算法——1204—递归分治法

1、斐波那契数列优化 使用滚动变量&#xff0c;保存当前计算结果和前两项值 (1)RAB (2)更新计算对象&#xff0c;AB&#xff0c;BR #include<iostream> using namespace std;int fun(int n) {if (n 0)return 0;if (n 1 || n 2)return 1;int num11;int num21;int su…

openstack内部rpc消息通信源码分析

我们知道openstack内部消息队列基于AMQP协议&#xff0c;默认使用的rabbitmq 消息队列。谈到rabbitmq&#xff0c;大家或许并不陌生&#xff0c;但或许会对oslo message有些陌生。openstack内部并不是直接使用rabbitmq&#xff0c;而是使用了oslo.message 。oslo.message 后端的…

Python 3 和 MongoDB 的集成使用

Python 3 和 MongoDB 的集成使用 MongoDB 是一个流行的 NoSQL 数据库&#xff0c;以其灵活的数据模型和强大的查询功能而闻名。Python 3 作为一种广泛使用的编程语言&#xff0c;与 MongoDB 的集成变得日益重要。本文将介绍如何在 Python 3 环境中集成和使用 MongoDB&#xff…