Transformer微调实战:通过低秩分解(LoRA)对T5模型进行微调(LoRA Fine Tune)

scient

scient一个用python实现科学计算相关算法的包,包括自然语言、图像、神经网络、优化算法、机器学习、图计算等模块。

scient源码和编译安装包可以在Python package index获取。

The source code and binary installers for the latest released version are available at the [Python package index].

https://pypi.org/project/scient

可以用pip安装scient

You can install scient like this:

pip install scient

也可以用setup.py安装。

Or in the scient directory, execute:

python setup.py install

scient.neuralnet

神经网络相关算法模块,包括attention、transformer、bert、lstm、resnet、crf、dataset、fit等。

scient.neuralnet.lora

实现了多个网络层的LoRA微调,包括Linear。

scient.neuralnet.lora.Linear(in_features: int, out_features: int, r:int, bias: bool = True)

Parameters

  • in_features : int
    Linear层的输入节点数.
  • out_features : int
    Linear层的输出节点数.
  • r : int
    中间层维度为r.
  • bias : bool, optional
    Linear层的bias参数.

Algorithms

LoRA的基本原理是冻结预训练的模型参数,然后在Transfomer的每一层中加入一个可训练的旁路矩阵(低秩可分离矩阵),接着将旁路输出与初始路径输出相加输入到网络当中,并只训练这些新增的旁路矩阵参数。其中,低秩可分离矩阵由两个矩阵组成,第一个矩阵负责降维,第二个矩阵负责升维,中间层维度为r,从而来模拟本征秩(intrinsic rank),这两个低秩矩阵能够大幅度减小参数量。

在这里插入图片描述

Examples

下面采用代码实例说明LoRA微调T5的过程,首先需要构建T5模型,T5模型的构建参见:Transformer经典模型实战:零基础训练一个面向中文的T5模型(Text to Text Transfer Transformer)
本示例所用的代码与上述链接中的T5模型构建、数据准备、训练、验证基本一致,不同之处是在模型构建时加入了如下LoRA部分:

pretrain_path='d:\\model.state_dict'#构建T5模型,并加载预训练的权重,后面对此预训练模型进行微调。
model=transformer.T5Transformer(vocab_size=vocab_size,dropout=0.1,ffn_size=3072)
model.load_state_dict(torch.load(pretrain_path),strict=False)#本示例的LoRA作用于attention中的query权重
for layer in model.encoder+model.decoder:# breaklayer.multi_head_attn.query=lora.Linear(layer.multi_head_attn.query.in_features, layer.multi_head_attn.query.out_features,r=64,bias=layer.multi_head_attn.query.bias)#LoRA矩阵的命名为 lora_A 和 lora_B,这里将LoRA矩阵之外的权重进行冻结
for k,v in model.named_parameters():# breakif 'lora' not in k:v.requires_grad=Falseelse:print(k,v.requires_grad)

进行如上设置,采用T5模型相同的训练方式,即可对T5进行微调,具体训练方式参见:Transformer经典模型实战:零基础训练一个面向中文的T5模型(Text to Text Transfer Transformer)

在训练前后,可以查看LoRA权重不断更新,非LoRA权重不更新,查看方式如下:

model.encoder[0].multi_head_attn.query.lora_A
model.encoder[0].multi_head_attn.query.lora_B
model.encoder[0].multi_head_attn.query.weight
model.encoder[0].multi_head_attn.query.bias

附代码中用到的tokenizer模型spiece.model、训练数据rewrite_train3.xlsx和预训练模型model.state_dict的下载地址:
链接:https://pan.baidu.com/s/12vEZBYldXvPrJTiFUEKGUw?pwd=DTFM
提取码:DTFM

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/51042.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Aria2安装和使用-Mac版

起因是需要网盘下载,无奈限速很烦,查找很多方案后,最终决定使用Aria2 Tampermonkey。 其中Aria2是一款开源轻量的下载软件,简单来说就是可以通过URL直接下载。 Tampermonkey则是一款插件,我这里是.crx结尾的谷歌插件…

文件IO函数:open/close,read/write,lseek

open和close函数 C语言中的文件IO操作需要使用到open()函数和close()函数来打开和关闭文件。 open()函数的原型如下: int open(const char *pathname, int flags); int open(const char *filename, int flags,mode_t mode);其中,filename表示要打开…

性能优化理论篇 | swap area是个什么东西

我们知道每台计算机的内存(RAM)都是有限的,而我们的应用程序需要加载到内存才能被运行,如果一台机器运行多个应用程序时,内存可能会耗尽。Linux 系统中的“交换空间(也称为交换分区)”可以帮助缓…

Google Play开发者账号地址验证难题?这些经验或许能帮到你

目前,想要把应用顺利上架到 Google Play,已经不像以前那么简单了,主要是开发者需要应对 Google 日益严格的审核机制。其中,账号验证的地址验证绝对是让很多人头疼的一个环节。 今天就来给大家分享一些真实的经验和干货&#xff0c…

IA实验:静态路由(1基础版)

实验拓扑: 实验要求: 1.实现全网通 实验思路: 1.给各个设备配置好端口的IP地址并且划分广播域,暂定网段为: 192.168.1.0 192.168.2.0 192.168.3.0 2.pc1去ping各网段的设备以及端口,发现3.0网段以及2.2网…

数控单主轴走心机多少钱

单主轴走心机的价格因品牌、型号、性能及配置等因素而异,因此无法给出一个具体的统一价格。一般来说,单主轴走心机的价格在数万元到数十万元不等。 市场上某些品牌的单主轴走心机价格可能在十几万之间不等,而另一些高端型号或定制产品的价格可…

苹果上架没有iphone、没有ipad也可以生成截屏

使用flutter、uniapp或其他跨平台框架开发ios的APP,上架的时候都会遇到一个问题,上架的时候需要各种尺寸的设备来做ios截屏。 比如目前最新的要求是,iphone需要三种不同尺寸的设备的截屏,假如支持ipad则还需要使用ipad 2代和ipad…

从0开始搭建个人博客《第十一期:优化网站访问速度》

目录 一、背景说明 二、Nginx性能优化 (一)文件句柄 1.系统全局性修改和用户局部性修改 2.进程局部性修改 (二)CPU亲和配置 1.设置工作进程数 2.设置连接数 (三)事件处理模型优化 (四&…

MySql 高阶 概念(了解即可)

mysql 分为4层结构: 连接层:负责处理链接,鉴权,安全。 服务层:负责sql接口,sql分析,sql优化,sql缓存。 引擎层:负责执行服务层的操作,不同的引擎拥有不同的特…

《机器学习》一元、多元线性回归的实现 No.4

一、一元线性回归实现 先直接看完整代码: import pandas as pd import matplotlib.pyplot as plt from sklearn.linear_model import LinearRegressiondate pd.read_csv(data.csv) #导入数据plt.scatter(date[广告投入],date[销售额]) # 用散点图展示数据 plt.sh…

这对二婚夫妻结婚半年,一起生活才一个月,就走到了婚姻尽头!

这对二婚夫妻结婚半年,一起生活才一个月,就走到了婚姻尽头! 这是一篇涉离婚纠纷的民事起诉状 (范文点评) 离 婚 起 诉 状 原告:韩某斌,男,现年37岁,汉族,打…

记录一个变量溢出的bug

文章目录 如题 如题 count2变量溢出了(超过了255),结果导致busOff_16bitRecordHILTime变量莫名其妙被清0

「C++系列」vector 容器

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站:人工智能教程 文章目录 一、vector 容器1. 基本特性2. 基本操作3. 注意事项 二、应用场景1. 应用场景2. 案例案例一&#xff1…

DRF——请求的封装与版本管理

文章目录 django restframework1. 快速上手2. 请求数据的封装3. 版本管理3.1 URL的GET参数传递(*)3.2 URL路径传递(*)3.3 请求头传递3.4 二级域名传递3.5 路由的namespace传递 小结 django restframework 快速上手请求的封装版本…

公司的全称可以申请注册商标吗,还有什么注意!

近日有个网友找到普推知产商标老杨,发来公司全称,问这个可以申请注册商标不,看发来是“贵州**酒业有限公司”,应该是做茅台镇酒的,以前以分析过《公司全称能不能注册商标》,这次帮网友分析完做下补充。 公…

基于springboot的招聘系统的设计与实现

TOC springboot614基于springboot的招聘系统的设计与实现--论文 研究背景 近年来,由于计算机技术和互联网技术的快速发展,使得所有企事业单位内部都是数字化、信息化、无纸化的发展趋势,随着趋势的发展,各种决策系统、辅助系统…

编码器精度

系列文章目录 1.元件基础 2.电路设计 3.PCB设计 4.元件焊接 5.板子调试 6.程序设计 7.算法学习 8.编写exe 9.检测标准 10.项目举例 11.职业规划 文章目录 前言一、影响因素二、编码器精度三、位置因素四、环境因素五、磁编码器 前言 送给大学毕业后找不到奋斗方向的你&…

IntelliJ IDEA 集成 ShardingSphere-JDBC 访问分库分表

背景 众所周知,IntelliJ IDEA 是 Java 领域常用的开发工具之一,IDEA Ultimate(旗舰版)或其他例如 DataGrip 等 Intellij 平台的工具都集成了对数据库的访问能力。 但是,对于做了分库分表的项目,直接使用 …

中秋节月饼销售利用106短信群发平台业绩翻倍案例分析

在中秋节这一传统佳节,月饼作为节日的标志性食品,其销售市场竞争尤为激烈。为了在众多品牌中脱颖而出,不少月饼销售企业开始探索创新的营销方式。其中,利用106短信群发平台进行精准营销,成为众多企业实现业绩翻倍的有效…

TCP端口范围

ip_local_port_range sysctl -a | grep ip_local_port_range | head 默认情况下,net.ipv4.ip_local_port_range的默认值为32768-60999。这意味着本地应用程序可以使用的端口号范围为32768到60999。 sysctl -a | grep net.ipv4.ip_local_reserved_ports |head …