Transformers 加速的一些常用技巧

Transformers 是一个强大的架构,但模型因其采用的自注意力机制,虽然能够有效地处理序列数据并捕获长距离依赖关系,但同时也容易导致在训练过程中出现OOM(Out of Memory,内存不足)或者达到GPU的运行时限制。

主要是因为

  1. 参数数量庞大:Transformer模型通常包含大量的参数,尤其是在模型层面进行扩展时(例如,增加层数或头数)。这些参数需要大量的内存来存储权重和梯度。
  2. 自注意力计算:自注意力机制需要对输入序列的每个元素与其他所有元素计算其相互关系,导致计算复杂度和内存需求随着输入长度的增加而显著增加。对于非常长的序列,这一点尤其突出。
  3. 激活和中间状态存储:在训练过程中,需要存储前向传播中的中间激活状态,以便于反向传播时使用。这增加了额外的内存负担。

为了解决这些问题,我们今天来总结以下一些常用的加速策略

固定长度填充

在处理文本数据时,由于文本序列的长度可能各不相同,但许多机器学习模型(尤其是基于Transformer的模型)需要输入数据具有固定的尺寸,因此需要对文本序列进行固定长度填充(padding)。

在使用Transformer模型时,填充部分不应影响到模型的学习。因此通常需要使用注意力掩码(attention mask)来指示模型在自注意力计算时忽略这些填充位置。通过这种固定长度填充和相应的处理方法,可以使得基于Transformer的模型能够有效地处理不同长度的序列数据。在实际应用中,这种方法是处理文本输入的常见策略。

 def fixed_pad_sequences(sequences, max_length, padding_value=0):padded_sequences = []for sequence in sequences:if len(sequence) >= max_length:padded_sequence = sequence[:max_length]  # Trim the sequence if it exceeds max_lengthelse:padding = [padding_value] * (max_length - len(sequence))  # Calculate paddingpadded_sequence = sequence + padding  # Pad the sequencepadded_sequences.append(padded_sequence)return padded_sequences

这种方式会将所有的序列填充成一个长度,这样虽然长度相同了,但是因为序列的实际大小本来就不同,同一批次很可能出现有很多填充的情况,所以就出现了动态填充策略。

动态填充是在每个批处理中动态填充输入序列到最大长度。与固定长度填充不同,在固定长度填充中,所有序列都被填充以匹配整个数据集中最长序列的长度,动态填充根据该批中最长序列的长度单独填充每个批中的序列。

这样虽然每个批次的长度是不同的,但是批次内部的长度是相同的,可以加快处理速度。

 def pad_sequences_dynamic(sequences, padding_value=0):max_length = max(len(seq) for seq in sequences)  # Find the maximum length in the sequencespadded_sequences = []for sequence in sequences:padding = [padding_value] * (max_length - len(sequence))  # Calculate paddingpadded_sequence = sequence + padding  # Pad the sequencepadded_sequences.append(padded_sequence)return padded_sequences

等长匹配

等长匹配是在训练或推理过程中将长度相近的序列分组成批处理的过程。等长匹配通过基于序列长度将数据集划分为桶,然后从这些桶中采样批次来实现的。

从上图可以看到,通过等长匹配的策略,减少了填充量,这样也可以加速计算

 def uniform_length_batching(sequences, batch_size, padding_value=0):# Sort sequences based on their lengthssequences.sort(key=len)# Divide sequences into buckets based on lengthbuckets = [sequences[i:i+batch_size] for i in range(0, len(sequences), batch_size)]# Pad sequences within each bucket to the length of the longest sequence in the bucketpadded_batches = []for bucket in buckets:max_length = len(bucket[-1])  # Get the length of the longest sequence in the bucketpadded_bucket = []for sequence in bucket:padding = [padding_value] * (max_length - len(sequence))  # Calculate paddingpadded_sequence = sequence + padding  # Pad the sequencepadded_bucket.append(padded_sequence)padded_batches.append(padded_bucket)return padded_batches

自动混合精度

自动混合精度(AMP)是一种通过使用单精度(float32)和半精度(float16)算法的组合来加速深度学习模型训练的技术。它利用了现代gpu的功能,与float32相比,使用float16数据类型可以更快地执行计算,同时使用更少的内存。

 import torchfrom torch.cuda.amp import autocast, GradScaler# Define your modelmodel = YourModel()# Define optimizer and loss functionoptimizer = torch.optim.Adam(model.parameters(), lr=1e-3)criterion = torch.nn.CrossEntropyLoss()# Create a GradScaler object for gradient scalingscaler = GradScaler()# Inside the training loopfor inputs, targets in dataloader:# Clear previous gradientsoptimizer.zero_grad()# Cast inputs and targets to the appropriate deviceinputs, targets = inputs.to(device), targets.to(device)# Enable autocasting for forward passwith autocast():# Forward passoutputs = model(inputs)loss = criterion(outputs, targets)# Backward pass# Scale the loss valuescaler.scale(loss).backward()# Update model parametersscaler.step(optimizer)# Update the scale for next iterationscaler.update()

AMP在训练过程中动态调整计算精度,允许模型在大多数计算中使用float16,同时自动将某些计算提升为float32,以防止下流或溢出等数值不稳定问题。

Fp16 vs Fp32

双精度(FP64)消耗64位。符号值为1位,指数值为11位,有效精度为52位。

单精度(FP32)消耗32位。符号值为1位,指数值为8位,有效精度为23位。

半精度(FP16)消耗16位。符号值为1位,指数值为5位,有效精度为10位。

所以Fp16可以提高内存节省,并可以大大提高模型训练的速度。考虑到Fp16的优势和它在模型使用方面的主导区域,它非常适合推理任务。但是fp16会产生数值精度的损失,导致计算或存储的值不准确,考虑到这些值的精度至关重要。

另外就是这种优化师针对于分类任务的,对于回归这种需要精确数值的任务Fp16的表现并不好。

总结

以上这些方法,可以在一定程度上缓解内存不足和计算资源的限制,但是对于大型的模型我们还是需要一个强大的GPU。

https://avoid.overfit.cn/post/7240bee210cd408a90ca04279830040e

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/11200.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI大模型探索之路-训练篇22: ChatGLM3微调实战-从原理到应用的LoRA技术全解

系列篇章💥 AI大模型探索之路-训练篇1:大语言模型微调基础认知 AI大模型探索之路-训练篇2:大语言模型预训练基础认知 AI大模型探索之路-训练篇3:大语言模型全景解读 AI大模型探索之路-训练篇4:大语言模型训练数据集概…

MPLAB X IDE编译attiny1616工程报错却无报错信息

MPLAB X IDE(XC-8编译器)编译报错,无具体错误内容,仅显示需要xc-8 pro的警告。 内存占用率显示为81%,未超标。 原因:软件使用了microchip的bootloader功能。应用程序起始地址(也是bootloader结束地址)设置错…

社交巨头:探索Facebook的震撼力量

Facebook作为社交媒体领域的巨头,不仅在数字化社会中占据着重要地位,更是影响了人们的生活、工作和社交方式。本文将深入探索Facebook的震撼力量,从多个角度解读其在当今社会中的重要性和影响。 1. 全球用户覆盖的壮观规模 Facebook作为全球…

docker安装时报错:Error: Nothing to do

安装docker时报以下错误 解决方法: 1.下载关于docker的相关依赖环境 yum -y install yum-utils device-mapper-persistent-data lvm22.设置下载Docker的镜像源 yum-config-manager --add-repo http://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo3…

FMEA存在的五个主要不足及改进措施——FMEA软件

免费试用FMEA软件-免费版-SunFMEA 在制造业和产品设计领域,失效模式与影响分析(Failure Modes and Effects Analysis,简称FMEA)被广泛运用,用于预防潜在的设计或制造缺陷。然而,尽管FMEA在风险管理方面发挥…

开发者集结号:大湾区 Open Source Day 邀您共探技术前沿

开源技术正以其开放、协作的特性,引领着软件开发的新潮流,是推动社会进步的重要力量。作为开发者,您是否渴望深入了解开源项目的前沿动态?由ALC深圳与2024中国互联网发展创新与投资大赛联合举办、FISCO金链盟深度参与的大湾区 Ope…

MySQL————创建存储过程函数

存储过程使用大纲 有参数传递 delimiter $$ 声明一个名称为get_student_introduce create procedure add_student_infor( in p_userName VARCHAR(20),in p_phone VARCHAR(11),in p_sex char(2),in p_introduce VARCHAR(255)) 开始操作 BEGIN 撰写真正在操作DMLDQL都行 INSE…

CSS---复合选择器、元素显示模式和背景(三)

一、CSS的复合选择器 1.1 什么是复合选择器 在CSS中,可以根据选择器的类型把选择器分为基础选择器和复合选择器,复合选择器是建立在基础选择器之上,对基本选择器进行组合形成的。 复合选择器是由两个或多个基础选择器连写组成,它…

【云原生】kubernetes核心组件

引言: Kubernetes 是为运行分布式集群而建立的,分布式系统的本质使得网络成为 Kubernetes 的核心和必要组成部分,了解 Kubernetes 网络模型可以使你能够正确运行、监控和排查应用程序故障。 一、Kubernetes的核心组件 1.1、Master组件 1.1.…

基于Springboot+Vue的Java项目-农产品直卖平台系统开发实战(附演示视频+源码+LW)

大家好!我是程序员一帆,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:Java毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 Python毕业设计 &am…

可道云teamOS企业网盘实用插件介绍:实时在线流程图编辑与分享,用在线流程图打造数字化工作流程

在使用企业网盘用于日常办公的情况下,有一些实用的在线小工具能为团队效率和协作带来一定的提升。 今天要给大家介绍的可道云teamOS的在线画流程图,是很值得介绍的一个在线工具。 在线流程图:直观展示,高效便捷 以往我们想要梳理…

FANUC机器人单轴零点标定时提示无法执行零点标定,由于重力补偿已启用,所有机器人轴的脉冲计数必须有效

FANUC机器人单轴零点标定时提示无法执行零点标定,由于重力补偿已启用,所有机器人轴的脉冲计数必须有效 首先,机器人由于长时间断电未使用,6个轴的编码器数据全部丢失,上电后报警SRVO-062, 有关SRVO-062故障报警的相关内容可参考以下链接: FANUC机器人SRVO-062报警原因分…

Scratch四级:第08讲 排序算法

第08讲 排序算法 教练:老马的程序人生 微信:ProgrammingAssistant 博客:https://lsgogroup.blog.csdn.net/ 讲课目录 常考的排序算法项目制作:“三个数排序”项目制作:“成绩查询”项目制作:“排序”项目制…

单片机智能灯控制系统源程序仿真原理图与论文全套资料

目录 1、设计描述 2、仿真图 3、程序 4、资料内容 资料下载地址:单片机智能灯控制系统源程序仿真原理图与论文全套资料下载 1、设计描述 设计了一款智能控制系统。 AT89C51LCD1602DS1302按键LED组成了这样一个完整的设计。 P2.0-P2.3 4个LED等代表庭院内的4…

架构设计之学新而知故

缘由 因为一些特殊的机缘,接触到洋葱架构等一些新架构设计概念。 尝试理解了一段时间,就想简单梳理下对它们的理解,以达到学新而知故 😃 信息增益 以前计算机专业并不设置通信领域的信息论的专业课程,但是&#xf…

英语复习之英语形近词总结(四)

英语形近词总结复习第四部分: 单词 释义例句 genuine 英 /ˈdʒenjuɪn/ 美 /ˈdʒenjuɪn/ adj.真实的,真正的;诚恳的 1.Only genuine refugees can apply for asylum. 只有真正的难民才能申请政治避难。 《牛津词典》 2.This isnt a genui…

C++笔试强训day19

目录 1.小易的升级之路 2.礼物的最大价值 3.对称之美 1.小易的升级之路 链接 模拟就行&#xff0c;唯一可能是难点得就是gcd&#xff08;最大公约数&#xff09; #include <iostream> using namespace std; #define int long long const int N 1e5 10; int arr[N];…

利用IP地址查询解决被“薅羊毛”的方法

在互联网时代&#xff0c;随着各种网络诈骗手段的不断更新和演变&#xff0c;“薅羊毛”成为了一种常见的网络犯罪行为。其中&#xff0c;利用查询IP地址进行欺诈活动已经成为一种普遍的手段。当个人或组织的IP地址被不法分子查询后&#xff0c;可能会面临虚假注册、盗取个人信…

AVL Cruise与Simulink联合仿真(通过MATLAB DLL方式)

最近毕业设计需要用到AVL Cruise与Simulink进行联合仿真&#xff0c;分析汽车模型的经济性。下面介绍一下我所知的AVL Cruise与Simulink联合仿真的几种方式&#xff0c;它们各自的优缺点&#xff0c;以及DLL方式联合仿真的具体配置过程。我这里用的MATLAB软件版本是2021a&#…

有边数限制的最短路

文章目录 题目 有边数限制的最短路算法分析1、问题&#xff1a;为什么Dijkstra不能使用在含负权的图中&#xff1f;dijkstra详细步骤2、什么是bellman - ford算法&#xff1f;3、bellman - ford算法的具体步骤4、在下面代码中&#xff0c;是否能到达n号点的判断中需要进行if(di…