【Python】科研代码学习:八 FineTune PretrainedModel (用 trainer,用 script);LLM文本生成

【Python】科研代码学习:八 FineTune PretrainedModel [用 trainer,用 script] LLM文本生成

  • 自己整理的 HF 库的核心关系图
  • 用 trainer 来微调一个预训练模型
  • 用 script 来做训练任务
  • 使用 LLM 做生成任务
    • 可能犯的错误,以及解决措施

自己整理的 HF 库的核心关系图

  • 根据前面几期,自己整理的核心库的使用/继承关系
    在这里插入图片描述

用 trainer 来微调一个预训练模型

  • HF官网API:FT a PretrainedModel
    今天讲讲FT训练相关的内容吧
    这里就先不提用 keras 或者 native PyTorch 微调,直接看一下用 trainer 微调的基本流程
  • 第一步:加载数据集和数据集预处理
    使用 datasets 进行加载 HF 数据集
from datasets import load_datasetdataset = load_dataset("yelp_review_full")

另外,需要用 tokenizer 进行分词。自定义分词函数,然后使用 dataset.map() 可以把数据集进行分词。

from transformers import AutoTokenizertokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")def tokenize_function(examples):return tokenizer(examples["text"], padding="max_length", truncation=True)tokenized_datasets = dataset.map(tokenize_function, batched=True)

也可以先选择其中一小部分的数据单独拿出来,做测试或者其他任务

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
  • 第二步,加载模型,选择合适的 AutoModel 或者比如具体的 LlamaForCausalLM 等类。
    使用 model.from_pretrained() 加载
from transformers import AutoModelForSequenceClassificationmodel = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)
  • 第三步,加载 / 创建训练参数 TrainingArguments
from transformers import TrainingArgumentstraining_args = TrainingArguments(output_dir="test_trainer")
  • 第四步,指定评估指标。trainer 在训练的时候不会去自动评估模型的性能/指标,所以需要自己提供一个
    ※ 这个 evaluate 之前漏了,放后面学,这里先摆一下 # TODO
import numpy as np
import evaluatemetric = evaluate.load("accuracy")
  • 第五步,使用 trainer 训练,提供之前你创建好的:
    model模型,args训练参数,train_dataset训练集,eval_dataset验证集,compute_metrics评估方法
trainer = Trainer(model=model,args=training_args,train_dataset=small_train_dataset,eval_dataset=small_eval_dataset,compute_metrics=compute_metrics,
)
trainer.train()
  • 完整代码,请替换其中的必要参数来是配置自己的模型和任务
from datasets import load_dataset
from transformers import (LlamaTokenizer,LlamaForCausalLM,TrainingArguments,Trainer,)
import numpy as np
import evaluatedef tokenize_function(examples):return tokenizer(examples["text"], padding="max_length", truncation=True)metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):logits, labels = eval_predpredictions = np.argmax(logits, axis=-1)return metric.compute(predictions=predictions, references=labels)"""
Load dataset, tokenizer, model, training args
preprosess into tokenized dataset
split training dataset and eval dataset
"""
dataset = load_dataset("xxxxxxxxxxxxxxxxxxxx")tokenizer = LlamaTokenizer.from_pretrained("xxxxxxxxxxxxxxxxxxxxxxxxxx")
tokenized_datasets = dataset.map(tokenize_function, batched=True)small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))model = LlamaForCausalLM.from_pretrained("xxxxxxxxxxxxxxx")training_args = TrainingArguments(output_dir="xxxxxxxxxxxxxx")"""
define metrics
set trainer and train
"""trainer = Trainer(model=model,args=training_args,train_dataset=small_train_dataset,eval_dataset=small_eval_dataset,compute_metrics=compute_metrics,
)trainer.train()

用 script 来做训练任务

  • 我们在很多项目中,都会看到启动脚本是一个 .sh 文件,一般里面可能会这么写:
python examples/pytorch/summarization/run_summarization.py \--model_name_or_path google-t5/t5-small \--do_train \--do_eval \--dataset_name cnn_dailymail \--dataset_config "3.0.0" \--source_prefix "summarize: " \--output_dir /tmp/tst-summarization \--per_device_train_batch_size=4 \--per_device_eval_batch_size=4 \--overwrite_output_dir \--predict_with_generate
  • 或者最近看到的一个
OUTPUT_DIR=${1:-"./alma-7b-dpo-ft"}
pairs=${2:-"de-en,cs-en,is-en,zh-en,ru-en,en-de,en-cs,en-is,en-zh,en-ru"}
export HF_DATASETS_CACHE=".cache/huggingface_cache/datasets"
export TRANSFORMERS_CACHE=".cache/models/"
# random port between 30000 and 50000
port=$(( RANDOM % (50000 - 30000 + 1 ) + 30000 ))accelerate launch --main_process_port ${port} --config_file configs/deepspeed_train_config_bf16.yaml \run_cpo_llmmt.py \--model_name_or_path haoranxu/ALMA-13B-Pretrain \--tokenizer_name haoranxu/ALMA-13B-Pretrain \--peft_model_id  haoranxu/ALMA-13B-Pretrain-LoRA \--cpo_scorer kiwi_xcomet \--cpo_beta 0.1 \--use_peft \--use_fast_tokenizer False \--cpo_data_path  haoranxu/ALMA-R-Preference \--do_train \--language_pairs ${pairs} \--low_cpu_mem_usage \--bf16 \--learning_rate 1e-4 \--weight_decay 0.01 \--gradient_accumulation_steps 1 \--lr_scheduler_type inverse_sqrt \--warmup_ratio 0.01 \--ignore_pad_token_for_loss \--ignore_prompt_token_for_loss \--per_device_train_batch_size 2 \--evaluation_strategy no \--save_strategy steps \--save_total_limit 1 \--logging_strategy steps \--logging_steps 0.05 \--output_dir ${OUTPUT_DIR} \--num_train_epochs 1 \--predict_with_generate \--prediction_loss_only \--max_new_tokens 256 \--max_source_length 256 \--seed 42 \--overwrite_output_dir \--report_to none \--overwrite_cache 
  • 玛雅,这么多 --xxx ,看着头疼,也不知道怎么搞出来这么多参数作为启动文件的。
    这种就是通过 script 启动任务了
  • github:transformers/examples
    看一下 HF github 给的一些任务的 examples 学习例子,就会发现
    main 函数中,会有这样的代码
    这个就是通过 argparser 来获取参数
    貌似还有 parserHfArgumentParser,这些都可以打包解析参数,又是挖个坑 # TODO
    这样的话,就可以通过 .sh 来在启动脚本中提供相关参数了
def main():parser = argparse.ArgumentParser()parser.add_argument("--model_type",default=None,type=str,required=True,help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),)parser.add_argument("--model_name_or_path",default=None,type=str,required=True,help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),)parser.add_argument("--prompt", type=str, default="")parser.add_argument("--length", type=int, default=20)parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")# ....... 太长省略
  • 用脚本启动还有什么好处呢
    可以使用 accelerate launch run_summarization_no_trainer.py 进行加速训练
    再给 accelerate 挖个坑 # TODO
  • 所以,在 .sh script 启动脚本中具体能提供哪些参数,取决于这个入口 .py 文件的 parser 打包解析了哪些参数,然后再利用这些参数做些事情。

使用 LLM 做生成任务

  • HF官网API:Generation with LLMs
    官方都特地给这玩意儿单独开了一节,就说明其中有些很容易踩的坑…
  • 对于 CausalLM,首先看一下 next token 的生成逻辑:输入进行分词与嵌入后,通过多层网络,然后进入到一个LM头,最终获得下一个 token 的概率预测
  • 那么生成句子的逻辑,就是不断重复这个过程,获得 next token 概率预测后,通过一定的算法选择下一个 token,然后再重复该操作,就能生成整个句子了。
  • 那什么时候停止呢?要么是下一个token选择了 eos,要么是到达了之前定义的 max token length
    在这里插入图片描述
  • 接下来看一下代码逻辑
  • 第一步,加载模型
    device_map:控制模型加载在 GPUs上,不过一般我会使用 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 以及 os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
    load_in_4bit 设置加载量化
from transformers import AutoModelForCausalLMmodel = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto", load_in_4bit=True
)
  • 第二步,加载分词器和分词
    记得分词的向量需要加载到 cuda
from transformers import AutoTokenizertokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", padding_side="left")
model_inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt").to("cuda")
  • 但这个是否需要分词取决于特定的 model.generate() 方法的参数
    就比如 disc 模型的 generate() 方法的参数为:
    也就是说,我输入的 prompt 只用提供字符串即可,又不需要进行分词或者分词器了。
    在这里插入图片描述
  • 第三步,通常的 generate 方法,输入是 tokenized 后的数组,然后获得 ids 之后再 decode 变成对应的字符结果
generated_ids = model.generate(**model_inputs)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
  • 当然我也可以批处理,一次做多个操作,批处理需要设置pad_token
tokenizer.pad_token = tokenizer.eos_token  # Most LLMs don't have a pad token by default
model_inputs = tokenizer(["A list of colors: red, blue", "Portugal is"], return_tensors="pt", padding=True
).to("cuda")
generated_ids = model.generate(**model_inputs)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

可能犯的错误,以及解决措施

  • 控制输出句子的长度
    需要在 generate 方法中提供 max_new_tokens 参数
model_inputs = tokenizer(["A sequence of numbers: 1, 2"], return_tensors="pt").to("cuda")# By default, the output will contain up to 20 tokens
generated_ids = model.generate(**model_inputs)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]# Setting `max_new_tokens` allows you to control the maximum length
generated_ids = model.generate(**model_inputs, max_new_tokens=50)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
  • 生成策略修改
    有时候默认使用贪心策略来获取 next token,这个时候容易出问题(循环生成等),需要设置 do_sample=True
    在这里插入图片描述

  • pad 对齐方向
    如果输入不等长,那么会进行pad操作
    由于默认是右侧padding,而LLM在训练时没有学会从pad_token接下来的生成策略,所以会出问题
    所以需要设置 padding_side="left![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/6084ff91d85c49e28a4faf498b8e5997.png) "
    在这里插入图片描述

  • 如果没有使用正确的 prompt(比如训练时的prompt格式),得到的结果就会不如预期
    (in one sitting = 一口气) (thug = 暴徒)
    这里需要参考 HF对话模型的模板 以及 HF LLM prompt 指引
    在这里插入图片描述
    比如说,QA的模板就像这样。
    更高级的还有 few shotCOT 技巧。

torch.manual_seed(4)
prompt = """Answer the question using the context below.
Context: Gazpacho is a cold soup and drink made of raw, blended vegetables. Most gazpacho includes stale bread, tomato, cucumbers, onion, bell peppers, garlic, olive oil, wine vinegar, water, and salt. Northern recipes often include cumin and/or pimentón (smoked sweet paprika). Traditionally, gazpacho was made by pounding the vegetables in a mortar with a pestle; this more laborious method is still sometimes used as it helps keep the gazpacho cool and avoids the foam and silky consistency of smoothie versions made in blenders or food processors.
Question: What modern tool is used to make gazpacho?
Answer:
"""sequences = pipe(prompt,max_new_tokens=10,do_sample=True,top_k=10,return_full_text = False,
)for seq in sequences:print(f"Result: {seq['generated_text']}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/736983.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ZYNQ实验--PDM波形生成

一、PDM简介 将信号的振幅变化按比例地变换成脉冲宽度的变化,得到脉冲宽度调制(PDM)。详细的原理理论可以参考该文:文献阅读–Pulse-Width Modulation,本文主要介绍PDM的FPGA实现,PDM的生成方式很多具体形式根据需求会有所不同 二…

【Stable Diffusion】入门:原理简介+应用安装(Windows)+生成步骤

【Stable Diffusion】入门:原理简介应用安装(Windows)生成步骤 原理简介应用安装 原理简介 稳定扩散生成模型(Stable Diffusion)是一种潜在的文本到图像扩散模型,能够在给定任何文本输入的情况下生成照片般逼真的图像。 应用安…

中国广电的独特优势:与三大运营商相比的亮点

2023年,中国广电正式上市了,发出了第一批号段192的号码,然而值得大家了解的是:在中国的通信市场中,中国移动、中国联通和中国电信长期以来占据主导地位。然而,随着中国广电的加入,市场格局正在发…

了解转义字符

了解转义字符 也许在前面的代码中你看到 \n , \0 很纳闷是啥。其实在字符中有⼀组特殊的字符是转义字符,转义字符顾名思义:转变原来的意思的字符。 比如:我们有字符 n ,在字符串中打印的时候自然能打印出这个字符,如下…

鸿蒙操作系统 HarmonyOS 3.2 API 9 Stage模型通过ArkTS接入高德地图

用鸿蒙ArkTS语言开发地图APP应用时&#xff0c;很多地图厂商只接入了鸿蒙Java&#xff0c;ArkTS版本陆续接入中&#xff0c;等一段时间才能面世&#xff0c;当前使用地图只能通过鸿蒙的Web组件&#xff0c;将HTML页面嵌入到鸿蒙APP中。具体方法如下&#xff1a;编写HTML <!…

C++容器适配器stack、queue、priority_queue

文章目录 C容器适配器stack、queue、priority_queue1、stack1.1、stack的介绍1.2、stack的使用1.3、stack的模拟实现 2、queue2.1、queue的介绍2.2、queue的使用2.3、queue的模拟实现 3、priority_queue3.1、priority_queue的介绍3.2、priority_queue的使用3.3、仿函数3.4、pri…

IAR全面支持小华全系芯片,强化工控及汽车MCU生态圈

IAR Embedded Workbench for Arm已全面支持小华半导体系列芯片&#xff0c;加速高端工控MCU和车用MCU应用的安全开发 嵌入式开发软件和服务的全球领导者IAR与小华半导体有限公司&#xff08;以下简称“小华半导体”&#xff09;联合宣布&#xff0c;IAR Embedded Workbench fo…

C语言——递归题

对于递归问题&#xff0c;我们一定要想清楚递归的结束条件&#xff0c;每个递归的结束条件&#xff0c;就是思考这个问题的起始点。 题目1&#xff1a; 思路&#xff1a;当k1时&#xff0c;任何数的1次方都是原数&#xff0c;此时返回n&#xff0c;这就是递归的结束条件&#…

基于FPGA加速的bird-oid object算法实现

导语 今天继续康奈尔大学FPGA 课程ECE 5760的典型案例分享——基于FPGA加速的bird-oid object算法实现。 &#xff08;更多其他案例请参考网站&#xff1a; Final Projects ECE 5760&#xff09; 1. 项目概述 项目网址 ECE 5760 Final Project 模型说明 Bird-oid object …

企业计算机服务器中了mkp勒索病毒如何解密,mkp勒索病毒解密流程

网络技术的应用与发展&#xff0c;为企业的生产运营提高了效率&#xff0c;越来越多的企业利用网络开展多项工作业务&#xff0c;利用网络的优势&#xff0c;可以为企业更好的服务&#xff0c;但是稍不注意就会被网络威胁所盯上。近日&#xff0c;云天数据恢复中心接到多家企业…

CAP告诉你系统没法做到完美,只能做到权衡和适当

一、CAP介绍 CAP原理&#xff0c;全称为Consistency&#xff08;一致性&#xff09;、Availability&#xff08;可用性&#xff09;和Partition tolerance&#xff08;分区容错性&#xff09;&#xff0c;是分布式系统设计中的基本原理。它强调了在设计分布式系统时&#xff0c…

面试题:分布式锁用了 Redis 的什么数据结构

在使用 Redis 实现分布式锁时&#xff0c;通常使用 Redis 的字符串&#xff08;String&#xff09;。Redis 的字符串是最基本的数据类型&#xff0c;一个键对应一个值&#xff0c;它能够存储任何形式的字符串&#xff0c;包括二进制数据。字符串类型的值最多可以是 512MB。 Re…

二次供水无人值守解决方案

二次供水无人值守解决方案 二次供水系统存在一定的管理难题和技术瓶颈&#xff0c;如设备老化、维护不及时导致的水质安全隐患&#xff0c;以及如何实现高效运行和智能化管理等问题。在一些地区&#xff0c;特别是老旧小区或农村地区&#xff0c;二次供水设施建设和改造滞后&a…

grpc的metadata机制

引言 gRPC让我们可以像本地调用一样实现远程调用&#xff0c;对于每一次的RPC调用中&#xff0c;都可能会有一些有用的数据&#xff0c;而这些数据就可以通过metadata来传递。metadata是以key-value的形式存储数据的&#xff0c;其中key是 string类型&#xff0c;而value是[]s…

mysql日常优化的总结

文章目录 一、数据表结构相关优化建字段类型注意事项1. int类型的选择2.varchar、char、text类型3.date、datetime、timestamp类型 表规划1. 垂直分表2. 水平分表 二、查询语句优化1.对于字段多的表&#xff0c;避免使用SELECT *2.避免使用!操作符3.避免使用null做条件4.like查…

ElasticSearchLinux安装和springboot整合的记录和遇到的问题

前面整合遇到的一些问题有的记录在下面了&#xff0c;有的当时忘了记录下来&#xff0c;希望下面的能帮到你们 1&#xff1a;Linux安装ES 下载安装&#xff1a; 参考文章&#xff1a;连接1 连接2 wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch…

PostgreSQL 安装部署

文章目录 一、PostgreSQL部署方式1.Yum方式部署2.RPM方式部署3.源码方式部署4.二进制方式部署5.Docker方式部署 二、PostgreSQL部署1.Yum方式部署1.1.部署数据库1.2.连接数据库 2.RPM方式部署2.1.部署数据库2.2.连接数据库 3.源码方式部署3.1.准备工作3.2.编译安装3.3.配置数据…

手机app制作商用系统软件开发

手机端的用户占比已经超过了电脑端的用户量&#xff0c;企业想要发展手机端的业务就必须拥有自己的app软件&#xff0c;我们公司就是专门为企业开发手机软件的公司&#xff0c;依据我们多年的开发经验为大家提供在开发app软件的时候怎么选择开发软件的公司。 手机app制…

【竞技宝】LOL:knight阿狸伤害爆炸 BLG2-0轻取RA

北京时间2024年3月11日,英雄联盟LPL2024春季常规赛继续进行,昨日共进行三场比赛,首场比赛由BLG对阵RA。本场比赛BLG选手个人实力碾压RA2-0轻松击败对手。以下是本场比赛的详细战报。 第一局: BLG:剑魔、千珏、妮蔻、卡牌、洛 RA:乌迪尔、蔚、阿卡丽、斯莫德、芮尔 首局比赛,B…

QMS质量管理系统在离散型制造业的应用及效益

在离散型制造行业中&#xff0c;质量管理是确保产品满足客户期望和市场需求的关键环节。QMS质量管理系统的实施为企业提供了一种全面、系统的方法来管理和改进质量。 例如&#xff0c;在汽车制造行业&#xff0c;QMS质量管理系统可以应用于零部件采购和装配过程的质量控制。通过…