动手做个DialoGPT:生成式多轮对话模型

文 | 苏剑林

编 | 兔子酱


前段时间刷Arixv的时候,发现清华大学开源了一个大规模的中文闲聊语料库LCCC,从开源的文件上来看,这可能是目前开源的数量最大、质量最好的闲聊语料库了,而且还包含了部分多轮对话聊天,总的来说可玩性还是蛮强的。笔者也被它吸引到了,尝试着用它来训练了一个闲聊对话模型,结果看上去还是不错的,在此分享一下自己的经验。

论文名称
《A Large-Scale Chinese Short-Text Conversation Dataset》

论文链接
https://arxiv.org/abs/2008.03946

项目地址
https://github.com/thu-coai/CDial-GPT

Arxiv访问慢的小伙伴也可以在 【夕小瑶的卖萌屋】订阅号后台回复关键词 【0917】 下载论文PDF~

语料简介

这里简单介绍一下LCCC这个数据集(Large-scale Cleaned Chinese Conversation),具体细节大家可以去Github上看,下载链接也在上面。LCCC分base和large两个版本,base主要是来源于微博对话,large则是在base的基础上融合了其他开源对话语料,按照作者的说法,LCCC经过了严格的清洗过程,所以整体质量看上去还是很不错的。

为了简化任务,所有样本都被处理成双人对话。下面是一些样本示例:

A: 等过年咱们回去买点兔头好好吃顿火锅
B: 太原就没看见有好吃的兔头
A: 我从虹桥给你带个回去那天瞅到一正宗的
B: 最爱你了
A: 那是必须

A: 嗯嗯,我再等等!你现在在上海吧?上海风好像比南京还大呢,少出门吧
B: 对啊,我在家,没事儿。一定要小心啊!

A: 我去年也去转了一圈,还碰见以前的体育老师了,合了个影
B: 哈哈我还去找高一时侯的英语老师没找到她刚好有事情没在学校~
A: 你也是真心找回忆了哦
B: 哈哈毕业了没去过想去看看啊

模型设计

知道了数据长什么样之后,我们接下来就要去设计模型了。显然,我们需要做的就是训练一个模型,预测下一个该回复什么。既然语料里包含了多轮对话,那么我们还要求这个模型支持多轮对话。考虑对话历史的最简单的方式,就是把直到当前句的所有历史对话都拼接成单句文本,来作为模型的输入信息。

给定一些输入,预测一个输出,从形式上来看我们应该用Seq2Seq模型。直接用Seq2Seq其实问题也不大,但标准的Seq2Seq一般用于形式比较固定的输入输出,比如输入的文本长度应该是集中在某个范围内,不宜变化太大,但考虑多轮对话的话,理论上我们也不知道前面有多少轮对话,因此原则上输入文本长度是无限制的。用Seq2Seq的话,还有训练效率低的问题,就是我们每轮对话每次我们只能训练一句回复,如果一个多轮对话有n句回复,那么就要拆分为n个样本来训练了。

因此,我们需要一个长度能相当自由地变化的、同时能预测整个多轮对话的模型,实现这个需求的比较适当的选择就是单向语言模型(LM、GPT),做法如下图:

如图所示,我们选择当前主流的Transformer模型,按照BERT的常规输入格式,将每句对话用[SEP]拼接起来,然后就训练一个从左往右的单向语言模型。为了区分不同的说话角色,我们对不同的说话者用不同的Segment Id区分。此外,考虑到BERT和GPT都是用了绝对位置编码,可处理的文本长度存在一个上限,而对话轮数理论上是无限的,所以这里我们采用了相对位置编码的NEZHA作为基本结构,并使用NEZHA的预训练权重作为模型的初始化权重。

说白了,就是往NEZHA里边加入了下三角形式的Attention Mask,使其变为一个语言模型,相关介绍请参考《从语言模型到Seq2Seq:Transformer如戏,全靠Mask》[1]

训练细节

首先,给出笔者的参考实现以及训练好的模型:

https://github.com/bojone/nezha_gpt_dialog

代码基于bert4keras实现,基本结构是NEZHA Base。为了使得训练更加“平民化”,笔者只用了一张卡(22G的RTX)配合梯度累积来训练,显存不够大的读者可以缩小batch_size。开源的模型是用的学习率、0.01的权重衰减率的Adam优化器进行训练的,训练了1000万步(20多天),单步batch_size为16,累积梯度步数为16,所以相当于以256的batch_size训练了62.5万步,最终的loss大概是2.83(平均每个token的交叉熵)。

我们来测试一下我们的训练结果。

>>> print(chatbot.response([u'我最近天天吃火锅']))
火锅好吃啊
>>> print(chatbot.response([u'我最近天天吃火锅']))
我也是
>>> print(chatbot.response([u'我最近天天吃火锅']))
哈哈哈哈哈哈我最近一周吃了两顿火锅
>>> print(chatbot.response([u'我最近天天吃火锅']))
我也是,每次都吃的撑
>>> print(chatbot.response([u'我最近天天吃火锅']))
吃火锅不长胖
>>> print(chatbot.response([u'我最近天天吃火锅']))
火锅好啊
>>> print(chatbot.response([u'我最近天天吃火锅']))
你们那边有这种火锅?
>>> print(chatbot.response([u'我最近天天吃火锅']))
火锅好吃吗?我也想吃
>>> print(chatbot.response([u'我最近天天吃火锅']))
吃吧,吃了就开心了
>>> print(chatbot.response([u'我最近天天吃火锅']))
我也想天天吃火锅!
>>> print(chatbot.response([u'我最近天天吃火锅']))
火锅吃不了多久的。

>>> print(chatbot.response([u'你那天气好吗']))
还好,不热,有点冷
>>> print(chatbot.response([u'你那天气好吗', u'还好,不热,有点冷', u'那你得穿多厚的衣服']))
我穿了两件

对比分析

CDial-GPT也开源了自己训练的预训练模型,笔者也将它转换为bert4keras能加载的格式了,CDial-GPT-tf[2],读者也可以测试比对一下。从训练上来看,CDial-GPT使用pytorch实现的模型,基本结构是GPT Base,使用了4张2080Ti,总batch_size为32,累积梯度64步,论文说训练了30个epoch,总步数约2100万步(笔者的两倍),因此大概相当于batch_size为2048训练了33万步。

在输入设计上,CDial-GPT也有所不同,如下图:

如图所示,CDial-GPT跟我们前述设计的主要不同是多轮对话之间的拼接方式,我们之前是直接用[SEP]连接,它是用[speaker1]、[speaker2](图中简记为S1、S2)这样的角色标记来连接,最后才用一个[SEP]表示回复结束。这样一来,由于预测部分的格式跟历史的格式不一样,因此每次只能训练一句回复,多轮对话要拆分为多个样本来训练,理论上是增加了训练复杂性的(要训练多步才能把一个多轮对话样本训练完)。

至于效果上,个人测试的感觉是两者没什么明显差别。有兴趣的读者也可以自行比较测试。

文章总结

本文主要分享了一次对话模型实践,基于开源的LCCC闲聊语料库,利用语言模型(GPT)对多轮对话进行生成式建模,得到了一个相对通用的闲聊对话模型,最后将本文的思路与CDial-GPT本身开源的模型进行了比较。


文末福利
后台回复关键词【入群
加入卖萌屋NLP/IR/Rec与求职讨论群
有顶会审稿人、大厂研究员、知乎大V和妹纸
等你来撩哦~

参考文献

[1] 《从语言模型到Seq2Seq:Transformer如戏,全靠Mask》:
https://kexue.fm/archives/6933
[2] CDial-GPT-tf:
https://github.com/bojone/CDial-GPT-tf

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/479963.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

搜索引擎背后的数据结构和算法

文章目录1. 整体系统介绍2. 搜集2.1 待爬取网页链接文件:links.bin2.2 网页判重文件:bloom_filter.bin2.3 原始网页存储文件:doc_raw.bin2.4 网页链接及其编号的对应文件:doc_id.bin3. 分析3.1 抽取网页文本信息3.2 分词并创建临时…

论文浅尝 | DKN: 面向新闻推荐的深度知识感知网络

笔记整理:仲亮靓,东南大学硕士研究生,研究方向是基于知识图谱的推荐系统动机新闻文本的语言非常凝练,其中包含了很多实体和常识知识。但目前的新闻个性化推荐方法都没有利用这些外部知识,也没有使用新闻之间潜在的知识…

聊聊工业界做机器学习的里程碑

文 | 吴海波编 | YY阅读说明,本文的机器学习领域限制于互联网搜索、推荐、广告场景,仅限于个人观点。2017年,我和团队的几个核心去了趟北京,找了各大互联网公司一线实战的同学,交流各自在机器学习上的经验。这次交流让…

直通BAT JVM必考题:Minor GC、Major GC、Full GC的区别

Java面试过程,JVM属于必考题系列: 直通BAT必考题系列:深入详解JVM内存模型与JVM参数详细配置 直通BAT必考题系列:JVM的4种垃圾回收算法、垃圾回收机制与总结 直通BAT必考题系列:7种JVM垃圾收集器特点,优…

matplotlib绘制多张图、多子图、多例图

绘制多图 关键: fig plt.figure(1) 表示新建第几个图 import matplotlib.pyplot as pltfig plt.figure(1) plt_rec_loss [1,2,3,4,5,6] plt_rec_recall [4,3,6,5,8,9] plt.xlabel("epoch") plt.ylabel("loss") plt.plot(range(len(plt_re…

jieba分词并做分析

Github:结巴分词地址 https://github.com/fxsjy/jieba 几种分词方法的简单使用:一 . jieba安装、示例 pip install jieba,jieba分词的语料是基于人民日报。分词示例1 import jieba 2 3 str1 江州市长江大桥 4 word_object jieba.cut(s…

研讨会 | CCF TF 第 17 期:认知计算产业化落地

CCF TF 技术前线只为技术专家CCFTF第17期主题 认知计算产业化落地2019年05月11日上海斯波特酒店五楼(上海市南丹路15号,徐汇区政府对面)人类迈入人工智能时代,技术的发展使得机器可以从大数据中提取信息,串联成知识&a…

短网址系统

文章目录1. 短网址服务整体介绍2. 如何通过哈希算法生成短网址?2.1 如何让短网址更短2.2 如何解决哈希冲突?2.3 如何优化哈希算法生成短网址的性能?3. 如何通过ID生成器生成短网址?3.1 相同的原始网址可能会对应不同的短网址3.2 如…

一个神经元的价值和一个神经病的坚持

作者 | 周博磊来源 | 机器之心一个神经元能够催生多少故事?香港中文大学信息工程系助理教授周博磊近日撰文介绍了他自 2015 年开始至今对神经元的研究经历。最近,他与 David Bau、朱俊彦等人合作的神经元研究论文发表在了 PNAS 杂志上。以下是周博磊的原…

直通BAT必考题系列:深入剖析JVM之G1收集器、及回收流程、与推荐用例

金三银四马上到了,即将进入面试的高峰期。在BAT面试中,JVM基本都是必考的系列。你至少需要掌握JVM内存模型与JVM参数详细配置,JVM的4种垃圾回收算法、垃圾回收机制与总结,以及今天重点谈到的JVM垃圾回收算法的实现:JVM…

多任务学习方法

最近一直在做多任务,但是效果好象没什么提升,因为都是凭自己的想法和感觉在做。于是上网查找了一些这方面的资料,寻求一些理论上的支撑和前人经验上的帮助。 多任务学习: 故名思意,就是多个任务一起学习。为什么要进行…

曹羽 | 从知识工程到知识图谱全面回顾

本文转载自公众号:集智俱乐部。文本挖掘和图形数据库 | ©ontotext导语知识工程是符号主义人工智能的典型代表,近年来越来越火的知识图谱,就是新一代的知识工程技术。知识工程将如何影响未来人工智能领域的发展,甚至让计算机拥…

4大JVM性能分析工具详解,及内存泄漏分析方案

谈到性能优化分析一般会涉及到: Java代码层面的,典型的循环嵌套等 还会涉及到Java JVM:内存泄漏溢出等 MySQL数据库优化:分库分表、慢查询、长事务的优化等 阿里P8架构师谈:MySQL慢查询优化、索引优化、以及表等优化…

从 0 搭建一个工业级推荐系统

推荐系统从来没像现在这样,影响着我们的生活。当你上网购物时,天猫、京东会为你推荐商品;想了解资讯,头条、知乎会为你准备感兴趣的新闻和知识;想消遣放松,抖音、快手会为你奉上让你欲罢不能的短视频。而驱…

论文浅尝 | 虚拟知识图谱:软件系统和应用案例综述

本文转载自公众号:DI数据智能。Virtual Knowledge Graphs: An Overview of Systems and Use Cases作者:Guohui Xiao, Linfang Ding, Benjamin Cogrel & Diego Calvanese供稿:Guohui Xiao编者按:Data Intelligence 发表意大利博…

LeetCode 169. 求众数(摩尔投票)

文章目录1. 题目信息2. 解题思路3. 代码3.1 排序3.2 map计数3.3 摩尔投票1. 题目信息 给定一个大小为 n 的数组,找到其中的众数。众数是指在数组中出现次数大于 ⌊ n/2 ⌋ 的元素。 你可以假设数组是非空的,并且给定的数组总是存在众数。 示例 1:输入…

阿里P8架构师谈:JVM的内存分配、运行原理、回收算法机制

不管是BAT面试,还是工作实践中的JVM调优以及参数设置,或者内存溢出检测等,都需要涉及到Java虚拟机的内存模型、内存分配,以及回收算法机制等,这些都是必考、必会技能。 JVM内存模型 JVM内存模型可以分为两个部分&…

我的BERT!改改字典,让BERT安全提速不掉分(已开源)

文 | 苏剑林编 | 小轶背景当前,大部分中文预训练模型都是以字为基本单位的,也就是说中文语句会被拆分为一个个字。中文也有一些多粒度的语言模型,比如创新工场的ZEN和字节跳动的AMBERT,但这类模型的基本单位还是字,只不…

2020年考证时间表汇总!这些证书值得拥有!

原文地址: https://zhuanlan.zhihu.com/p/100824416 2020年考证时间表汇总!这些证书值得拥有!已认证的官方帐号154 人赞同了该文章昨日之日不可留,2019年已然过去,2020年的我们不能再一成不变!快根据自身情…

征稿 | 2019年全国知识图谱与语义计算大会(CCKS2019)第二轮征稿启事

2019年全国知识图谱与语义计算大会China Conference on Knowledge Graph and Semantic Computing (CCKS 2019)2019年8月24日-27日,杭州征稿截止: 2019年5月18日全国知识图谱与语义计算大会(CCKS: China Conference on Knowledge Graph and Semantic Comp…