何使用BERT模型实现中文的文本分类

原文网址:https://blog.csdn.net/Real_Brilliant/article/details/84880528

如何使用BERT模型实现中文的文本分类

  • 前言
  • Pytorch
    • readme
    • 参数表
    • 算法流程
      • 1. 概述
      • 2. 读取数据
      • 3. 特征转换
      • 4. 模型训练
      • 5. 模型测试
      • 6. 测试结果
      • 7. 总结

前言

  1. Google官方BERT代码(Tensorflow)
  2. 本文章参考的BERT分类代码(Pytorch)
  3. 本文章改进的BERT中文文本分类代码(Pytorch)
  4. BERT模型介绍

Pytorch

readme

  • 请先安装pytorch的BERT代码,代码源见前言(2)
    pip install pytorch-pretrained-bert
    
    • 1

参数表

data_dirbert_modeltask_name
输入数据目录加载的bert模型,对于中文文本请输入’bert-base-chinese输入数据预处理模块,最好根据应用场景自定义
model_save_pthmax_seq_length*train_batch_size
模型参数保存地址最大文本长度batch大小
learning_ratenum_train_epochs
Adam初始学习步长最大epoch数

* max_seq_length = 所设定的文本长度 + 2 ,BERT会给每个输入文本开头和结尾分别加上[CLS][SEP]标识符,因此会占用2个字符空间,其作用会在后续进行详细说明。

算法流程

1. 概述

训练阶段
利用验证集调整参数
选取验证集上得分最高的模型
测试阶段
加载预训练模型
读取数据
特征转换
模型训练
保存最佳模型参数
加载训练阶段最佳模型
读取数据
特征转换
输入模型并进行测试

2. 读取数据

  • 对应于参数表中的task_name,是用于数据读取的模块
  • 可以根据自身需要自定义新的数据读取模块
  • 以输入数据为json文件时为例,数据读取模块包含两个部分:
    • 基类DataProcessor
      class DataProcessor(object):		def get_train_examples(self, data_dir):raise NotImplementedError()
      
      def get_dev_examples(self, data_dir):raise NotImplementedError()def get_test_examples(self, data_dir):raise NotImplementedError()def get_labels(self):raise NotImplementedError()@classmethod
      def _read_json(cls, input_file, quotechar=None):"""Reads a tab separated value file."""dicts = []with codecs.open(input_file, 'r', 'utf-8') as infs:for inf in infs:inf = inf.strip()dicts.append(json.loads(inf))return dicts
      
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 用于数据读取的模块MyPro
    class MyPro(DataProcessor):def get_train_examples(self, data_dir):return self._create_examples(self._read_json(os.path.join(data_dir, "train.json")), 'train')
    
    def get_dev_examples(self, data_dir):return self._create_examples(self._read_json(os.path.join(data_dir, "val.json")), 'dev')def get_test_examples(self, data_dir):return self._create_examples(self._read_json(os.path.join(data_dir, "test.json")), 'test')def get_labels(self):return [0, 1]def _create_examples(self, dicts, set_type):examples = []for (i, infor) in enumerate(dicts):guid = "%s-%s" % (set_type, i)text_a = infor['question']label = infor['label']examples.append(InputExample(guid=guid, text_a=text_a, label=label))return examples
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
  • 需要注意的几点是:
    1. data_dir目录下应包含名为trainvaltest的三个文件,根据文件格式不同需要对读取方式进行修改
    2. get_labels()返回的是所有可能的类别label_list,比如['数学', '英语', '语文'][1, 2, 3]
    3. 模块最终返回一个名为examples的列表,每个列表元素中包含序号、中文文本、类别三个元素
  • 3. 特征转换

    • convert_examples_to_features是用于将examples转换为特征,也即features的函数。
    • features包含4个数据:
      • input_ids:分词后每个词语在vocabulary中的id,补全符号对应的id为0,[CLS][SEP]的id分别为101和102。应注意的是,在中文BERT模型中,中文分词是基于字而非词的分词
      • input_mask:真实字符/补全字符标识符,真实文本的每个字对应1,补全符号对应0,[CLS][SEP]也为1。
      • segment_ids:句子A和句子B分隔符,句子A对应的全为0,句子B对应的全为1。但是在多数文本分类情况下并不会用到句子B,所以基本不用管。
      • label_id :将label_list中的元素利用字典转换为index标识,即
        label_map = {}
        for (i, label) in enumerate(label_list):label_map[label] = i
        
        • 1
        • 2
        • 3
    • features中一个元素的例子是:
      在这里插入图片描述
    • 转换完成后的特征值就可以作为输入,用于模型的训练和测试

    4. 模型训练

    • 完成读取数据、特征转换之后,将特征送入模型进行训练
    • 训练算法为BERT专用的Adam算法
    • 训练集、测试集、验证集比例为6:2:2
    • 每一个epoch后会在验证集上进行验证,并给出相应的f1值,如果f1值大于此前最高分则保存模型参数,否则flags加1。如果flags大于6,也即连续6个epoch模型的性能都没有继续优化,停止训练过程。
      f1 = val(model, processor, args, label_list, tokenizer, device)
      if f1 > best_score:best_score = f1print('*f1 score = {}'.format(f1))flags = 0checkpoint = {'state_dict': model.state_dict()}torch.save(checkpoint, args.model_save_pth)
      else:print('f1 score = {}'.format(f1))flags += 1if flags >=6:break
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
    • 如果epoch数超过先前设定的num_train_epochs,同样会停止迭代。

    5. 模型测试

    • 先加载模型
    • 送数据,取得分,完事
    • 暂时还没加打印测试结果到文件的功能,后续会加上

    6. 测试结果

    val_F1test_F1
    Fast text0.72180.7094
    Text rnn + bigru0.73830.7194
    Text cnn0.72920.7088
    bigru + attention0.73350.7146
    RCNN0.73550.7213
    BERT0.79380.787
    • 基于真实数据做的文本分类,用过不少模型,BERT的性能可以说是独一档
    • BERT确实牛逼,不过一部分原因也是模型量级就不一样

    7. 总结

    • 使用代码的时候按照参数表修改下参数,把数据按照命名规范放data_dir目录下一般就没啥问题了
    • 最多还要修改下读取数据的代码(如果数据不是.json格式的),就可以跑通了
    • 最后可以根据个人需要,对模型训练逻辑、epoch数、学习步长等地方做进一步修改
    • 代码地址已经放在前言(3)里了
                                    </div><div data-report-view="{&quot;mod&quot;:&quot;1585297308_001&quot;,&quot;dest&quot;:&quot;https://blog.csdn.net/Real_Brilliant/article/details/84880528&quot;,&quot;extend1&quot;:&quot;pc&quot;,&quot;ab&quot;:&quot;new&quot;}"><div></div></div><link href="https://csdnimg.cn/release/phoenix/mdeditor/markdown_views-60ecaf1f42.css" rel="stylesheet"></div>
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/479834.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode 152. 乘积最大子序列(DP)

文章目录1. 题目信息2. 解题1. 题目信息 给定一个整数数组 nums &#xff0c;找出一个序列中乘积最大的连续子序列&#xff08;该序列至少包含一个数&#xff09;。 示例 1:输入: [2,3,-2,4] 输出: 6 解释: 子数组 [2,3] 有最大乘积 6。 示例 2:输入: [-2,0,-1] 输出: 0 解释…

YouTube深度学习推荐系统的十大工程问题

文 | 王喆源 | https://zhuanlan.zhihu.com/p/52504407这篇文章主要介绍了YouTube深度学习系统论文中的十个工程问题&#xff0c;为了方便进行问题定位&#xff0c;我们还是简单介绍一下背景知识&#xff0c;简单回顾一下YouTube的深度推荐系统论文Deep Neural Networks for Yo…

RPC框架的实现原理,及RPC架构组件详解

RPC的由来 随着互联网的发展&#xff0c;网站应用的规模不断扩大&#xff0c;常规的垂直应用架构已无法应对&#xff0c;分布式服务架构以及流动计算架构势在必行&#xff0c;亟需一个治理系统确保架构有条不紊的演进。 单一应用架构 当网站流量很小时&#xff0c;只需一个应…

论文浅尝 | 时序与因果关系联合推理

论文笔记整理&#xff1a;李昊轩&#xff0c;南京大学硕士&#xff0c;研究方向为知识图谱、自然语言处理。来源&#xff1a;ACL2018链接&#xff1a;https://www.aclweb.org/anthology/P18-1212动机理解事件之间的时间和因果关系是一项基本的自然语言理解任务。由于原因一定先…

LeetCode 64. 最小路径和(DP)

文章目录1. 题目信息2. 解题1. 题目信息 给定一个包含非负整数的 m x n 网格&#xff0c;请找出一条从左上角到右下角的路径&#xff0c;使得路径上的数字总和为最小。 说明&#xff1a;每次只能向下或者向右移动一步。 示例:输入: [[1,3,1],[1,5,1],[4,2,1] ] 输出: 7 解释…

加快读博失败的10种方法

文 | 德先生源 | 知乎博士难毕&#xff0c;全球皆如此。差不多每个学校都有1/3到一半的博士研究生拿不到学位。读博失败不仅本人难受&#xff0c;导师也不好过。为了帮助自己的博士生们顺利毕业&#xff0c;犹他大学的Matt Might教授&#xff08;计算机科学家&#xff0c;生物学…

要成为一个 Java 架构师得学习哪些知识以及方法?

“ 最近在架构师线下实战中&#xff0c;被问到最多的一个问题&#xff0c;就是要成为一个 Java 架构师得学习哪些知识&#xff0c;以及怎样才能做到架构师这个级别&#xff1f; 今天主要澄清几个关于架构师的几大误区。 架构师并不是人人都能做到的&#xff0c;如果你能走到…

论文浅尝 | 一种可解释的语义匹配复值网络

笔记整理&#xff1a;耿玉霞&#xff0c;浙江大学直博生。研究方向&#xff1a;知识图谱&#xff0c;零样本学习&#xff0c;自然语言处理等。论文链接&#xff1a;https://arxiv.org/pdf/1904.05298.pdf本文是发表在 NAACL 2019 上的最佳可解释性论文。受量子力学中数学模型的…

指针都没搞懂,还能算得上 C++ 老司机?

在工业界&#xff0c;有这样一个规律&#xff1a;“ 但凡能用其他语言的都不会用C&#xff0c;只能用C的必然用C。”但是&#xff0c;C的学习和项目开发都比较困难。一个有经验的老手也经常搞出野指针&#xff0c;内存泄露等bug&#xff0c;包括我自己在学C的时候也非常痛苦。所…

DSSM、CNN-DSSM、LSTM-DSSM等深度学习模型在计算语义相似度上的应用+距离运算

在NLP领域&#xff0c;语义相似度的计算一直是个难题&#xff1a;搜索场景下query和Doc的语义相似度、feeds场景下Doc和Doc的语义相似度、机器翻译场景下A句子和B句子的语义相似度等等。本文通过介绍DSSM、CNN-DSSM、LSTM-DSSM等深度学习模型在计算语义相似度上的应用&#xff…

如何才能真正的提高自己,真正成为一名出色的架构师?

“ 有读者朋友给我留言&#xff0c;如何才能真正的提高自己&#xff0c;成为一名架构师&#xff0c;有学习各种语言的小伙伴。 这里我结合我的学习方法论&#xff0c;再结合我自己的经验&#xff0c;分享部分心得&#xff0c;希望对你有所帮助。 欢迎小伙伴留言给到你现在遇…

论文浅尝 | 从知识图谱流中学习时序规则

论文笔记整理&#xff1a;汪寒&#xff0c;浙江大学硕士&#xff0c;研究方向为知识图谱、自然语言处理。链接&#xff1a;http://ceur-ws.org/Vol-2350/paper15.pdf动机知识图谱是现在十分流行的数据管理方式&#xff0c;在最近几年应用广泛。但目前的基于KG的规则挖掘主要都是…

如何选择数据结构和算法(转)

文章目录1. 时间、空间复杂度 ! 性能2. 抛开数据规模谈数据结构和算法都是“耍流氓”3. 结合数据特征和访问方式来选择数据结构4. 区别对待IO密集、内存密集和计算密集5. 善用语言提供的类&#xff0c;避免重复造轮子6. 千万不要漫无目的地过度优化熟知每种数据结构和算法的功能…

Linux服务器安装cuda,cudnn,显卡驱动和pytorch超详细流程

原文链接&#xff1a;https://blog.csdn.net/kingfoulin/article/details/98872965 基本的环境 首先了解自己服务器的操作系统内核版本等信息&#xff1a; 查看自己操作系统的版本信息&#xff1a;cat /etc/issue或者是 cat /etc/lsb-release等命令 查看服务器显卡信息&…

自训练:超越预训练,展现强大互补特性的上分新范式!

文 | 香侬科技编 | 兔子酱背景预训练&#xff08;Pre-training&#xff09;模型自BERT问世以来就一发不可收拾&#xff0c;目前已经在自然语言理解和生成两个方面取得了突破性成就。但是&#xff0c;作为它的一个“兄弟”&#xff0c;自训练&#xff08;Self-training&#xff…

论文浅尝 | 通过文本到文本神经问题生成的机器理解

论文笔记整理&#xff1a;程茜雅&#xff0c;东南大学硕士&#xff0c;研究方向&#xff1a;自然语言处理&#xff0c;知识图谱。Citation: Yuan X, WangT, Gulcehre C, et al. Machine comprehension by text-to-text neural question generation[J]. arXiv preprint arXiv:17…

安装paddlepaddle-GPU 报libcudnn.so和libcublas.so找不到的解决方案

第一步&#xff0c;查找两个的文件位置 第二步&#xff1a; 由于cudcun实在cuda10.0的基础上安装的&#xff0c;解压cudcnn的tar包之后会出现一个cuda-10.0文件夹&#xff0c;而不是cuda. 第三步&#xff1a; 在一步出现的位置找到了和libcublas.so.10对其进行了重命名就好了…

LeetCode 221. 最大正方形(DP)

文章目录1. 题目信息2. 解题1. 题目信息 在一个由 0 和 1 组成的二维矩阵内&#xff0c;找到只包含 1 的最大正方形&#xff0c;并返回其面积。 示例: 来源&#xff1a;力扣&#xff08;LeetCode&#xff09; 链接&#xff1a;https://leetcode-cn.com/problems/maximal-squ…

anaconda配置虚拟环境

一般是在服务器上&#xff0c;创建一个自己的虚拟环境&#xff0c;自己来用&#xff0c;不影响别人的环境&#xff0c;也不用被别人安装环境影响。 打开终端 1.查看当前存在哪些虚拟环境 conda env list 或 conda info -e 2.创建名字为 lly_env 的虚拟环境(名字自己取一个&am…

论文浅尝 | 知识图谱推理中表示学习和规则挖掘的迭代学习方法

作者&#xff1a;张文&#xff0c;浙江大学在读博士&#xff0c;研究方向为知识图谱的表示学习&#xff0c;推理和可解释。本文是我们与苏黎世大学以及阿里巴巴合作的工作&#xff0c;发表于WWW2019&#xff0c;这篇工作将知识图谱推理的两种典型方法&#xff0c;即表示学习和规…