完全小白如何用Windows电脑运行人生中的第一个Bert文本分类代码(更简单版)

诸神缄默不语-个人CSDN博文目录

我准备过几天录一个讲解视频。先等一下,现在只有图文版。

本文代码最早写于2024年3月27日,不保证未来以下代码及操作过程仍然可以使用。
本文主要关注中文仇恨检测短文本分类,数据集来源于datasets官网,本教程不介绍datasets包的使用技能。

完整代码见:https://github.com/PolarisRisingWar/all-notes-in-one/blob/main/hate_text_classification.ipynb

文章目录

  • 1. 环境配置
  • 2. 设置超参
  • 3. 导入包
  • 4. 构建tokenizer和模型
  • 5. 导入数据集
  • 6. 数据集预处理
  • 7. 设置训练和评估超参数
  • 8. 训练
  • 9. 测试

1. 环境配置

Anaconda安装教程:Anaconda教程(持续更新ing…)
或视频版:https://www.bilibili.com/video/BV1K34y1G7av/

PyTorch安装教程:PyTorch安装教程

transformers安装教程:huggingface.transformers安装教程
此外还需要下载datasets, evaluate, accelerate包,直接用pip下载就可以

pip教程:pip详解(持续更新ing…)

BERT中文版预训练模型权重:下载这里所有的文件:bert-base-chinese at main
然后放到某个文件夹里,这个文件夹路径在下面的代码里放到pretrained_path位置处

2. 设置超参

# 超参设置
pretrained_path = r"D:\allApplications\forPython\llm\bert-base-chinese"  #预训练模型权重路径
max_epoch_num = 1  #运行epoch数
output_dim = 2  #分类标签数(本文做仇恨检测,所以是二分类)

3. 导入包

import numpy as npfrom transformers import (AutoTokenizer,AutoModelForSequenceClassification,TrainingArguments,Trainer,
)import datasets, evaluate

4. 构建tokenizer和模型

tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
model = AutoModelForSequenceClassification.from_pretrained(pretrained_path, num_labels=output_dim
)

5. 导入数据集

我直接用的是datasets官网上的Paul/hatecheck-mandarin数据集,可以参考以下代码将数据文件先下载到本地:

import datasets
dataset=datasets.load_dataset("Paul/hatecheck-mandarin")dataset.save_to_disk("hatecheck-mandarin")

以后就可以直接通过如下代码导入数据集:

all_dataset = datasets.load_from_disk("hatecheck-mandarin")

6. 数据集预处理

这个数据集只有测试集,我尝试性地选择了30个样本分别作为训练集、验证集和测试集。

label_map = {"hateful": 1, "non-hateful": 0}  #将标签映射为数字def tokenize_function(examples):"""批量预处理样本的代码"""return_dict = tokenizer(examples["test_case"], padding="max_length", truncation=True, max_length=512)  #tokenize输入文本return_dict["label"] = [label_map[x] for x in examples["label_gold"]]  #将标签映射为数字return return_dicttokenized_datasets = all_dataset.map(tokenize_function, batched=True)  #对数据集进行批量预处理#将数据集划分为训练集、验证集、测试集
example_train_dataset = tokenized_datasets["test"].select(range(10))
example_valid_dataset = tokenized_datasets["test"].select(range(10, 20))
example_test_dataset = tokenized_datasets["test"].select(range(20, 30))

7. 设置训练和评估超参数

#训练阶段超参数
training_args = TrainingArguments(output_dir="test_checkpoint",  #训练权重存储路径evaluation_strategy="epoch",  #每个epoch评估一次report_to="none",  #设置不用wandb记录训练日志push_to_hub=False,  #不将训练权重上传到huggingface hubnum_train_epochs=max_epoch_num,  #训练轮数per_device_train_batch_size=1,  #每个设备(CPU/GPU)上的batch sizeno_cuda=True,  #不使用GPU(我是因为电脑上GPU内存只有2G根本带不动,反正只是个示例代码干脆不用了)do_train=True,  #训练save_steps=10,  #每10个step保存一次checkpointsave_total_limit=1,  #保存最后1个checkpoint
)#调用evaluate包的准确率评估指标
metric = evaluate.load("accuracy")def compute_metrics(eval_pred):logits, labels = eval_predpredictions = np.argmax(logits, axis=-1)return metric.compute(predictions=predictions, references=labels)#设置训练器
trainer = Trainer(model=model,args=training_args,train_dataset=example_train_dataset,eval_dataset=example_valid_dataset,compute_metrics=compute_metrics,
)

8. 训练

trainer.train()

一边训练一边就会保存训练权重,打印进度条和评估指标:
在这里插入图片描述

9. 测试

直接得到测试集上的预测结果。
predictions是每个样本对应的logits(两个元素分别是正负标签归一化前的预测概率)。
label_ids是标签。
metrics是评估指标(损失函数、准确率、运行时间)

result = trainer.predict(example_test_dataset)
print(result)

输出:

PredictionOutput(predictions=array([[-2.42961  ,  3.2694852],[-2.601208 ,  3.2841954],[-2.4568973,  3.0550063],[-2.48325  ,  3.239044 ],[-2.5793636,  3.283456 ],[-2.356497 ,  2.8706086],[-2.5480254,  3.2449925],[-2.4836488,  3.2778778],[-2.6639569,  3.2867384],[-2.5087523,  3.067698 ]], dtype=float32), label_ids=array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int64), metrics={'test_loss': 0.003417523577809334, 'test_accuracy': 1.0, 'test_runtime': 17.7476, 'test_samples_per_second': 0.563, 'test_steps_per_second': 0.113})

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/1613.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

绿联搭建rustdesk服务器

绿联搭建rustdesk服务器,不再使用向日葵 注意:本服务器需要有动态公网IP以及自己的域名,ipv6未测试。 1. 拉取镜像 rustdesk/rustdesk-server-s6:latest 注意是这个-s6的镜像。 2. 部署镜像 2.1 内存配置 本服务器比较省内存&#xff0…

关于agi中的Function Calling深入解析

接口(Interface) 两种常见接口: 1、人机交互接口,User Interface,简称UI 2、应用程序编程接口,Application Programming Interface,简称API 接口能【通】的关键,是两边都要遵守约定。 人要按照UI的设计来操作。UI的设计要符合…

Android Studio实现内容丰富的安卓养老平台

获取源码请点击文章末尾QQ名片联系,源码不免费,尊重创作,尊重劳动 158安卓养老 1.开发环境 后端用springboot框架,安卓的用android studio开发android stuido3.6 jak1.8 idea mysql tomcat 2.功能介绍 安卓端: 1.注册登…

【数据结构(八)上】二叉树经典习题

❣博主主页: 33的博客❣ ▶文章专栏分类: Java从入门到精通◀ 🚚我的代码仓库: 33的代码仓库🚚 🫵🫵🫵关注我带你学更多数据结构的知识 目录 1.前言2.经典习题2.1相同的树2.2另一棵子树2.3翻转二叉树2.4平衡二叉树2.5对…

直播美颜工具与视频美颜SDK:技术深入探索

直播美颜工具和视频美颜SDK的出现,为直播平台和应用开发者提供了丰富的选择。本文将深入探讨这些技术的原理、应用和发展趋势。 一、美颜算法 直播美颜工具的核心在于其先进的美颜算法。这些算法通过对图像进行分析和处理,实时地修饰主播的面部特征&am…

vsstudio 如何远程调试

你可能需要调试一个在本地生成的 Windows 桌面项目,然后在远程计算机上运行可执行文件。本主题阐释如何更改本地项目设置以在远程计算机上运行应用程序。C++ 项目会自动部署到远程计算机。您将需要手动部署 .NET Framework 可执行文件。 设置 Visual C++ 项目 此处显示的过程…

项目开发流程

项目开发流程 👩‍🦳项目立项 估计项目的花费,确定大致的所需开发人员数,确定项目是否可行; 👩‍🦰需求分析 整体过程: 项目背景和目标,即项目的目的是什么 用户需求&…

SQLAIchemy 异步DBManager封装-01入门理解

前言 SQLAlchemy 是一个强大的 Python SQL 工具包和对象关系映射(ORM)系统,是业内比较流行的ORM,设计非常优雅。随着其2.0版本的发布,SQLAlchemy 引入了原生的异步支持,这极大地增强了其在处理高并发和异步…

Windows 的常用命令(不分大小写)

Net user (查看当前系统所有的账户) net user yourname password /add 添加新用户 net localgroup administrators yourname /add 添加管理员权限 net user yourname /delete 删除用户 net user 命令 [colorred]说明:以下命令仅限持管理员…

opencv人脸打马赛克

import cv2def FaceFind(imgPath: str) -> list:image cv2.imread(imgPath)gray cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)face_cascade cv2.CascadeClassifier(haarcascade_frontalface_default.xml)# 返回人脸坐标列表faces face_cascade.detectMultiScale(gray, scal…

数据结构11:二叉树的链式结构

文章目录 快速创建链式二叉树二叉树的遍历前序、中序、后序层序 二叉树的基本操作二叉树的节点个数二叉树叶节点的个数二叉树第k层结点个数二叉树查找值为x的结点 二叉树基础oj练习单值二叉树检查两颗树是否相同对称二叉树二叉树的前序遍历另一颗树的子树 二叉树的创建和销毁二…

谷雨时节,雨水渐多湿气旺盛,吃什么养生?听听张婉如医生怎么说

谷雨春光晓,山川黛色青。 叶间鸣戴胜,泽水长浮萍。 4月19日21时59分迎来谷雨,雨生百谷,这是谷雨节气的意思,它是春季的最后一个节气,这个时节早晚温差大,空气湿气重,如何养生呢&am…

java在线问卷调查系统的设计与实现(springboot+mysql源码+文档)

风定落花生,歌声逐流水,大家好我是风歌,混迹在java圈的辛苦码农。今天要和大家聊的是一款基于springboot的在线问卷调查系统。项目源码以及部署相关请联系风歌,文末附上联系信息 。 项目简介: 基于java的在线问卷调查…

电脑便签怎么固定位置 能固定在桌面的电脑便签

在繁忙的工作中,电脑便签是我离不开的小助手。每当灵感闪现,或是需要记录待办事项时,我总会打开便签,快速地记录下来。它就像我电脑屏幕上的一块“记事板”,随时提醒我未完成的工作和即将到来的任务。 但有一段时间&a…

i管家空间不足提醒怎么关闭

i管家的空间不足提醒是为了提醒用户手机存储空间不足,可能会影响手机的正常运行。目前,这个提醒功能是无法直接关闭的。如果您希望减少这类提醒的出现,可以尝试以下几种方法: 清理手机存储:检查手机中是否有不需要的文…

restful请求风格的增删改查-----查询and添加

一、restful风格的介绍 restful也称之为REST ( Representational State Transfer ),可以将它理解为一种软件架构风格或设计风格,而不是一个标准。简单来说,restful风格就是把请求参数变成请求路径的一种风格。例如,传统的URL请求…

Darknet,看过很多篇,这个最清晰了

Darknet深度学习框架:YOLO背后的强大支持 Darknet,一个由Joseph Redmon开发的轻量级神经网络框架,以其在计算机视觉任务,特别是目标检测中的卓越表现而闻名。本文将详细介绍Darknet的基本概念、结构以及它在深度学习领域的应用。…

UE4_动画基础_根运动Root Motion

学习笔记,仅供参考! 在游戏动画中,角色的碰撞胶囊体(或其他形状)通常由控制器驱动通过场景。然后来自该胶囊体的数据用于驱动动画。例如,如果胶囊体在向前移动,系统就会知道在角色上播放一个跑步…

Kivy Pyinstaller Windows 打包

各种报错 ImportErrorWhenRunningHook: Failed to import module __PyInstaller_hooks_0_kivy required by hook for module 三天美好时光啥也没干,就研究这个了。 打包成功,运行应用程序exe闪退的。终于打包成功了。 这所有的原因都是因为我爱你。如果…

小型架构实验模拟

一 实验需求 二 实验环境 22 机器: 做nginx 反向代理 做静态资源服务器 装 nginx keepalived filebeat 44机器: 做22 机器的备胎 装nginx keepalived 99机器:做mysql的主 装mysqld 装node 装filebeat 77机器:做mysq…