大语言模型控制生成的过程Trick:自定义LogitsProcessor实践

前言

在大模型的生成过程中,部分原生的大语言模型未经过特殊的对齐训练,往往会“胡说八道”的生成一些敏感词语等用户不想生成的词语,最简单粗暴的方式就是在大模型生成的文本之后,添加敏感词库等规则手段进行敏感词过滤,但是在生成过程中,生成敏感词仍然耗费了时间和算力成本。

本文以chatglm2-6B为例,通过自定义LogitsProcessor,实践大模型在生成过程中控制一些词语的生成。

LogitsProcessor

从下面代码可以看到,LogitsProcessor的作用就是在生成过程中修改score,改变模型输出的概率分布的工具。

class LogitsProcessor:"""Abstract base class for all logit processors that can be applied during generation."""@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:raise NotImplementedError(f"{self.__class__} is an abstract class. Only classes inheriting this class can be called.")class LogitsProcessorList(list):"""This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a`scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each[`LogitsProcessor`] or [`LogitsWarper`] to the inputs."""def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:r"""Args:input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):Prediction scores of a language modeling head. These can be logits for each vocabulary when not usingbeam search or log softmax for each vocabulary token when using beam searchkwargs (`Dict[str, Any]`, *optional*):Additional kwargs that are specific to a logits processor.Return:`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:The processed prediction scores."""for processor in self:function_args = inspect.signature(processor.__call__).parametersif len(function_args) > 2:if not all(arg in kwargs for arg in list(function_args.keys())[2:]):raise ValueError(f"Make sure that all the required parameters: {list(function_args.keys())} for "f"{processor.__class__} are passed to the logits processor.")scores = processor(input_ids, scores, **kwargs)else:scores = processor(input_ids, scores)return scores

自定义LogitsProcessor实践

回到正题,如何自定义LogitsProcessor控制大模型生成的过程呢?下面直接上实践代码:

class new_logits_processor(LogitsProcessor):def __init__(self, forbid_token_id_list: List[int] = None):self.forbid_token_id_list = forbid_token_id_listdef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:for id_ in self.forbid_token_id_list:scores[:, id_] = -float('inf')return scores

forbid_token_id_list是不让模型生成词语的id映射列表,对于这些抑制生成的词语,在自定义logits_processor时将其概率推向负无穷大即可。

chatglm2-6B详细实践代码:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextStreamer
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
from typing import List
import torchclass new_logits_processor(LogitsProcessor):def __init__(self, forbid_token_id_list: List[int] = None):self.forbid_token_id_list = forbid_token_id_listdef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:for id_ in self.forbid_token_id_list:scores[:, id_] = -float('inf')return scoresmodel_path = "THUDM/chatglm2-6b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path, trust_remote_code=True).to('mps')def add_forbid_words():'''添加需要抑制的词语,这里简单添加了数字和几个词语进行对比:return:list'''forbid_words = []for i in range(10):forbid_words.append(tokenizer.convert_tokens_to_ids(str(i)))forbid_words.append(tokenizer.convert_tokens_to_ids("首先"))forbid_words.append(tokenizer.convert_tokens_to_ids("积极"))forbid_words.append(tokenizer.convert_tokens_to_ids("回答"))forbid_words.append(tokenizer.convert_tokens_to_ids("勇敢"))forbid_words.append(tokenizer.convert_tokens_to_ids("勇气"))return forbid_wordslogits_processor = LogitsProcessorList()
logits_processor.append(new_logits_processor(add_forbid_words()))streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)input = "列举出10个积极的词语:"outputs = model.generate(tokenizer(input, return_tensors='pt').input_ids.to("mps"),max_new_tokens=1024,logits_processor=logits_processor,  # 不开启注释即可streamer=streamer
)
decode_text = tokenizer.batch_decode(outputs, streamer=streamer)[0]
print(decode_text)

抑制前输出:

1. 勇敢
2. 快乐
3. 成功
4. 努力
5. 积极
6. 乐观
7. 自信
8. 开朗
9. 团结
10. 奋斗

抑制后输出:

- 积极主动
- 乐观向上
- 自信
- 自律
- 诚实守信
- 乐于助人
- 勇于尝试
- 坚韧不拔
- 乐观开朗
- 团结一心

小结

本文通过自定义LogitsProcessor,简单的实践了大语言模型在生成过程中屏蔽生成用户自定义词语的trick。在现实场景中,根据特定场景探索如何灵活的利用LogitsProcessor进行有针对性的控制生成模型的生成过程非常重要。

参考文献

【1】https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/generation/logits_process.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/39913.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

30行JS代码带你手写自动回复语音聊天机器人

🥂(❁◡❁)您的点赞👍➕评论📝➕收藏⭐是作者创作的最大动力🤞 前言 现如今生活中到处都是聊天机器人的身影,聊天机器人不仅仅能减少人工的聊天压力,而且十分的可爱有趣,安卓系统的小AI&#xf…

Springboot整合Mybatis调用Oracle存储过程

1、配置说明 Oracel11g+springboot2.7.14+mybatis3.5.13 目标:springboot整合mybatis访问oracle中的存储过程,存储过程返回游标信息。 mybatis调用oracle中的存储过程方式 2、工程结构 3、具体实现 3.1、在Oracle中创建测试数据库表 具体数据可自行添加 create table s…

Lodash——使用与实例

1. 简介 Lodash是一个一致性、模块化、高性能的JavaScript实用库。Lodash通过降低array、number、objects、string等等的使用难度从而让JavaScript变得简单。Lodash的模块方法,非常适用于: 遍历array、object 和 string对值进行操作和检测创建符合功能的…

字符个数统计(同类型只统计一次)

思路:因为题目圈定出现的字符都是 ascii 值小于等于127的字符,因此只需要定义一个标记数组大小为128 ,然后将字符作为数组下标在数组中进行标记,若数组中没有标记过表示第一次出现,进行计数,否则表示重复字…

简单线性回归:预测事物间简单关系的利器

文章目录 🍀简介🍀什么是简单线性回归?🍀简单线性回归的应用场景使用步骤:注意事项: 🍀代码演示🍀结论 🍀简介 在数据科学领域,线性回归是一种基本而强大的统…

Kali Linux助您网络安全攻防实战

Kali Linux:黑客与防御者的神器 Kali Linux是一款专为网络安全测试和攻防实践而设计的操作系统。它汇集了大量的安全工具,可以用于渗透测试、漏洞扫描、密码破解等任务,不仅为黑客提供了强大的攻击能力,也为安全防御者提供了测试和…

Kafka 入门到起飞 - 什么是 HW 和 LEO?何时更新HW和LEO呢?

上文我们已经学到, 一个Topic(主题)会有多个Partition(分区)为了保证高可用,每个分区有多个Replication(副本)副本分为Leader 和 Follower 两个角色,Follower 从Leader同…

爬虫逆向实战(十八)--某得科技登录

一、数据接口分析 主页地址:某得科技 1、抓包 通过抓包可以发现数据接口是AjaxLogin 2、判断是否有加密参数 请求参数是否加密? 查看“载荷”模块可以发现有一个password加密参数和一个__RequestVerificationToken 请求头是否加密? 无…

【Linux】Reactor模式

Reactor模式 Reactor模式的定义 Reactor反应器模式,也叫做分发者模式或通知者模式,是一种将就绪事件派发给对应服务处理程序的事件设计模式。 Reactor模式的角色构成 Reactor主要由以下五个角色构成: reactor模式的角色 角色解释Handle(句…

保姆级别讲解Python数据处理,你绝对能会

名字:阿玥的小东东 学习:Python、C/C 主页链接:阿玥的小东东的博客_CSDN博客-python&&c高级知识,过年必备,C/C知识讲解领域博主 目录 1. 文件读取 2. 数据处理 3. 处理结果输出 总的来说 为了咱们让程序跑起来,我们需…

DAY3,ARM(LED点灯实验)

1.汇编实现开发板三盏灯点亮熄灭&#xff1b; .text .global _start _start: /**********LED123点灯**************/RCC_INIT:1使能PE10 PF10 PE8RCC..寄存器,E[4]1 F[5]1 0x50000a28ldr r0,0x50000a28ldr r1,[r0]orr r1,r1,#(0x3 << 4)str r1,[r0]LED1_INET:2初始化LED…

酷开系统 | 酷开科技大数据,更好的与目标消费人群建立联系

众所周知&#xff0c;OTT的一大优势在于强曝光&#xff0c;能够给消费者带来强烈的视觉冲击&#xff0c;强化品牌认知。但是&#xff0c;要想达到提升品牌认知&#xff0c;首先要保证OTT的流量规模&#xff0c;实现对目标人群的有效覆盖。得年轻消费者得“天下”&#xff0c;年…

tk切换到mac的code分享

文章目录 前言一、基础环境配置二、开发软件与扩展1.用到的开发软件与平替、扩展情况 总结 前言 最近换上了coding人生的第一台mac&#xff0c;以前一直偏好tk&#xff0c;近来身边的朋友越来越多的用mac了&#xff0c;win的自动更新越来越占磁盘了&#xff0c;而且win11抛弃了…

vue elementui v-for 循环el-table-column 第一列数据变到最后一个

这个动态渲染table表格时发现el-table-column 第一列数据变到最后一个 序号被排到后面 代码 修改后 <el-table:data"tableData"tooltip-effect"dark"style"width: 100%"height"500"><template v-for"(item, index) i…

PostCSS在vue中的使用

1、安装 PostCSS 和所需的插件。在命令行中运行以下命令&#xff1a; npm install postcss autoprefixer cssnano postcss-pxtorem --save-dev 这将安装 PostCSS、Autoprefixer、CSSnano 和 postcss-pxtorem 插件&#xff0c;同时将它们添加到项目的开发依赖中。 2、在项目根目…

每天一道leetcode:1926. 迷宫中离入口最近的出口(图论中等广度优先遍历)

今日份题目&#xff1a; 给你一个 m x n 的迷宫矩阵 maze &#xff08;下标从 0 开始&#xff09;&#xff0c;矩阵中有空格子&#xff08;用 . 表示&#xff09;和墙&#xff08;用 表示&#xff09;。同时给你迷宫的入口 entrance &#xff0c;用 entrance [entrancerow, …

SpringBoot的配置文件(properties与yml)

文章目录 1. 配置文件的作用2. 配置文件格式3. 配置文件的使用方法3.1. properties配置文件3.1.1. 基本语法和使用3.1.2. properties优缺点分析 3.2. yml配置文件3.2.1. 基本语法与使用3.2.2. yml中单双引号问题3.2.3. yml配置不同类型的数据类型及null3.2.4. 配置对象3.2.5. 配…

android设置竖屏仍然跟随屏幕旋转怎么办

如题所问&#xff0c;我最近遇到一个bug&#xff0c;就是设置了摇感&#xff0c;然后有用户反馈说设置了手机下拉的系统设置-屏幕旋转-关闭。然后屏幕还是会旋转的问题。 首先&#xff0c;我们先从如何设置横竖屏了解下好了 设置横屏和竖屏的方法&#xff1a; 方法一&#x…

uni-app引入sortable列表拖拽,兼容App和H5,拖拽排序。

效果: 拖拽排序 背景&#xff1a; 作为一名前端开发人员&#xff0c;在工作中难免会遇到拖拽功能&#xff0c;分享一个github上一个不错的拖拽js库&#xff0c;能满足我们在项目开发中的需要&#xff0c;下面是我在uniapp中使用SortableJS的使用详细流程&#xff1b; vue开发…

Centos7安装docker后默认开启docker0的网卡|卸载默认网卡

一&#xff1a; 停掉服务 systemctl stop docker [rootwww ~]# systemctl stop docker [rootwww ~]# systemctl status docker ● docker.service - Docker Application Container Engine Loaded: loaded (/usr/lib/systemd/system/docker.service; enabled; vendor prese…