使用RNN预测文档归属作者

文章目录

    • 1. 文本处理
    • 2. 文本序列化
    • 3. 数据集拆分
    • 4. 建立RNN模型
    • 5. 训练
    • 6. 测试

参考 基于深度学习的自然语言处理

1. 文本处理

数据预览

# 有两个作者的文章(A, B),定义为0, 1
A = 0 # hamilton
B = 1 # madison
UNKNOWN = -1
# 把同一作者的文章全部合并到一个文件
textA, textB = '', ''import os
for file in os.listdir('./papers/A'):textA += preprocessing('./papers/A/'+file)
for file in os.listdir('./papers/B'):textB += preprocessing('./papers/B/'+file)
  • 把同一作者的文档合并,去除\n, 多余空格,以及作者的名字(防止数据泄露)
def preprocessing(file_path):with open(file_path, 'r') as f:lines = f.readlines()text = ' '.join(lines[1:]).replace('\n',' ').replace('  ', ' ').lower().replace('hamilton','').replace('madison','')text = ' '.join(text.split())return text

print("文本A的长度:{}".format(len(textA)))
print("文本B的长度:{}".format(len(textB)))文本A的长度:216394
文本B的长度:230867

2. 文本序列化

  • 采用字符级别的 tokenizer char_level=True
from keras.preprocessing.text import Tokenizer
char_tokenizer = Tokenizer(char_level=True)char_tokenizer.fit_on_texts(textA + textB) # 训练tokenizerlong_seq_a = char_tokenizer.texts_to_sequences([textA])[0] # 文本转 ids 序列
long_seq_b = char_tokenizer.texts_to_sequences([textB])[0]Xa, ya = make_subsequence(long_seq_a, A) # 切分成多个等长的子串样本
Xb, yb = make_subsequence(long_seq_b, B)



  • ids 序列切分成等长的子串样本
SEQ_LEN = 30 # 切分序列的长度,超参数
import numpy as np
def make_subsequence(long_seq, label, seq_len=SEQ_LEN):numofsubseq = len(long_seq)-seq_len+1 # 滑窗,可以取出来这么多种X = np.zeros((numofsubseq, seq_len)) # 数据y = np.zeros((numofsubseq, 1)) # 标签for i in range(numofsubseq):X[i] = long_seq[i:i+seq_len] # seq_len 大小的滑窗y[i] = labelreturn X, y
print('字符的种类:{}'.format(len(char_tokenizer.word_index))) # 52
# {' ': 1, 'e': 2, 't': 3, 'o': 4, 'i': 5, 'n': 6, 'a': 7, 's': 8, 'r': 9, 'h': 10,
#  'l': 11, 'd': 12, 'c': 13, 'u': 14, 'f': 15, 'm': 16, 'p': 17, 'b': 18, 'y': 19, 'w': 20,
#  ',': 21, 'g': 22, 'v': 23, '.': 24, 'x': 25, 'k': 26, 'j': 27, ';': 28, 'q': 29, 'z': 30,
#  '-': 31, '?': 32, '"': 33, '1': 34, ':': 35, '8': 36, '7': 37, '(': 38, ')': 39, '2': 40,
#  '0': 41, '3': 42, '4': 43, '6': 44, "'": 45, '!': 46, ']': 47, '5': 48, '[': 49, '@': 50,
#  '9': 51, '%': 52}
print('A训练集大小:{}'.format(Xa.shape))
print('B训练集大小:{}'.format(Xb.shape))
A训练集大小:(216365, 30)
B训练集大小:(230838, 30)

3. 数据集拆分

  • A、B数据集混合
# 堆叠AB训练数据在一起
X = np.vstack((Xa, Xb))
y = np.vstack((ya, yb))
  • 训练集,测试集拆分
# 训练集测试集拆分
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

4. 建立RNN模型

from keras.models import Sequential
from keras.layers import SimpleRNN, Dense, EmbeddingEmbedding_dim = 128 # 输出的嵌入的维度
RNN_size = 256 #  RNN 单元个数model = Sequential()
model.add(Embedding(input_dim=len(char_tokenizer.word_index)+1,output_dim=Embedding_dim,input_length=SEQ_LEN))
model.add(SimpleRNN(units=RNN_size, return_sequences=False)) # 只输出最后一步 
# return the last output in the output sequence
model.add(Dense(1, activation='sigmoid')) # 二分类model.compile(optimizer='adam', loss='binary_crossentropy',metrics=['accuracy'])
model.summary()

模型结构:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, 30, 128)           6784      
_________________________________________________________________
simple_rnn (SimpleRNN)       (None, 256)               98560     
_________________________________________________________________
dense (Dense)                (None, 1)                 257       
=================================================================
Total params: 105,601
Trainable params: 105,601
Non-trainable params: 0
_________________________________________________________________

如果return_sequences=True,后两个输出维度如下:(增加了序列长度维度)

simple_rnn_1 (SimpleRNN)     (None, 30, 256)           98560     
_________________________________________________________________
dense_1 (Dense)              (None, 30, 1)             257       

5. 训练

batch_size = 4096 # 一次梯度下降使用的样本数量
epochs = 20  # 训练轮数
history = model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs,validation_data=(X_test, y_test),verbose=1)
Epoch 1/20
88/88 [==============================] - 59s 669ms/step - loss: 0.6877 - accuracy: 0.5436 - val_loss: 0.6856 - val_accuracy: 0.5540
Epoch 2/20
88/88 [==============================] - 56s 634ms/step - loss: 0.6830 - accuracy: 0.5564 - val_loss: 0.6844 - val_accuracy: 0.5550
Epoch 3/20
88/88 [==============================] - 56s 633ms/step - loss: 0.6825 - accuracy: 0.5577 - val_loss: 0.6829 - val_accuracy: 0.5563
Epoch 4/20
88/88 [==============================] - 56s 634ms/step - loss: 0.6816 - accuracy: 0.5585 - val_loss: 0.6788 - val_accuracy: 0.5641
Epoch 5/20
88/88 [==============================] - 56s 637ms/step - loss: 0.6714 - accuracy: 0.5813 - val_loss: 0.6670 - val_accuracy: 0.5877
Epoch 6/20
88/88 [==============================] - 56s 637ms/step - loss: 0.6532 - accuracy: 0.6113 - val_loss: 0.6435 - val_accuracy: 0.6235
Epoch 7/20
88/88 [==============================] - 57s 648ms/step - loss: 0.6287 - accuracy: 0.6424 - val_loss: 0.6159 - val_accuracy: 0.6563
Epoch 8/20
88/88 [==============================] - 55s 620ms/step - loss: 0.5932 - accuracy: 0.6807 - val_loss: 0.5747 - val_accuracy: 0.6971
Epoch 9/20
88/88 [==============================] - 54s 615ms/step - loss: 0.5383 - accuracy: 0.7271 - val_loss: 0.5822 - val_accuracy: 0.7178
Epoch 10/20
88/88 [==============================] - 56s 632ms/step - loss: 0.4803 - accuracy: 0.7687 - val_loss: 0.4536 - val_accuracy: 0.7846
Epoch 11/20
88/88 [==============================] - 61s 690ms/step - loss: 0.3979 - accuracy: 0.8190 - val_loss: 0.3940 - val_accuracy: 0.8195
Epoch 12/20
88/88 [==============================] - 60s 687ms/step - loss: 0.3257 - accuracy: 0.8572 - val_loss: 0.3248 - val_accuracy: 0.8564
Epoch 13/20
88/88 [==============================] - 59s 668ms/step - loss: 0.2637 - accuracy: 0.8897 - val_loss: 0.2980 - val_accuracy: 0.8742
Epoch 14/20
88/88 [==============================] - 56s 638ms/step - loss: 0.2154 - accuracy: 0.9115 - val_loss: 0.2326 - val_accuracy: 0.9023
Epoch 15/20
88/88 [==============================] - 56s 639ms/step - loss: 0.1822 - accuracy: 0.9277 - val_loss: 0.2112 - val_accuracy: 0.9130
Epoch 16/20
88/88 [==============================] - 56s 640ms/step - loss: 0.1504 - accuracy: 0.9412 - val_loss: 0.1803 - val_accuracy: 0.9267
Epoch 17/20
88/88 [==============================] - 58s 660ms/step - loss: 0.1298 - accuracy: 0.9499 - val_loss: 0.1662 - val_accuracy: 0.9331
Epoch 18/20
88/88 [==============================] - 57s 643ms/step - loss: 0.1132 - accuracy: 0.9567 - val_loss: 0.1643 - val_accuracy: 0.9358
Epoch 19/20
88/88 [==============================] - 58s 659ms/step - loss: 0.1018 - accuracy: 0.9613 - val_loss: 0.1409 - val_accuracy: 0.9441
Epoch 20/20
88/88 [==============================] - 57s 642ms/step - loss: 0.0907 - accuracy: 0.9659 - val_loss: 0.1325 - val_accuracy: 0.9475
  • 绘制训练过程
import pandas as pd
import matplotlib.pyplot as plt
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
plt.gca().set_ylim(0, 1) # set the vertical range to [0-1]
plt.show()

6. 测试

# 测试for file in os.listdir('./papers/Unknown'):# 测试文本处理unk_file = preprocessing('./papers/Unknown/'+file)# 文本转ids序列unk_file_seq = char_tokenizer.texts_to_sequences([unk_file])[0]# 提取固定长度的子串,形成多个样本X_unk, _ = make_subsequence(unk_file_seq, UNKNOWN)# 预测y_pred = model.predict(X_unk)y_pred = y_pred > 0.5votesA = np.sum(y_pred==0)votesB = np.sum(y_pred==1)print("文章 {} 被预测为 {} 写的,投票数 {} : {}".format(file,"A:hamilton" if votesA > votesB else "B:madison",max(votesA, votesB),min(votesA, votesB)))

输出:5个文本的作者,都预测对了

文章 paper_1.txt 被预测为 B:madison 写的,投票数 122118563
文章 paper_2.txt 被预测为 B:madison 写的,投票数 108998747
文章 paper_3.txt 被预测为 A:hamilton 写的,投票数 70416343
文章 paper_4.txt 被预测为 A:hamilton 写的,投票数 50634710
文章 paper_5.txt 被预测为 A:hamilton 写的,投票数 68784876

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/473405.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode 1674. 使数组互补的最少操作次数(差分思想)

文章目录1. 题目2. 解题1. 题目 给你一个长度为 偶数 n 的整数数组 nums 和一个整数 limit 。 每一次操作,你可以将 nums 中的任何整数替换为 1 到 limit 之间的另一个整数。 如果对于所有下标 i(下标从 0 开始),nums[i] nums[…

Kaggle 房价预测竞赛优胜方案:用 Python 进行全面数据探索

[导读]Kaggle 的房价预测竞赛从 2016 年 8 月开始,到 2017 年 2 月结束。这段时间内,超过 2000 多人参与比赛,选手采用高级回归技术,基于我们给出的 79 个特征,对房屋的售价进行了准确的预测。今…

使用GRU单元的RNN模型生成唐诗

文章目录1. 读取数据2. 字符索引3. 创建文本序列4. 创建文本编码序列5. 使用GRU单元建立RNN模型6. 文本生成参考 基于深度学习的自然语言处理 本文使用 GRU 单元建立 RNN 网络,使用唐诗三百首进行训练,使用模型生成唐诗。 GRU RNN 网络能够克服简单RNN…

Python数据结构常见的八大排序算法(详细整理)

前言 八大排序,三大查找是《数据结构》当中非常基础的知识点,在这里为了复习顺带总结了一下常见的八种排序算法。 常见的八大排序算法,他们之间关系如下: 排序算法.png 他们的性能比较: 下面,利用Python分别…

牛客 牛牛选物(01背包)

文章目录1. 题目2. 解题1. 题目 链接:https://ac.nowcoder.com/acm/contest/9887/A 来源:牛客网 牛牛有现在有n个物品,每个物品有一个体积v[i]和重量g[i],他想选择其中总体积恰好为V的若干个物品,想使这若干个物品的总重量最大&…

微信小程序最常用的布局——Flex布局

最近在学习微信小程序,在设计首页布局的时候,新认识了一种布局方式display:flex 1 .container { 2 display: flex; 3 flex-direction: column; 4 align-items: center; 5 background-color: #b3d4db; 6 } 编译之后的效果很明显,界面…

LeetCode 649. Dota2 参议院(循环队列)

文章目录1. 题目2. 解题1. 题目 Dota2 的世界里有两个阵营:Radiant(天辉)和 Dire(夜魇) Dota2 参议院由来自两派的参议员组成。现在参议院希望对一个 Dota2 游戏里的改变作出决定。他们以一个基于轮为过程的投票进行。在每一轮中,每一位参议员都可以行…

'[linux下tomcat 配置

tomcat目录结构bin ——Tomcat执行脚本目录 conf ——Tomcat配置文件 lib ——Tomcat运行需要的库文件(JARS) logs ——Tomcat执行时的LOG文件 temp ——Tomcat临时文件存放目录 webapps ——Tomcat的主要Web发布目录(存放我们自己的JSP,SER…

微信小程序基础(一)

一.注册小程序账号,下载IDE 1.官网注册https://mp.weixin.qq.com/,并下载IDE。 2.官方文档一向都是最好的学习资料。 注意: (1)注册账号之后会有一个appid,新建项目的时候需要填上,不然很多…

[Kaggle] Spam/Ham Email Classification 垃圾邮件分类(RNN/GRU/LSTM)

文章目录1. 读入数据2. 文本处理3. 建模4. 训练5. 测试练习地址:https://www.kaggle.com/c/ds100fa19 相关博文 [Kaggle] Spam/Ham Email Classification 垃圾邮件分类(spacy) [Kaggle] Spam/Ham Email Classification 垃圾邮件分类&#xff…

微信小程序中实现瀑布流布局和无限加载

瀑布流布局是一种比较流行的页面布局方式,最典型的就是Pinterest.com,每个卡片的高度不都一样,形成一种参差不齐的美感。 在HTML5中,我们可以找到很多基于jQuery之类实现的瀑布流布局插件,轻松做出这样的布局形式。在…

LeetCode 1684. 统计一致字符串的数目(哈希)

文章目录1. 题目2. 解题1. 题目 给你一个由不同字符组成的字符串 allowed 和一个字符串数组 words 。 如果一个字符串的每一个字符都在 allowed 中,就称这个字符串是 一致 字符串。 请你返回 words 数组中 一致 字符串的数目。 示例 1: 输入&#xff…

微信小程序:页面跳转时传递数据到另一个页面

一、功能描述 页面跳转时,同时把当前页面的数据传递给跳转的目标页面,并在跳转后的目标页面进行展示 二、功能实现 1. 代码实现 test1页面 // pages/test1/test1.js Page({/*** 页面的初始数据*/data: {name:Tom,age:12},buttonListener:function(){…

LeetCode 1685. 有序数组中差绝对值之和(前缀和)

文章目录1. 题目2. 解题1. 题目 给你一个 非递减 有序整数数组 nums 。 请你建立并返回一个整数数组 result,它跟 nums 长度相同,且result[i] 等于 nums[i] 与数组中所有其他元素差的绝对值之和。 换句话说, result[i] 等于 sum(|nums[i]-…

对一个 复杂的json结果进行取值的例子

1 JSON结果集 1 [2 {3 "J_LP_OPERATE_MAIN": {4 "ID": "1900036295",5 "FILL_MAN": "周兴福",6 "FILL_DEPT": "运维一班",7 "STATE…

微信小程序正则判断姓名和手机号

一、页面效果 二、json文件 //获取应用实例 const app getApp() Page({/*** 页面的初始数据*/data: {array: [速美, 现代, 淮安],mode: scaleToFill,src: ../../images/1.png,userInfo: {},hasUserInfo: false,canIUse: wx.canIUse(button.open-type.getUserInfo),userName: …

LeetCode 1686. 石子游戏 VI(贪心)

文章目录1. 题目2. 解题283 / 1660,前17%681 / 6572,前10.4%1. 题目 Alice 和 Bob 轮流玩一个游戏,Alice 先手。 一堆石子里总共有 n 个石子,轮到某个玩家时,他可以 移出 一个石子并得到这个石子的价值。 Alice 和 B…

T4生成实体和简单的CRUD操作

主要跟大家交流下T4,我这里针对的是mysql,我本人比较喜欢用mysql,所以语法针对mysql,所以你要准备mysql的DLL了,同理sqlserver差不多,有兴趣可以自己写写,首先网上找了一个T4的帮助类,得到一些数据库属性,命名为 DbHelper.ttinclude <# template debug"false" hos…

LeetCode 1688. 比赛中的配对次数(模拟)

文章目录1. 题目2. 解题1. 题目 给你一个整数 n &#xff0c;表示比赛中的队伍数。比赛遵循一种独特的赛制&#xff1a; 如果当前队伍数是 偶数 &#xff0c;那么每支队伍都会与另一支队伍配对。总共进行 n / 2 场比赛&#xff0c;且产生 n / 2 支队伍进入下一轮。如果当前队…

微信小程序使用函数的方法

一、使用来自不同页面的函数 函数写在util.js页面 function formatTime(date) {var year date.getFullYear()var month date.getMonth() 1var day date.getDate()var hour date.getHours()var minute date.getMinutes()var second date.getSeconds()return [year, mon…