BERT+PET方式数据处理

基于BERT+PET方式数据预处理介绍

在这里插入图片描述

BERT+PET方式数据预处理🐾

  • 本项目中对数据部分的预处理步骤如下:
    1. 查看项目数据集
    2. 编写Config类项目文件配置代码
    3. 编写数据处理相关代码

1 查看项目数据集🐾

  • 数据存放位置:/Users/***/PycharmProjects/llm/prompt_tasks/PET/data

  • data文件夹里面包含4个txt文档,分别为:train.txt、dev.txt、prompt.txt、verbalizer.txt


1.1 train.txt
  • train.txt为训练数据集,其部分数据展示如下:
水果	脆脆的,甜味可以,可能时间有点长了,水分不是很足。
平板	华为机器肯定不错,但第一次碰上京东最糟糕的服务,以后不想到京东购物了。
书籍	为什么不认真的检查一下, 发这么一本脏脏的书给顾客呢!
衣服	手感不错,用料也很好,不知道水洗后怎样,相信大品牌,质量过关,五星好评!!!
水果	苹果有点小,不过好吃,还有几个烂的。估计是故意的放的。差评。
衣服	掉色掉的厉害,洗一次就花了

train.txt一共包含63条样本数据,每一行用\t分开,前半部分为标签(label),后半部分为原始输入 (用户评论)。

如果想使用自定义数据训练,只需要仿照上述示例数据构建数据集即可。


1.2 dev.txt
  • dev.txt为验证数据集,其部分数据展示如下:
书籍	"一点都不好笑,很失望,内容也不是很实用"
衣服	完全是一条旧裤子。
手机	相机质量不错,如果阳光充足,可以和数码相机媲美.界面比较人性化,容易使用.软件安装简便
书籍	明明说有货,结果送货又没有了。并且也不告诉我,怎么评啊
洗浴	非常不满意,晚上洗的头发,第二天头痒痒的不行了,还都是头皮屑。
水果	这个苹果感觉是长熟的苹果,没有打蜡,不错,又甜又脆

dev.txt一共包含590条样本数据,每一行用\t分开,前半部分为标签(label),后半部分为原始输入 (用户评论)。

如果想使用自定义数据训练,只需要仿照上述示例数据构建数据集即可。

1.3 prompt.txt
  • prompt.txt为人工设定提示模版,其数据展示如下:
这是一条{MASK}评论:{textA}。

其中,用大括号括起来的部分为「自定义参数」,可以自定义设置大括号内的值。

示例中 {MASK} 代表 [MASK] token 的位置,{textA} 代表评论数据的位置。

你可以改为自己想要的模板,例如想新增一个 {textB} 参数:

{textA}和{textB}是{MASK}同的意思。
1.4 verbalizer.txt🐾
  • verbalizer.txt 主要用于定义「真实标签」到「标签预测词」之间的映射。在有些情况下,将「真实标签」作为 [MASK] 去预测可能不具备很好的语义通顺性,因此,我们会对「真实标签」做一定的映射。

  • 例如:

"中国爆冷2-1战胜韩国"是一则[MASK][MASK]新闻。	体育
  • 这句话中的标签为「体育」,但如果我们将标签设置为「足球」会更容易预测。

  • 因此,我们可以对「体育」这个 label 构建许多个子标签,在推理时,只要预测到子标签最终推理出真实标签即可,如下:

体育 -> 足球,篮球,网球,棒球,乒乓,体育
  • 项目中标签词映射数据展示如下:
电脑	电脑
水果	水果
平板	平板
衣服	衣服
酒店	酒店
洗浴	洗浴
书籍	书籍
蒙牛	蒙牛
手机	手机
电器	电器

verbalizer.txt 一共包含10个类别,上述数据中,我们使用了1对1的verbalizer, 如果想定义一对多的映射,只需要在后面用","分割即可, eg:

水果	苹果,香蕉,橘子

若想使用自定义数据训练,只需要仿照示例数据构建数据集

2 编写Config类项目文件配置代码🐾

  • 代码路径:/Users/***/PycharmProjects/llm/prompt_tasks/PET/pet_config.py

  • config文件目的:配置项目常用变量,一般这些变量属于不经常改变的,比如:训练文件路径、模型训练次数、模型超参数等等

具体代码实现:

# coding:utf-8
import torch
import sys
print(sys.path)class ProjectConfig(object):def __init__(self):# 是否使用GPUself.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'# 预训练模型bert路径self.pre_model = '/home/prompt_project/bert-base-chinese'self.train_path = '/home/prompt_project/PET/data/train.txt'self.dev_path = '/home/prompt_project/PET/data/dev.txt'self.prompt_file = '/home/prompt_project/PET/data/prompt.txt'self.verbalizer = '/home/prompt_project/PET/data/verbalizer.txt'self.max_seq_len = 512self.batch_size = 8self.learning_rate = 5e-5# 权重衰减参数(正则化,抑制模型过拟合)self.weight_decay = 0# 预热学习率(用来定义预热的步数)self.warmup_ratio = 0.06self.max_label_len = 2self.epochs = 50self.logging_steps = 10self.valid_steps = 20self.save_dir = '/home/prompt_project/PET/checkpoints'if __name__ == '__main__':pc = ProjectConfig()print(pc.prompt_file)print(pc.pre_model)

3 编写数据处理相关代码🐾

  • 代码路径:/Users/***/PycharmProjects/llm/prompt_tasks/PET/data_handle.

  • data_handle文件夹中一共包含三个py脚本:template.py、data_preprocess.py、data_loader.py

3.1 template.py
  • 目的:构建固定模版类,text2id的转换

  • 导入必备工具包

# -*- coding:utf-8 -*-
from rich import print # 终端层次显示
from transformers import AutoTokenizer
import numpy as np
import sys
sys.path.append('..')
from pet_config import *
  • 定义HardTemplate类
class HardTemplate(object):"""硬模板,人工定义句子和[MASK]之间的位置关系。"""def __init__(self, prompt: str):"""Args:prompt (str): prompt格式定义字符串, e.g. -> "这是一条{MASK}评论:{textA}。""""self.prompt = promptself.inputs_list = []                       # 根据文字prompt拆分为各part的列表self.custom_tokens = set(['MASK'])          # 从prompt中解析出的自定义token集合self.prompt_analysis()                         # 解析prompt模板def prompt_analysis(self):"""将prompt文字模板拆解为可映射的数据结构。Examples:prompt -> "这是一条{MASK}评论:{textA}。"inputs_list -> ['这', '是', '一', '条', 'MASK', '评', '论', ':', 'textA', '。']custom_tokens -> {'textA', 'MASK'}"""idx = 0while idx < len(self.prompt):str_part = ''if self.prompt[idx] not in ['{', '}']:self.inputs_list.append(self.prompt[idx])if self.prompt[idx] == '{':                  # 进入自定义字段idx += 1while self.prompt[idx] != '}':str_part += self.prompt[idx]             # 拼接该自定义字段的值idx += 1elif self.prompt[idx] == '}':raise ValueError("Unmatched bracket '}', check your prompt.")if str_part:self.inputs_list.append(str_part)# 将所有自定义字段存储,后续会检测输入信息是否完整self.custom_tokens.add(str_part)  idx += 1def __call__(self,inputs_dict: dict,tokenizer,mask_length,max_seq_len=512):"""输入一个样本,转换为符合模板的格式。Args:inputs_dict (dict): prompt中的参数字典, e.g. -> {"textA": "这个手机也太卡了", "MASK": "[MASK]"}tokenizer: 用于encoding文本mask_length (int): MASK token 的长度Returns:dict -> {'text': '[CLS]这是一条[MASK]评论:这个手机也太卡了。[SEP]','input_ids': [1, 47, 10, 7, 304, 3, 480, 279, 74, 47, 27, 247, 98, 105, 512, 777, 15, 12043, 2],'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1,1, 1, 1, 1, 1, 1, 1, 1, 1, 1],'mask_position': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}"""# 定义输出格式outputs = {'text': '', 'input_ids': [],'token_type_ids': [],'attention_mask': [],'mask_position': []}str_formated = ''for value in self.inputs_list:if value in self.custom_tokens:if value == 'MASK':str_formated += inputs_dict[value] * mask_lengthelse:str_formated += inputs_dict[value]else:str_formated += value# print(f'str_formated-->{str_formated}')encoded = tokenizer(text=str_formated,truncation=True,max_length=max_seq_len,padding='max_length')# print(f'encoded--->{encoded}')outputs['input_ids'] = encoded['input_ids']outputs['token_type_ids'] = encoded['token_type_ids']outputs['attention_mask'] = encoded['attention_mask']token_list = tokenizer.convert_ids_to_tokens(encoded['input_ids'])outputs['text'] = ''.join(token_list)mask_token_id = tokenizer.convert_tokens_to_ids(['[MASK]'])[0]condition = np.array(outputs['input_ids']) == mask_token_idmask_position = np.where(condition)[0].tolist()outputs['mask_position'] = mask_positionreturn outputsif __name__ == '__main__':pc = ProjectConfig()tokenizer = AutoTokenizer.from_pretrained(pc.pre_model)hard_template = HardTemplate(prompt='这是一条{MASK}评论:{textA}')print(hard_template.inputs_list)print(hard_template.custom_tokens)tep = hard_template(inputs_dict={'textA': '包装不错,苹果挺甜的,个头也大。', 'MASK': '[MASK]'},tokenizer=tokenizer,max_seq_len=30,mask_length=2)print(tep)print(tokenizer.convert_ids_to_tokens([3819, 3352]))print(tokenizer.convert_tokens_to_ids(['水', '果']))

3.2 data_preprocess.py🐾
  • 目的: 将样本数据转换为模型接受的输入数据

  • 导入必备的工具包

from template import *
from rich import print
from datasets import load_dataset
# partial:把一个函数的某些参数给固定住(也就是设置默认值),返回一个新的函数,调用这个新函数会更简单
from functools import partial
from pet_config import *

  • 定义数据转换方法convert_example()
def convert_example(examples: dict,tokenizer,max_seq_len: int,max_label_len: int,hard_template: HardTemplate,train_mode=True,return_tensor=False) -> dict:"""将样本数据转换为模型接收的输入数据。Args:examples (dict): 训练数据样本, e.g. -> {"text": ['手机	这个手机也太卡了。','体育	世界杯为何迟迟不见宣传',...]}max_seq_len (int): 句子的最大长度,若没有达到最大长度,则padding为最大长度max_label_len (int): 最大label长度,若没有达到最大长度,则padding为最大长度hard_template (HardTemplate): 模板类。train_mode (bool): 训练阶段 or 推理阶段。return_tensor (bool): 是否返回tensor类型,如不是,则返回numpy类型。Returns:dict (str: np.array) -> tokenized_output = {'input_ids': [[1, 47, 10, 7, 304, 3, 3, 3, 3, 47, 27, 247, 98, 105, 512, 777, 15, 12043, 2], ...],'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], ...],'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], ...],'mask_positions': [[5, 6, 7, 8], ...],'mask_labels': [[2372, 3442, 0, 0], [2643, 4434, 2334, 0], ...]}"""tokenized_output = {'input_ids': [],'token_type_ids': [],'attention_mask': [],'mask_positions': [],'mask_labels': []}for i, example in enumerate(examples['text']):if train_mode:label, content = example.strip().split('\t')else:content = example.strip()inputs_dict = {'textA': content,'MASK': '[MASK]'}encoded_inputs = hard_template(inputs_dict=inputs_dict,tokenizer=tokenizer,max_seq_len=max_seq_len,mask_length=max_label_len)tokenized_output['input_ids'].append(encoded_inputs["input_ids"])tokenized_output['token_type_ids'].append(encoded_inputs["token_type_ids"])tokenized_output['attention_mask'].append(encoded_inputs["attention_mask"])tokenized_output['mask_positions'].append(encoded_inputs["mask_position"])if train_mode:label_encoded = tokenizer(text=[label])  # 将label补到最大长度# print(f'label_encoded-->{label_encoded}')label_encoded = label_encoded['input_ids'][0][1:-1]label_encoded = label_encoded[:max_label_len]add_pad = [tokenizer.pad_token_id] * (max_label_len - len(label_encoded))label_encoded = label_encoded + add_padtokenized_output['mask_labels'].append(label_encoded)for k, v in tokenized_output.items():if return_tensor:tokenized_output[k] = torch.LongTensor(v)else:tokenized_output[k] = np.array(v)return tokenized_outputif __name__ == '__main__':pc = ProjectConfig()train_dataset = load_dataset('text', data_files=pc.train_path)print(type(train_dataset))print(train_dataset)# print('*'*80)# print(train_dataset['train']['text'])tokenizer = AutoTokenizer.from_pretrained(pc.pre_model)hard_template = HardTemplate(prompt='这是一条{MASK}评论:{textA}')convert_func = partial(convert_example,tokenizer=tokenizer,hard_template=hard_template,max_seq_len=30,max_label_len=2)dataset = train_dataset.map(convert_func, batched=True)for value in dataset['train']:print(value)print(len(value['input_ids']))break

3.3 data_loader.py🐾
  • 目的:定义数据加载器

  • 导入必备的工具包

# coding:utf-8
from torch.utils.data import DataLoader
from transformers import default_data_collator
from data_preprocess import *
from pet_config import *pc = ProjectConfig() # 实例化项目配置文件
tokenizer = AutoTokenizer.from_pretrained(pc.pre_model)

  • 定义获取数据加载器的方法get_data()
def get_data():# prompt定义prompt = open(pc.prompt_file, 'r', encoding='utf8').readlines()[0].strip()  hard_template = HardTemplate(prompt=prompt)  # 模板转换器定义dataset = load_dataset('text', data_files={'train': pc.train_path,'dev': pc.dev_path})# print(dataset)# print(f'Prompt is -> {prompt}')new_func = partial(convert_example,tokenizer=tokenizer,hard_template=hard_template,max_seq_len=pc.max_seq_len,max_label_len=pc.max_label_len)dataset = dataset.map(new_func, batched=True)train_dataset = dataset["train"]dev_dataset = dataset["dev"]# print('train_dataset', train_dataset[:2])# print('*'*80)train_dataloader = DataLoader(train_dataset,shuffle=True,collate_fn=default_data_collator,batch_size=pc.batch_size)dev_dataloader = DataLoader(dev_dataset,collate_fn=default_data_collator,batch_size=pc.batch_size)return train_dataloader, dev_dataloaderif __name__ == '__main__':train_dataloader, dev_dataloader = get_data()print(len(train_dataloader))print(len(dev_dataloader))for i, value in enumerate(train_dataloader):print(i)print(value)print(value['input_ids'].dtype)break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/23238.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uniapp内置的button组件的问题

问题描述 由于想要使用uniapp内置button组件的开放能力&#xff0c;所以就直接使用了button&#xff0c;但是他本身带着边框&#xff0c;而且使用 border&#xff1a;none&#xff1b;是没有效果的。 问题图片 解决方案 button::after {border: none;} 正确样式 此时的分享…

HarmonyOS(31) @Prop标签使用指南

Prop Prop简介State和Prop的同步场景使用示例参考资料 Prop简介 子组件中Prop装饰的变量可以和父组件建立单向的同步关系。子组件Prop装饰的变量是可变的&#xff0c;但是变化不会同步回其父组件。Prop变量允许子组件修改&#xff0c;但修改后的变化不会同步回父组件。当父组件…

python书上的动物是啥

Python的创始人为Guido van Rossum。1989年圣诞节期间&#xff0c;在阿姆斯特丹&#xff0c;Guido为了打发圣诞节的无趣&#xff0c;决心开发一个新的脚本解释程序&#xff0c;做为ABC语言的一种继承。之所以选中Python作为程序的名字&#xff0c;是因为他是一个叫Monty Python…

【核心动画-转场动画-CATransition Objective-C语言】

一、转场动画,CATransition, 1.接下来,我们来说这个转场动画啊,效果呢,会做这么一个小例子, 感觉有一个3D的一个样式一样, 转场动画呢,就是说,你在同一个View,比如说,imageView,去切换图片的时候,你可以去用这个,转场动画, 实际上,包括,控制器之间的切换,也…

新手上路:Linux虚拟机创建与Hadoop集群配置指南①(未完)

一、基础阶段 Linux操作系统: 创建虚拟机 1.创建虚拟机 打开VM,点击文件,新建虚拟机,点击自定义,下一步 下一步 这里可以选择安装程序光盘映像文件,我选择稍后安装 选择linux系统 位置不选C盘,创建一个新的文件夹VM来放置虚拟机,将虚拟机名字改为master方便后续识别…

期望24K,商汤科技golang开发 社招一二三 + hr 面

商汤科技对数据库和中间件相关的东西问的比其他的大厂要少很多&#xff0c;可能他们更多是和算法相关&#xff0c;没有什么高并发的场景。总体感觉对技术的要求不是特别高。当时问了他们主管&#xff0c;我面试的部门的工作是主要去实现他们算法部门研究的算法&#xff0c;感觉…

[图解]企业应用架构模式2024新译本讲解09-领域模型2

1 00:00:01,750 --> 00:00:03,030 代码还是一样的 2 00:00:03,040 --> 00:00:12,640 我们还是从前面人家做的复刻案例来看 3 00:00:14,170 --> 00:00:15,200 这个是它的类图 4 00:00:15,640 --> 00:00:20,650 我们同样用UModel逆转&#xff0c;这个太小了&#…

windows RNDIS开发-概念

远程 NDIS (RNDIS) 是一种独立于总线的类&#xff0c;适用于动态 即插即用 (PnP) 总线&#xff08;例如 USB、1394、蓝牙和 InfiniBand&#xff09;上的以太网 (802.3) 网络设备。 远程 NDIS 通过抽象控制和数据通道在主计算机与远程 NDIS 设备之间定义与总线无关的消息协议。 …

【微信小程序】页面导航

声明式导航 导航到 tabbar 页 tabBar页面指的是被配置为tabBar的页面。 在使用<navigator>组件跳转到指定的tabBar页面时&#xff0c;需要指定url属性和open-type属性&#xff0c;其中&#xff1a; url 表示要跳转的页面的地址&#xff0c;必须以/开头open-type表示跳…

spring boot3登录开发-2(3邮件验证码接口实现)

⛰️个人主页: 蒾酒 &#x1f525;系列专栏&#xff1a;《spring boot实战》 目录 写在前面 上文衔接 接口设计与实现 1.接口分析 2.实现思路 3.代码实现 1.定义验证码短信HTML模板枚举类 2.定义验证码业务接口 3. 验证码业务接口实现 4.控制层代码 4.测试 写…

场外个股期权标的有哪些?

今天带你了解场外个股期权标的有哪些&#xff1f;场外个股期权是一种金融衍生品&#xff0c;它不在交易所内进行交割&#xff0c;而是在交易所以外的场所进行交易的股票期权合约。 场外个股期权标的有哪些&#xff1f; 场外个股期权的标的通常包括A股市场上的融资融券标的&…

ARM服务器在云手机中可以提供哪些支持

ARM服务器作为云手机的底层支撑&#xff0c;在很多社媒APP或者电商APP平台都有着很多看不见的功劳&#xff0c;可以说ARM扮演着至关重要的底层支持角色&#xff1b; 首先&#xff0c;ARM 服务器为云手机提供了强大的计算能力基础。云手机需要处理大量的数据和复杂的运算&#x…

微服务第一轮

课程文档 目录 一、业务流程 1、登录 Controller中的接口&#xff1a; Service中的实现impl&#xff1a; Service中的实现impl所继承的接口IService&#xff08;各种方法&#xff09;&#xff1a; VO&#xff1a; DTO&#xff1a; 2、搜索商品 ​Controller中的接口&a…

【亚马逊云科技 CSDN 联合巨献】 「对话AI 构建者:从基础到应用的 LLM 全景培训」 限时免费!

&#x1f680;&#x1f31f;【亚马逊云科技 & CSDN 联合巨献】 &#x1f4da;「对话AI 构建者&#xff1a;从基础到应用的 LLM 全景培训」&#x1f525; 限时免费&#xff01; &#x1f4c6; 抓紧时间&#xff01;6月7日前注册&#xff0c;原价 399&#xff0c;现在仅需 0…

C基础与SDK调试方法

REVIEW 上次学习了一下软件使用流程zynq PS点灯-CSDN博客 本次学习一下C编程基础与调试方法 1. 硬件编程原理 小梅哥视频链接&#xff1a; 07_Xilinx嵌入式裸机硬件编程原理_哔哩哔哩_bilibili 对应的课程笔记&#xff1a;【zynq课程笔记】【裸机】【第7课 】【硬件编程原理…

C++ STL - 容器

C STL&#xff08;标准模板库&#xff09;中的容器是一组通用的、可复用的数据结构&#xff0c;用于存储和管理不同类型的数据。 目录 零. 简介&#xff1a; 一 . vector&#xff08;动态数组&#xff09; 二. list&#xff08;双向链表&#xff09; 三. deque&#xff08…

yolov8摔倒检测(包含数据集+训练好的模型)

基于先进的YOLOv8模型&#xff0c;实现了一套高效可靠的人体摔倒检测系统。YOLOv8作为YOLO系列的最新成员&#xff0c;以其卓越的检测速度和准确性&#xff0c;在计算机视觉领域尤其是目标检测任务中表现出色。本系统不仅能够实时处理视频流或监控画面&#xff0c;还能对静态图…

SwiftUI中Menu和ControlGroup的使用

本篇文章主要介绍一下Menu组件和ControlGroup组件的使用。Menu组件是在iOS 14&#xff08;tvOS 17&#xff09;推出的一个组件&#xff0c;点击后提供一个可选择的操作列表。ControlGroup组件是一个容器视图&#xff0c;以视觉上适当的方式为给定的上下文显示语义相关的控件&am…

【面试干货】SQL语言分类

【面试干货】SQL语言分类 1、数据查询语言&#xff08;DQL&#xff09;2、数据操纵语言&#xff08;DML&#xff09;3、数据定义语言&#xff08;DDL&#xff09;4、数据控制语言&#xff08;DCL&#xff09;5、结语 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收…

使用gradio库实现Web应用,允许用户上传图像,并使用YOLOv8模型对图像进行目标检测。

一、Gradio Gradio 详细介绍 Gradio 是一个用于构建和分享机器学习模型和数据科学应用的开源Python库。它简化了创建交互式Web界面的过程&#xff0c;让开发者可以快速搭建原型并与他人分享。 主要特性 易用性&#xff1a; 无需前端开发经验&#xff1a;只需几行Python代码就…