微调大模型时,如何进行数据预处理? 将<input, output>转换为模型所需的<input_ids, labels, attention_mask>

原始训练数据集格式如下:

<input, output>

形式:字符

模型训练所需数据格式如下:

# tokenizer处理后
return {"input_ids": example,"labels": labels,"attention_mask": example_mask,
}

将字符转换为id,生成三个部分。

讲解

在大模型训练中,input_idslabelsattention_mask 是标准的数据格式,用于表示输入文本和相关的辅助信息。这些数据是模型训练中必需的,尤其是在自然语言处理(NLP)任务中。下面是这些字段的详细解释:

1. input_ids

  • 含义input_ids 是输入文本的 Token IDs(即词汇表中的索引)。
  • 用途:它们是模型输入的实际内容。模型通常需要将文本转换为数字形式来处理,input_ids 就是将每个词(或子词)映射到一个整数,这个整数表示该词在预训练模型的词汇表中的位置。
  • 例子
    • 假设输入文本是:"I love AI."
    • 经过分词和映射后,可能得到 input_ids = [101, 1045, 2293, 9474, 102],其中:
      • 101102 是特殊的标记(如开始和结束标记)。
      • 1045 是词 "I" 对应的 ID。
      • 2293 是词 "love" 对应的 ID,依此类推。
  • 注意:每个模型都有自己独特的词汇表,因此同样的文本在不同模型中得到的 input_ids 可能不同。

2. labels

  • 含义labels 是模型输出的目标数据,通常用于监督学习任务中的目标标签。
  • 用途labels 用于计算损失函数(比如交叉熵损失)来优化模型。它通常与 input_ids 具有相同的格式,用于生成目标预测。在一些任务中,labels 是输入文本的某种变换,比如机器翻译的目标句子,或者是文本分类任务中的标签。
  • 例子
    • 对于语言模型的训练,labels 通常与 input_ids 相同,代表的是下一个词的预测。即,模型在每个位置预测下一个词的 input_id
    • 在问答任务中,labels 可能是模型应当输出的答案。
    • 对于分类任务,labels 可以是一个整数值,表示文本的类别(如 012 等)。

语言模型的例子

  • 输入句子:"I love AI."
  • input_ids = [101, 1045, 2293, 9474, 102]
  • labels 对于语言模型任务,可能和 input_ids 一样:
    • labels = [1045, 2293, 9474, 102, -100]
    • 这里,-100 是一个特殊的标记,用于掩盖某些位置的标签,避免在计算损失时对某些位置进行更新。

3. attention_mask

  • 含义attention_mask 是一个与 input_ids 等长的向量,指示模型在哪些位置需要关注(即哪些位置是有效的)以及哪些位置应忽略(即哪些位置是填充的)。
  • 用途:在许多 NLP 模型中,文本可能需要进行填充(padding),特别是在处理不同长度的文本时。attention_mask 用于指示模型应该计算注意力的哪些位置,而哪些位置是填充,应该被忽略。
    • 1 表示该位置是有效的,模型应该关注这个位置。
    • 0 表示该位置是填充,模型应该忽略这个位置。
  • 例子
    • 假设输入文本是两句话,分别是:"I love AI." 和 "I enjoy machine learning."
    • input_ids[101, 1045, 2293, 9474, 102, 101, 1045, 2829, 4974, 102]
    • 由于句子长度不同,第二句需要填充。假设最大长度是 10,所以填充的位置会使用 0,并且 attention_mask 会标明哪些位置需要关注:
    • attention_mask = [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
    • 这里,1 表示模型应该关注的位置,0 表示填充的位置。

总结:

这些字段在训练时的意义:

  • input_ids:模型的输入序列,表示文本的词汇索引。
  • labels:目标输出序列,通常与 input_ids 相同,用于计算损失函数。
  • attention_mask:指示哪些位置是有效的,哪些是填充,帮助模型避免处理填充部分的数据。

这种格式在 NLP 任务中广泛使用,尤其是在语言模型训练(如 GPT、BERT 等)和序列到序列任务(如机器翻译、问答系统)中。

在大多数情况下,input_ids不包含 训练数据中的真实输出(如目标标签)。input_ids 主要用于表示输入序列,即模型的输入,而真实的输出通常会在 labels 中提供。我们可以通过具体的任务来更好地理解它们之间的关系。

input_ids 和 labels

1. 语言建模任务(如 GPT)

对于语言建模任务,input_idslabels 是非常相似的,甚至有时完全相同。这是因为语言模型的任务是根据前面的上下文预测下一个词。因此,模型的输入(input_ids)和目标输出(labels)是相同的,且 labels 的每个位置都表示目标词。

举例

  • 输入句子: "I love AI"

  • 假设使用的词汇表索引(简化表示):I -> 1045, love -> 2293, AI -> 9474

  • input_ids = [1045, 2293, 9474](这就是模型的输入)

  • labels = [2293, 9474, -100]-100 是一个占位符,表示忽略该位置)

在语言模型任务中,模型的目标是预测每个位置的下一个词。所以 labels 会从第一个词开始,包含实际的目标词。input_idslabels 对于每个位置来说,在训练时是同步的,只是模型预测的是 下一个 词。

2. 序列标注任务(如命名实体识别 NER 或 POS 标注)

在序列标注任务中,input_ids 仍然是输入序列的表示,但 labels 是每个输入单词或标记的标签。这时,input_ids 仅包含输入文本的词汇索引,而真实标签(比如实体类别或词性标签)则在 labels 中。

举例

  • 输入文本: "I love AI"

  • 目标标签(假设为命名实体识别任务):I -> O, love -> O, AI -> B-ORG

  • input_ids = [1045, 2293, 9474](表示输入的文本)

  • labels = [0, 0, 1](表示每个词的标签,其中 0 是普通词,1 表示 "AI" 是一个组织实体)

在这种任务中,input_idslabels 是不同的,input_ids 表示输入,而 labels 表示这些输入对应的标签。

3. 序列到序列任务(如机器翻译或文本生成)

在机器翻译或文本生成任务中,input_idslabels 也会有明显的区别:

  • input_ids 是源语言的文本表示。
  • labels 是目标语言的文本表示。

例如,在翻译任务中:

  • 输入文本(源语言): "I love AI"

  • 目标文本(目标语言): "J'aime l'IA"

  • input_ids = [1045, 2293, 9474](表示源语言)

  • labels = [2013, 1244, 1849](表示目标语言)

在这种情况下,input_idslabels 完全不同,因为它们分别表示源语言和目标语言。

4. 总结:

  • input_ids:表示模型的输入文本的 token 化结果,通常是对文本进行分词、编码后的词汇索引。它仅包含输入数据,不包含目标输出。
  • labels:表示模型的目标输出,通常用于训练期间计算损失。在语言建模任务中,labels 可能和 input_ids 一样;但在其他任务(如分类、序列标注、机器翻译等)中,labelsinput_ids 完全不同,表示模型应该生成或预测的目标结果。

因此,input_ids 不包含 训练数据中的真实输出,而 labels 才是训练时用来计算损失和评估模型性能的目标值。

处理代码

crop_train.json训练数据集格式如下:

{"instruction": "你是农作物领域专门进行关系抽取的专家。请从给定的文本中抽取出关系三元组,不存在的关系返回空列表。请按照JSON字符串的格式回答。","input": "煤是一种常见的化石燃料,家庭用煤经过了从\"煤球\"到\"蜂窝煤\"的演变。","output": "[{\"head\": \"煤\", \"relation\": \"use\", \"tail\": \"燃料\"}]"
},

数据预处理代码如下: 

import json
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainerdataset = load_dataset("json", data_files="./crop_train.json", split="train")
print(f"dataset: {dataset}")tokenizer = AutoTokenizer.from_pretrained("./glm-4-9b-chat", trust_remote_code=True)
print(f"tokenizer: {tokenizer}")def process_func(example):MAX_LENGTH = 256input_ids, attention_mask, labels = [], [], []# 合并example的instruction和input字段为一个字符串instruction = f"{example['instruction']} {example['input']}".strip()  # queryinstruction = tokenizer.apply_chat_template([{"role": "user", "content": instruction}],add_generation_prompt=True,tokenize=True,return_tensors="pt",return_dict=True)  # '[gMASK] <sop> <|user|> \nquery <|assistant|>'# 检查example["output"]是否是列表,并相应地处理if isinstance(example["output"], list):response_text = "\n".join(example["output"])else:response_text = "\n" + example["output"]response = tokenizer(response_text, add_special_tokens=False)  # \n response, 缺少eos token# input_ids = input + outputinput_ids = instruction["input_ids"][0].numpy().tolist() + response["input_ids"] + [tokenizer.eos_token_id]attention_mask = instruction["attention_mask"][0].numpy().tolist() + response["attention_mask"] + [1]# labels = input(-100) + outputlabels = [-100] * len(instruction["input_ids"][0].numpy().tolist()) + response["input_ids"] + [tokenizer.eos_token_id]if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}# 训练数据集经过预处理后生成<input_ids, labels, attention_mask>
tokenized_ds = dataset.map(process_func, remove_columns=['instruction', 'input', 'output'])print(f"All tokenizer tokens ids: {tokenized_ds}")     # features: ['input_ids', 'attention_mask', 'labels'],# tokenized_ds: 包含input_ids, attention_mask, labels = [], [], []
input_ids_1 = tokenized_ds[0]["input_ids"]
attention_mask_1 = tokenized_ds[0]["attention_mask"]
labels_1 = tokenized_ds[0]["labels"]
print(f"input_ids_1: {input_ids_1}")
print(f"attention_mask_1: {attention_mask_1}")
print(f"labels_1: {labels_1}")input_text_1 = tokenizer.decode(input_ids_1)
print(f"input_ids_1_decode: {input_text_1}")

tokenized_ds:里面包含所有的训练数据经过转换后的 ['input_ids', 'attention_mask', 'labels']集合,用于model直接使用来训练。

模型训练所需数据格式如下: 

input_ids_1: [151331, 151333, 151336, 198, 103408, 112687, 99788, 102014, 98638, 99172, 115023, 98314, 100153, 1773, 98964, 98484, 98602, 100966, 103231, 98322, 100319, 107325, 99172, 120673, 98555, 3837, 107399, 102189, 104559, 98745, 106522, 1773, 98964, 99928, 5370, 121478, 98314, 104714, 99770, 1773, 10231, 227, 97, 100375, 104250, 112075, 106512, 3837, 99716, 98340, 100855, 114094, 98484, 1, 100855, 98781, 1, 98344, 1, 125272, 100855, 1, 98314, 110001, 1773, 151337, 198, 58, 4913, 1983, 788, 330, 100855, 497, 330, 22166, 788, 330, 810, 497, 330, 14576, 788, 330, 106512, 9204, 60, 151329]
attention_mask_1: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
labels_1: [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 198, 58, 4913, 1983, 788, 330, 100855, 497, 330, 22166, 788, 330, 810, 497, 330, 14576, 788, 330, 106512, 9204, 60, 151329]input_ids_1_decode: 
[gMASK] <sop> <|user|> 
你是农作物领域专门进行关系抽取的专家。请从给定的文本中抽取出关系三元组,不存在的关系返回空列表。请按照JSON字符串的格式回答。 煤是一种常见的化石燃料,家庭用煤经过了从"煤球"到"蜂窝煤"的演变。 
<|assistant|> [{"head": "煤", "relation": "use", "tail": "燃料"}] <|endoftext|>

模型训练

# 模型训练参数
args = TrainingArguments(output_dir="./chatbot",per_device_train_batch_size=2,gradient_accumulation_steps=8,gradient_checkpointing=True,logging_steps=100,num_train_epochs=10,learning_rate=1e-4,remove_unused_columns=False,save_strategy="epoch"
)# 开始训练
trainer = Trainer(model=model,args=args,# 使用 <input_ids, labels, attention_mask>train_dataset=tokenized_ds.select(range(10000)),data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)

总结

LLM的训练或者微调都是需要<input_ids, labels, attention_mask>形式的数据,e. g. Structured IE中的三个task的dataset仍是如此。

class MOFDataset(Dataset):def __init__(self, dataset_config, tokenizer, split_name, max_words=1024):#self.data = json.load(open(dataset_config.data_path))if split_name == "train":self.data = json.load(open(dataset_config.data_path+"/train.json")) # self.data[0]["train"]  # Adjust this based on your dataset's structureelse:self.data = json.load(open(dataset_config.data_path+"/val.json"))# self.data[0]["validation"]  # Adjust this based on your dataset's structureself.max_words = max_wordsself.tokenizer = tokenizerdef __len__(self):return len(self.data)def __getitem__(self, index):IGNORE_INDEX = -100  # The default setting in CrossEntropyLossitem = self.data[index]#prompt = f"### Instruction:\n{item['instruction']}\n\n### Input:\n{item['input']}\n\n### Response:"prompt = item['input']#f"item['input']\n\n"# example = input+ outputexample = prompt + item["output"]
#        print(example)prompt = torch.tensor(self.tokenizer.encode(prompt), dtype=torch.int64)# example = input_ids + output_idsexample = self.tokenizer.encode(example)    # input+output
#        print(example)# example = input_ids + output_ids + <eos>example.append(self.tokenizer.eos_token_id)example = torch.tensor(example, dtype=torch.int64)padding = self.max_words - example.shape[0]# 用 -1 填充if padding > 0:example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))# 截断elif padding < 0:example = example[: self.max_words]# labels = examplelabels = copy.deepcopy(example)# labels = input(-1) + output# 复制 example,并将 example 中与 prompt 对应的部分设置为 -1。这样,模型在训练时只会关注 output 部分作为标签,而忽略掉输入部分。labels[: len(prompt)] = -1# 创建一个 mask,用于标记 example 中不为 -1(即有效)的部分。example_mask = example.ge(0)# 创建一个 mask,用于标记 labels 中不为 IGNORE_INDEX(即有效)的部分。label_mask = labels.ge(0)# 将无效的部分(填充部分)置为 0,这样它们不会对损失计算产生影响。example[~example_mask] = 0labels[~label_mask] = IGNORE_INDEXexample_mask = example_mask.float()label_mask = label_mask.float()return {"input_ids": example,   # example = input_ids + output_ids + <eos>"labels": labels,       # labels = input(-1) + output,输入部分被替换为 -1,只保留 output 部分作为目标标签。"attention_mask": example_mask,  # 返回一个 mask,指示哪些位置是有效的(即不是填充部分)。}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/65152.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

git推送本地仓库到远程(Gitee)

目录 一、注册创建库 二、创建仓库 三、推送本地仓库到远程 1.修改本地仓库用户名和邮箱 2.本地库关联远程仓库 3.拉取远程仓库的文件 4.推送本地库的文件 5.查看远程仓库 四、远程分支查看 1.查看远程分支 2.修改test.txt文件 一、注册创建库 Gitee官网&#xff1…

GoZero框架接入数据库引擎Gorm 并实战:构建简单的 CRUD 业务API

GoZero 是一个高性能的微服务框架&#xff0c;它基于 Go 语言开发&#xff0c;提供了丰富的工具支持&#xff0c;能够帮助开发者快速构建可扩展、易维护的应用。Gorm 是 Go 语言中常用的 ORM 库&#xff0c;它帮助我们简化数据库操作&#xff0c;使用面向对象的方式进行增删改查…

KNN分类算法 HNUST【数据分析技术】(2025)

1.理论知识 KNN&#xff08;K-Nearest Neighbor&#xff09;算法是机器学习算法中最基础、最简单的算法之一。它既能用于分类&#xff0c;也能用于回归。KNN通过测量不同特征值之间的距离来进行分类。 KNN算法的思想&#xff1a; 对于任意n维输入向量&#xff0c;分别对应于特征…

探索Flink动态CEP:杭州银行的实战案例

摘要&#xff1a;本文撰写自杭州银行大数据工程师唐占峰、欧阳武林老师。将介绍 Flink 动态 CEP的定义与核心概念、应用场景、并深入探讨其技术实现并介绍使用方式。主要分为以下几个内容&#xff1a; Flink动态CEP简介 Flink动态CEP的应用场景 Flink动态CEP的技术实现 Flin…

打造高效租赁小程序让交易更便捷

内容概要 在如今节奏飞快的商业世界里&#xff0c;租赁小程序如同一只聪明的小狐狸&#xff0c;迅速突围而出&#xff0c;成为商家与消费者之间的桥梁。它不仅简化了交易流程&#xff0c;还在某种程度上将传统租赁模式带入了互联网时代。越来越多的企业意识到&#xff0c;这种…

【MinIO系列】MinIO Client (mc) 完全指南

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

Jmeter录制https请求

jmeter 5.5版本&#xff0c;chrome浏览器 1、首先添加Test Plan-Thread Group-HTTP(S) Test Script Recorder 2、设置HTTP(S) Test Script Recorder界面的Port&#xff08;监听端口&#xff0c;设置浏览器代理时需要与这里保持一致&#xff09;、HTPS Domains&#xff08;录制…

前端最新Vue2+Vue3基础入门到实战项目全套教程,自学前端vue就选黑马程序员,一套全通关!

Vue 快速上手 Vue概念 Vue 是一个用于构建用户界面的渐进式框架 构建用户界面&#xff1a;基于数据渲染出用户看到的页面 渐进式&#xff1a;循序渐进 框架&#xff1a;一套完整的项目解决方案 Vue 的两种使用方式: ① Vue 核心包开发 场景:局部 模块改造 ② Vue 核心包 &am…

基于Spring Boot的高校请假管理系统

一、系统背景与意义 随着高校规模的扩大和学生数量的增加&#xff0c;传统的请假管理方式已经难以满足高校管理的需求。人工请假流程繁琐、耗时长&#xff0c;且容易出现信息错误或遗漏。因此&#xff0c;开发一套基于Spring Boot的高校请假管理系统具有重要意义&#xff0c;它…

Gate.io 平台通证 GT:持续赋能与销毁、财富效应显著

在瞬息万变的加密市场中&#xff0c;每一轮牛熊转换都在加速 CEX 市场的一轮又一轮洗牌&#xff0c;这也使得该赛道的格局始终处于动态的变化。而在本轮牛市中&#xff0c;CEX 赛道也正在从最初的三大领衔变成了多强角逐&#xff0c;而 Gate.io 作为创立 11 余年的老牌交易平台…

WebRTC音视频同步原理与实现详解(下)

WebRTC音视频同步原理与实现详解&#xff08;上&#xff09; 第四章、音视频同步实现详解 4.1 音视频同步标准 音视频做到什么程度才算是同步呢&#xff1f; 关于音画同步, 业界有3个标准&#xff1a; 1&#xff09;ITU-R BT.1359&#xff08;1998&#xff09;&#xff1a…

1.系统学习-线性回归

系统学习-线性回归 前言线性回归介绍误差函数梯度下降梯度下降示例 回归问题常见的评价函数1. MAE, mean absolutely error2. MSE, mean squared error3. R square &#xff08;决定系数或R方&#xff09; 机器学习建模流程模型正则化拓展阅读作业 链接: 2.系统学习-逻辑回归 …

Oracle 日常巡检

1. 检查服务器状态 1.1. CPU使用情况 1.1.1. top top 命令是 Linux 和 Unix 系统中用于显示实时系统状态的工具&#xff0c;特别是对于监控 CPU 和内存的使用非常有用。 在命令行中输入 top&#xff0c;top 会显示一个实时更新的界面&#xff0c;其中包含系统的关键指标&am…

熊军出席ACDU·中国行南京站,详解SQL管理之道

12月21日&#xff0c;2024 ACDU中国行在南京圆满收官&#xff0c;本次活动分为三个篇章——回顾历史、立足当下、展望未来&#xff0c;为线上线下与会观众呈现了一场跨越时空的技术盛宴&#xff0c;吸引了众多业内人士的关注。云和恩墨副总经理熊军出席此次活动并发表了主题演讲…

如何在网页端使用 IDE 高效地阅读 GitHub 源码?

如何在网页端使用 IDE 高效地阅读 GitHub 源码&#xff1f; 前言什么是 GitHub1s&#xff1f;使用 GitHub1s 阅读 browser-use 项目源码步骤 1: 打开 GitHub 项目页面步骤 2: 修改 URL 使用 GitHub1s步骤 3: 浏览文件结构步骤 4: 使用代码高亮和智能补全功能步骤 5: 快速跳转和…

3D布展平台主要有哪些功能?有什么特点?

3D布展平台是一种利用3D技术和虚拟现实&#xff08;VR&#xff09;技术&#xff0c;为用户提供线上虚拟展览和展示服务的平台。这些平台通常允许用户创建、设计和发布3D虚拟展厅&#xff0c;从而提供沉浸式的展览体验。以下是对3D布展平台的详细介绍&#xff1a; 一、主要功能 …

TowardsDataScience 博客中文翻译 2018~2024(一百二十三)

TowardsDataScience 博客中文翻译 2018~2024&#xff08;一百二十三&#xff09; 引言 从 2018 年到 2024 年&#xff0c;数据科学的进展超越了许多技术领域的速度。Towards Data Science 博客依然是这个领域的关键平台&#xff0c;记录了从基础工具到前沿技术的多方面发展。…

Docker Build 命令详解:在 Ubuntu 上构建 Docker 镜像教程

简介 Docker 通过提供轻量级、可移植和高效的解决方案&#xff0c;彻底改变了软件开发和部署。docker build 命令是 Docker 镜像创建过程的核心。本文将探讨 docker build 命令、其语法、用法以及优化 Docker 构建的最佳实践。本教程的目标是手把手教你如何在 Linux 服务器上使…

Springboot应用开发:配置类整理

目录 编写目的 一、线程池 1.1 setCorePoolSize 1.2 setMaxPoolSize 1.3 setQueueCapacity 1.4 setKeepAliveSeconds 1.5 setThreadNamePrefix 1.6 setRejectedExecutionHandler 1.7 示例代码 二、Durid数据库连接池 2.1 ServletRegistrationBean 2.2 FilterRegist…

【Spring】深入解析 Spring 原理:Bean 的多方面剖析(源码阅读)

&#x1f525;个人主页&#xff1a; 中草药 &#x1f525;专栏&#xff1a;【Java】登神长阶 史诗般的Java成神之路 一、Bean的作用域 在 Java Spring 框架中&#xff0c;Bean 的作用域是一个关键概念&#xff0c;它决定了 Bean 的生命周期和实例化方式&#xff0c;对应用的性…