以Llama-2为例,在生成模型中使用自定义StoppingCriteria

以Llama-2为例,在生成模型中使用自定义StoppingCriteria

  • 1. 前言
  • 2. 场景介绍
  • 3. 解决方法
  • 4. 结语

1. 前言

在之前的文章中,介绍了使用transformers模块创建的模型,其generate方法的详细原理和使用方法,文章链接:

以beam search为例,详解transformers中generate方法(上)
以beam search为例,详解transformers中generate方法(下)

其中提到了用户参与生成过程的两个关键组件,logits_processorstopping_criteria,使用这两个类,是用户控制生成过程的主要手段。其中,logits_processor用来在生成过程中,根据用户设置的指定规则,强行修改当前step在词表空间上的概率分布,而stopping_criteria,根据用户所规定的规则来中止生成。

这两个组件在transformers模块中都有一些预设的类可以直接使用,预设类的基本信息介绍可参考以beam search为例,详解transformers中generate方法(上)。

本文将结合实际应用场景,介绍用户如何根据自己的需求来设计并实现一个自定义的stopping_criteria,来控制生成过程提前结束。

2. 场景介绍

这次介绍的场景是,使用Llama-2的生成能力对一段新闻进行概括,希望能够生成一句简短的话,来概括新闻中发生的最核心的事情。

通过给定对话背景,结合历史样例的方式,希望Llama-2能够输出期望的结果。

对话的prompt构造方法可以参考之前的内容:NLP实践——Llama-2 多轮对话prompt构建。

然而,即便是采用了in-context learning的方式,Llama-2生成的结果仍然过于冗长。

例如对于这样一篇新闻:

text = """, Photo Credit : Associated Press Four air crew members were missing after an Australian army helicopter ditched into waters off the Queensland state coast during joint military exercises with the United States, officials said Saturday. The MRH-90 Taipan helicopter went down near Lindeman Island, a Great Barrier... ..."""  
# 后边忽略若干内容

模型生成的结果为:

Four Australian army air crew members are missing after an Australian army MRH-90 Taipan helicopter ditched into waters off the Queensland state coast during joint military exercises with the United States. The helicopter went down near Lindeman Island, a Great Barrier Reef tourist resort, at around 11 pm on Friday. A search involving US, Canadian, and Australian personnel is underway to find the missing crew, who are all Australian men. Debris that appeared to be from a helicopter has been recovered, according to Queensland Police Assistant Commissioner Douglas McDonald. The Taipan was taking part in Talisman Sabre, a biennial joint US-Australian military exercise that is largely based in Queensland. This year's exercise involves 13 nations and over 30,000 military personnel. Defense Minister Richard Marles said the helicopter ditched, which refers to an emergency landing on water. He added that defense exercises, which are so necessary for the readiness of our defense force, are serious and carry risk. US Defense Secretary... ...
# 后边忽略若干内容

可以看出,并不是模型生成的结果不好,但是它太啰嗦了,而对于我的需求而言,模型只需要输出其中的第一句话就足够了。

这时候可能有人就会觉得:“那我分句然后把第一句话保留下来不就好了?”

——这样做虽然也可以达成效果,但是这个生成过程,时间和算力已经被消耗了。

所以需要采取方法,让模型在生成到第一个句号的时候,就停止生成,返回结果。于是就需要用到今天的主角——Stopping Criteria。

3. 解决方法

transformers模块中内置了几个默认的stopping criteria,然而,在很多情况下,它们并不能满足需求,这时,就需要创建自定义的stopping criteria。

首先需要引用基类:

from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, \STOPPING_CRITERIA_INPUTS_DOCSTRING, add_start_docstrings

其中,

  • StoppingCriteriaList是一个容器,需要将所有的criteria都添加到其中,generate时传入的是这个容器;
  • StoppingCriteria是基础类,自定义的criteria需要继承这个基础类。

接下来就实现一个criteria,效果是,遇到指定的token时,就停止生成:

class StopAtSpecificTokenCriteria(StoppingCriteria):"""当生成出第一个指定token时,立即停止生成---------------ver: 2023-08-02by: changhongyu"""def __init__(self, token_id_list: List[int] = None):""":param token_id_list: 停止生成的指定token的id的列表"""self.token_id_list = token_id_list@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:# return np.argmax(scores[-1].detach().cpu().numpy()) in self.token_id_list# 储存scores会额外占用资源,所以直接用input_ids进行判断return input_ids[0][-1].detach().cpu().numpy() in self.token_id_list

那么,如果希望遇到句号就停止生成,那就用句号对应的token_id去实例化一个这样的stopping criteria,并将它添加到容器中:

# Llama-2的词表中,英文句号的id是29889
stopping_criteria = StoppingCriteriaList()
stopping_criteria.append(StopAtSpecificTokenCriteria(token_id_list=[29889]))

然后,在生成的时候,假如原本的生成指令是:

model.generate(**inputs)

那么再把stopping criteria作为参数传入进去,就可以发挥效果了:

model.generate(stopping_criteria=stopping_criteria, **inputs)

4. 结语

Stopping Criteria用于在每一个step的生成结束时,判断生成过程是否要结束,是用户控制生成过程的有效手段,其发挥作用的方式也比较直接,实现自定义criteria也并不复杂,只需要确保该类的调用方法返回值是bool值,并覆盖全部情况即可。

Logits Processor是用户控制生成的另一个有效工具,在接下来的博客中,还将介绍自定义logits processor是如何使用的,欢迎感兴趣的同学继续关注。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/21221.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

动态IP代理的工作原理

动态IP代理在现今的网络环境中扮演着重要角色。它不仅能够帮助用户绕过网站的访问限制和封锁,还可以保护用户的隐私和匿名性。本文将深入分析动态IP代理的工作原理,解释其关键要素,并探讨其在现代网络中的应用。 在互联网时代,随着…

GaussDB数据库的元数据及其管理简介

目录 一、前言 二、元数据简介 1、元数据定义 2、元数据分类 3、数据库元数据管理 三、GaussDB数据库的元数据管理 1、GaussDB数据库的元数据管理 2、通过“SQL 系统表/系统视图/系统函数”的方式管理(采集)元数据 1)获取表、视图及…

flutter开发实战-实现css线性渐变转换flutter渐变LinearGradient功能

flutter开发实战-实现css线性渐变转换flutter渐变LinearGradient功能 在之前项目开发中,遇到更换样式,由于从服务器端获取的样式均为css属性值,需要将其转换成flutter类对应的属性值。这里只处理线性渐变linear-gradient 比如渐变 “linear-…

全球模拟器市场规模,行业洞察

模拟器适用于飞机、汽车、航船驾驶训练、教育和研究的装备。汽车模拟器,交通安全教育宣传基地推荐设备、新机动车驾驶学员的培训,道路训练,大中专院校、中学以及驾校的素质教育与劳动技能教育。飞行仿真器 (flight simulator) 在地面模仿飞机…

【Linux操作系统】网络配置详解:从原理到实践(详细通俗讲明DNS)

导语:网络配置是Linux系统中的一项重要任务,合理的网络配置可以保证计算机与其他设备的正常通信。本文将详细介绍Linux网络配置的原理和实践,包括网络配置原理、查看网络IP和网关、测试网络连通性、网络环境配置、设置主机名和hosts映射以及主…

win10笔记本显示器根据页面显示亮度自动调节亮度的问题

系统是win10企业版,针对这个问题查了很多种方法,比如: 1、控制面板->硬件和声音->电源选项->点击当前电源计划的更改计划设置->更改高级电源设置->显示->启用自适应亮度 但是我发现我的电源计划只有平衡这一种&#xff0c…

windows运行WPscan报错:无法打开库libcurl.dll

windows运行WPscan报错:无法打开库libcurl.dll 1.问题背景2.解决方案1.问题背景 在Windows上启动WPScan时: wpscan --url xxx.ru提示如下错误: Could not open library libcurl.dll: �� ������ ��������� ������. . Could not open library libcu

selenium官文文档阅读总结(day 3)

1.关联型xpath的用法 driver.find_element(By.XPATH,//a[text()"xxx"]/ancestor::祖先元素的标签名//……) 2.selenium等待 等待的作用 :在系统运行的过程中,等待网页内容的加载显示。需要耗费的时间,与网络速度、接口的复杂程度…

申请科技型中小企业的好处有哪些?

科技型中小企业,这是由国家出台的,科技部认定的,对中小型企业的一种荣誉。这种企业是有一定数量的科技人员从事科技研究开发,有了一定成果并转化为高新技术产品或服务,实现可持续发展的中小企业。 申请科技型中小企业有…

【雕爷学编程】 MicroPython动手做(38)——控制触摸屏

MixPY——让爱(AI)触手可及 MixPY布局 主控芯片:K210(64位双核带硬件FPU和卷积加速器的 RISC-V CPU) 显示屏:LCD_2.8寸 320*240分辨率,支持电阻触摸 摄像头:OV2640,200W像素 扬声器&#…

【Golang 接口自动化08】使用标准库httptest完成HTTP请求的Mock测试

目录 前言 http包的HandleFunc函数 http.Request/http.ResponseWriter httptest 定义被测接口 测试代码 测试执行 总结 资料获取方法 前言 Mock是一个做自动化测试永远绕不过去的话题。本文主要介绍使用标准库net/http/httptest完成HTTP请求的Mock的测试方法。 可能有…

【安装】阿里云轻量服务器安装Ubuntu图形化界面(端口号/灰屏问题)

阿里云官网链接 https://help.aliyun.com/zh/simple-application-server/use-cases/use-vnc-to-build-guis-on-ubuntu-18-04-and-20-04 网上搜了很多教程,但是我没在界面看到有vnc连接,后面才发现官网有教程。 其实官网很详细了,不过这里还是…

题解 | #1001.Count# 2023杭电暑期多校6

1001.Count 签到 题目大意 给定 n , m , k n,m,k n,m,k ,构造长度为 n n n 的整数序列,元素大小范围为 a i ∈ [ 1 , m ] a_i\in [1,m] ai​∈[1,m] ,并且需要保证前 k k k 个元素和后 k k k 个元素对应相同 求可以构造出的序列数量 …

Flink Windows(窗口)详解

Windows(窗口) Windows是流计算的核心。Windows将流分成有限大小的“buckets”,我们可以在其上应用聚合计算(ProcessWindowFunction,ReduceFunction,AggregateFunction或FoldFunction)等。在Fl…

【java安全】无Commons-Collections的Shiro550反序列化利用

文章目录 【java安全】无Commons-Collections的Shiro550反序列化利用Shiro550利用的难点CommonsBeanutils1是否可以Shiro中?什么是serialVersionUID?W 无依赖的Shiro反序列化利用链POC 【java安全】无Commons-Collections的Shiro550反序列化利用 Shiro5…

整数拆分——力扣343

文章目录 题目描述法一 动态规划法二 动态规划优化法三 数学 题目描述 法一 动态规划 int integerBreak(int n) {vector<int> dp(n1);for(int i2;i<n;i){int curMax 0;for(int j1;j<i;j){curMax max(curMax, max(j*(i-j), j*dp[i-j]));}dp[i] curMax;} return d…

AI赋能转型升级 助力打造“数智辽宁”——首次大模型研讨沙龙在沈成功举行

当前&#xff0c;以“ChatGPT”为代表的大模型正在引领新一轮全球人工智能技术发展浪潮&#xff0c;推动人工智能从以专用小模型定制训练为主的“手工作坊时代”&#xff0c;迈入以通用大模型预训练为主的“工业化时代”&#xff0c;正不断加速实体经济智能化升级&#xff0c;深…

MyBatis的三级缓存

MyBatis的三级缓存 一、什么是MyBatis的三级缓存&#xff1f; MyBatis的三级缓存指的是一级缓存、二级缓存和三级缓存。 缓存是一种提高数据读取性能的技术&#xff0c;在MyBatis中&#xff0c;一级缓存指的是Session缓存&#xff0c;二级缓存指的是Mapper级的缓存&#xff…

变量方法常用命名

文件路径&#xff1a; \src\locales\lang\zh-CN 在多语言应用程序中&#xff0c;通常会将不同语言的翻译文本存储在不同的文件中&#xff0c;这样可以方便地管理和维护多个语言版本。locales目录一般用于存放所有的语言文件&#xff0c;lang目录则用于存放特定语言的文件。 具体…

C语言编程技巧 全局变量在多个c文件中公用的方法

在使用C语言编写程序时&#xff0c;经常会遇到这样的情况&#xff1a;我们希望在头文件中定义一个全局变量&#xff0c;并将其包含在两个不同的C文件中&#xff0c;以便这个全局变量可以在这两个文件中共享。举个例子&#xff0c;假设项目文件夹"project"下有三个文件…