3 文本分类入门finetune:bert-base-chinese

项目实战:

数据准备工作

        `bert-base-chinese` 是一种预训练的语言模型,基于 BERT(Bidirectional Encoder Representations from Transformers)架构,专门用于中文自然语言处理任务。BERT 是由 Google 在 2018 年提出的一种革命性的预训练模型,通过大规模的无监督训练,能够学习到丰富的语言表示。

        `bert-base-chinese` 是 BERT 在中文语料上进行预训练的版本,它包含了 12 层 Transformer 编码器和 110 万个参数。这个模型在中文文本上进行了大规模的预训练,可以用于各种中文自然语言处理任务,如文本分类、命名实体识别、情感分析等。

使用 `bert-base-chinese` 模型时,可以将其作为一个特征提取器,将输入的文本转换为固定长度的向量表示,然后将这些向量输入到其他机器学习模型中进行训练或推断。也可以对 `bert-base-chinese` 进行微调,将其用于特定任务的训练。

        预训练的 `bert-base-chinese` 模型可以通过 Hugging Face 的 Transformers 库进行加载和使用。在加载模型后,可以使用它的 `encode` 方法将文本转换为向量表示,或者使用 `forward` 方法对文本进行特定任务的预测。

        需要注意的是,`bert-base-chinese` 是一个通用的中文语言模型,但它可能在特定的任务上表现不佳。在某些情况下,可能需要使用更大的模型或进行微调来获得更好的性能。

进行微调时,可以按照以下步骤进行操作:

1. 准备数据集:首先,你需要准备一个与你的任务相关的标注数据集。这个数据集应该包含输入文本以及相应的标签或注释,用于训练和评估模型。

2. 加载预训练模型:使用 Hugging Face 的 Transformers 库加载预训练的 `bert-base-chinese` 模型。你可以选择加载整个模型或只加载其中的一部分,具体取决于你的任务需求。

3. 创建模型架构:根据你的任务需求,创建一个适当的模型架构。这通常包括在 `bert-base-chinese` 模型之上添加一些额外的层,用于适应特定的任务。

4. 数据预处理:将你的数据集转换为适合模型输入的格式。这可能包括将文本转换为输入的编码表示,进行分词、填充和截断等操作。

5. 定义损失函数和优化器:选择适当的损失函数来衡量模型预测与真实标签之间的差异,并选择合适的优化器来更新模型的参数。

6. 微调模型:使用训练集对模型进行训练。在每个训练步骤中,将输入文本提供给模型,计算损失并进行反向传播,然后使用优化器更新模型的参数。

7. 评估模型:使用验证集或测试集评估模型的性能。可以计算准确率、精确率、召回率等指标来评估模型在任务上的表现。

8. 调整和优化:根据评估结果,对模型进行调整和优化。你可以尝试不同的超参数设置、模型架构或训练策略,以获得更好的性能。

9. 推断和应用:在微调完成后,你可以使用微调后的模型进行推断和应用。将新的输入文本提供给模型,获取预测结果,并根据任务需求进行后续处理。

需要注意的是,微调的过程可能需要大量的计算资源和时间,并且需要对模型和数据进行仔细的调整和优化。此外,合适的数据集规模和质量对于获得良好的微调结果也非常重要。

准备模型

https://huggingface.co/bert-base-chinese/tree/main

数据集

与之前的训练数据一样使用;

代码部分

# 导入transformers
import transformers
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup# 导入torch
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F# 常用包
import re
import numpy as np
import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from collections import defaultdict
from textwrap import wrap
from tqdm import tqdm%matplotlib inline
%config InlineBackend.figure_format='retina' # 主题device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
devicePRE_TRAINED_MODEL_NAME = '../bert-base-chinese/' # 英文bert预训练模型
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
tokenizerdf = pd.read_csv("../data/online_shopping_10_cats.csv")
#myDataset[0]
RANDOM_SEED = 1012
df_train, df_test = train_test_split(df, test_size=0.1, random_state=RANDOM_SEED)
df_val, df_test = train_test_split(df_test, test_size=0.5, random_state=RANDOM_SEED)
df_train.shape, df_val.shape, df_test.shapeclass MyDataSet(Dataset):def __init__(self,texts,labels,tokenizer,max_len):self.texts = textsself.labels = labelsself.tokenizer = tokenizerself.max_len = max_lendef __len__(self):return len(self.texts)def __getitem__(self,item):text = str(self.texts[item])label = self.labels[item]encoding = self.tokenizer(text=text,max_length=self.max_len,pad_to_max_length=True,add_special_tokens=True,return_attention_mask=True,return_token_type_ids=True,return_tensors='pt')return {"text":text,"input_ids":encoding['token_type_ids'].flatten(),"attention_mask":encoding['attention_mask'].flatten(),"labels":torch.tensor(label,dtype=torch.long)}
def create_data_loader(df,tokenizer,max_len,batch_size=4):ds = MyDataSet(texts=df['review'].values,labels=df['label'].values,tokenizer = tokenizer,max_len=max_len)return DataLoader(ds,batch_size=batch_size)
MAX_LEN = 512
BATCH_SIZE = 4
train_data_loader = create_data_loader(df_train,tokenizer,max_len=MAX_LEN, batch_size=BATCH_SIZE)
val_data_loader = create_data_loader(df_val,tokenizer,max_len=MAX_LEN, batch_size=BATCH_SIZE)
test_data_loader = create_data_loader(df_test,tokenizer,max_len=MAX_LEN, batch_size=BATCH_SIZE)class BaseBertModel(nn.Module):def __init__(self,n_class=2):super(BaseBertModel,self).__init__()self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)self.drop = nn.Dropout(0.2)self.out = nn.Linear(self.bert.config.hidden_size, n_class)passdef forward(self,input_ids,attention_mask):_,pooled_output = self.bert(input_ids=input_ids,attention_mask=attention_mask,return_dict = False)out = self.drop(pooled_output)return self.out(out)#pooled_output
model = BaseBertModel()
model = model.to(device)EPOCHS = 5 # 训练轮数
optimizer = AdamW(model.parameters(),lr=2e-5,correct_bias=False)
total_steps = len(train_data_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps=0,num_training_steps=total_steps)
loss_fn = nn.CrossEntropyLoss().to(device)def train_epoch(model,data_loader,loss_fn,optimizer,device,schedule,n_exmaples):model = model.train()losses = []correct_predcitions = 0for d in tqdm(data_loader):input_ids = d['input_ids'].to(device)attention_mask = d['attention_mask'].to(device)targets = d['labels'].to(device)outputs = model(input_ids=input_ids,attention_mask=attention_mask)_,preds = torch.max(outputs, dim=1)loss = loss_fn(outputs,targets)losses.append(loss.item())correct_predcitions += torch.sum(preds==targets)loss.backward()nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)optimizer.step()scheduler.step()optimizer.zero_grad()return correct_predictions.double() / n_examples, np.mean(losses)def eval_model(model, data_loader, loss_fn, device, n_examples):model = model.eval() # 验证预测模式losses = []correct_predictions = 0with torch.no_grad():for d in data_loader:input_ids = d["input_ids"].to(device)attention_mask = d["attention_mask"].to(device)targets = d["labels"].to(device)outputs = model(input_ids=input_ids,attention_mask=attention_mask)_, preds = torch.max(outputs, dim=1)loss = loss_fn(outputs, targets)correct_predictions += torch.sum(preds == targets)losses.append(loss.item())return correct_predictions.double() / n_examples, np.mean(losses)# train model
history = defaultdict(list) # 记录10轮loss和acc
best_accuracy = 0for epoch in range(EPOCHS):print(f'Epoch {epoch + 1}/{EPOCHS}')print('-' * 10)#(model,data_loader,loss_fn,device,schedule,n_exmaples)train_acc, train_loss = train_epoch(model,train_data_loader,loss_fn,optimizer,device,scheduler,len(df_train))print(f'Train loss {train_loss} accuracy {train_acc}')val_acc, val_loss = eval_model(model,val_data_loader,loss_fn,device,len(df_val))print(f'Val   loss {val_loss} accuracy {val_acc}')print()history['train_acc'].append(train_acc)history['train_loss'].append(train_loss)history['val_acc'].append(val_acc)history['val_loss'].append(val_loss)if val_acc > best_accuracy:torch.save(model.state_dict(), 'best_model_state.bin')best_accuracy = val_acc# 模型评估
test_acc, _ = eval_model(model,test_data_loader,loss_fn,device,len(df_test)
)
test_acc.item()def get_predictions(model, data_loader):model = model.eval()texts = []predictions = []prediction_probs = []real_values = []with torch.no_grad():for d in data_loader:texts = d["texts"]input_ids = d["input_ids"].to(device)attention_mask = d["attention_mask"].to(device)targets = d["labels"].to(device)outputs = model(input_ids=input_ids,attention_mask=attention_mask)_, preds = torch.max(outputs, dim=1)probs = F.softmax(outputs, dim=1)texts.extend(texts)predictions.extend(preds)prediction_probs.extend(probs)real_values.extend(targets)predictions = torch.stack(predictions).cpu()prediction_probs = torch.stack(prediction_probs).cpu()real_values = torch.stack(real_values).cpu()return texts, predictions, prediction_probs, real_valuesy_texts, y_pred, y_pred_probs, y_test = get_predictions(model,test_data_loader
)
print(classification_report(y_test, y_pred, target_names=[str(label) for label in class_names]))# 模型预测sample_text='Hard but Robust, Easy but Sensitive: How Encod.'
encoded_text = tokenizer.encode_plus(sample_text,max_length=MAX_LEN,add_special_tokens=True,return_token_type_ids=False,pad_to_max_length=True,return_attention_mask=True,return_tensors='pt',
)input_ids = encoded_text['input_ids'].to(device)
attention_mask = encoded_text['attention_mask'].to(device)output = model(input_ids, attention_mask)
_, prediction = torch.max(output, dim=1)print(f'Sample text: {sample_text}')
print(f'Danger label  : {label_id2cate[prediction.cpu().numpy()[0]]}')

runing....

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/203891.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

gpt1与bert区别

区别1:网络结构(主要是Masked Multi-Head-Attention和Multi-Head-Attention) gpt1使用transformer的decoder,单向编码,是一种基于语言模型的生成式模型,更适合生成下一个单词或句子 bert使用transformer的…

Domino多Web站点托管

大家好,才是真的好。 看到一篇文档,大概讲述的是他在家里架了一台Domino服务器,上面跑了好几个Internet的Web网站(使用Internet站点)。再租了一台云服务器,上面安装Nginx做了反向代理,代理访问…

轨迹分析:Palantir评估细胞分化潜能 类似于monocle2

轨迹分析是单细胞测序分析中重要的组成部分,它基于细胞谱系之间“具有中间态细胞”的理论基础,通过结合先验知识(细胞注释、markers)、细胞基因表达改变等,为在单细胞测序数据赋予了“假时间”(pseudotime&…

图的深度优先搜索(数据结构实训)

题目: 图的深度优先搜索 描述: 图的深度优先搜索类似于树的先根遍历,是树的先根遍历的推广。即从某个结点开始,先访问该结点,然后深度访问该结点的第一棵子树,依次为第二顶子树。如此进行下去,直…

每天五分钟计算机视觉:通过残差块搭建卷积残差神经网络Resnet

本文重点 随着深度神经网络的层数的增加,神经网络会变得越来越难以训练,之所以这样就是因为存在梯度消失和梯度爆炸问题。本节课程我们将学习跳跃连接方式,它可以从某一网络层获取激活a,然后迅速反馈给另外一层,甚至是神经网络的更深层,从而解决梯度消失的问题。 传统的…

关于命令行方式的MySQL服务无法启动问题原因之一解决

这里无法启动服务的原因为系统某些进行占用了3306端口问题 当你遇到无法启动的问题时,可以尝试通过netstat -ano命令查看系统进行信息,验证是否3306端口被占用 在本地地址列如果发现3306端口被占用,则通过 taskkill /f /pid 进程id命令关闭进…

matlab 点云放缩变换

目录 一、算法原理二、代码实现三、结果展示四、相关链接本文由CSDN点云侠原创,原文链接。爬虫网站自重。如果你不是在点云侠的博客中看到该文章,那么此处便是不要脸的爬虫与GPT。 一、算法原理 缩放可以独立应用于三个坐标轴,如将点 ( x , y , z ) ( x

dtm分布式事务框架之SAGA 实战

一.dtm分布式事务框架之SAGA 1.1DTM介绍 DTM是一款开源的分布式事务管理器,解决跨数据库、跨服务、跨语言栈更新数据的一致性问题。 通俗一点说,DTM提供跨服务事务能力,一组服务要么全部成功,要么全部回滚,避免只更…

【天线了解】1.004天线的了解以及使用

一。004天线使用步骤 1.打开天线 (1)天线的各种版本 注意: 《1》天线包括单通道天线程序,双通道天线程序等。 《2》在没有连接天线时,有的天线程序打不开。 (2)打开软件前的配置工作 注意&…

接鸡冠^^

欢迎来到程序小院 接鸡冠 玩法&#xff1a;左右移动棒棒君(小海豹)接住鸡冠&#xff0c;避开炸弹&#xff0c;若不小心接住炸弹则游戏结束&#xff0c; 赶紧接鸡冠吧&#xff0c;看看你能够接住多少鸡冠哦^^。。开始游戏https://www.ormcc.com/play/gameStart/211 html <di…

【精选】设计模式——策略设计模式-两种举例说明,具体代码实现

Java策略设计模式 简介 策略设计模式是一种行为型设计模式&#xff0c;它允许在运行时选择算法的行为。 在软件开发中&#xff0c;我们常常需要根据不同情况采取不同的行为。通常的做法是使用大量的条件语句来实现这种灵活性&#xff0c;但这会导致代码变得复杂、难以维护和扩…

Unity打包EXE自定义(拖拽)窗口大小

代码 using System.Collections; using System.Collections.Generic; using UnityEngine; using System; using System.Runtime.InteropServices; public class MyWindow : MonoBehaviour {[DllImport("user32.dll")]private static extern IntPtr GetActiveWindow(…

CSS-自适应导航栏(flex | grid)

目标&#xff1a;实现左右各有按钮&#xff0c;中间是内容&#xff0c;自适应显示中间的内容导航栏&#xff0c;即 根据中间的宽度大小显示内容。 自适应导航栏 总结&#xff1a;推荐 flex布局 / grid布局 flex布局&#xff1a; 两侧 flex:1; ----->中间自适应 grid布局&…

uniapp(微信小程序)聊天实例,支持图片,语音,表情(附源码)

效果预览 安装教程 配置 请参考Dome 会话配置 {info:{// 用户关键字userKey:2666,// 用户手机userPhone:15252156614,// 用户昵称userName: 健健,// 头像headImg: http://d.hiphotos.baidu.com/image/h%3D300/sign0defb42225381f3081198ba999004c67/6159252dd42a2834a75bb01…

CRM客户关系管理系统的主要功能有哪些?

我们都知道&#xff0c;CRM系统可以帮助企业加快业务增长。如果一个企业能提高业务效率、跨团队协作、有效管理客户、轻松共享和同步数据&#xff0c;那么企业竞争力将极大地提高。基于此&#xff0c;我们说说CRM客户关系管理系统的主要功能分析。 完整的CRM是什么样的&#x…

红队专题-开源资产扫描系统-ARL资产灯塔系统

ARL资产灯塔系统 安装说明问题 &#xff1a; 安装说明 源码地址 https://github.com/TophantTechnology/ARL https://github.com/TophantTechnology/ARL/wiki/Docker-%E7%8E%AF%E5%A2%83%E5%AE%89%E8%A3%85-ARL 安装环境 uname -a Linux VM-24-12-centos 3.10.0-1160.49.1.e…

亚马逊云科技re:Invent,生成式AI正在彻底改变开发者的工作方式

去年此时&#xff0c;ChatGPT横空出世席卷全球&#xff0c;许多人称其意味着AI的iPhone时刻到来。CSDN创始人蒋涛对此曾预测&#xff1a;「下一步就是应用时刻&#xff0c;新应用时代将来临……大模型将推动更多的AI应用程序员诞生」。 在2023亚马逊云科技re:Invent全球大会第三…

Linux--环境变量

一.基本概念 * 环境变量 (environment variables) 一般是指在操作系统中用来指定操作系统运行环境的一些参数 * 如&#xff1a;我们在编写 C/C 代码的时候&#xff0c;在链接的时候&#xff0c;从来不知道我们的所链接的动态静态库在哪里&#xff0c;但 是照样可以链接成功&am…

使用jenkins插件Allure生成自动化测试报告

前言 以前做自动化测试的时候一直用的HTMLTestRunner来生成测试报告&#xff0c;后来也尝试过用Python的PyH模块自己构建测试报告&#xff0c;在后来看到了RobotFramework的测试报告&#xff0c;感觉之前用的测试报告都太简陋&#xff0c;它才是测试报告应该有的样子。也就是在…

微信小程序 -- ios 底部小黑条样式问题

问题&#xff1a; 如图&#xff0c;ios有的机型底部伪home键会显示在按钮之上&#xff0c;导致点击按钮的时候误触 解决&#xff1a; App.vue <script>export default {wx.getSystemInfo({success: res > {let bottomHeight res.screenHeight - res.safeArea.bott…