xlnet源码解读(简易pytorch实现版本)
xlnet这个模型还是相当复杂的,我看了很长一段时间也还是有很多地方没有搞明白,最后又在网上搜了很多大佬写的相关博客,才算是大致弄明白了,想了解xlnet的原理,请参考原论文,这里推荐一位大佬写的博客,写得非常清楚明白,也解决了我的很多困惑。
原理讲解博客:https://blog.csdn.net/qq_37236745/article/details/108846515
在这里,我重点讲解一下xlnet的代码实现,我这个代码是
Simple XLNet implementation with Pytorch Wrapper!
代码链接:https://github.com/graykode/xlnet-Pytorch
注:这代码里面没有说实验配置,但是我自己的实验环境是cuda11.3 ,经过实验可以正常运行,我的实验配置供参考,如下:
python:3.6
torch 1.10.0+cu113
torchvision 0.11.1+cu113
代码的实现。
首先是代码文件目录:
其中,dengdan_data_utils.py是我自己调试的时候复制的data_utils.py文件,方便我调试。跑代码记录.txt是我自己用来在跑代码的时候记录一些东西的文件。
从目录结构来看,结构较为清晰,基本上从名字就知道那个文件是用来干嘛的了。main.py是主文件,模型入口、xlnet.py是模型文件、data_utils.py是数据加载及处理文件。
接下来,按照main.py文件中代码运行的顺序来讲解此代码。
首先是导入一些包,这里不多说了。
然后就是程序入口,在命令行设置参数的代码,也不多说。
之后加载分词器以及模型初始化:
这里我们从main.py文件转向xlnet.py文件去看一下模型的初始化,重点讲一下如下几个参数:
个人理解,他们分别表示计算内容流注意力、查询流注意力和片段嵌入时的偏置参数(对模型原理不明白的话需要去看论文或博客)。
看完模型初始化,回到main.py,在设置好优化器和损失函数后,就开始迭代循环epoch,进行数据创建、数据排列和模型训练了。我个人还没明白为什么每个epoch都要进行一次全新的数据创建。
接下来,分别讲数据创建、数据排列和模型训练这三个部分。
-
数据创建:
这个部分需要注意的是,创建的两个句子不一定是相邻的,用标签label来进行区分,1表示两个句子相邻,0表示不是相邻的。
经过数据创建处理后,返回的是一个列表features,其中每一个feature都包含几个部分,如下:
-
数据排列:
这里的输入是数据创建部分的每一个feature,reuse_len、perm_size为256、seq_len为512、num_predict为85。数据排列这里我看了很久都没有看明白:
疑难点记录
这个是什么意思??有没有知道的佬可以给我科普一下~
到现在数据排列的部分我也没有看明白,555~
不写了,欢迎给我留言~
最后,附上我个人复现代码时的注释和个人笔记,欢迎参考我的仓库:
https://gitee.com/deng-dan-neu/my-keyan-jilu