Paddle 基于预训练模型 ERNIE-Gram 实现语义匹配

文章目录

    • 1. 导入一些包
    • 2. 加载数据
    • 3. 数据预处理
      • 3.1 获取tokenizer,得到 input_ids, token_type_ids
      • 3.2 转换函数、batch化函数、sampler、data_loader
    • 4. 编写模型
    • 5. 学习率、参数衰减、优化器、loss、评估标准
    • 6. 评估函数
    • 7. 训练+评估
    • 8. 保存模型到文件
    • 9. 预测
    • 10. 多GPU并行设置

项目介绍 项目链接:https://aistudio.baidu.com/aistudio/projectdetail/2029701
单机多卡训练参考:https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/02_paddle2.0_develop/06_device_cn.html
支持 star PaddleNLP github https://github.com/PaddlePaddle/PaddleNLP

1. 导入一些包

import time
import os
import numpy as np
import paddle
import paddlenlp
import paddle.nn.functional as F
from paddlenlp.datasets import load_dataset
import paddle.distributed as dist  # 并行

2. 加载数据

batch_size = 64
epochs = 5# 加载数据集
train_ds, dev_ds = load_dataset("lcqmc", splits=["train", "dev"])
# 展示数据
for i, example in enumerate(train_ds):if i < 5:print(example)# {'query': '喜欢打篮球的男生喜欢什么样的女生', 'title': '爱打篮球的男生喜欢什么样的女生', 'label': 1}# {'query': '我手机丢了,我想换个手机', 'title': '我想买个新手机,求推荐', 'label': 1}# {'query': '大家觉得她好看吗', 'title': '大家觉得跑男好看吗?', 'label': 0}# {'query': '求秋色之空漫画全集', 'title': '求秋色之空全集漫画', 'label': 1}# {'query': '晚上睡觉带着耳机听音乐有什么害处吗?', 'title': '孕妇可以戴耳机听音乐吗?', 'label': 0}

3. 数据预处理

3.1 获取tokenizer,得到 input_ids, token_type_ids

# 使用预训练模型的tokenizer
tokenizer = paddlenlp.transformers.ErnieGramTokenizer.from_pretrained("ernie-gram-zh")
# https://gitee.com/paddlepaddle/PaddleNLP/blob/develop/docs/model_zoo/transformers.rstdef convert_data(data, tokenizer, max_seq_len=512, is_test=False):text1, text2 = data["query"], data["title"]encoded_inputs = tokenizer(text=text1, text_pair=text2, max_seq_len=max_seq_len)input_ids = encoded_inputs["input_ids"]token_type_ids = encoded_inputs["token_type_ids"]if not is_test:label = np.array([data["label"]], dtype="int64")return input_ids, token_type_ids, labelreturn input_ids, token_type_idsinput_ids, token_type_ids, label = convert_data(train_ds[0], tokenizer)
print(input_ids)
# [1, 692, 811, 445, 2001, 497, 5, 654, 21, 692, 811, 614, 356, 314, 5, 291, 21, 2, 
#  329, 445, 2001, 497, 5, 654, 21, 692, 811, 614, 356, 314, 5, 291, 21, 2]
print(token_type_ids)
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
print(label)
# [1]

3.2 转换函数、batch化函数、sampler、data_loader

  • 包装转换函数,方便简化后续代码
from functools import partial
trans_func = partial(convert_data, tokenizer=tokenizer, max_seq_len=512)
  • 生成 data_loader
from paddlenlp.data import Stack, Pad, Tuple# batch 化函数
batchify_fn = lambda samples, fn=Tuple(Pad(axis=0, pad_val=tokenizer.pad_token_id),Pad(axis=0, pad_val=tokenizer.pad_token_type_id),Stack(dtype="int64") # 分别对应于 input_ids, token_type_ids, label
): [d for d in fn(samples)]
# 将长度不同的多个句子padding到统一长度,取N个输入数据中的最大长度
# 长度是指的: 一个batch中的最大长度,主要考虑性能开销
# paddlenlp.data.Tuple	将多个batchify函数包装在一起batch_sampler = paddle.io.DistributedBatchSampler(train_ds, batch_size=batch_size, shuffle=True)
# 注意训练可以用 用分布式的 sampler,充分利用资源train_data_loader = paddle.io.DataLoader(dataset=train_ds.map(trans_func), # 数据转换batch_sampler=batch_sampler, # 取样collate_fn=batchify_fn, # batch化函数return_list=True
)batch_sampler = paddle.io.BatchSampler(dev_ds, batch_size=batch_size, shuffle=False)
dev_data_loader = paddle.io.DataLoader(dataset=dev_ds.map(trans_func),batch_sampler=batch_sampler,collate_fn=batchify_fn,return_list=True
)

4. 编写模型

预训练模型,接 FC

import paddle.nn as nn
pretrained_model = paddlenlp.transformers.ErnieGramModel.from_pretrained("ernie-gram-zh")# %%class TeachingPlanModel(nn.Layer):def __init__(self, pretrained_model, dropout=None):super().__init__()self.ptm = pretrained_modelself.dropout = nn.Dropout(dropout if dropout is not None else 0.1)self.clf = nn.Linear(self.ptm.config["hidden_size"], 2)def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):_, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids, attention_mask)cls_embedding = self.dropout(cls_embedding)logits = self.clf(cls_embedding)probs = F.softmax(logits)return probsmodel = TeachingPlanModel(pretrained_model)

5. 学习率、参数衰减、优化器、loss、评估标准

from paddlenlp.transformers import LinearDecayWithWarmupnum_training_steps = len(train_data_loader) * epochs# 学习率调度器
lr_scheduler = LinearDecayWithWarmup(5e-5, num_training_steps, 0.0)
# 衰减的参数
decay_params = [p.name for n, p in model.named_parameters()if not any(nd in n for nd in ["bias", "norm"])
]# 优化器
optimizer = paddle.optimizer.AdamW(learning_rate=lr_scheduler,parameters=model.parameters(),weight_decay=0.0,apply_decay_param_fun=lambda x: x in decay_params
)# 损失函数
criterion = paddle.nn.loss.CrossEntropyLoss()# 评估标准
metric = paddle.metric.Accuracy()

6. 评估函数

@paddle.no_grad()
def evaluate(model, criterion, metric, data_loader, phase="dev"):model.eval()metric.reset()losses = []for batch in data_loader:input_ids, token_type_ids, labels = batch# 前向传播probs = model(input_ids=input_ids, token_type_ids=token_type_ids)# 损失loss = criterion(probs, labels)losses.append(loss.numpy())# 准确率correct = metric.compute(probs, labels)metric.update(correct)acc = metric.accumulate()print("评估 {} loss: {:.5}, acc: {:.5}".format(phase, np.mean(losses), acc))model.train()metric.reset()

7. 训练+评估

global global_step
global_step = 0
t_start = time.time()
for epoch in range(1, epochs + 1):for step, batch in enumerate(train_data_loader, start=1):input_ids, token_type_ids, labels = batch# 前向传播probs = model(input_ids=input_ids, token_type_ids=token_type_ids)# 损失loss = criterion(probs, labels)# 准确率correct = metric.compute(probs, labels)metric.update(correct)acc = metric.accumulate()global_step += 1# 打印训练信息if global_step % 10 == 0:print("训练步数 %d, epoch: %d, batch: %d, loss: %.5f, acc: %.5f, speed: %.2f step/s"% (global_step, epoch, step, loss, acc,10 / (time.time() - t_start)))t_start = time.time()# 反向传播loss.backward()# 更新参数optimizer.step()lr_scheduler.step()# 清除梯度optimizer.clear_grad()# 训练100步,评估一次if global_step % 100 == 0:evaluate(model, criterion, metric, dev_data_loader, "dev")

训练过程:

训练步数 5010, epoch: 3, batch: 1278, loss: 0.39062, acc: 0.90781, speed: 0.33 step/s
训练步数 5020, epoch: 3, batch: 1288, loss: 0.41552, acc: 0.90312, speed: 1.87 step/s
训练步数 5030, epoch: 3, batch: 1298, loss: 0.34011, acc: 0.90521, speed: 1.57 step/s
训练步数 5040, epoch: 3, batch: 1308, loss: 0.37718, acc: 0.90703, speed: 1.55 step/s
训练步数 5050, epoch: 3, batch: 1318, loss: 0.35848, acc: 0.91125, speed: 1.80 step/s
训练步数 5060, epoch: 3, batch: 1328, loss: 0.37751, acc: 0.91042, speed: 1.67 step/s
训练步数 5070, epoch: 3, batch: 1338, loss: 0.42495, acc: 0.91161, speed: 1.72 step/s
训练步数 5080, epoch: 3, batch: 1348, loss: 0.38556, acc: 0.91035, speed: 1.67 step/s
训练步数 5090, epoch: 3, batch: 1358, loss: 0.40671, acc: 0.91024, speed: 1.85 step/s
训练步数 5100, epoch: 3, batch: 1368, loss: 0.36824, acc: 0.91000, speed: 1.74 step/s
评估 dev loss: 0.44395, acc: 0.86321
训练步数 5110, epoch: 3, batch: 1378, loss: 0.41520, acc: 0.92188, speed: 0.32 step/s
训练步数 5120, epoch: 3, batch: 1388, loss: 0.42261, acc: 0.91250, speed: 1.65 step/s
训练步数 5130, epoch: 3, batch: 1398, loss: 0.37139, acc: 0.91615, speed: 1.68 step/s
训练步数 5140, epoch: 3, batch: 1408, loss: 0.38124, acc: 0.90781, speed: 1.68 step/s
训练步数 5150, epoch: 3, batch: 1418, loss: 0.41482, acc: 0.90781, speed: 1.76 step/s
训练步数 5160, epoch: 3, batch: 1428, loss: 0.38554, acc: 0.91120, speed: 1.75 step/s
训练步数 5170, epoch: 3, batch: 1438, loss: 0.38424, acc: 0.91027, speed: 1.77 step/s
训练步数 5180, epoch: 3, batch: 1448, loss: 0.39620, acc: 0.90938, speed: 1.72 step/s
训练步数 5190, epoch: 3, batch: 1458, loss: 0.41320, acc: 0.90747, speed: 1.77 step/s
训练步数 5200, epoch: 3, batch: 1468, loss: 0.39017, acc: 0.90859, speed: 1.64 step/s
评估 dev loss: 0.4526, acc: 0.8556

8. 保存模型到文件

pathname = "checkpoint"
isExists = os.path.exists(pathname)
if not isExists:os.mkdir(pathname)save_dir = os.path.join(pathname, "model_%d" % global_step)
save_param_path = os.path.join(save_dir, "model_state_pdparams")paddle.save(model.state_dict(), save_param_path)
tokenizer.save_pretrained(save_dir)

9. 预测

def predict(model, data_loader):batch_probs = []model.eval() # 评估模式with paddle.no_grad(): # 不需要梯度更新for batch_data in data_loader:input_ids, token_type_ids = batch_datainput_ids = paddle.to_tensor(input_ids)token_type_ids = paddle.to_tensor(token_type_ids)batch_prob = model(input_ids=input_ids, token_type_ids=token_type_ids).numpy()batch_probs.append(batch_prob)batch_probs = np.concatenate(batch_probs, axis=0)return batch_probs# 数据转换函数
trans_func_test = partial(convert_data, tokenizer=tokenizer, max_seq_len=512, is_test=True)# batch化函数
batchify_fn = lambda samples, fn=Tuple(Pad(axis=0, pad_val=tokenizer.pad_token_id),Pad(axis=0, pad_val=tokenizer.pad_token_type_id)
): [data for data in fn(samples)]# 加载测试集
test_ds = load_dataset("lcqmc", splits=["test"])# 定义 sampler
batch_sampler = paddle.io.BatchSampler(test_ds, batch_size=batch_size, shuffle=False)# 定义data_loader
predict_data_loader = paddle.io.DataLoader(dataset=test_ds.map(trans_func_test),batch_sampler=batch_sampler,collate_fn=batchify_fn,return_list=True
)# 定义模型
pretrained_model = paddlenlp.transformers.ErnieGramModel.from_pretrained("ernie-gram-zh")
model = TeachingPlanModel(pretrained_model)# 加载训练好的参数
state_dict = paddle.load(save_param_path)
# 设置参数
model.set_dict(state_dict)# 预测
y_probs = predict(model, predict_data_loader)
y_preds = np.argmax(y_probs, axis=1)# 预测结果写入文件
with open("lcqmc.tsv", 'w', encoding="utf-8") as f:f.write("index\tprediction\n")for idx, y_pred in enumerate(y_preds):f.write("{}\t{}\n".format(idx, y_pred))# text_pair = test_ds.data[idx]# text_pair["label"] = y_pred# print(text_pair)

10. 多GPU并行设置

import paddle.distributed as dist  # 并行if __name__ == "__main__":dist.init_parallel_env()  # 初始化并行环境# 启动命令 python -m paddle.distributed.launch --gpus '0,1' xxx.py &# your code 。。。

可以看见 2个 GPU 都使用起来了

Sat Jun 19 18:18:34 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA Tesla T4     Off  | 00000000:00:09.0 Off |                    0 |
| N/A   67C    P0    69W /  70W |   9706MiB / 15109MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA Tesla T4     Off  | 00000000:00:0A.0 Off |                    0 |
| N/A   68C    P0    68W /  70W |  11004MiB / 15109MiB |     99%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------++-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     34450      C   ...nda3/envs/pp21/bin/python     9703MiB |
|    1   N/A  N/A     34453      C   ...nda3/envs/pp21/bin/python    11001MiB |
+-----------------------------------------------------------------------------+

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/472077.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

j2ee和mysql怎么连接_J2EE数据库连接不再烦恼

刚开始接触j2ee的时候总是为数据库的开关连接问题而烦恼,虽然问题很简单却很是琐碎,于是干脆写成一个类将所有必要的基本操作全部总结进去,以后只要轻松的import一下就可以了啊:)菜鸟们enjoying!import java.sql.Connection;import java.sql.Statement;import java.sql.ResultS…

SQL SERVER PIVOT 行转列、列传行

在数据库操作中&#xff0c;有些时候我们遇到需要实现“行转列”的需求&#xff0c;例如一下的表为某店铺的一周收入情况表&#xff1a; WEEK_INCOME(WEEK VARCHAR(10),INCOME DECIMAL) 我们先插入一些模拟数据&#xff1a; INSERT INTO WEEK_INCOME SELECT 星期一,1000 UNION…

python安装scipy出现红字_windows下安装numpy,scipy遇到的问题总结

1.安装numpy下载numpy编译包&#xff0c;进入该目录下&#xff0c; 调用命令 python setup.py install进行安装&#xff0c;返回错误&#xff1a;error: Unable to find vcvarsall.bat出现这个原因的问题貌似跟vc编译器有关&#xff0c;具体原因没有细究&#xff0c;但是经Goog…

mysql cluster 查看数据库表名称_MySQL Cluster如何创建磁盘表方法解读

MySQL Cluster采用一系列的Disk Data objects来实现磁盘表;接下来为您详细介绍一、概念MySQL Cluster采用一系列的Disk Data objects来实现磁盘表。Tablespaces&#xff1a;作用是作为其他Disk Data objects的容器。Undo log files&#xff1a;存储事务进行回滚需要的信息&…

(运算符) 运算符

& 运算符既可作为一元运算符也可作为二元运算符。 备注 一元 & 运算符返回操作数的地址&#xff08;要求 unsafe 上下文&#xff09;。 为整型和 bool 类型预定义了二进制 & 运算符。 对于整型&#xff0c;& 计算操作数的逻辑按位“与”。 对于 bool 操作数&am…

LeetCode 1903. 字符串中的最大奇数

文章目录1. 题目2. 解题1. 题目 给你一个字符串 num &#xff0c;表示一个大整数。 请你在字符串 num 的所有 非空子字符串 中找出 值最大的奇数 &#xff0c;并以字符串形式返回。如果不存在奇数&#xff0c;则返回一个空字符串 “” 。 子字符串 是字符串中的一个连续的字符…

python模拟qq空间登录_python selenium模拟登录163邮箱和QQ空间

最近在看python网络爬虫&#xff0c;于是我想自己写一个邮箱和QQ空间的自动登录的小程序&#xff0c;下面以登录163邮箱和QQ空间和为例&#xff1a;了解到在Web应用中经常会遇到frame/iframe 表单嵌套页面的应用&#xff0c;WebDriver 只能在一个页面上对元素识别与定位&#x…

mysql分页插件springboot_SpringBoot--使用Mybatis分页插件

1、导入分页插件包和jpa包org.springframework.bootspring-boot-starter-data-jpacom.github.pagehelperpagehelper-spring-boot-starter1.2.52、增加分页配置# 主键自增回写方法,默认值MYSQL,详细说明请看文档mapper:identity: MYSQL# 设置 insert 和 update 中&#xff0c;是…

top 命令详解

作用&#xff1a; 实时动态查看系统的整体运行情况&#xff0c; 是一个综合了多方信息监测系统性能和运行信息的实用工具。 选项&#xff1a;-b 以批处理模式操作-c 显示完整的命令-d 屏幕刷新间隔时间-I 忽略失效过程-s 保密模式-S 累积模式-i 设置时间间隔-u 指定用户…

LeetCode 1904. 你完成的完整对局数

文章目录1. 题目2. 解题1. 题目 一款新的在线电子游戏在近期发布&#xff0c;在该电子游戏中&#xff0c;以 刻钟 为周期规划若干时长为 15 分钟 的游戏对局。 这意味着&#xff0c;在 HH:00、HH:15、HH:30 和 HH:45 &#xff0c;将会开始一个新的对局&#xff0c;其中 HH 用一…

python scipy库函数solve用法_如何在中使用事件scipy.integrate.solve_ivp

我不确定事件处理是否scipy.integrate.solve_ivp工作正常。在下面的例子中&#xff0c;我对一个导数进行积分&#xff0c;得到一个三次多项式&#xff0c;它的根在x-6&#xff0c;x-2和x2。我设置了一个事件函数&#xff0c;返回y&#xff0c;在x值处为零。我希望在解决方案的t…

将MYSQL查询导出到文件

sql文件&#xff1a; set names utf8; select * from xxxxx mysql命令&#xff1a; mysql -h xxxx -uxxxx -p < 4.sql > 4.txt 转载于:https://www.cnblogs.com/aguncn/p/4449969.html

mysql维护计划任务_浅谈MySQL event 计划任务

一、查看event是否开启show variables like %sche%;set global event_scheduler 1;二、-- 设置时区并设置计划事件调度器开启&#xff0c;也可以 event_scheduler onset time_zone 8:00;set global event_scheduler 1;-- 设置该事件使用或所属的数据库base数据库use test;--…

LeetCode 1905. 统计子岛屿(BFS)

文章目录1. 题目2. 解题1. 题目 给你两个 m x n 的二进制矩阵 grid1 和 grid2 &#xff0c;它们只包含 0 &#xff08;表示水域&#xff09;和 1 &#xff08;表示陆地&#xff09;。 一个 岛屿 是由 四个方向 &#xff08;水平或者竖直&#xff09;上相邻的 1 组成的区域。 任…

vue是什么软件_Angular vs React vs Vue:2020年的最佳选择是什么?

在2020年&#xff0c;想象没有HTML&#xff0c;CSS和Javascript的Web开发是不切实际的。 Javascript是Web应用程序前端开发的灵魂。 如果您登陆此页面&#xff0c;那么我认为您在Java语言和Java编程语言的不同框架和库之间感到困惑。企业和软件开发人员最常见的一些查询是&…

Ajax的实现

一、JavaScript的ajax //Ajaxvar xhr;if(window.XMLHttpRequest){ //除IE外的浏览器xhr new XMLHttpRequest()}else{xhr new ActiveXObject("Microsoft.XMLHTTP"); //IE}xhr.open(get,http://demo_get.asp,true); //三个参数&#xff0c;method&#xff0c;…

LeetCode 1910. 删除一个字符串中所有出现的给定子字符串

文章目录1. 题目2. 解题1. 题目 给你两个字符串 s 和 part &#xff0c;请你对 s 反复执行以下操作直到 所有 子字符串 part 都被删除&#xff1a; 找到 s 中 最左边 的子字符串 part &#xff0c;并将它从 s 中删除。 请你返回从 s 中删除所有 part 子字符串以后得到的剩余…

mysql获取网站绝对路径_Symfony2获取web目录绝对路径、相对路径、网址的方法

本文实例讲述了Symfony2获取web目录绝对路径、相对路径、网址的方法。分享给大家供大家参考&#xff0c;具体如下&#xff1a;对于你的需求&#xff0c;Symfony2通过DIC提供了kernel服务&#xff0c;以及request(请求)的封装。在controller里(在其他地方你可以自行注入kernel&a…

tcp长连接和短连接的区别_TCP --- 连接

一个TCP连接由4个元组组成&#xff1a;2个ip地址和2个端口号tcp三次握手为什么是三次握手解决历史连接问题通过三次握手才能阻止重复历史连接的初始化通过三次握手&#xff0c;才能对通讯双方的初始序号初始化如果只有2次握手&#xff0c;发送方一旦发送创建连接的请求就无法撤…

python substr函数_Sql SUBSTR函数

SQL常用函数总结SQL常用函数总结 这是我在项目开发中使用db2数据库写存储过程的时候经常用到的sql函数.希望对大家有所帮助: sql cast函数 (1).CAST()函数的参数是一个表达式,它包括用AS关键字分 ...SQL中CHARINDEX&lpar;&rpar;&sol;INSTR&lpar;&rpar;函数…