NLP之Bert实现文本多分类

文章目录

  • 代码
  • 代码整体流程解读
  • debug上面的代码

代码

from pypro.chapters03.demo03_数据获取与处理 import train_list, label_list, val_train_list, val_label_list
import tensorflow as tf
from transformers import TFBertForSequenceClassificationbert_model = "bert-base-chinese"model = TFBertForSequenceClassification.from_pretrained(bert_model, num_labels=32)
model.compile(metrics=['accuracy'], loss=tf.nn.sigmoid_cross_entropy_with_logits)
model.summary()
result = model.fit(x=train_list[:24], y=label_list[:24], batch_size=12, epochs=1)
print(result.history)
# 保存模型(模型保存的本质就是保存训练的参数,而对于深度学习而言还保存神经网络结构)
model.save_weights('../data/model.h5')model = TFBertForSequenceClassification.from_pretrained(bert_model, num_labels=32)
model.load_weights('../data/model.h5')
result = model.predict(val_train_list[:12])  # 预测值
print(result)
result = tf.nn.sigmoid(result)
print(result)
result = tf.cast(tf.greater_equal(result, 0.5), tf.float32)
print(result)

代码整体流程解读

这段代码的目的是利用TensorFlow和transformers库来进行文本序列的分类任务。下面是整体流程的概述和逐步计划:

  1. 导入必要的库和数据:

    • 从一个叫做 pypro.chapters03.demo03_数据获取与处理 的模块中导入了四个列表:train_list, label_list, val_train_list, val_label_list。这些列表分别包含训练数据、训练标签、验证数据和验证标签。
    • 导入TensorFlow和transformers库。
  2. 初始化预训练的BERT模型:

    • 使用 bert-base-chinese 模型初始化一个用于序列分类的BERT模型。
    • 模型被配置为对32个不同的标签进行分类。
  3. 编译模型:

    • 使用sigmoid交叉熵作为损失函数,并跟踪准确度作为性能指标。
  4. 模型摘要:

    • 输出模型的概要信息,包括每一层的名称、类型、输出形状和参数数量。
  5. 训练模型:

    • 使用提供的训练数据和标签(仅取前24个样本)来训练模型。
    • 批量大小设置为12,训练仅进行1个时代(epoch),这意味着数据将通过模型传递一次。
  6. 输出训练结果:

    • 打印训练过程中记录的历史数据,通常包括损失值和准确度。
  7. 保存模型权重:

    • 将训练后的模型权重保存到本地文件 model.h5
  8. 加载模型权重:

    • 初始化一个新的模型结构,并加载之前保存的权重。
  9. 模型预测:

    • 使用验证数据(仅取前12个样本)进行预测。
  10. 激活函数处理:

    • 将预测结果通过sigmoid函数处理,转换成0到1之间的值。
  11. 转换预测结果:

    • 通过比较预测值是否大于或等于0.5来将概率转换为二进制分类结果。

debug上面的代码

下面逐行解释上述代码:

  1. from pypro.chapters03.demo03_数据获取与处理 import train_list, label_list, val_train_list, val_label_list

    这行代码从demo03_数据获取与处理模块中导入四个列表。这些列表包含训练数据和标签(train_list, label_list),以及验证数据和标签(val_train_list, val_label_list)。这是数据准备步骤的一部分。

  2. import tensorflow as tf

    这行代码导入了TensorFlow库,它是一个广泛用于机器学习和深度学习任务的开源库。

  3. from transformers import TFBertForSequenceClassification

    这里导入了transformers库中的TFBertForSequenceClassification类。transformers库包含了许多预训练模型,用于NLP任务,这里特别导入的是适用于TensorFlow的BERT模型,用于序列分类任务。

  4. bert_model = "bert-base-chinese"

    定义一个字符串变量bert_model,它保存了预训练模型的名称。在这里,我们将使用中文BERT基础模型。

  5. model = TFBertForSequenceClassification.from_pretrained(bert_model, num_labels=32)

    使用bert-base-chinese模型和TFBertForSequenceClassification类创建一个新的序列分类模型实例。num_labels=32表明有32个不同的类别用于分类。

  6. model.compile(metrics=['accuracy'], loss=tf.nn.sigmoid_cross_entropy_with_logits)

    编译模型,设置度量为准确度(accuracy),并使用sigmoid_cross_entropy_with_logits作为损失函数,这通常用于二分类问题,但在这里,由于是多标签分类(32个类别),可能是对每个标签进行二分类。

  7. model.summary()

    输出模型的摘要信息,包括模型中的层,每层的输出形状和参数数量等详细信息。

  8. result = model.fit(x=train_list[:24], y=label_list[:24], batch_size=12, epochs=1)

    开始训练模型,仅使用前24个样本作为训练数据和标签。批处理大小设置为12,意味着每次梯度更新将基于12个样本。epochs=1表示整个数据集只通过模型训练一次。

  9. print(result.history)

    打印出训练过程中的历史数据,如损失和准确度。

  10. model.save_weights('../data/model.h5')

    保存训练好的模型权重到本地文件model.h5

  11. model = TFBertForSequenceClassification.from_pretrained(bert_model, num_labels=32)

    再次初始化一个模型,用于演示如何从头加载一个模型。

  12. model.load_weights('../data/model.h5')

    加载先前保存的模型权重。

  13. result = model.predict(val_train_list[:12]) # 预测值

    使用验证数据集中的前12个样本进行预测,得到模型的输出。

  14. print(result)

    打印出预测结果。

  15. result = tf.nn.sigmoid(result)

    将模型的原始输出通过sigmoid函数转换,得到一个在0到1之间的值,表示属于每个类别的概率。

  16. print(result)

    再次打印经过sigmoid激活函数处理后的预测结果。

  17. result = tf.cast(tf.greater_equal(result, 0.5), tf.float32)

    将sigmoid输出的概率转换为二分类结果。对于每个标签,如果概率大于或等于0.5,则认为该样本属于该标签(转换为1),否则不属于(转换为0)。

  18. `print

(result)`

最后,打印出转换后的分类结果。

整体而言,这段代码展示了使用预训练的BERT模型在一个多标签文本分类任务上的训练、保存、加载和预测的完整过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/133536.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

优维低代码实践:打包发布

导语 优维低代码技术专栏,是一个全新的、技术为主的专栏,由优维技术委员会成员执笔,基于优维7年低代码技术研发及运维成果,主要介绍低代码相关的技术原理及架构逻辑,目的是给广大运维人提供一个技术交流与学习的平台。…

【C++】开源:rapidjson数据解析库配置与使用

😏★,:.☆( ̄▽ ̄)/$:.★ 😏 这篇文章主要介绍rapidjson数据解析库配置与使用。 无专精则不能成,无涉猎则不能通。——梁启超 欢迎来到我的博客,一起学习,共同进步。 喜欢的朋友可以关注一下&…

【亲测推荐】魔方财务和魔方云系统开源全解密

简介 资源入口点击进入 众所周知,魔方财务现在官方售价299,那么接下来就是带来开心版,详细手写实测安装教程已经放在付费资源中 展示 > 本文由博客一文多发平台 [OpenWrite](https://openwrite.cn?fromarticle_bottom) 发布!

MySQL索引事务存储引擎

索引:是一个排序的列表 列表中存储的是索引的值和包含这个值数据所在行的物理地址 索引的作用 利用索引数据库可以快速定位 大大加快查询速度表的数据很大 或查询需要关联多个表 使用索引也可以查询速度加快表与表之间的连接速度使用分组和排序时可以大大减少时间提…

数据结构 - 全貌总结

目录 一. 前言 二. 分类 三. 常见的线性和非线性结构 一. 前言 数据结构是计算机存储、组织数据的方式。一种好的数据结构可以带来更高的运行或者存储效率。数据在内存中是呈线性排列的,但是我们可以使用指针等道具,构造出类似“树形”等复杂结构。 数…

java项目之宠物管理系统(ssm框架)

项目简介 宠物管理系统实现了以下功能: 管理员:首页、个人中心、宠物分类管理、商品分类管理、宠物用品管理、宠物商店管理、宠物领养管理、用户管理、宠物寄存管理、用户领养管理、宠物挂失管理、论坛管理、管理员管理、系统管理、订单管理。前台首页…

使用 Rust 进行程序

首先,我们需要安装必要的库。在终端中运行以下命令来安装 scraper 和 reqwest 库: rust cargo install scraper reqwest 然后,我们可以开始编写程序。以下是一个基本的爬虫程序,用于爬取 上的图片: rust use reqwe…

Vue3.0中父组件与子组件的通信传值props与emit :VCA模式

简介 什么是props Props 是 Vue 组件之间通信的一种方式,通过 Props,父组件可以向子组件传递数据,即:父组件可以通过组件标签上的属性值把数据传递到子组件中。子组件可以根据自己的属性和方法去渲染展示数据或执行某些操作。由…

Pinia 是什么?Redux、Vuex、Pinia 的区别?

结论先行: Pinia 是 Vue 官方团队开发的一个全新状态管理库。与 Redux、Vuex 相同,核心都是解决组件间的通信和数据的共享问题。 Pinia 和 Vuex 类似,但使用起来更加简单和直观。因为 Pinia 基于 Vue3 的 Composition 组合式 API 风格&…

金豺算法优化VMD参数,六种适应度函数任意切换,最小包络熵、样本熵、信息熵、排列熵、排列熵/互信息熵、包络谱峰值因子...

声明:对于作者的原创代码,禁止转售倒卖,违者必究! 本期采用金豺优化算法(Golden Jackal optimization, GJO)优化VMD参数。选取六种适应度函数进行优化,以此确定VMD的最佳k和α参数。6种适应度函数分别是:最…

大厂真题:【模拟】阿里蚂蚁2023秋招-奇偶操作

题目描述与示例 题目描述 小红有一个长度为n的数组a,她将对数组进行m次操作,每次操作有两种类型: 将数组中所有值为奇数的元素加上x将数组中所有值为偶数的元素加上x 请你输出m次操作后的数组 输入描述 第一行两个整数n和m,表示…

初识JVM

1. JVM内存区域划分 jvm在启动的时候,会申请到一整个很大的内存区域。整个一大块区域,不太好用。为了更方便使用,把整个区域隔成了很多区域,每个区域都有不同的作用。 本地方法栈 此处提到的栈和数据结构中的栈不是一个东西&…

如何在Linux机器上使用ssh远程连接Windows Server服务器

如何在Linux机器上使用ssh远程连接Windows Server服务器 一、源起二、使用ssh远程连接Windows1.先决条件(1)至少运行 Windows Server 2019 或 Windows 10(内部版本 1809)的设备。(2)PowerShell 5.1 或更高版…

【广州华锐互动】影视制作VR在线学习:身临其境,提高学习效率

随着科技的不断发展,影视后期制作技术也在日新月异。然而,传统的教学方式往往难以满足学员的学习需求,无法充分展现影视后期制作的魅力和潜力。近年来,虚拟现实(VR)技术的崛起为教学领域带来了新的机遇。通过VR教学课件&#xff0…

超详细Linux搭建Hadoop集群

一、给计算机集群起别名——互通 总纲: 1、准备3台客户机(关闭防火墙、静态IP、主机名称都设置好) 2、安装JDK(可点击) 3、配置环境变量 4、安装Hadoop 5、配置hadoop的环境变量 6、配置集群 7、群起测试 1.1、环境准备…

蓝鹏测控平台软件 智能制造生产线的大脑

测控软件平台,是由包括底层驱动程序、通讯协议等,集数据采集、自动反馈控制、信息分析以及多种工程应用于一体的一种电子信息处理平台。 蓝鹏测控软件平台目前支持各种文本标签 、数字标签;支持趋势图、波动图、缺陷图及统计图表。多端口实现…

MCU常见通信总线串讲(一)—— UART和USART

🙌秋名山码民的主页 😂oi退役选手,Java、大数据、单片机、IoT均有所涉猎,热爱技术,技术无罪 🎉欢迎关注🔎点赞👍收藏⭐️留言📝 获取源码,添加WX 目录 前言一…

使用 curator 连接 zookeeper 集群 Invalid config event received

dubbo整合zookeeper 如图,错误日志 2023-11-04 21:16:18.699 ERROR 7459 [main-EventThread] org.apache.curator.framework.imps.EnsembleTracker Caller0 at org.apache.curator.framework.imps.EnsembleTracker.processConfigData(EnsembleTracker.java…

Lyapunov function 李雅普诺夫函数

文章目录 正文定义对定义中出现的术语的进一步讨论 Basic Lyapunov theorems for autonomous systems 自治系统的基本李雅普诺夫定理Locally asymptotically stable equilibrium 局部渐近稳定平衡Stable equilibrium 稳定平衡Globally asymptotically stable equilibrium 全局渐…

计算机毕业设计java+vue+springboot的论坛信息网站

项目介绍 本论文系统地描绘了整个网上论坛管理系统的设计与实现,主要实现的功能有以下几点:管理员;首页、个人中心、用户管理、公告管理、公告类型管理、热门帖子管理、帖子分类管理、留言板管理、论坛新天地、我的收藏管理、系统管理&#…