【深度学习】快速入门KerasNLP:微调BERT模型完成电影评论情感分类任务

简介:本文将介绍 KerasNLP 的安装及使用,以及如何使用它在情感分析任务中微调 BERT 的预训练模型。

1. KerasNLP库

KerasNLP 是一个自然语言处理库,兼容 TensorFlow、JAX 和 PyTorch 等多种深度学习框架。基于 Keras 3 构建,这些模型、层、指标和分词器可以在任何框架中训练和序列化,并且可重复应用于其他框架中,无需其他复杂开发步骤。

安装代码

pip install --upgrade keras-nlp
pip install --upgrade keras

2. BERT模型介绍

BERT,全称为Bidirectional Encoder Representations from Transformers,是由谷歌AI团队提出的一种预训练语言模型。

它基于Transformer架构,通过双向的编码器对文本进行建模,即同时考虑上下文信息,从而捕捉词汇间的深层语义关系。

BERT在预训练阶段使用无监督的Masked Language Model(掩码语言模型)和Next Sentence Prediction(下一句预测)任务进行训练,随后可以通过微调在各种自然语言处理任务中取得显著的效果。BERT的出现极大地提升了NLP领域的表现,广泛应用于问答系统、文本分类、命名实体识别等任务。

更多的NLP模型参考:KerasNLP Models


3. 代码示例

项目介绍

项目的主要目标是通过微调预训练的BERT模型,准确地将电影评论分类为正面或负面。
在这里插入图片描述

数据集介绍
本文使用的是tensorflow内置的IMDB影评数据集。该数据集包含来自互联网电影数据库(IMDB)的 50,000 条影评,用于二分类任务(正面和负面)。IMDB数据集是情感分析的经典数据集,广泛用于评估和比较不同模型的性能。

# 配置环境
import os
os.environ["KERAS_BACKEND"] = "tensorflow"  # Or "jax" or "torch"!# 导入库
import keras_nlp
import tensorflow_datasets as tfds# 导入数据
imdb_train, imdb_test = tfds.load("imdb_reviews",split=["train[:10%]", "test[:10%]"], #原代码为split=["train", "test"],这里只取10%的样本量以减少训练耗时as_supervised=True,batch_size=16,
)# 加载BERT模型
classifier = keras_nlp.models.BertClassifier.from_preset("bert_base_en_uncased", num_classes=2,  # 结果只需要两种分类:正面OR负面
)# 模型训练
classifier.fit(imdb_train, validation_data=imdb_test)# 预测结果
classifier.predict(["What an amazing movie!", "A total waste of my time."])

结果输出:
![[超快速入门 KerasNLP & KerasCV-20240625170242587.webp|524]]在这里插入图片描述

解释: 每行对应一个输入样本(电影评论),每个样本的预测分数有两个值。这些分数是未经过处理的原始logits,分别对应两个分类(正面和负面)。

  • 第一行[-2.00009, 1.8325567]:对应(What an amazing movie!)。 由于正面评论(1.83)的分数高于负面评论(-2.00),模型预测为正面评论。

  • 第二行[1.9168645, -1.5912567]:对应样本二(A total waste of my time.)。由于负面评论(1.91)的分数高于正面评论(-1.59),模型预测为负面评论。

在此例中,使用的是 KerasNLP 的 BertClassifier,默认情况下,它会按照标签顺序输出预测分数。假设数据集中正面评论标签为1,负面评论标签为0,那么模型输出的第一个分数对应标签0(负面),第二个分数对应标签1(正面)。


查看数据集的类别标签及顺序:

info = tfds.builder('imdb_reviews').info
print(info.features['label'].names)

输出:![[超快速入门 KerasNLP & KerasCV-20240625171738863.webp]]


结果转换
我们可以使用Softmax函数将原始分数logits转换成对应的类别标签:

import numpy as np
import tensorflow as tflogits = np.array([[-2.00009, 1.8325567], [1.9168645, -1.5912567]])# 1. 定义 softmax 函数:
def softmax(x):return tf.nn.softmax(x)# 2. 计算 softmax 概率:
probabilities = softmax(logits)# 3. 获取预测类别索引
predicted_classes = np.argmax(probabilities, axis=1)# 定义类别标签映射
class_labels = ['neg', 'pos']# 将预测类别索引转换为对应的标签
predicted_labels = [class_labels[idx] for idx in predicted_classes]# 打印每条评论的预测结果
test_reviews = ["What an amazing movie!", "A total waste of my time."]
for review, label in zip(test_reviews, predicted_labels):print(f"Review: \"{review}\" -> Sentiment: {label}")# 输出 logits、softmax 概率和预测类别(可选)
print("Logits:\n", logits)
print("Probabilities:\n", probabilities)
print("Predicted Classes:\n", predicted_classes)

![[超快速入门 KerasNLP & KerasCV-20240625173420064.webp]]

参考链接:KerasNLP

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/34375.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

核密度估计kde的本质

核密度估计的本质就是插值,不是拟合,只是不要求必须过已知点。 核为box窗函数 核为高斯函数

python利用cartopy绘制带有经纬度的地图

参考: https://makersportal.com/blog/2020/4/24/geographic-visualizations-in-python-with-cartopy https://scitools.org.uk/cartopy/docs/latest/ https://stackoverflow.com/questions/69465435/cartopy-show-tick-marks-of-axes 具体实现方式: …

201.回溯算法:全排列(力扣)

class Solution { public:vector<int> res; // 用于存储当前排列组合vector<vector<int>> result; // 用于存储所有的排列组合void backtracing(vector<int>& nums, vector<bool>& used) {// 如果当前排列组合的长度等于 nums 的长度&am…

Mybatis 到 MyBatisPlus

Mybatis 到 MyBatisPlus Mybatis MyBatis&#xff08;官网&#xff1a;https://mybatis.org/mybatis-3/zh/index.html &#xff09;是一款优秀的 持久层 &#xff08;ORM&#xff09;框架&#xff0c;用于简化JDBC的开发。是 Apache的一个开源项目iBatis&#xff0c;2010年这…

【图像处理实战】去除光照不均(Python)

这篇文章主要是对参考文章里面实现一种小拓展&#xff1a; 可处理彩色图片&#xff08;通过对 HSV 的 V 通道进行处理&#xff09;本来想将嵌套循环改成矩阵运算的&#xff0c;但是太麻烦了&#xff0c;而且代码也不好理解&#xff0c;所以放弃了。 代码 import cv2 import …

虚拟化 之八 详解构造带有 jailhouse 的 openEuler 发行版(ARM 飞腾派)

基本环境 嵌入式平台下,由于资源的限制,通常不具备通用性的 Linux 发行版,各大主流厂商都会提供自己的 Linux 发行版。这个发行版通常是基于某个 Linux 发行版构建系统来构建的,而不是全部手动构建,目前主流的 Linux 发行版构建系统是 Linux 基金会开发的 Yocto 构建系统。…

用一个暑假|用AlGC-stable diffusion 辅助服装设计及展示,让你在同龄人中脱颖而出!

大家好&#xff0c;我是设计师阿威 Stable Diffusion是一款开源AI绘画工具&#xff0c; 用户输入语言指令&#xff0c;即可自动生成各种风格的绘画图片 Stable Diffusion功能强大&#xff0c;生态完整、使用方便。支持大部分视觉模型上传&#xff0c;且可自己定制模型&#x…

什么是大模型?一文读懂大模型的基本概念

大模型是指具有大规模参数和复杂计算结构的机器学习模型。本文从大模型的基本概念出发&#xff0c;对大模型领域容易混淆的相关概念进行区分&#xff0c;并就大模型的发展历程、特点和分类、泛化与微调进行了详细解读&#xff0c;供大家在了解大模型基本知识的过程中起到一定参…

win7 的 vmware tools 安装失败

没有安装vmware tools的系统屏幕显示异常。桌面是比较小的图像&#xff0c;四周是黑边在 vmware 软件里 方法1&#xff0c;下补丁 https://www.catalog.update.microsoft.com/Search.aspx?qkb4474419 方法2&#xff0c;使用老版vm tools http://softwareupdate.vmware.com/c…

【ARM】MDK工程切换高版本的编译器后出现error A1137E报错

【更多软件使用问题请点击亿道电子官方网站】 1、 文档目标 解决工程从Compiler 5切换到Compiler 6进行编译时出现一些非语法问题上的报错。 2、 问题场景 对于一些使用Compiler 5进行编译的工程&#xff0c;要切换到Compiler 6进行编译的时候&#xff0c;原本无任何报错警告…

各大广告商竞相厮杀下,诞生了一个偏门的副业方式

前段时间&#xff0c;想买摩托车&#xff0c;但是媳妇不让买&#xff0c;所以我打算偷偷买&#xff0c;然后萌生了去摆摊赚钱的想法&#xff0c;但是还没有实施就在网上接触到了“某赚”APP&#xff0c;于是一发不可收拾&#xff0c;用我的话来说&#xff0c;我做的不是副业&am…

佑驾创新A股夭折再冲港股:三年亏损超5亿,商业化盈利难题何解

《港湾商业观察》廖紫雯 日前&#xff0c;深圳佑驾创新科技股份有限公司&#xff08;以下简称&#xff1a;佑驾创新&#xff09;递表港交所&#xff0c;保荐机构为中信证券、中金公司。佑驾创新曾于2023年8月启动A股上市辅导&#xff0c;但2024年5月公司终止了与辅导机构的上市…

【ai】trition:tritonclient yolov4:部署ubuntu18.04

X:\05_trition_yolov4_clients\01-python server代码在115上,client本想在windows上, 【ai】trition:tritonclient.utils.shared_memory 仅支持linux 看起来要分离。 client代码远程部署在ubuntu18.04上 ubuntu18.04 创建yolov4-trition python=3.7 环境 (base) zhangbin@ub…

基于matlab的图像灰度化与图像反白

1原理 2.1 图像灰度化原理 图像灰度化是将彩色图像转换为灰度图像的过程&#xff0c;使得每个像素点仅包含一个灰度值&#xff0c;从而简化了图像的复杂度。灰度化原理主要可以分为以下几种方法&#xff1a; 亮度平均法 原理&#xff1a;将图像中每个像素的RGB值的平均值作为…

[深度学习] 生成对抗网络GAN

生成对抗网络&#xff08;Generative Adversarial Networks&#xff0c;GANs&#xff09;是一种由 Ian Goodfellow 等人在2014年提出的深度学习模型Generative Adversarial Networks。GANs的基本思想是通过两个神经网络&#xff08;生成器和判别器&#xff09;的对抗过程&#…

VMware vCenter Server 8.0U3 发布下载 - 集中式管理 vSphere 环境

VMware vCenter Server 8.0U3 发布下载 - 集中式管理 vSphere 环境 Server Management Software | vCenter 请访问原文链接&#xff1a;https://sysin.org/blog/vmware-vcenter-8-u3/&#xff0c;查看最新版。原创作品&#xff0c;转载请保留出处。 作者主页&#xff1a;sys…

如何解决ssh远程连接自动断开的问题

文章目录 1. 问题描述2. 配置SSH设置2.1 调整服务器端的设置2.2 调整客户端的设置 3. 调整用户断开时长 1. 问题描述 SSH 远程连接断开是一个常见的问题&#xff0c;尤其是在网络不稳定或长时间没有活动时。文本介绍一些常见的方法和技巧来保持 SSH 连接稳定和避免断开。 2. …

基于Python/MNE处理fnirs数据

功能性近红外光谱技术在脑科学领域被广泛应用&#xff0c;市面上也已经有了许多基于MATLAB的优秀工具包及相关教程&#xff0c;如&#xff1a;homer、nirs_spm等。而本次教程将基于Python的MNE库对fNIRS数据进行处理。 本次教程基于&#xff1a;https://mne.tools/stable/auto_…

自动驾驶系统功能安全解决方案解析

电信、公用事业、运输和国防等关键基础设施服务需要定位、导航和授时&#xff08;PNT&#xff09;技术来运行。但是&#xff0c;广泛采用定位系统&#xff08;GPS&#xff09;作为PNT信息的主要会引入漏洞。 在为关键基础设施制定PNT解决方案时&#xff0c;运营商必须做出两个…

运维入门技术——监控的三个维度(非常详细)零基础收藏这一篇就够了_监控维度怎么区分

一个好的监控系统最后要做到的形态:实现Metrics、Tracing、Logging的融合。监控的三个维度也就是Metrics、Tracing、Logging。 Metrics Metrics也就是我们常说的指标。 首先它的典型特征就是可聚合(aggregatable).什么是可聚合的呢,简单讲可聚合就是一种基本单位可以在一种维…