大模型 | NEFTune之引入随机噪声对大模型训练的收益

大模型 | NEFTune之引入随机噪声对大模型训练的收益

paper中提到,在模型foward过程中,对inputs_embedding增加适度的随机噪声,会带来显著的收益。

Paper: https://arxiv.org/pdf/2310.05914.pdf
Github: https://github.com/neelsjain/NEFTune

文章目录

  • 大模型 | NEFTune之引入随机噪声对大模型训练的收益
  • 理论
  • 一. 实践方法
    • 1.1 等待Hugging发布该功能
    • 1.2 直接封装model
    • 1.3 改写compute_loss


理论

核心是输入经过Embedding层后,再加入一个均匀分布的噪声,噪声的采样范围为 [ − α L d , α L d ] [-\frac{\alpha}{\sqrt{Ld}},\frac{\alpha}{\sqrt{Ld}}] [Ld α,Ld α]之间,其中 α \alpha α为噪声超参,L为输入长度,d为Embedding层维度(即hidden维度)
在这里插入图片描述
在AlpacaEval榜单上,利用GPT4作为评分器,在多个数据上微调Llama2-7B模型,NEFTune方法相较于直接微调方法,均有显著提高。
在这里插入图片描述
可以缓解模型在指令微调阶段的过拟合现象,可以更好的利用预训练阶段的知识内容。

一. 实践方法

1.1 等待Hugging发布该功能

进度:等待hugging face正式发布此功能,2023-10-26

[10/17/2023] NEFTune has been intregrated into the Huggingface’s TRL (Transformer Reinforcement Learning) library. See Annoucement.

1.2 直接封装model

进度:直接对模型进行如下封装,原理是对model.embed_tokens.forward()进行改写,经实践,这种方法不管用,会报堆栈溢出的error。

from torch.nn import functional as Fdef NEFTune(model, noise_alpha=5)def noised_embed(orig_embed, noise_alpha):def new_func(x):# during training, we add noise to the embedding# during generation, we don't add noise to the embeddingif model.training:embed_init = orig_embed(x)dims = torch.tensor(embed_init.size(1) * embed_init.size(2))mag_norm = noise_alpha/torch.sqrt(dims)return embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)else:return orig_embed(x)return new_func##### NOTE: this is for a LLaMA model ##### ##### For a different model, you need to change the attribute path to the embedding #####model.base_model.model.model.embed_tokens.forward = noised_embed(model.base_model.model.model.embed_tokens, noise_alpha)return model

1.3 改写compute_loss

进度:loss能够正常计算,但optimzer会报错,可能与精度有关,尚未解决

由于损失函数是自己写的,因此尝试在model(**input)前,追加噪声代码。注意,原先传入model的是input_ids,而当下由于我们将inputs_embeds增加了噪声,因此传入model的将直接替换为inputs_embeds,代码如下

class TargetLMLossNeft(Loss):def __init__(self, ignore_index):super().__init__()self.ignore_index = ignore_indexself.loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index)def __call__(self, model, inputs, training_args, return_outputs=False):input_ids = inputs['input_ids'] # B x L [3, 964]attention_mask = inputs['attention_mask'] # B x L target_mask = inputs['target_mask'] # B x L###  ----------------------------- add noise to embedsneftune_alpha = 5embed_device = model.base_model.model.model.embed_tokens.weight.deviceembeds_init = model.base_model.model.model.embed_tokens.forward(input_ids).to(embed_device) # 先forward一下, 变成B X L X hidden_state# embed_device = model.model.embed_tokens.weight.device# embeds_init = model.model.embed_tokens.forward(input_ids).to(embed_device)input_mask = attention_mask.to(embeds_init) # B x Linput_lengths = torch.sum(input_mask, 1) # B, 计算每个sample的实际长度noise_ = torch.zeros_like(embeds_init).uniform_(-1,1) # B X L X hidden_state, 且值域在[-1,1]正态分布delta = noise_ * input_mask.unsqueeze(2) # 追加一个维度,由B X L 变成 B X L X hidden_statedims = input_lengths * embeds_init.size(-1)mag = neftune_alpha / torch.sqrt(dims)delta = (delta * mag.view(-1, 1, 1)).detach() # B X L X hidden_stateinputs_embeds = delta + embeds_init### ----------------------------- add noise to embeds# 模型前馈预测, 原来传入的是input_ids,而现在需要直接将增加了noise的inputs_embeds传入# outputs = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, return_dict=True)logits = outputs["logits"] if isinstance(outputs, dict) else outputs[0] # 正常应该是torch.float32#logits.requires_grad = True # 奇怪,为什么这里会默认为False, 难道是因为上边的detach()# 将labels中不属于target的部分,设为ignore_index,只计算target部分的losslabels = torch.where(target_mask == 1, input_ids, self.ignore_index)shift_logits = logits[..., :-1, :].contiguous()shift_labels = labels[..., 1:].contiguous()# Flatten the tokensloss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) # float32loss.requires_grad = Truereturn (loss, outputs) if return_outputs else loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/120462.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Centos7 安装 Etcd

Github上下载并解压安装包 wget https://github.com/coreos/etcd/releases/download/v3.4.10/etcd-v3.4.10-linux-amd64.tar.gz tar xzvf etcd-v3.4.10-linux-amd64.tar.gz mv etcd-v3.4.10-linux-amd64 /opt/etcd解压后是一些文档和两个二进制文件etcd和etcdctl。etcd是serve…

网络攻击的发展

在当今数字化时代,网站被攻击已经成为常态,网络威胁愈演愈烈。这些攻击不仅威胁到企业的安全,还可能导致严重的商业危机。本文将探讨为什么网络流量攻击变得如此普遍和容易,并分析未来可能引发的商业危机。 ​ 网络流量攻击的普遍…

【OpenCV实现图像的几何变换】

文章目录 概要:OpenCV实现图像的几何变换、图像阈值和平滑图像变换小结 概要:OpenCV实现图像的几何变换、图像阈值和平滑图像 使用OpenCV库进行图像处理的三个重要主题:几何变换、图像阈值处理以及图像平滑。在几何变换部分,详细…

【Linux精讲系列】——yum软件包管理

​作者主页 📚lovewold少个r博客主页 ⚠️本文重点:Linux系统软件包管理工具yum讲解 😄每日一言:踏向彼岸的每一步,都是到达彼岸本身。 目录 前言 Linux系统下的软件下载方式 yum 查看软件包 如何安装软件 如何卸…

myTracks for Mac:GPS轨迹记录器的强大与便捷

你是否曾经在户外活动或旅行中,希望能够记录下你的移动轨迹?或者在工作中,需要跟踪你的行程路线?myTracks for Mac 是一款强大的 GPS 轨迹记录器,它可以帮助你实现这些愿望。 myTracks 是一款专门为 Mac 设计的 GPS 轨…

el-tree业务

<el-form-item label"选择节点" prop"node_ids"><el-checkboxv-if"regionList.length"v-model"selectAll":disabled"selectDisabled":indeterminate"isIndeterminate":show-checkbox"!selectDisabl…

微信JSAPI支付对接

简介 JSAPI支付是指商户通过调用微信支付提供的JSAPI接口&#xff0c;在支付场景中调起微信支付模块完成收款。 应用场景 JSAPI支付适用于线下场所、公众号场景和PC网站场景。 商户已有H5商城网站&#xff0c;用户通过消息或扫描二维码在微信内打开网页时&#xff0c;可以调…

机器学习-学习率:从理论到实战,探索学习率的调整策略

目录 一、引言二、学习率基础定义与解释学习率与梯度下降学习率对模型性能的影响 三、学习率调整策略常量学习率时间衰减自适应学习率AdaGradRMSpropAdam 四、学习率的代码实战环境设置数据和模型常量学习率时间衰减Adam优化器 五、学习率的最佳实践学习率范围测试循环学习率&a…

Docker 批量导入镜像

可以编写一个脚本&#xff0c;该脚本循环遍历一个文件夹中的所有镜像存档文件&#xff0c;并使用 docker load 命令加载它们。以下是一个 Bash 脚本示例&#xff1a; #!/bin/bash# 指定存档文件所在的目录 archive_dir"/path/to/archives/"# 遍历存档文件并加载到 D…

十四、城市建成区时空扩张分析——景观格局指数

一、前言 景观格局指数:指景观格局与景观指数,景观格局通常是指景观的空间结构特征,具体是指由自然或人为形成的,一系列大小、形状各异,排列不同的景观镶嵌体在景观空间的排列,它即是景观异质性的具体表现,同时又是包括干扰在内的各种生态过程在不同尺度上作用的结果。…

Shopee新店多久出单?shopee新店如何运营?——站斧浏览器

shopee新店多久出单&#xff1f; 就以店铺每天上新来说&#xff0c;从店铺下来那天开始&#xff0c;每天10-20个产品去上新&#xff0c;正常情况下两周以内你的店铺是一定会有订单产生的。如果一两个月过去了&#xff0c;店铺还是没有单出&#xff0c;那就证明店铺存在很大的问…

【spark客户端】Spark SQL CLI详解:怎么执行sql文件、注释怎么写,支持的文件路径协议、交互式模式使用细节

文章目录 一. Spark SQL Command Line Options(命令行参数)二. The hiverc File1. without the -i2. .hiverc 介绍 三. 支持的路径协议四. 支持的注释类型五. Spark SQL CLI交互式命令六. Examples1. running a query from the command line2. setting Hive configuration vari…

缓解光纤激光切割机老化之如何保养光纤激光切割机的光学镜片

激光切割头具备极高的精密度和昂贵的价格&#xff0c;是光纤激光切割机最关键的运行部分之一。在日常的光纤激光切割机维修过程中频繁出现的关于切割头使用寿命的问题就是内部光学镜片的污染及损坏。 部分导致光纤激光切割机激光切割头光学镜片污染的原因主要包括&#xff1a;对…

【APP VTable】和市面上的 Table 组件一样,都是接收表格[] 以及数据源[]

博主&#xff1a;_LJaXi Or 東方幻想郷 专栏&#xff1a; uni-app | 小程序开发 开发工具&#xff1a;HBuilderX 这里写目录标题 表格组件USE 表格组件 <template><view class"scroll-table-wrapper"><view class"scroll-table-container"…

SpringMVC原理及核心组件

一、SpringMVC原理及核心组件 1、 Spring MVC的工作原理 Spring MVC 是一个对javaWeb中Servlet 简化和封装&#xff0c; 1.首先SpringMVC 配置DispatcherServlet 来接受所有的请求&#xff0c;我们通过DispatcherServlet 响应的所有数据&#xff0c;DispatcherServlet 是Htt…

iOS安全加固方法及实现

​ 目录 iOS安全加固方法及实现 摘要 引言 iOS安全加固方法及实现 一、字符串加密 二、类名方法名混淆 三、程序代码混淆 四、加入安全SDK 总结 参考资料 摘要 本文介绍了iOS平台下的应用安全保护方法&#xff0c;包括字符串加密、类名方法名混淆、程序代码混淆和加入…

好数组——尺取法

好数组 给定一个长度为 n 的数组 a&#xff0c;计算数组 a 中所有子数组中好数组的数目。 好数组定义如下&#xff1a; 对于数组 al ,al1, ⋯ ,ar &#xff0c;若数组中所有数的质因数种类数不超过 k&#xff0c;则称为好数组。 Input 输入的第一行包含两个正整数 n,k (1≤…

杂牌行车记录仪特殊AVI结构恢复案例

最近遇到一个杂牌的行车记录仪需要恢复数据&#xff0c;其使用AVI格式&#xff0c;但是在扫描恢复的过程中却发现厂家对其AVI结构进行了“魔改”致程序无法正常识别 故障存储:16G SD卡 fat32文件系统 故障现象: 16G的SD卡&#xff0c;在发生事故后客户尝试自行接到手机上读…

系列三、Spring IOC

一、概述 IOC的中文意思是控制反转&#xff0c;通俗地讲就是把创建对象的控制权交给了Spring去管理&#xff0c;以前是由程序员自己去创建控制对象&#xff0c;现在交由Spring去创建控制。 二、优点 集中管理对象&#xff0c;方便维护&#xff0c;降低耦合度。 三、IOC的底层…

前后端分离使用RSA加密

简介 1、前提 本篇文章前端使用的react&#xff0c;后端用的springboot&#xff0c;前端用什么框架都可以&#xff0c;大体实现逻辑是一样的&#xff0c;而且也是用jsencrypt这个库&#xff0c;只是后端可以我写的&#xff08;大部分是copy的别人的代码&#xff09;用RAS的工…