使用注意力机制建模 - 标准化日期格式

文章目录

    • 1. 概述
    • 2. 数据
    • 3. 模型
    • 4. 训练
    • 5. 测试

参考 基于深度学习的自然语言处理

本文使用attention机制的模型,将各种格式的日期转化成标准格式的日期

1. 概述

  • LSTM、GRU 减少了梯度消失的问题,但是对于复杂依赖结构的长句子,梯度消失仍然存在
  • 注意力机制能同时看见句子中的每个位置,并赋予每个位置不同的权重(注意力),且可以并行计算

2. 数据

  • 生成日期数据
from faker import Faker
from babel.dates import format_date
import random
fake = Faker()
fake.seed(123)
random.seed(321)# 各种日期格式
FORMATS = ['short','medium','long','full','full','full','full','full','full','full','full','full','full','d MMM YYY','d MMMM YYY','dd MMM YYY','d MMM, YYY','d MMMM, YYY','dd, MMM YYY','d MM YY','d MMMM YYY','MMMM d YYY','MMMM d, YYY','dd.MM.YY']
  • 生成日期数据:随机格式(X),标准格式(Y)
def load_date():# 加载一些日期数据dt = fake.date_object() # 随机一个日期human_readable = format_date(dt, format=random.choice(FORMATS),locale='en_US')# 使用随机选取的格式,生成日期human_readable = human_readable.lower().replace(',','')machine_readable = dt.isoformat() # 标准格式return human_readable, machine_readable, dttest_date = load_date()

输出:

  • 建立字典,以及映射关系(字符 :idx)
from tqdm import tqdm # 显示进度条
def load_dateset(num_of_data):human_vocab = set()machine_vocab = set()dataset = []Tx = 30 # 日期最大长度for i in tqdm(range(num_of_data)):h, m, _ = load_date()if h is not None:dataset.append((h, m))human_vocab.update(tuple(h))machine_vocab.update(tuple(m))human = dict(zip(sorted(human_vocab)+['<unk>', '<pad>'],list(range(len(human_vocab)+2))))# x 字符:idx 的映射inv_machine = dict(enumerate(sorted(machine_vocab)))# idx : y 字符machine = {v : k for k, v in inv_machine.items()}# y 字符 : idxreturn dataset, human, machine, inv_machinem = 10000 # 样本个数
dataset, human_vocab, machine_vocab, inv_machine_vocab = load_dateset(m)
  • 日期(char序列)转 ids 序列,并且 pad / 截断
import numpy as np
from keras.utils import to_categoricaldef string_to_int(string, length, vocab):string = string.lower().replace(',','')if len(string) > length: # 长了,截断string = string[:length]rep = list(map(lambda x : vocab.get(x, '<unk>'), string))# 对string里每个char 使用 匿名函数 获取映射的id,没有的话,使用unk的id,map返回迭代器,转成listif len(string) < length:rep += [vocab['<pad>']]*(length-len(string))# 长度不够,加上 pad 的 idreturn rep # 返回 [ids,...]
  • 根据 ids 序列生成 one_hot 矩阵
def process_data(dataset, human_vocab, machine_vocab, Tx, Ty):X,Y = zip(*dataset)print("处理前 X:{}".format(X))print("处理前 Y:{}".format(Y))X = np.array([string_to_int(date, Tx, human_vocab) for date in X])Y = [string_to_int(date, Ty, machine_vocab) for date in Y]print("处理后 X的shape:{}".format(X.shape))print("处理后 Y: {}".format(Y))Xoh = np.array(list(map(lambda x : to_categorical(x, num_classes=len(human_vocab)), X)))Yoh = np.array(list(map(lambda x : to_categorical(x, num_classes=len(machine_vocab)), Y)))return X, np.array(Y), Xoh, Yoh
Tx = 30 # 输入长度
Ty = 10 # 输出长度
X, Y, Xoh, Yoh = process_data(dataset, human_vocab, machine_vocab, Tx, Ty)


检查生成的 one_hot 编码矩阵维度

print(X.shape)
print(Y.shape)
print(Xoh.shape)
print(Yoh.shape)

输出:

(10000, 30)
(10000, 10)
(10000, 30, 37)
(10000, 10, 11)

3. 模型

  • softmax 激活函数,求注意力权重
from keras import backend as K
def softmax(x, axis=1):ndim = K.ndim(x)if ndim == 2:return K.softmax(x)elif ndim > 2:e = K.exp(x - K.max(x, axis=axis, keepdims=True))s = K.sum(e, axis=axis, keepdims=True)return e/selse:raise ValueError('维度不对,不能是1维')
  • 模型组件
from keras.layers import RepeatVector, LSTM, Concatenate, \Dense, Activation, Dot, Input, Bidirectionalrepeator = RepeatVector(Tx) # 重复 Tx 次
# 重复器
# Input shape:
#     2D tensor of shape `(num_samples, features)`.
#
# Output shape:
#     3D tensor of shape `(num_samples, n, features)`.
concator = Concatenate(axis=-1) # 拼接器
densor1 = Dense(10, activation='tanh') # FC
densor2 = Dense(1, activation='relu') # FC
activator = Activation(softmax, name='attention_weights') # 计算注意力权重
dotor = Dot(axes=1) # 加权
  • 模型
def one_step_attention(h, s_prev):s_prev = repeator(s_prev) # 将前一个输出状态重复 Tx 次concat = concator([h, s_prev]) # 与 全部句子状态 拼接e = densor1(concat) # 经过 FCenergies = densor2(e) # 经过FCalphas = activator(energies) # 得到注意力权重context = dotor([alphas, h]) # 跟原句子状态做attentionreturn context # 得到上下文向量,后序输入到解码器# 解码器,是一个单向LSTM
n_h = 32
n_s = 64
post_activation_LSTM_cell = LSTM(n_s, return_state=True) # 单向LSTM
output_layer = Dense(len(machine_vocab), activation=softmax) # FC 输出预测值from keras.models import Model
def model(Tx, Ty, n_h, n_s, human_vocab_size, machine_vocab_size):X = Input(shape=(Tx,human_vocab_size), name='input_first')s0 = Input(shape=(n_s,),name='s0')c0 = Input(shape=(n_s,),name='c0')s = s0c = c0outputs = []h = Bidirectional(LSTM(n_h, return_sequences=True))(X) # 编码器得到整个序列的状态for t in range(Ty): # 解码器 推理context = one_step_attention(h, s) # attention 得到上下文向量s, _, c = post_activation_LSTM_cell(context, initial_state=[s,c])out = output_layer(s) # FC 输出预测outputs.append(out)model = Model(inputs=[X,s0,c0], outputs=outputs)return modelmodel = model(Tx,Ty,n_h,n_s,len(human_vocab), len(machine_vocab))
model.summary()from keras.utils import plot_model
plot_model(model, to_file='model.png',show_shapes=True,rankdir='TB')

输出:

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_first (InputLayer)        [(None, 30, 37)]     0                                            
__________________________________________________________________________________________________
s0 (InputLayer)                 [(None, 64)]         0                                            
__________________________________________________________________________________________________
bidirectional (Bidirectional)   (None, 30, 64)       17920       input_first[0][0]                
__________________________________________________________________________________________________
repeat_vector (RepeatVector)    (None, 30, 64)       0           s0[0][0]                         lstm[0][0]                       lstm[1][0]                       lstm[2][0]                       lstm[3][0]                       lstm[4][0]                       lstm[5][0]                       lstm[6][0]                       lstm[7][0]                       lstm[8][0]                       
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 30, 128)      0           bidirectional[0][0]              repeat_vector[0][0]              bidirectional[0][0]              repeat_vector[1][0]              bidirectional[0][0]              repeat_vector[2][0]              bidirectional[0][0]              repeat_vector[3][0]              bidirectional[0][0]              repeat_vector[4][0]              bidirectional[0][0]              repeat_vector[5][0]              bidirectional[0][0]              repeat_vector[6][0]              bidirectional[0][0]              repeat_vector[7][0]              bidirectional[0][0]              repeat_vector[8][0]              bidirectional[0][0]              repeat_vector[9][0]              
__________________________________________________________________________________________________
dense (Dense)                   (None, 30, 10)       1290        concatenate[0][0]                concatenate[1][0]                concatenate[2][0]                concatenate[3][0]                concatenate[4][0]                concatenate[5][0]                concatenate[6][0]                concatenate[7][0]                concatenate[8][0]                concatenate[9][0]                
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 30, 1)        11          dense[0][0]                      dense[1][0]                      dense[2][0]                      dense[3][0]                      dense[4][0]                      dense[5][0]                      dense[6][0]                      dense[7][0]                      dense[8][0]                      dense[9][0]                      
__________________________________________________________________________________________________
attention_weights (Activation)  (None, 30, 1)        0           dense_1[0][0]                    dense_1[1][0]                    dense_1[2][0]                    dense_1[3][0]                    dense_1[4][0]                    dense_1[5][0]                    dense_1[6][0]                    dense_1[7][0]                    dense_1[8][0]                    dense_1[9][0]                    
__________________________________________________________________________________________________
dot (Dot)                       (None, 1, 64)        0           attention_weights[0][0]          bidirectional[0][0]              attention_weights[1][0]          bidirectional[0][0]              attention_weights[2][0]          bidirectional[0][0]              attention_weights[3][0]          bidirectional[0][0]              attention_weights[4][0]          bidirectional[0][0]              attention_weights[5][0]          bidirectional[0][0]              attention_weights[6][0]          bidirectional[0][0]              attention_weights[7][0]          bidirectional[0][0]              attention_weights[8][0]          bidirectional[0][0]              attention_weights[9][0]          bidirectional[0][0]              
__________________________________________________________________________________________________
c0 (InputLayer)                 [(None, 64)]         0                                            
__________________________________________________________________________________________________
lstm (LSTM)                     [(None, 64), (None,  33024       dot[0][0]                        s0[0][0]                         c0[0][0]                         dot[1][0]                        lstm[0][0]                       lstm[0][2]                       dot[2][0]                        lstm[1][0]                       lstm[1][2]                       dot[3][0]                        lstm[2][0]                       lstm[2][2]                       dot[4][0]                        lstm[3][0]                       lstm[3][2]                       dot[5][0]                        lstm[4][0]                       lstm[4][2]                       dot[6][0]                        lstm[5][0]                       lstm[5][2]                       dot[7][0]                        lstm[6][0]                       lstm[6][2]                       dot[8][0]                        lstm[7][0]                       lstm[7][2]                       dot[9][0]                        lstm[8][0]                       lstm[8][2]                       
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 11)           715         lstm[0][0]                       lstm[1][0]                       lstm[2][0]                       lstm[3][0]                       lstm[4][0]                       lstm[5][0]                       lstm[6][0]                       lstm[7][0]                       lstm[8][0]                       lstm[9][0]                       
==================================================================================================
Total params: 52,960
Trainable params: 52,960
Non-trainable params: 0
________________________________________________________________________________________________

4. 训练

from keras.optimizers import Adam
# 优化器
opt = Adam(learning_rate=0.005, decay=0.01)
# 配置模型
model.compile(optimizer=opt, loss='categorical_crossentropy',metrics=['accuracy'])# 初始化 解码器状态
s0 = np.zeros((m, n_s))
c0 = np.zeros((m, n_s))
outputs = list(Yoh.swapaxes(0, 1))
# Yoh shape 10000*10*11,调换0,1轴,为10*10000*11
# outputs list,长度 10, 每个里面是array 10000*11history = model.fit([Xoh, s0, c0], outputs,epochs=10, batch_size=128,validation_split=0.1)
  • 绘制 loss 和 各位置的准确率
from matplotlib import pyplot as plt
import pandas as pd
his = pd.DataFrame(history.history)
print(his.columns)
loss = history.history['loss']
val_loss = history.history['val_loss']plt.plot(loss, label='train Loss')
plt.plot(val_loss, label='valid Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid()
plt.show()# 列 具体的名字根据运行次数,会有变化
col_train_acc = ('dense_7_accuracy', 'dense_7_1_accuracy', 'dense_7_2_accuracy','dense_7_3_accuracy', 'dense_7_4_accuracy', 'dense_7_5_accuracy','dense_7_6_accuracy', 'dense_7_7_accuracy', 'dense_7_8_accuracy','dense_7_9_accuracy')
col_test_acc = ('val_dense_7_accuracy', 'val_dense_7_1_accuracy','val_dense_7_2_accuracy', 'val_dense_7_3_accuracy','val_dense_7_4_accuracy', 'val_dense_7_5_accuracy','val_dense_7_6_accuracy', 'val_dense_7_7_accuracy','val_dense_7_8_accuracy', 'val_dense_7_9_accuracy')
train_acc = pd.DataFrame(history.history[c] for c in col_train_acc)
test_acc = pd.DataFrame(history.history[c] for c in col_test_acc)train_acc.plot()
plt.title('Training Accuracy on pos')
plt.legend()
plt.grid()
plt.show()test_acc.plot()
plt.title('Validation Accuracy on pos')
plt.legend()
plt.grid()
plt.show()

5. 测试

s0 = np.zeros((1, n_s))
c0 = np.zeros((1, n_s))
test_data,_,_,_ = load_dateset(10)
for x,y in test_data:print(x + " ==> " +y)
for x,_ in test_data:source = string_to_int(x, Tx, human_vocab)source = np.array(list(map(lambda a : to_categorical(a, num_classes=len(human_vocab)), source)))source = source[np.newaxis, :]pred = model.predict([source, s0, c0])pred = np.argmax(pred, axis=-1)output = [inv_machine_vocab[int(i)] for i in pred]print('source:',x)print('output:',''.join(output))

输出:

18 april 2014 ==> 2014-04-18
saturday august 22 1998 ==> 1998-08-22
october 22 1995 ==> 1995-10-22
thursday february 29 1996 ==> 1996-02-29
wednesday october 17 1979 ==> 1979-10-17
7 12 73 ==> 1973-12-07
9/30/01 ==> 2001-09-30
22 may 2001 ==> 2001-05-22
7 march 1979 ==> 1979-03-07
19 feb 2013 ==> 2013-02-19

预测10个,错误了4个,日期字符不完全正确

source: 18 april 2014
output: 2014-04-18
source: saturday august 22 1998
output: 1998-08-22
source: october 22 1995
output: 1995-12-22 # 错误 10 月
source: thursday february 29 1996
output: 1996-02-29
source: wednesday october 17 1979
output: 1979-10-17
source: 7 12 73
output: 1973-02-07 # 错误 12月
source: 9/30/01
output: 2001-05-00 # 错误 09-30
source: 22 may 2001
output: 2011-05-22 # 错误 2001
source: 7 march 1979
output: 1979-03-07
source: 19 feb 2013
output: 2013-02-19

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/473360.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

牛客 数学实验(模拟)

文章目录1. 题目2. 解题1. 题目 链接&#xff1a;https://ac.nowcoder.com/acm/contest/10166/A 来源&#xff1a;牛客网 牛牛在做数学实验。 老师给了牛牛一个数字n&#xff0c;牛牛需要不断地将所有数位上的值做乘法运算&#xff0c;直至最后数字不发生变化为止。 请你帮牛…

css动画之波纹

样式定义&#xff1a; #ContactAbout { height: auto; position: relative; overflow: hidden; } #sectioncontact { display: block; width: 100%; position: relative; height: 700px; z-index: 10; }#sectioncontact .map { width: 370px; height: 280px; position: absolut…

牛客 奇怪的排序问题(单调栈/遍历)

文章目录1. 题目2. 解题1. 题目 链接&#xff1a;https://ac.nowcoder.com/acm/contest/10166/B 来源&#xff1a;牛客网 操场上有n个人排成一队&#xff0c;这n个人身高互不相同&#xff0c;可将他们的身高视为一个1到n的排列。 这时需要把队伍变成升序&#xff0c;也就是从…

Python 中,matplotlib绘图无法显示中文的问题

在python中&#xff0c;默认情况下是无法显示中文的&#xff0c;如下代码&#xff1a; import matplotlib.pyplot as plt# 定义文本框和箭头格式 decisionNode dict(boxstyle "sawtooth", fc "0.8") leafNode dict(boxstyle "round4", fc …

牛客 XOR和(找规律)

文章目录1. 题目2. 解题1. 题目 链接&#xff1a;https://ac.nowcoder.com/acm/contest/10166/C 来源&#xff1a;牛客网 牛牛最近学会了异或操作&#xff0c;于是他发现了一个函数 f(x)x⊕(x−1)f(x)x\oplus (x-1)f(x)x⊕(x−1)&#xff0c;现在牛牛给你一个数 n&#xff0c…

天池 在线编程 数组划分III(计数)

文章目录1. 题目2. 解题1. 题目 https://tianchi.aliyun.com/oj/231188302809557697/235445278655844965 给你一个整数数组和一个整数K&#xff0c;请你判断数组是否可以划分为若干大小为k序列&#xff0c;并满足以下条件&#xff1a; 数组中的每一个数恰恰出现在一个序列中…

详解nohup和 区别

一、nohup nohup 命令运行由 Command参数和任何相关的 Arg参数指定的命令&#xff0c;忽略所有挂断&#xff08;SIGHUP&#xff09;信号。在注销后使用 nohup 命令运行后台中的程序。要运行后台中的 nohup 命令&#xff0c;添加 & &#xff08; 表示“and”的符号&#xf…

天池 在线编程 最小振幅(排序)

文章目录1. 题目2. 解题1. 题目 https://tianchi.aliyun.com/oj/231188302809557697/235445278655844966 给定一个由N个整数组成的数组A&#xff0c;一次移动&#xff0c;我们可以选择此数组中的任何元素并将其替换为任何值。 数组的振幅是数组A中的最大值和最小值之间的差。…

Python把两张图片拼接为一张图片并保存

这里主要用Python扩展库pillow中Image对象的paste()方法把两张图片拼接起来 from os import listdir from PIL import Imagedef pinjie():# 获取当前文件夹中所有JPG图像im_list [Image.open(fn) for fn in listdir() if fn.endswith(.jpg)]# 图片转化为相同的尺寸ims []for…

天池 在线编程 高效作业处理服务(01背包DP)

文章目录1. 题目2. 解题1. 题目 https://tianchi.aliyun.com/oj/231188302809557697/235445278655844967 Twitter正在测试一种名为Pigeon的新工作处理服务。 Pigeon可以用任务实际持续时间的两倍处理任务&#xff0c;并且每个任务都有一个权重。 此外&#xff0c;Pigeon在一…

LeetCode 1694. 重新格式化电话号码(模拟)

文章目录1. 题目2. 解题1. 题目 给你一个字符串形式的电话号码 number 。 number 由数字、空格 、和破折号 - 组成。 请你按下述方式重新格式化电话号码。 首先&#xff0c;删除 所有的空格和破折号。 其次&#xff0c;将数组从左到右 每 3 个一组 分块&#xff0c;直到 剩…

LeetCode 1695. 删除子数组的最大得分(前缀和+哈希+双指针)

文章目录1. 题目2. 解题1. 题目 给你一个正整数数组 nums &#xff0c;请你从中删除一个含有 若干不同元素 的子数组。 删除子数组的 得分 就是子数组各元素之 和 。 返回 只删除一个 子数组可获得的 最大得分 。 如果数组 b 是数组 a 的一个连续子序列&#xff0c;即如果它…

大型网站系统与Java中间件实践

大型网站系统与Java中间件实践&#xff08;贯通分布式高并发高数据高访问量网站架构与实现之权威著作&#xff0c;九大一线互联网公司CTO联合推荐&#xff09; 曾宪杰 著 ISBN 978-7-121-22761-5 2014年4月出版 定价&#xff1a;65.00元 340页 16开 编辑推荐 到底是本什么书…

LeetCode 1696. 跳跃游戏 VI(优先队列 / 单调队列)

文章目录1. 题目2. 解题2.1 贪心错误解2.2 优先队列/单调队列1. 题目 给你一个下标从 0 开始的整数数组 nums 和一个整数 k 。 一开始你在下标 0 处。每一步&#xff0c;你最多可以往前跳 k 步&#xff0c;但你不能跳出数组的边界。 也就是说&#xff0c;你可以从下标 i 跳到…

LeetCode 1697. 检查边长度限制的路径是否存在(排序+并查集)

文章目录1. 题目2. 解题1. 题目 给你一个 n 个点组成的无向图边集 edgeList &#xff0c;其中 edgeList[i] [ui, vi, disi] 表示点 ui 和点 vi 之间有一条长度为 disi 的边。请注意&#xff0c;两个点之间可能有 超过一条边 。 给你一个查询数组queries &#xff0c;其中 qu…

限制RICHTEXTBOX的输入的范围

附件&#xff1a;http://files.cnblogs.com/xe2011/WindowsFormsApplication_LimitRichTextBoxInput.rarusing System;using System.Collections.Generic;using System.ComponentModel;using System.Data;using System.Drawing;using System.Linq;using System.Text;using Syst…

NLP项目工作流程

文章目录1. 谷歌Colab设置2. 编写代码3. flask 微服务4. 打包到容器5. 容器托管参考 基于深度学习的自然语言处理使用这篇文章的数据(情感分类)进行学习。 1. 谷歌Colab设置 Colab 地址 新建笔记本 设置 选择 GPU/TPU 加速计算 测试 GPU 是否分配 import tensorflow…

牛客 牛牛浇树(差分)

文章目录1. 题目2. 解题1. 题目 链接&#xff1a;https://ac.nowcoder.com/acm/contest/10323/A 来源&#xff1a;牛客网 牛牛现在在花园养了n棵树&#xff0c;按顺序从第1棵到第n棵排列着。 牛牛每天会按照心情给其中某一个区间的树浇水。 例如如果某一天浇水的区间为[2,4]&…

再议 语法高亮插件的选择

之前一篇《为博客园选择一个小巧霸气的语法高亮插件》介绍了语法高亮插件的选择&#xff0c;当时只注重速度了。这些天在做深度定制的时候发现一个严重的问题&#xff0c;匹配精度不够。 什么是匹配精度呢&#xff1f;简单说就是没有把代码分块&#xff0c;是否分的足够细&…

Python自定义时间间隔访问网页

方法一&#xff1a;利用webbrowser import time import webbrowserwhile True: # 死循环time.sleep(60 * 1) # 程序等待时间&#xff0c;这里等待1min&#xff0c;参数的基本单位是秒print("正在访问&#xff1a;请稍等。。。")webbrowser.open("https://blo…