NLP实践——LLM生成过程中防止重复循环

NLP实践——LLM生成过程中防止重复

  • 1. 准备工作
  • 2. 问题分析
  • 3. 创建processor
    • 3.1 防止重复生成的processor
    • 3.2 防止数字无规则循环的processor
  • 4. 使用

本文介绍如何使用LogitsProcessor避免大模型在生成过程中出现重复的问题。

1. 准备工作

首先实例化一个大模型,以GLM2为例:

import re
import os
import json
import random
from typing import *
from copy import deepcopyimport torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList, MaxNewTokensCriteria, StoppingCriteria

创建模型:

tokenizer = AutoTokenizer.from_pretrained(".../ChatGLM2/", trust_remote_code=True)
model = AutoModel.from_pretrained(".../ChatGLM2/", trust_remote_code=True).half()
model.to('cuda:0')

2. 问题分析

接下来思考一下,如何防止模型不停的重复呢?重复分为几种情况,一个字符循环出现,或者多个字符循环出现,例如:

'abcdeeeeee'
'abcdededede'

从生成的过程来考虑,防止模型生成重复的内容,第一步自然是要判断模型陷入了重复,第二步就是打断它重复的过程,也就是将重复的token,在当前step生成的时候,将其概率设置为-inf,那么重复的过程自然就停止了。

3. 创建processor

3.1 防止重复生成的processor

先来解决如何判定重复。这里直接去leetcode上找一个题,获取一个字符串中最大的重复片段,解法如下:

def longest_dup_substring(s: str) -> str:# 生成两个进制a1, a2 = random.randint(26, 100), random.randint(26, 100)# 生成两个模mod1, mod2 = random.randint(10**9+7, 2**31-1), random.randint(10**9+7, 2**31-1)n = len(s)# 先对所有字符进行编码arr = [ord(c)-ord('a') for c in s]# 二分查找的范围是[1, n-1]l, r = 1, n-1length, start = 0, -1while l <= r:m = l + (r - l + 1) // 2idx = check(arr, m, a1, a2, mod1, mod2)# 有重复子串,移动左边界if idx != -1:l = m + 1length = mstart = idx# 无重复子串,移动右边界else:r = m - 1return s[start:start+length] if start != -1 else ""def check(arr, m, a1, a2, mod1, mod2):n = len(arr)aL1, aL2 = pow(a1, m, mod1), pow(a2, m, mod2)h1, h2 = 0, 0for i in range(m):h1 = (h1 * a1 + arr[i]) % mod1h2 = (h2 * a2 + arr[i]) % mod2# 存储一个编码组合是否出现过seen = {(h1, h2)}for start in range(1, n - m + 1):h1 = (h1 * a1 - arr[start - 1] * aL1 + arr[start + m - 1]) % mod1h2 = (h2 * a2 - arr[start - 1] * aL2 + arr[start + m - 1]) % mod2# 如果重复,则返回重复串的起点if (h1, h2) in seen:return startseen.add((h1, h2))# 没有重复,则返回-1return -1

效果如下:

longestDupSubstring('埃尔多安经济学可以重振经济,土耳其土耳其')
# '土耳其'

那么我们就可以写一个processor,在每一个step即将生成的时候,判定一下,是否之前已经生成的结果中,出现了重复。以及,如果出现了重复,则禁止重复部分的第一个token(例如上面例子中,土耳其的土字),在当前step被生成。

针对实际使用中由这个processor引发的一些其他的问题,我又对这个processor增加了一点规则限制,一个比较好用的版本如下。

其中的参数threshold是判断重复多少的情况算作循环,例如将threshold设置为10,那么如果重复部分的长度是3,重复了3次,3×3=9,则不被判定为陷入了循环,而如果重复了4次,3×4=12,则被判定为循环,此时processor将发挥效果了。

class ForbidDuplicationProcessor(LogitsProcessor):"""防止生成的内容陷入循环。当循环内容与循环次数之乘积大于指定次数则在生成下一个token时将循环内容的第一个token概率设置为0---------------ver: 2023-08-17by: changhongyu"""def __init__(self, tokenizer, threshold: int = 10):self.tokenizer = tokenizerself.threshold = thresholddef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:current_sequence = self.tokenizer.decode(input_ids[0][current_token_len: ])current_dup_str = longest_dup_substring(current_sequence)if len(current_dup_str):# 如果存在重复子序列,则根据其长度与重复次数判断是否禁止循环if len(current_dup_str) > 1 or (len(current_dup_str) == 1 and current_dup_str * self.threshold in current_sequence):if len(current_dup_str) * current_sequence.count(current_dup_str) >= self.threshold:token_ids = self.tokenizer.encode(current_dup_str)# 获取截止目前的上一个tokenlast_token = input_ids[0][-1].detach().cpu().numpy().tolist()if len(token_ids) and last_token == token_ids[-1]:# 如果截止目前的上一个token,与重复部分的最后一个token一致# 说明即将触发重复, 先把重复部分的第一个token禁掉scores[:, token_ids[0]] = 0# 然后按出现比率判断是否重复部分内还有其他重复for token_id in token_ids:if token_ids.count(token_id) * len(token_ids) > 1.2:scores[:, token_id] = 0return scores

需要注意的是,为了获取当前的序列已经生成的长度,需要在processor的外部,也就是与model.generate同级的结构处,定义一个全局变量current_token_len

global current_token_len

3.2 防止数字无规则循环的processor

出了上述的情况,还有一种常见的循环,无法利用上面的规则解决,即数字无规则循环的情况。针对这个场景,创建另一个processor,只要连续出现的数字出现次数,大于一定的阈值,则禁止当前step再次生成数字。

class MaxConsecutiveProcessor(LogitsProcessor):"""给定一个集合,集合中的字符最多连续若干次下一次生成时不能再出现该集合中的字符---------------ver: 2023-08-17by: changhongyu---------------修复bugver: 2023-09-11"""def __init__(self, consecutive_token_ids, max_num: int = 10):self.consecutive_token_ids = consecutive_token_idsself.max_num = max_numdef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:input_ids_list = input_ids.squeeze(0).detach().cpu().numpy().tolist()cur_num = 0for token in input_ids_list[::-1]:if token in self.consecutive_token_ids:cur_num += 1else:breakif cur_num >= self.max_num:# 如果连续次数超过阈值,那集合中的所有token在下一个step都不可以再出现for token_id in self.consecutive_token_ids:scores[..., token_id] = 0return scores

4. 使用

使用方法非常简单,首先创建processor容器。对processor不熟悉的同学,可以去看之前的文章,有非常详细的介绍。

logits_processor = LogitsProcessorList()

然后对于ChatGLM而言,需要先添加其默认的processor:

logits_processor.append(InvalidScoreLogitsProcessor())

接下来,再添加防止陷入循环的两个processor:

number_tokens = [str(i) for i in range(10)] + ['.', '-']
number_token_ids = [tokenizer.convert_tokens_to_ids(tok) for tok in number_tokens]
logits_processor.append(ForbidDuplicationProcessor(tokenizer))
logits_processor.append(MaxConsecutiveProcessor(number_token_ids))

最后在调用generate的时候,把logits_processor作为参数传进去就可以了。

以上便是使用logits_processor来防止大模型在生成过程中陷入循环的方法。经过我的反复调整,基本可以覆盖大多数情景,如果在使用中遇到了bug,也欢迎指出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/166756.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

实时语音克隆:5 秒内生成任意文本的语音 | 开源日报 No.84

CorentinJ/Real-Time-Voice-Cloning Stars: 43.3k License: NOASSERTION 这个开源项目是一个实时语音克隆工具&#xff0c;可以在5秒内复制一种声音&#xff0c;并生成任意文本的语音。 该项目的主要功能包括&#xff1a; 从几秒钟的录音中创建声纹模型根据给定文本使用参考…

数字化转型没钱?没人?没IT?低代码平台轻松帮你搞定

随着数字技术的不断渗透&#xff0c;数字化已经不仅仅是一个趋势&#xff0c;而是深入人心的日常生活部分。在这样的时代背景下&#xff0c;企业面临的挑战也愈发严峻&#xff1a;如何不断创新&#xff0c;满足用户日益增长的业务需求&#xff1f; 传统的开发方式&#xff0c;随…

基于单片机设计的大气气压检测装置(STC89C52+BMP180实现)

一、前言 本项目设计一个大气气压检测装置&#xff0c;该装置以单片机为基础&#xff0c;采用STC89C52作为核心控制芯片&#xff0c;结合BMP180模块作为气压传感器。大气气压&#xff0c;也就是由气体重力在大气层中产生的压力&#xff0c;其变化与天气预报、气象观测以及高度…

江苏某市人民医院实现IT基础资源统一监控

一、背景介绍 江苏某市人民医院是一家拥有丰富医疗资源和庞大患者群体的医疗机构。随着医疗业务的不断发展&#xff0c;其IT系统的规模和复杂性也不断增加&#xff0c;涉及各类IT资源&#xff0c;包括服务器、网络设备、数据库、应用软件等。为了提高IT系统的可靠性和稳定性&am…

11.7统一功能处理

一.登录拦截器 1.实现一个普通的类,实现HeadlerInterceptor接口,重写preHeadler方法. 2.将拦截器添加到配置中,并设定拦截规则. 二.访问前缀添加 方法1: 方法2:properties 三.统一异常处理 以上返回的是空指针异常,如果是别的异常就不会识别,建议加上最终异常 . 四.统一数据格…

英语学习软件 Eudic欧路词典 mac中文版介绍说明

欧路词典 mac (Eudic) 是一个功能强大的英语学习工具&#xff0c;它包含了丰富的英语词汇、短语和例句&#xff0c;并提供了发音、例句朗读、单词笔记等功能。 Eudic欧路词典 mac 软件介绍 多语种支持&#xff1a;欧路词典支持多种语言&#xff0c;包括英语、中文、日语、法语…

uni微信小程序 map 添加padding

问题背景&#xff1a; 规划驾车线路的时候&#xff0c;使用uni的include-points指定可视范围的时候&#xff0c;会很极限。导致marker不能完全显示。 解决方法 给地图显示范围添加padding (推荐) <mapid"myMap":markers"markers":polyline"pol…

视频服务网关的三大部署(二)

视频网关是软硬一体的一款产品&#xff0c;可提供多协议&#xff08;RTSP/ONVIF/GB28181/海康ISUP/EHOME/大华、海康SDK等&#xff09;的设备视频接入、采集、处理、存储和分发等服务&#xff0c; 配合视频网关云管理平台&#xff0c;可广泛应用于安防监控、智能检测、智慧园区…

spark写入关系型数据库的duplicateIncs参数使用

在看一段spark写数据到关系型数据库代码时&#xff0c;发现一个参数没有见过&#xff1a; df.write.format("org.apache.spark.sql.execution.datasources.jdbc2").options(Map("savemode" -> JDBCSaveMode.Update.toString,"driver" -> …

Android13 launcher循环切页

launcher 常规切页&#xff1a;https://blog.csdn.net/a396604593/article/details/125305234 循环切页 我们知道&#xff0c;launcher切页是在packages\apps\Launcher3\src\com\android\launcher3\PagedView.java的onTouchEvent中实现的。 1、滑动限制 public boolean onT…

Python与设计模式--门面模式

8-Python与设计模式–门面模式 一、火警报警器&#xff08;1&#xff09; 假设有一组火警报警系统&#xff0c;由三个子元件构成&#xff1a;一个警报器&#xff0c;一个喷水器&#xff0c; 一个自动拨打电话的装置。其抽象如下&#xff1a; class AlarmSensor:def run(self):…

c语言习题1124

分别定义函数求圆的面积和周长。 写一个函数&#xff0c;分别求三个数当中的最大数。 写一个函数&#xff0c;计算输入n个数的乘积 一个判断素数的函数&#xff0c;在主函数输入一个整数&#xff0c;输出是否为素数的信息 写一个函数求n! ,利用该函数求1&#xff01;2&…

功率半导体器件CV测试系统

概述 电容-电压(C-V)测量广泛用于测量半导体参数&#xff0c;尤其是MOS CAP和MOSFET结构。MOS(金属-氧化物-半导体)结构的电容是外加电压的函数&#xff0c;MOS电容随外加电压变化的曲线称之为C-V曲线&#xff08;简称C-V特性&#xff09;&#xff0c;C-V 曲线测试可以方便的确…

opencv-使用 Haar 分类器进行面部检测

Haar 分类器是一种用于对象检测的方法&#xff0c;最常见的应用之一是面部检测。Haar 分类器基于Haar-like 特征&#xff0c;这些特征可以通过计算图像中的积分图来高效地计算。 在OpenCV中&#xff0c;Haar 分类器被广泛用于面部检测。以下是一个简单的使用OpenCV进行面部检测…

鸿蒙系统使用hdc_std.exe使用身份证读卡器等外设USB获得权限方法

hdc_std.exe是OpenHarmony 的命令行工具&#xff0c;由于使用的开源鸿蒙开发板上面没有文件管理器&#xff0c;所以无法通过U盘等方式进行安装.hap应用。 下面是使用hdc_std.exe安装身份证读卡器的步骤&#xff1a; 1、hdc_std.exe放桌面&#xff0c;然后WINR&#xff0c;打开…

CBTC 2023氢能展倒计时6天,最新同期会议活动Plus版发布

随着时间的推移&#xff0c;CBTC2023深圳氢能技术展览会即将拉开序幕。这场盛会将于11月30日在深圳福田会展中心盛大开幕&#xff0c;以“以储赋能&#xff0c;智造未来”为主题&#xff0c;旨在搭建一个商务交流、供需合作、创新产品发布的平台&#xff0c;让氢能全产业链之间…

寻找质数 II

题目描述 输入两个整数 a&#xff0c;b&#xff0c;计算并输出小于 a 的 b个质数&#xff0c;所有符合条件的质数里&#xff0c;输出最大的 b 个质数&#xff0c;按照从大到小输出&#xff0c;使用空格隔开。 假如符合条件的数量不够&#xff0c;则输出已经满足的质数。 如果…

详解Java中的异常体系机构(throw,throws,try catch,finally)

目录 一.异常的概念 二.异常的体系结构 三.异常的处理 异常处理思路 LBYL&#xff1a;Look Before You Leap EAFP: Its Easier to Ask Forgiveness than Permission 异常抛出throw 异常的捕获 提醒声明throws try-catch捕获处理 finally的作用 四.自定义异常类 一.异…

微信小程序:This Mini Program cannot be opened as your Weixin version is out-of-date.

项目场景&#xff1a; 问题描述 升级基础库3.2.0&#xff0c;然后PC端整个小程序都打不开了&#xff0c;点击小程序提示”This Mini Program cannot be opened as your Weixin version is out-of-date. Update Weixin to the latest version.“&#xff0c;并且点击Update Wei…

一个悄然崛起的国产软件!!AI 又进化了!!

大家好&#xff0c;我是 Jack。 AI 写代码想必很多人都体验过了&#xff0c;使用 AI 编程工具是一个大趋势&#xff0c;越早学会使用 AI 辅助你写代码&#xff0c;你的效率也会越高。 甚至有些公司已经要求员工具备 AI 编程能力。 对于学生党&#xff0c;AI 编程可以帮助我们…