OpenAI大模型中的模型推理

模型推理

推理有两个方案,一个和训练相同,直接加入Lora层,不过会增加推理延时因为多了lora层的计算,适合线下测评用,如下

from peft import PeftModel
from transformers import AutoModel, AutoTokenizer
​
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, load_in_8bit=True, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = PeftModel.from_pretrained(model, "./lora_ckpt")
model.half().to(device)
model.eval()

另一个没有推理延时的方案,是先把lora权重和原始模型权重进行合并,把合并后的参数存储成新的bin文件,然后和加载常规模型一样加载合并后的模型参数进行推理。权重合并的代码如下

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
# when merging disable int8
model = AutoModel.from_pretrained("THUDM/chatglm-6b", load_in_8bit=False, torch_dtype=torch.float16,trust_remote_code=True, device_map={"": "cpu"},
)
## 用来检查权重是否合并成功,合并成功weight会改变
first_weight = model.base_model.layers[0].attention.query_key_value.weight
first_weight_old = first_weight.clone()
​
# 返回的不是新的模型,而是在原始模型上加了adapter层
lora_model = PeftModel.from_pretrained(model,"./lora_ckpt",device_map={"": "cpu"},torch_dtype=torch.float16,
)
# 报错:A*B shape mismatch,大概率是get_peft_model错误修改了peft_config里面的fan_in_fan_out参数,某个peft的revision有这个bug
lora_model = lora_model.merge_and_unload()
lora_model.train(False)
​
# 报错:大概率peft训练有问题,检查adapter.bin大小
assert not torch.allclose(first_weight_old, first_weight), 'Weight Should Change after Lora Merge'
​
# lora模型权重把原模型权重加了prefix,这里移除恢复原始key
deloreanized_sd = {k.replace("base_model.model.", ""): vfor k, v in lora_model.state_dict().items()if "lora" not in k
}
# 保存合并后的模型权重
lora_model.save_pretrained(output_dir, state_dict=deloreanized_sd)

T5

  • paper: 2019.10 Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer

  • Task: Everything

  • Prompt: 前缀式人工prompt

  • Model: Encoder-Decoder

  • Take Away: 加入前缀Prompt,所有NLP任务都可以转化为文本生成任务

img

T5论文的初衷如标题所言,是为了全面公平的对比不同预训练和迁移策略的贡献和效果,避免在A模型上效果不好的预训练目标在B上可能效果更优的情况,对比项包括

  • 预训练目标:语言模型,乱序还原,MLM(不同的掩码率),Span掩码, etc

  • 预训练数据:构建C4数据集,从C4抽取不同领域语料来训练

  • 模型架构: Encoder-Decoder,Decoder Only,Encoder Only

  • 迁移策略:逐步解冻,全量微调,局部微调

  • 其他:多任务预训练,模型大小

说句题外话,再看论文结果发现Encoder-Decoder的模型结果+SpanMLM损失函数效果最好。不知道这是否是谷歌押注T5,而没有像OpenAI一样选择Deocder结构的原因。

具体对比结果这里不细说,本文只关注T5为了公平对比以上差异,提出的Text2Text的通用建模框架:用相同的模型,相同的预训练,相同的损失函数和解码方式,把文本分类,摘要,翻译,QA都转化成了生成任务,而转化的方式就是通过加入前缀prompt。

针对不同的下游微调任务,我们看下T5提出的Text2Text是如何构建prompt模板的

  1. WMT英语到德语的翻译任务,输入是'translate English to German:'+input, 输出是翻译结果

  2. CNN Mail摘要任务: 文本摘要任务,输入是‘Summarize:'+input,输出是摘要

  3. MNLI任务:输入是'mnli hypothesis:'+假设+'premise:'+叙述,输出是contradiction, entailment,neutral

  4. STS文本相似任务:输入是'stsb sentence1:'+input1+‘sentence2:’+input2, 输出是1~5的打分(离散化)

  5. 问答SQuAD任务:输入是'question:'+提问+ 'context:'+上下文,输出是答案

不难发现在T5的时代,prompt模板的构建还比较粗糙,更多是单纯的任务名称+任务类型来区分不同的NLP任务,只是让模型在解码时多一层条件概率,既给定不同prompt前缀在解码时采用不同的条件概率(attention)。并没有太多从语义和上下文关联的角度去进行prompt模板的构建,我猜这是T5在论文中提到他们尝试了不同的prompt模板发现效果影响有限的原因(哈哈因为都不太好所以没啥差异),不不能否定T5在通用LM上做出的贡献~

PET-TC(a)

  • paper a: 2020.1 Exploiting Cloze Questions for Few Shot Text Classification and Natural

  • prompt: 单字完形填空式人工Prompt

  • Task: Text Classification

  • Model: Roberta-large, XLM-R

  • Take Away: 加入完形填空式Prompt把文本分类任务转化成单字MLM

img

和第一章的LAMA相似,PET-TC也是把输入映射成完形填空式的prompt模板,对掩码词进行预测作为分类标签。不过PET没有直接使用prompt,而是用了半监督的方案。用多个prompt模板微调模型后,对大规模无监督数据进行预测,然后在伪标签上进行常规的模型微调,哈哈绕了一个圈最后还是输出的常规微调的模型。我大胆猜测作者很看好prompt范式在微调时引入的前置语义信息,以及无额外参数的设定,但是对不同prompt和answer模板带来的不稳定性感到头疼,于是搞出这么个折中的方法~

prompt & Answer Engineer

PET针对每个数据集人工设计了prompt模板和Answer词对标签的映射。针对单双文本输入分别举两个例子,以下a,b为原始输入文本,'_'位置为MASK词

  • 单输入:Yelp评论1~5星打分,标签词分别为terrible, bad,okay,good,great

img

  • 双输入:AG's News新闻四分类问题, 标签词分别为分类名称Worlds,Sports, Business, Science/Tech,

img

可以看出作者构建prompt模板的思路是尽可能还原文本所在的上下文场景,Answer词的选取是一对一的构建模式,每个label只选取一个词来表示。

固定prompt微调LM

完形填空式的prompt模板在微调时的优势,我认为主要有以下三点

  • 没有额外参数的引入,常规微调需要引入hidden_size * label_size的额外参数(classify head)作为每个标签对应的空间表征,这部分需要针对下游任务重头学习。而完形填空的token是在原始vocab中的,于是只需要调整标签词的预训练表征让它在label上线性可分即可

  • 前置语义信息的引入,因为标签词的选取本身符合label的原始语义,例如以上YELP评论打分中的5个形容词本身就是隐含了评论质量信息的,所以会引入部分前置信息,避免重头学习,这一点和MRC有些相似

  • 预训练和微调的一致性高,都是解决完形填空问题,学习目标一致

微调的损失函数是交叉熵,作者没有引入额外参数,而是把MASK位置上模型的预估logits在label上归一化来得到分类预测。例如上面的AG新闻分类任务,先得到MASK位置worlds,sports,business,science这四个词的预测logits,然后归一化得到预估概率,再和分类标签计算交叉熵。

为了避免灾难遗忘作者在下游任务微调时加入了预训练的MLM任务,于是微调的损失函数如下

半监督+蒸馏

这部分的设计可以和prompt的部分分开来看,是一个半监督方案。以上每个任务对应的多个prompt模板,分别固定prompt微调LM得到一版模型,然后在大量的未标注样本上进行预测,再对多个模型的预测值进行加权得到伪标签。

最终在为标签上使用常规的微调方案(加classifier head),训练模型作为输出,这一步类比知识蒸馏。所以PET最后输出的还是常规的监督微调模型,Prompt只是被当做了一种半监督方案。效果上在小样本的设定上比直接使用监督微调都有一定的效果提升。

img

作者还做了iPET对以上过程通过迭代逐步扩大数据集,提高伪标签准确率的方案,不过这么麻烦的实现一点都不适合我这种懒人,哈哈就不细说了~

针对PET有几点疑问

  • 完形填空类的prompt,在微调过程中可能的灾难遗忘,是否因为对label词的微调偏离了词在原始文本中语义表征,以及和其他词的相对位置

  • prompt模板差异带来的效果差异尚未解决,人工构建的prompt模板不一定是最优的

  • Answer词单token,以及和label一一对应的设定,限制性较强。这部分在后面的续作里作者做了改良

后面介绍的几个模型,大多是基于PET上述问题的改良~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/581041.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解密负载均衡:如何平衡系统负载(下)

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

spring boot 增量包部署,jar包变小

##pom.xml配置 <plugins><plugin><groupId>org.springframework.boot</groupId><artifactId>spring-boot-maven-plugin</artifactId><configuration><layout>ZIP</layout><includes><include><groupId&…

云原生数据库性能对比(阿里云、百度智能云、腾讯云)

本文作者 LYZ 近些年&#xff0c;云原生数据库成为云厂商的重要发展方向&#xff0c;阿里云、百度智能云、腾讯云均先后发布了自研的云原生数据库。笔者认为云原生数据库具有更高的性价比、更极致的弹性&#xff0c;可以满足业务发展的不同阶段和负载场景的需求&#xff0c;也是…

FPGA高端项目:SDI 视频+音频编解码,提供工程源码和技术支持

目录 1、前言免责声明 2、相关方案推荐我这里已有的 GT 高速接口解决方案我目前已有的SDI编解码方案 3、设计思路和框架设计框图GV8601A均衡EQGTX 时钟要求GTX 调用与控制SMPTE SD/HD/3G-SDISMPTE SD/HD/3G-SDI 接收SMPTE SD/HD/3G-SDI 发送 SDI 视频接收数据处理SDI 音频接收-…

pycharm 工具栏不见了

新版pycharm后&#xff0c; 菜单栏和工具栏不见了 目录 我发现的解决方法&#xff1a; 其他旧版的解决方法&#xff1a; 我发现的解决方法&#xff1a; 其他旧版的解决方法&#xff1a; 另外&#xff0c;一些使用pycharm的新手可能会由于不熟悉软件的功能而误操作&#xff…

【头歌实训】PySpark Streaming 数据源

文章目录 第1关&#xff1a;MySQL 数据源任务描述相关知识PySpark JDBC 概述PySpark JDBCPySpark Streaming JDBC 编程要求测试说明答案代码 第2关&#xff1a;Kafka 数据源任务描述相关知识Kafka 概述Kafka 使用基础PySpark Streaming Kafka 编程要求测试说明答案代码 第1关&a…

scikit-learn文档中的数据生成器

目录 1. make_classification: 2. make_regression: 3. make_blobs: 4. make_moons: 5.make_circles 6. make_sparse_coded_signal: 1. make_classification: 这是一个用于生成复杂二维数据的函数&#xff0c;通常用于可视化分类器的学习过程或者测试机器学习算法的性能…

Kali Linux如何启动SSH并在Windows系统远程连接

文章目录 1. 启动kali ssh 服务2. kali 安装cpolar 内网穿透3. 配置kali ssh公网地址4. 远程连接5. 固定连接SSH公网地址6. SSH固定地址连接测试 简单几步通过[cpolar 内网穿透](cpolar官网-安全的内网穿透工具 | 无需公网ip | 远程访问 | 搭建网站)软件实现ssh 远程连接kali! …

工具系列:TimeGPT_(9)模型交叉验证

交叉验证 文章目录 交叉验证外生变量比较不同的模型 时间序列预测中的主要挑战之一是随着时间的推移固有的不确定性和变异性&#xff0c;因此验证所采用的模型的准确性和可靠性至关重要。交叉验证是一种强大的模型验证技术&#xff0c;特别适用于此任务&#xff0c;因为它提供了…

使用 GitHub 进行团队协作的操作指南

目录 前言1 使用github进行团队开发的意义2 邀请成员加入团队3 克隆和提交代码3.1 克隆远程仓库到本地3.2 加入暂存区3.3 提交修改到本地仓库3.4 设置本地仓库和远程仓库的关联3.5 将本地仓库的代码推送到远程仓库 结语 前言 GitHub 是一个广泛使用的基于 Git 的代码托管平台&…

Java - 获取 Jar 包内的 pom.xml 文件

目录 一.引言 二.通过 jar 命令 ◆ 查看 Jar 包内文件 ◆ 导出 Pom.xml ◆ 导出 Jar 包内文件 三.通过 unzip 命令 ◆ 导出 Jar 包内文件 四.总结 一.引言 引用其他同学的 Jar 包时&#xff0c;需要获取其对应 jar 包内的 pom.xml 文件检查版本依赖关系&#xff0c;下…

MYSQL存储过程和存储函数-数据库实验五

Mysql数据库实验及练习题相关 MySQL 数据库和表的管理-数据库实验一 MySQL连接查询、索引、视图-数据库实验二、实验三 MySQL约束、触发器-数据库实验四 MYSQL存储过程和存储函数-数据库实验五 MySQL批量随机生成name、TEL、idNumber MYSQL数据库的安全管理-数据库实验六 MYSQ…

基于JetCache整合实现一级、二级缓存方案(方案实现)

目录 一、整体方案说明 1.1 需求说明 1.2 整体方案实现组件结构图 二、Caffeine缓存实现 2.1 组件说明 2.2 组件结构图 2.3 组件Maven依赖 2.4 组件功能实现源码 2.4.1 CaffeineCacheManager扩展实现 2.4.2 CaffeineConfiguration配置类实现 2.4.3 涉及其他组件的类 …

如何在Android Termux中使用SFTP实现远程传输文件

文章目录 1. 安装openSSH2. 安装cpolar3. 远程SFTP连接配置4. 远程SFTP访问5. 配置固定远程连接地址6、结语 SFTP&#xff08;SSH File Transfer Protocol&#xff09;是一种基于SSH&#xff08;Secure Shell&#xff09;安全协议的文件传输协议。与FTP协议相比&#xff0c;SFT…

Spring Boot 中的虚拟线程

在本文中&#xff0c;我将讨论 Spring Boot 中的虚拟线程。 什么是虚拟线程&#xff1f; 虚拟线程作为 Java 中的一项功能引入&#xff0c;旨在简化并发性。 Virtual threads 是 轻量级的线程&#xff0c;由 Java Virtual Machine 而不是操作系统管理。它们被设计为易于使用且…

ElasticSearch:centos7安装elasticsearch7,kibana,ik中文分词器,云服务器安装elasticsearch

系统&#xff1a;centos7 elasticsearch: 7.17.16 安装目录&#xff1a;/usr/local 云服务器的安全组&#xff1a;开放 9200 和5601的端口 一、下载安装elasticsearch7.17.16 1、安装 #进入安装目录 cd /usr/local#下载elasticsearch wget https://artifacts.elastic.co/d…

Elasticsearch:在不停机的情况下优化 Elasticsearch Reindex

实现零停机、高效率和成功迁移更新的指南。更多阅读&#xff1a;Elasticsearch&#xff1a;如何轻松安全地对实时 Elasticsearch 索引 reindex 你的数据。 在使用 Elasticsearch 的时候&#xff0c;总会有需要修改索引映射的时候&#xff0c;遇到这种情况&#xff0c;我们只能做…

前端实现websocket类封装

随着Web应用程序的发展&#xff0c;越来越多的人开始利用Websocket技术来构建实时应用程序。Websocket是一种在客户端和服务器之间建立持久连接的协议。这种协议可以在一个单独的连接上实现双向通信。与HTTP请求-响应模型不同&#xff0c;Websocket允许服务器自主地向客户端发送…

想要学会JVM调优,先掌握JVM内存模型和JVM运行原理

1、前言 今天将和你一起探讨Java虚拟机&#xff08;JVM&#xff09;的性能调优。 JVM算是面试中的高频问题了&#xff0c;通常情况下总会有人问到&#xff1a;请你讲解下 JVM 的内存模型&#xff0c;JVM 的 性能调优做过&#xff1f; 2、为什么 JVM 在 Java 中如此重要 首…

利用网络教育系统构建个性化学习平台

在现代教育中&#xff0c;网络教育系统作为一种创新的学习方式&#xff0c;为学生提供了更加个性化和灵活的学习体验。在本文中&#xff0c;我们将通过简单的技术代码&#xff0c;演示如何构建一个基础的网络教育系统&#xff0c;为学生提供个性化的学习路径和资源。 1. 环境…