完全小白如何用Windows电脑运行人生中的第一个Bert文本分类代码(更简单版)

诸神缄默不语-个人CSDN博文目录

我准备过几天录一个讲解视频。先等一下,现在只有图文版。

本文代码最早写于2024年3月27日,不保证未来以下代码及操作过程仍然可以使用。
本文主要关注中文仇恨检测短文本分类,数据集来源于datasets官网,本教程不介绍datasets包的使用技能。

完整代码见:https://github.com/PolarisRisingWar/all-notes-in-one/blob/main/hate_text_classification.ipynb

文章目录

  • 1. 环境配置
  • 2. 设置超参
  • 3. 导入包
  • 4. 构建tokenizer和模型
  • 5. 导入数据集
  • 6. 数据集预处理
  • 7. 设置训练和评估超参数
  • 8. 训练
  • 9. 测试

1. 环境配置

Anaconda安装教程:Anaconda教程(持续更新ing…)
或视频版:https://www.bilibili.com/video/BV1K34y1G7av/

PyTorch安装教程:PyTorch安装教程

transformers安装教程:huggingface.transformers安装教程
此外还需要下载datasets, evaluate, accelerate包,直接用pip下载就可以

pip教程:pip详解(持续更新ing…)

BERT中文版预训练模型权重:下载这里所有的文件:bert-base-chinese at main
然后放到某个文件夹里,这个文件夹路径在下面的代码里放到pretrained_path位置处

2. 设置超参

# 超参设置
pretrained_path = r"D:\allApplications\forPython\llm\bert-base-chinese"  #预训练模型权重路径
max_epoch_num = 1  #运行epoch数
output_dim = 2  #分类标签数(本文做仇恨检测,所以是二分类)

3. 导入包

import numpy as npfrom transformers import (AutoTokenizer,AutoModelForSequenceClassification,TrainingArguments,Trainer,
)import datasets, evaluate

4. 构建tokenizer和模型

tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
model = AutoModelForSequenceClassification.from_pretrained(pretrained_path, num_labels=output_dim
)

5. 导入数据集

我直接用的是datasets官网上的Paul/hatecheck-mandarin数据集,可以参考以下代码将数据文件先下载到本地:

import datasets
dataset=datasets.load_dataset("Paul/hatecheck-mandarin")dataset.save_to_disk("hatecheck-mandarin")

以后就可以直接通过如下代码导入数据集:

all_dataset = datasets.load_from_disk("hatecheck-mandarin")

6. 数据集预处理

这个数据集只有测试集,我尝试性地选择了30个样本分别作为训练集、验证集和测试集。

label_map = {"hateful": 1, "non-hateful": 0}  #将标签映射为数字def tokenize_function(examples):"""批量预处理样本的代码"""return_dict = tokenizer(examples["test_case"], padding="max_length", truncation=True, max_length=512)  #tokenize输入文本return_dict["label"] = [label_map[x] for x in examples["label_gold"]]  #将标签映射为数字return return_dicttokenized_datasets = all_dataset.map(tokenize_function, batched=True)  #对数据集进行批量预处理#将数据集划分为训练集、验证集、测试集
example_train_dataset = tokenized_datasets["test"].select(range(10))
example_valid_dataset = tokenized_datasets["test"].select(range(10, 20))
example_test_dataset = tokenized_datasets["test"].select(range(20, 30))

7. 设置训练和评估超参数

#训练阶段超参数
training_args = TrainingArguments(output_dir="test_checkpoint",  #训练权重存储路径evaluation_strategy="epoch",  #每个epoch评估一次report_to="none",  #设置不用wandb记录训练日志push_to_hub=False,  #不将训练权重上传到huggingface hubnum_train_epochs=max_epoch_num,  #训练轮数per_device_train_batch_size=1,  #每个设备(CPU/GPU)上的batch sizeno_cuda=True,  #不使用GPU(我是因为电脑上GPU内存只有2G根本带不动,反正只是个示例代码干脆不用了)do_train=True,  #训练save_steps=10,  #每10个step保存一次checkpointsave_total_limit=1,  #保存最后1个checkpoint
)#调用evaluate包的准确率评估指标
metric = evaluate.load("accuracy")def compute_metrics(eval_pred):logits, labels = eval_predpredictions = np.argmax(logits, axis=-1)return metric.compute(predictions=predictions, references=labels)#设置训练器
trainer = Trainer(model=model,args=training_args,train_dataset=example_train_dataset,eval_dataset=example_valid_dataset,compute_metrics=compute_metrics,
)

8. 训练

trainer.train()

一边训练一边就会保存训练权重,打印进度条和评估指标:
在这里插入图片描述

9. 测试

直接得到测试集上的预测结果。
predictions是每个样本对应的logits(两个元素分别是正负标签归一化前的预测概率)。
label_ids是标签。
metrics是评估指标(损失函数、准确率、运行时间)

result = trainer.predict(example_test_dataset)
print(result)

输出:

PredictionOutput(predictions=array([[-2.42961  ,  3.2694852],[-2.601208 ,  3.2841954],[-2.4568973,  3.0550063],[-2.48325  ,  3.239044 ],[-2.5793636,  3.283456 ],[-2.356497 ,  2.8706086],[-2.5480254,  3.2449925],[-2.4836488,  3.2778778],[-2.6639569,  3.2867384],[-2.5087523,  3.067698 ]], dtype=float32), label_ids=array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int64), metrics={'test_loss': 0.003417523577809334, 'test_accuracy': 1.0, 'test_runtime': 17.7476, 'test_samples_per_second': 0.563, 'test_steps_per_second': 0.113})

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/1613.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【深度学习实战(13)】训练之加载预训练权重

一、预训练权重加载流程 预训练权值加载流程:pretrained_dict -> yolo_dict(backbone) -> temp_dict(与model_dict匹配上的) -> model_dict(load_state_dict加载) 二、代码 # ------------------------------------------------------# # 预训练权值加…

手写java设计模式之单例模式,附源码解读

在Java应用中,单例对象能保证在一个JVM中,该对象只有一个实例存在。这样的模式有几个好处: 1、减少类的频繁创建,减少使用频繁使用new创建实例,减少GC压力。 2、某些应用场景下,使用单例模式,…

绿联搭建rustdesk服务器

绿联搭建rustdesk服务器,不再使用向日葵 注意:本服务器需要有动态公网IP以及自己的域名,ipv6未测试。 1. 拉取镜像 rustdesk/rustdesk-server-s6:latest 注意是这个-s6的镜像。 2. 部署镜像 2.1 内存配置 本服务器比较省内存&#xff0…

关于agi中的Function Calling深入解析

接口(Interface) 两种常见接口: 1、人机交互接口,User Interface,简称UI 2、应用程序编程接口,Application Programming Interface,简称API 接口能【通】的关键,是两边都要遵守约定。 人要按照UI的设计来操作。UI的设计要符合…

layer弹出层点击关闭按钮刷新父页面

在弹出层页面&#xff0c;&#xff0c;找到layer关闭按钮&#xff0c;写一个关闭事件&#xff0c;里面去执行js方法。 例&#xff1a;页面写个a标签方便调用&#xff1a;<a id“hidalayerclose” style“display: none;” οnclick“fureload()”> parent.$(".lay…

Android Studio实现内容丰富的安卓养老平台

获取源码请点击文章末尾QQ名片联系&#xff0c;源码不免费&#xff0c;尊重创作&#xff0c;尊重劳动 158安卓养老 1.开发环境 后端用springboot框架&#xff0c;安卓的用android studio开发android stuido3.6 jak1.8 idea mysql tomcat 2.功能介绍 安卓端&#xff1a; 1.注册登…

mac: nvm is already installed in /Users/**/.nvm, trying to update using git

如图吐了&#xff0c;安装了nvm后出现了如下问题&#xff1a; nvm is already installed in /Users/**/.nvm, trying to update using git 原因分析&#xff1a; 这种情况可能出现在安装脚本检测到 nvm 已经存在于系统中&#xff0c;但是由于某些原因&#xff0c;终端无法识…

SOCKS5代理IP指什麼?

SOCKS5代理IP是一種網路協議&#xff0c;它可以在客戶端和目標伺服器之間建立一個隧道&#xff0c;以進行數據交換&#xff0c;並隱藏用戶的真實IP地址。它是SOCKS協議的最新版本&#xff0c;不僅可以支持TCP和UDP協議&#xff0c;還支持各種類型的網路請求&#xff0c;包括HTT…

【数据结构(八)上】二叉树经典习题

❣博主主页: 33的博客❣ ▶文章专栏分类: Java从入门到精通◀ &#x1f69a;我的代码仓库: 33的代码仓库&#x1f69a; &#x1faf5;&#x1faf5;&#x1faf5;关注我带你学更多数据结构的知识 目录 1.前言2.经典习题2.1相同的树2.2另一棵子树2.3翻转二叉树2.4平衡二叉树2.5对…

直播美颜工具与视频美颜SDK:技术深入探索

直播美颜工具和视频美颜SDK的出现&#xff0c;为直播平台和应用开发者提供了丰富的选择。本文将深入探讨这些技术的原理、应用和发展趋势。 一、美颜算法 直播美颜工具的核心在于其先进的美颜算法。这些算法通过对图像进行分析和处理&#xff0c;实时地修饰主播的面部特征&am…

vsstudio 如何远程调试

你可能需要调试一个在本地生成的 Windows 桌面项目,然后在远程计算机上运行可执行文件。本主题阐释如何更改本地项目设置以在远程计算机上运行应用程序。C++ 项目会自动部署到远程计算机。您将需要手动部署 .NET Framework 可执行文件。 设置 Visual C++ 项目 此处显示的过程…

项目开发流程

项目开发流程 &#x1f469;‍&#x1f9b3;项目立项 估计项目的花费&#xff0c;确定大致的所需开发人员数&#xff0c;确定项目是否可行&#xff1b; &#x1f469;‍&#x1f9b0;需求分析 整体过程&#xff1a; 项目背景和目标&#xff0c;即项目的目的是什么 用户需求&…

Springboot 操作Mongodb(一)

MongoDB概念 MongoDB 基本概念指的是学习 MongoDB 最先应该了解的词汇&#xff0c;比如 MongoDB 中的"数据库"、"集合"、"文档"这三个名词&#xff1a; 文档&#xff08;Document&#xff09;&#xff1a; 文档是 MongoDB 中最基本的数据单元&…

SQLAIchemy 异步DBManager封装-01入门理解

前言 SQLAlchemy 是一个强大的 Python SQL 工具包和对象关系映射&#xff08;ORM&#xff09;系统&#xff0c;是业内比较流行的ORM&#xff0c;设计非常优雅。随着其2.0版本的发布&#xff0c;SQLAlchemy 引入了原生的异步支持&#xff0c;这极大地增强了其在处理高并发和异步…

Windows 的常用命令(不分大小写)

Net user &#xff08;查看当前系统所有的账户&#xff09; net user yourname password /add 添加新用户 net localgroup administrators yourname /add 添加管理员权限 net user yourname /delete 删除用户 net user 命令 [colorred]说明&#xff1a;以下命令仅限持管理员…

opencv人脸打马赛克

import cv2def FaceFind(imgPath: str) -> list:image cv2.imread(imgPath)gray cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)face_cascade cv2.CascadeClassifier(haarcascade_frontalface_default.xml)# 返回人脸坐标列表faces face_cascade.detectMultiScale(gray, scal…

数据结构11:二叉树的链式结构

文章目录 快速创建链式二叉树二叉树的遍历前序、中序、后序层序 二叉树的基本操作二叉树的节点个数二叉树叶节点的个数二叉树第k层结点个数二叉树查找值为x的结点 二叉树基础oj练习单值二叉树检查两颗树是否相同对称二叉树二叉树的前序遍历另一颗树的子树 二叉树的创建和销毁二…

谷雨时节,雨水渐多湿气旺盛,吃什么养生?听听张婉如医生怎么说

谷雨春光晓&#xff0c;山川黛色青。 叶间鸣戴胜&#xff0c;泽水长浮萍。 4月19日21时59分迎来谷雨&#xff0c;雨生百谷&#xff0c;这是谷雨节气的意思&#xff0c;它是春季的最后一个节气&#xff0c;这个时节早晚温差大&#xff0c;空气湿气重&#xff0c;如何养生呢&am…

java在线问卷调查系统的设计与实现(springboot+mysql源码+文档)

风定落花生&#xff0c;歌声逐流水&#xff0c;大家好我是风歌&#xff0c;混迹在java圈的辛苦码农。今天要和大家聊的是一款基于springboot的在线问卷调查系统。项目源码以及部署相关请联系风歌&#xff0c;文末附上联系信息 。 项目简介&#xff1a; 基于java的在线问卷调查…

linux内核源码分析--通用函数指针

除了稍早讨论的net_device 结构的列表管理字段外&#xff0c;还有一些字段用于管理一些结构&#xff0c;确保这些结构在不需要时能予以删除。 atomic_t refcnt 引用计数&#xff0c;此计数器变为零之前&#xff0c;设备无法除名&#xff0c;参见第八章。 int watchdog_timeo st…