大模型 | NEFTune之引入随机噪声对大模型训练的收益

大模型 | NEFTune之引入随机噪声对大模型训练的收益

paper中提到,在模型foward过程中,对inputs_embedding增加适度的随机噪声,会带来显著的收益。

Paper: https://arxiv.org/pdf/2310.05914.pdf
Github: https://github.com/neelsjain/NEFTune

文章目录

  • 大模型 | NEFTune之引入随机噪声对大模型训练的收益
  • 理论
  • 一. 实践方法
    • 1.1 等待Hugging发布该功能
    • 1.2 直接封装model
    • 1.3 改写compute_loss


理论

核心是输入经过Embedding层后,再加入一个均匀分布的噪声,噪声的采样范围为 [ − α L d , α L d ] [-\frac{\alpha}{\sqrt{Ld}},\frac{\alpha}{\sqrt{Ld}}] [Ld α,Ld α]之间,其中 α \alpha α为噪声超参,L为输入长度,d为Embedding层维度(即hidden维度)
在这里插入图片描述
在AlpacaEval榜单上,利用GPT4作为评分器,在多个数据上微调Llama2-7B模型,NEFTune方法相较于直接微调方法,均有显著提高。
在这里插入图片描述
可以缓解模型在指令微调阶段的过拟合现象,可以更好的利用预训练阶段的知识内容。

一. 实践方法

1.1 等待Hugging发布该功能

进度:等待hugging face正式发布此功能,2023-10-26

[10/17/2023] NEFTune has been intregrated into the Huggingface’s TRL (Transformer Reinforcement Learning) library. See Annoucement.

1.2 直接封装model

进度:直接对模型进行如下封装,原理是对model.embed_tokens.forward()进行改写,经实践,这种方法不管用,会报堆栈溢出的error。

from torch.nn import functional as Fdef NEFTune(model, noise_alpha=5)def noised_embed(orig_embed, noise_alpha):def new_func(x):# during training, we add noise to the embedding# during generation, we don't add noise to the embeddingif model.training:embed_init = orig_embed(x)dims = torch.tensor(embed_init.size(1) * embed_init.size(2))mag_norm = noise_alpha/torch.sqrt(dims)return embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)else:return orig_embed(x)return new_func##### NOTE: this is for a LLaMA model ##### ##### For a different model, you need to change the attribute path to the embedding #####model.base_model.model.model.embed_tokens.forward = noised_embed(model.base_model.model.model.embed_tokens, noise_alpha)return model

1.3 改写compute_loss

进度:loss能够正常计算,但optimzer会报错,可能与精度有关,尚未解决

由于损失函数是自己写的,因此尝试在model(**input)前,追加噪声代码。注意,原先传入model的是input_ids,而当下由于我们将inputs_embeds增加了噪声,因此传入model的将直接替换为inputs_embeds,代码如下

class TargetLMLossNeft(Loss):def __init__(self, ignore_index):super().__init__()self.ignore_index = ignore_indexself.loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index)def __call__(self, model, inputs, training_args, return_outputs=False):input_ids = inputs['input_ids'] # B x L [3, 964]attention_mask = inputs['attention_mask'] # B x L target_mask = inputs['target_mask'] # B x L###  ----------------------------- add noise to embedsneftune_alpha = 5embed_device = model.base_model.model.model.embed_tokens.weight.deviceembeds_init = model.base_model.model.model.embed_tokens.forward(input_ids).to(embed_device) # 先forward一下, 变成B X L X hidden_state# embed_device = model.model.embed_tokens.weight.device# embeds_init = model.model.embed_tokens.forward(input_ids).to(embed_device)input_mask = attention_mask.to(embeds_init) # B x Linput_lengths = torch.sum(input_mask, 1) # B, 计算每个sample的实际长度noise_ = torch.zeros_like(embeds_init).uniform_(-1,1) # B X L X hidden_state, 且值域在[-1,1]正态分布delta = noise_ * input_mask.unsqueeze(2) # 追加一个维度,由B X L 变成 B X L X hidden_statedims = input_lengths * embeds_init.size(-1)mag = neftune_alpha / torch.sqrt(dims)delta = (delta * mag.view(-1, 1, 1)).detach() # B X L X hidden_stateinputs_embeds = delta + embeds_init### ----------------------------- add noise to embeds# 模型前馈预测, 原来传入的是input_ids,而现在需要直接将增加了noise的inputs_embeds传入# outputs = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, return_dict=True)logits = outputs["logits"] if isinstance(outputs, dict) else outputs[0] # 正常应该是torch.float32#logits.requires_grad = True # 奇怪,为什么这里会默认为False, 难道是因为上边的detach()# 将labels中不属于target的部分,设为ignore_index,只计算target部分的losslabels = torch.where(target_mask == 1, input_ids, self.ignore_index)shift_logits = logits[..., :-1, :].contiguous()shift_labels = labels[..., 1:].contiguous()# Flatten the tokensloss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) # float32loss.requires_grad = Truereturn (loss, outputs) if return_outputs else loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/120462.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

网络攻击的发展

在当今数字化时代,网站被攻击已经成为常态,网络威胁愈演愈烈。这些攻击不仅威胁到企业的安全,还可能导致严重的商业危机。本文将探讨为什么网络流量攻击变得如此普遍和容易,并分析未来可能引发的商业危机。 ​ 网络流量攻击的普遍…

【OpenCV实现图像的几何变换】

文章目录 概要:OpenCV实现图像的几何变换、图像阈值和平滑图像变换小结 概要:OpenCV实现图像的几何变换、图像阈值和平滑图像 使用OpenCV库进行图像处理的三个重要主题:几何变换、图像阈值处理以及图像平滑。在几何变换部分,详细…

【Linux精讲系列】——yum软件包管理

​作者主页 📚lovewold少个r博客主页 ⚠️本文重点:Linux系统软件包管理工具yum讲解 😄每日一言:踏向彼岸的每一步,都是到达彼岸本身。 目录 前言 Linux系统下的软件下载方式 yum 查看软件包 如何安装软件 如何卸…

myTracks for Mac:GPS轨迹记录器的强大与便捷

你是否曾经在户外活动或旅行中,希望能够记录下你的移动轨迹?或者在工作中,需要跟踪你的行程路线?myTracks for Mac 是一款强大的 GPS 轨迹记录器,它可以帮助你实现这些愿望。 myTracks 是一款专门为 Mac 设计的 GPS 轨…

微信JSAPI支付对接

简介 JSAPI支付是指商户通过调用微信支付提供的JSAPI接口,在支付场景中调起微信支付模块完成收款。 应用场景 JSAPI支付适用于线下场所、公众号场景和PC网站场景。 商户已有H5商城网站,用户通过消息或扫描二维码在微信内打开网页时,可以调…

机器学习-学习率:从理论到实战,探索学习率的调整策略

目录 一、引言二、学习率基础定义与解释学习率与梯度下降学习率对模型性能的影响 三、学习率调整策略常量学习率时间衰减自适应学习率AdaGradRMSpropAdam 四、学习率的代码实战环境设置数据和模型常量学习率时间衰减Adam优化器 五、学习率的最佳实践学习率范围测试循环学习率&a…

【spark客户端】Spark SQL CLI详解:怎么执行sql文件、注释怎么写,支持的文件路径协议、交互式模式使用细节

文章目录 一. Spark SQL Command Line Options(命令行参数)二. The hiverc File1. without the -i2. .hiverc 介绍 三. 支持的路径协议四. 支持的注释类型五. Spark SQL CLI交互式命令六. Examples1. running a query from the command line2. setting Hive configuration vari…

缓解光纤激光切割机老化之如何保养光纤激光切割机的光学镜片

激光切割头具备极高的精密度和昂贵的价格,是光纤激光切割机最关键的运行部分之一。在日常的光纤激光切割机维修过程中频繁出现的关于切割头使用寿命的问题就是内部光学镜片的污染及损坏。 部分导致光纤激光切割机激光切割头光学镜片污染的原因主要包括:对…

【APP VTable】和市面上的 Table 组件一样,都是接收表格[] 以及数据源[]

博主&#xff1a;_LJaXi Or 東方幻想郷 专栏&#xff1a; uni-app | 小程序开发 开发工具&#xff1a;HBuilderX 这里写目录标题 表格组件USE 表格组件 <template><view class"scroll-table-wrapper"><view class"scroll-table-container"…

iOS安全加固方法及实现

​ 目录 iOS安全加固方法及实现 摘要 引言 iOS安全加固方法及实现 一、字符串加密 二、类名方法名混淆 三、程序代码混淆 四、加入安全SDK 总结 参考资料 摘要 本文介绍了iOS平台下的应用安全保护方法&#xff0c;包括字符串加密、类名方法名混淆、程序代码混淆和加入…

杂牌行车记录仪特殊AVI结构恢复案例

最近遇到一个杂牌的行车记录仪需要恢复数据&#xff0c;其使用AVI格式&#xff0c;但是在扫描恢复的过程中却发现厂家对其AVI结构进行了“魔改”致程序无法正常识别 故障存储:16G SD卡 fat32文件系统 故障现象: 16G的SD卡&#xff0c;在发生事故后客户尝试自行接到手机上读…

项目进度延误,危机管理5大注意事项

项目延误危机管理的重要性是不可忽视的。项目延误可能会导致资源浪费、成本增加、客户不满、信誉受损等一系列问题&#xff0c;严重影响项目的成功与效益。因此&#xff0c;有效地进行项目延误危机管理是至关重要的&#xff0c;一般主要是从以下5个方面进行管理&#xff1a; 1、…

《动手学深度学习 Pytorch版》 10.6 自注意力和位置编码

在注意力机制中&#xff0c;每个查询都会关注所有的键&#xff0d;值对并生成一个注意力输出。由于查询、键和值来自同一组输入&#xff0c;因此被称为 自注意力&#xff08;self-attention&#xff09;&#xff0c;也被称为内部注意力&#xff08;intra-attention&#xff09;…

竞赛 深度学习人体跌倒检测 -yolo 机器视觉 opencv python

0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; **基于深度学习的人体跌倒检测算法研究与实现 ** 该项目较为新颖&#xff0c;适合作为竞赛课题方向&#xff0c;学长非常推荐&#xff01; &#x1f947;学长这里给一个题目综合评分(每项满…

npm改变npm缓存路径和改变环境变量

在安装nodejs时&#xff0c;系统会自动安装在系统盘C&#xff0c; 时间久了经常会遇到C盘爆满&#xff0c;有时候出现红色&#xff0c;此时才发现很多时候是因为npm 缓存保存在C盘导致的&#xff0c;下面就介绍下如何改变npm缓存路径。 1、首先找到安装nodejs的路径&#xff0c…

JVM(Java Virtual Machine)G1收集器篇

前言 本文参考《深入理解Java虚拟机》&#xff0c;本文主要介绍G1收集器的收集思想和具体过程&#xff08;填上一篇文章留下的坑&#xff09; 本系列其他文章链接&#xff1a; JVM&#xff08;Java Virtual Machine&#xff09;内存模型篇 JVM&#xff08;Java Virtual Machi…

SQL sever中函数(2)

目录 一、函数分类及应用 1.1标量函数&#xff08;Scalar Functions&#xff09;&#xff1a; 1.1.1格式 1.1.2示例 1.1.3作用 1.2表值函数&#xff08;Table-Valued Functions&#xff09;&#xff1a; 1.2.1内联表值函数&#xff08;Inline Table-Valued Functions&am…

Linux shell编程学习笔记15:定义数组、获取数组元素值和长度

一、 Linux shell 脚本编程中的数组概述 数组是一种常见的数据结构。跟大多数编程语言一样&#xff0c;大多数Linux shell脚本支持数组&#xff0c;但对数组的支持程度各不相同&#xff0c;比如数组的维度&#xff0c;是支持一维数组还是多维数组&#xff1f;再如&#xff0c;…

Redis为什么变慢了

一、Redis为什么变慢了 1.Redis真的变慢了吗? 对 Redis 进行基准性能测试 例如,我的机器配置比较低,当延迟为 2ms 时,我就认为 Redis 变慢了,但是如果你的硬件配置比较高,那么在你的运行环境下,可能延迟是 0.5ms 时就可以认为 Redis 变慢了。 所以,你只有了解了你的…