NLP - 基于bert预训练模型的文本多分类示例

项目说明

项目名称

基于DistilBERT的标题多分类任务

项目概述

本项目旨在使用DistilBERT模型对给定的标题文本进行多分类任务。项目包括从数据处理、模型训练、模型评估到最终的API部署。该项目采用模块化设计,以便于理解和维护。

项目结构

.
├── bert_data
│   ├── train.txt
│   ├── dev.txt
│   └── test.txt
├── saved_model
├── results
├── logs
├── data_processing.py
├── dataset.py
├── training.py
├── app.py
└── main.py

文件说明

  1. bert_data/:存放训练集、验证集和测试集的数据文件。

    • train.txt
    • dev.txt
    • test.txt
  2. saved_model/:存放训练好的模型和tokenizer。

  3. results/:存放训练结果。

  4. logs/:存放训练日志。

  5. data_processing.py:数据处理模块,负责读取和预处理数据。

  6. dataset.py:数据集类模块,定义了用于训练和评估的数据集类。

  7. training.py:模型训练模块,定义了训练和评估模型的过程。

  8. app.py:模型部署模块,使用FastAPI创建API服务。

  9. main.py:主脚本,运行整个流程,包括数据处理、模型训练和部署。

数据集数据规范

为了确保数据处理和模型训练的顺利进行,请按照以下规范准备数据集文件。每个文件包含的标题和标签分别使用制表符(\t)分隔。以下是一个示例数据集的格式。

数据文件格式

数据文件应为纯文本文件,扩展名为.txt,文件内容的每一行应包含一个文本标题和一个对应的分类标签,用制表符分隔。数据文件不应包含表头。

数据示例
探索神秘的海底世界    7
如何在家中制作美味披萨    2
全球气候变化的原因和影响    1
最新的智能手机评测    8
健康饮食:如何搭配均衡的膳食    5
最受欢迎的电影和电视剧推荐    3
了解宇宙的奥秘:天文学入门    0
如何种植和照顾多肉植物    9
时尚潮流:今年夏天的必备单品    6
如何有效管理个人财务    4

注意事项

  • 标签规范:确保每个标题文本的标签是一个整数,表示类别。
  • 文本编码:确保数据文件使用UTF-8编码,避免中文字符乱码。
  • 数据一致性:确保训练、验证和测试数据格式一致,便于数据加载和处理。

通过以上规范和示例数据文件创建方法,可以确保数据文件符合项目需求,并顺利进行数据处理和模型训练。

模块说明

1. 数据处理模块 (data_processing.py)

功能:读取数据文件并进行预处理。

  • load_data(file_path): 读取指定路径的数据文件,并返回一个包含文本和标签的数据框。
  • tokenize_data(data, tokenizer, max_length=128): 使用BERT的tokenizer对数据进行tokenize处理。
  • main(): 加载数据、tokenize数据并返回处理后的数据。
2. 数据集类模块 (dataset.py)

功能:定义数据集类,便于模型训练。

  • TextDataset: 将tokenized数据和标签封装成PyTorch的数据集格式,便于Trainer进行训练和评估。
3. 模型训练模块 (training.py)

功能:定义训练和评估模型的过程。

  • train_model(): 加载数据和tokenizer,创建数据集,加载模型,设置训练参数,定义Trainer,训练和评估模型,保存训练好的模型和tokenizer。
4. 模型部署模块 (app.py)

功能:使用FastAPI进行模型部署。

  • predict(item: Item): 接收POST请求的文本输入,使用训练好的模型进行预测并返回分类结果。
  • FastAPI应用启动配置。
5. 主脚本 (main.py)

功能:运行整个流程,包括数据处理、模型训练和部署。

  • main(): 运行模型训练流程,并输出训练完成的提示。

运行步骤

  1. 安装依赖
pip install pandas torch transformers fastapi uvicorn scikit-learn
  1. 数据处理

确保bert_data文件夹下包含train.txtdev.txttest.txt文件,每个文件包含文本和标签,使用制表符分隔。

  1. 训练模型

运行main.py脚本,进行数据处理和模型训练:

python main.py

训练完成后,模型和tokenizer将保存在saved_model文件夹中。

  1. 部署模型

运行app.py脚本,启动API服务:

uvicorn app:app --reload

服务启动后,可以通过POST请求访问预测接口,进行文本分类预测。

示例请求

curl -X POST "http://localhost:8000/predict" -H "Content-Type: application/json" -d '{"text": "你的文本"}'

返回示例:

{"prediction": 3
}

注意事项

  • 确保数据文件格式正确,每行包含一个文本和对应的标签,使用制表符分隔。
  • 调整训练参数(如batch size和训练轮数)以适应不同的GPU配置。
  • 使用nvidia-smi监控显存使用,避免显存溢出。

项目代码

1. 数据处理模块

功能:读取数据文件并进行预处理。

# data_processing.py
import pandas as pd
from transformers import DistilBertTokenizerdef load_data(file_path):data = pd.read_csv(file_path, delimiter='\t', header=None)data.columns = ['text', 'label']return datadef tokenize_data(data, tokenizer, max_length=128):encodings = tokenizer(list(data['text']), truncation=True, padding=True, max_length=max_length)return encodingsdef main():# 加载Tokenizertokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-chinese')# 加载数据train_data = load_data('./bert_data/train.txt')dev_data = load_data('./bert_data/dev.txt')test_data = load_data('./bert_data/test.txt')# Tokenize数据train_encodings = tokenize_data(train_data, tokenizer)dev_encodings = tokenize_data(dev_data, tokenizer)test_encodings = tokenize_data(test_data, tokenizer)return train_encodings, dev_encodings, test_encodings, train_data['label'], dev_data['label'], test_data['label']if __name__ == "__main__":main()

2. 数据集类模块

功能:定义数据集类,便于模型训练。

# dataset.py
import torchclass TextDataset(torch.utils.data.Dataset):def __init__(self, encodings, labels):self.encodings = encodingsself.labels = labelsdef __getitem__(self, idx):item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}item['labels'] = torch.tensor(self.labels[idx])return itemdef __len__(self):return len(self.labels)

3. 模型训练模块

功能:定义训练和评估模型的过程。

# training.py
import torch
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
from dataset import TextDataset
import data_processingdef train_model():# 加载数据和tokenizertrain_encodings, dev_encodings, test_encodings, train_labels, dev_labels, test_labels = data_processing.main()# 创建数据集train_dataset = TextDataset(train_encodings, train_labels)dev_dataset = TextDataset(dev_encodings, dev_labels)test_dataset = TextDataset(test_encodings, test_labels)# 加载DistilBERT模型model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-chinese', num_labels=10)model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))# 设置训练参数training_args = TrainingArguments(output_dir='./results',          # 输出结果目录num_train_epochs=3,              # 训练轮数per_device_train_batch_size=16,  # 训练时每个设备的批量大小per_device_eval_batch_size=64,   # 验证时每个设备的批量大小warmup_steps=500,                # 训练步数weight_decay=0.01,               # 权重衰减logging_dir='./logs',            # 日志目录fp16=True,                       # 启用混合精度训练)# 定义Trainertrainer = Trainer(model=model,                         # 预训练模型args=training_args,                  # 训练参数train_dataset=train_dataset,         # 训练数据集eval_dataset=dev_dataset             # 验证数据集)# 训练模型trainer.train()# 评估模型eval_results = trainer.evaluate()print(eval_results)# 保存模型model.save_pretrained('./saved_model')tokenizer.save_pretrained('./saved_model')if __name__ == "__main__":train_model()

4. 模型部署模块

功能:使用FastAPI进行模型部署。

# app.py
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torchapp = FastAPI()# 加载模型和tokenizer
model = DistilBertForSequenceClassification.from_pretrained('./saved_model')
tokenizer = DistilBertTokenizer.from_pretrained('./saved_model')class Item(BaseModel):text: str@app.post("/predict")
def predict(item: Item):inputs = tokenizer(item.text, return_tensors="pt", max_length=128, padding='max_length', truncation=True)outputs = model(**inputs)prediction = torch.argmax(outputs.logits, dim=1)return {"prediction": prediction.item()}if __name__ == "__main__":import uvicornuvicorn.run(app, host="0.0.0.0", port=8000)

5. 主脚本

功能:运行整个流程,包括数据处理、模型训练和部署。

# main.py
import trainingdef main():# 训练模型training.train_model()print("模型训练完成并保存。")if __name__ == "__main__":main()

详细说明

  1. 数据处理模块

    • 读取训练集、验证集和测试集的数据文件。
    • 使用BERT的Tokenizer对数据进行tokenize处理,生成模型可接受的输入格式。
    • 提供主要的数据处理函数,包括加载数据和tokenize数据。
  2. 数据集类模块

    • 定义一个TextDataset类,用于将tokenized数据和标签封装成PyTorch的数据集格式,便于Trainer进行训练和评估。
  3. 模型训练模块

    • 使用数据处理模块加载和tokenize数据。
    • 创建训练和验证数据集。
    • 加载DistilBERT模型,并设置训练参数(包括启用混合精度训练)。
    • 使用Trainer进行模型训练和评估,并保存训练好的模型。
  4. 模型部署模块

    • 使用FastAPI创建一个简单的API服务。
    • 加载保存的模型和tokenizer。
    • 定义一个预测接口,通过POST请求接收文本输入并返回分类预测结果。
  5. 主脚本

    • 运行模型训练流程,并输出训练完成的提示。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/40200.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

苹果AI的国产大模型之争,没有悬念

文 | 智能相对论 作者 | 陈泊丞 苹果终于公布了最新的AI进程。 一个月前,正如此前预期的那样,人工智能是今年 WWDC 发布会的焦点。全程105分钟的主题演讲,就有40多分钟用于介绍苹果的AI成果。 苹果似乎还有意玩了一把“谐音梗”&#xff…

用机器改变人类方向

1800 世纪初,美国迎来了工业革命,这是一个由技术进步推动的变革时代。新机器和制造技术的引入重塑了经济格局,提高了生产效率,同时减少了某些领域对手工劳动的需求。因此,这种转变导致了失业。 如今,我们看…

实现点击按钮导出页面pdf

在Vue 3 Vite项目中,你可以使用html2canvas和jspdf库来实现将页面某部分导出为PDF文档的功能。以下是一个简单的实现方式: 1.安装html2canvas和jspdf: pnpm install html2canvas jspdf 2.在Vue组件中使用这些库来实现导出功能:…

统计信号处理基础 习题解答11-11

题目 考虑矢量MAP估计量 证明这个估计量对于代价函数 使贝叶斯风险最小。其中:, ,且. 解答 贝叶斯风险函数: 基于概率密度的非负特性,上述对积分要求最小,那就需要内层积分达到最小。令内层积分为: 上述积…

苹果Mac电脑能玩什么游戏 Mac怎么运行Windows游戏

相对于Windows平台来说,Mac电脑可玩的游戏较少。虽然苹果设备的性能足以支持各种大型游戏,但由于系统以及苹果配套服务的限制,很多游戏无法在Mac系统中运行。不过,借助虚拟机软件,Mac电脑可以突破系统限制玩更多的游戏…

react中jsx的语法规则

1.react核心库react.development.js 2.react_dom库,用于支持react操作dom(react-dom.development.js) 3.引入bable,解析jsx语法的库,用于将jsx转换为js(babel.min.js) 上述三个库是写基础react的基本库 下面我将用…

光照老化试验箱在化工产品暴晒测试中的应用

概述 光照老化试验箱是一种模拟自然光照条件下材料老化情况的实验设备,广泛应用于化工、建材、电子、汽车等行业中对材料的耐候性、耐光性能等进行测试。通过模拟日光中的紫外线和温度等环境因素,加速材料老化过程,以此评估材料在长期使用中…

2024阿里云大模型自定义插件(如何调用自定义接口)

1,自定义插件入口 2,插件定义:描述插件的参数 2.1,注意事项: 2.1.1,只支持json格式的参数;只支持application/JSON;如下图: 2.1.2,需要把接口描述进行修改&a…

03:Spring MVC

文章目录 一:Spring MVC简介1:说说自己对于Spring MVC的了解?1.1:流程说明: 一:Spring MVC简介 Spring MVC就是一个MVC框架,Spring MVC annotation式的开发比Struts2方便,可以直接代…

LeetCode 算法:二叉搜索树中第K小的元素 c++

原题链接🔗:二叉搜索树中第K小的元素 难度:中等⭐️⭐️ 题目 给定一个二叉搜索树的根节点 root ,和一个整数 k ,请你设计一个算法查找其中第 k 小的元素(从1开始计数)。 示例 1:…

网络爬虫之什么是代码混淆?初步理解代码混淆

爬虫逆向之什么是代码混淆?初步理解代码混淆 在网络爬虫和逆向工程的过程中,代码混淆是一项常见的技术,旨在保护代码不被轻易理解和逆向。对于爬虫工程师来说,理解并破解代码混淆是一个重要的技能。本文将详细介绍代码混淆的基本概…

GUI开发

Question One Java 实现动作监听,网格布局添加四个按钮,实现四个不同的文本显示 import java.awt.*; import java.awt.event.*; import javax.swing.*;class myGUI extends JFrame implements ActionListener{private Button b1, b2, b3, b4;private Tex…

0627,0628,0629,排序,文件

01:请实现选择排序,并分析它的时间复杂度,空间复杂度和稳定性 void selection_sort(int arr[], int n); 解答: 稳定性:稳定, 不稳定的,会发生长距离的交换 4 9 9 4 1 &#xf…

ubuntu,linux下屏蔽坏块方法-240625-240702封存

在windows下的屏蔽坏道的方法 机械硬盘坏道的文件系统级别的屏蔽方法_硬盘如何屏蔽坏扇区-CSDN博客 https://blog.csdn.net/cyuyan112233/article/details/139408503?spm1001.2014.3001.5502 【免费】磁盘坏道屏蔽工具磁盘坏道屏蔽工具_机械硬盘屏蔽坏扇区资源-CSDN文库 https…

第一周题目总结

1.车尔尼有一个数组 nums ,它只包含 正 整数,所有正整数的数位长度都 相同 。 两个整数的 数位不同 指的是两个整数 相同 位置上不同数字的数目。 请车尔尼返回 nums 中 所有 整数对里,数位不同之和。 示例 1: 输入&#xff1a…

【嵌入式DIY实例-ESP8266篇】-LCD ST7735显示网络时间

LCD ST7735显示网络时间 文章目录 LCD ST7735显示网络时间1、硬件准备2、代码实现本文将介绍如何使用 ESP8266 NodeMCU Wi-Fi 板实现互联网时钟,其中时间和日期显示在 ST7735 TFT 显示屏上。 ST7735 TFT是一款分辨率为128160像素的彩色显示屏,采用SPI协议与主控设备通信。 1…

Python中的变量和数据类型:Python中有哪些基本数据类型以及变量是如何声明的

在Python中,变量是用来存储数据的容器,而数据类型则定义了这些数据的种类。Python是一种动态类型语言,这意味着你不需要在声明变量时指定其类型;Python解释器会在运行时自动确定变量的类型。 Python中的基本数据类型 Python中有…

SQL语句(DML)

DML英文全称是Data Manipulation Language(数据操作语言),用来对数据库中表的数据记录进行增删改等操作 DML-添加数据 insert into employee(id, workno, name, gender, age, idcard) values (1,1,Itcast,男,10,123456789012345678);select *…

AI 与数据的智能融合丨大模型时代下的存储系统

WOT 全球技术创新大会2024北京站于 6 月 22 日圆满落幕。本届大会以“智启新纪,慧创万物”为主题,邀请到 60 位不同行业的专家,聚焦 AIGC、领导力、研发效能、架构演进、大数据等热门技术话题进行分享。 近年来,数据和人工智能已…

记录搭建一台可域名访问的HTTPS服务器

一、背景 近期公司业务涉及到微信小程序,即将开发完成需要按照微信小程序平台的要求提供带证书的域名请求服务器。 资源背景介绍如下: 1、域名 公司已有一个二级域名,再次申请新的二级域名并且实现ICP备案不仅需要花重金重新购买,…