文本分类的一种对抗训练方法

最近阅读了有关文本分类的文章,其中有一篇名为《Adversarail Training for Semi-supervised Text Classification》, 其主要思路实在文本训练时增加了一个扰动因子,即在embedding层加入一个小的扰动,发现训练的结果比不加要好很多。

模型的网络结构如下图:

 

下面就介绍一下这个对抗因子r的生成过程:

在进入lstm网络前先进行从w到v的计算,即将wordembedding 归一化:

然后定义模型的损失函数,令输入为x,参数为θ,Radv为对抗训练因子,损失函数为:

其中一个细节,虽然θˆ 是θ的复制,但是它是计算扰动的过程,不会参与到计算梯度的反向传播算法中。

然后就是求扰动:

 

 

先对表达式求导得到倒数g,然后对倒数g进行l2正则化的线性变换。

至此扰动则计算完成然后加入之前的wordembedding中参与模型训练。

下面则是模型的代码部分:

#构建adversarailLSTM模型class AdversarailLSTM(object):def __init__(self, config, wordEmbedding, indexFreqs):#定义输入self.inputX = tf.placeholder(tf.int32, [None, config.sequenceLength], name="inputX")self.inputY = tf.placeholder(tf.float32, [None, 1], name="inputY")self.dropoutKeepProb = tf.placeholder(tf.float32, name="dropoutKeepProb")#根据词频计算权重indexFreqs[0], indexFreqs[1] = 20000, 10000weights = tf.cast(tf.reshape(indexFreqs / tf.reduce_sum(indexFreqs), [1, len(indexFreqs)]), dtype=tf.float32)#词嵌入层with tf.name_scope("wordEmbedding"):#利用预训练的词向量初始化词嵌入矩阵normWordEmbedding = self._normalize(tf.cast(wordEmbedding, dtype=tf.float32, name="word2vec"), weights)#self.W = tf.Variable(tf.cast(wordEmbedding, dtype=tf.float32, name="word2vec"), name="W")self.embeddedWords = tf.nn.embedding_lookup(normWordEmbedding, self.inputX)#计算二元交叉熵损失with tf.name_scope("loss"):with tf.variable_scope("Bi-LSTM", reuse=None):self.predictions = self._Bi_LSTMAttention(self.embeddedWords)self.binaryPreds = tf.cast(tf.greater_equal(self.predictions, 0.5), tf.float32, name="binaryPreds")losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.predictions, labels=self.inputY)loss = tf.reduce_mean(losses)with tf.name_scope("perturloss"):with tf.variable_scope("Bi-LSTM", reuse=True):perturWordEmbedding = self._addPerturbation(self.embeddedWords, loss)print("perturbSize:{}".format(perturWordEmbedding))perturPredictions = self._Bi_LSTMAttention(perturWordEmbedding)perturLosses = tf.nn.sigmoid_cross_entropy_with_logits(logits=perturPredictions, labels=self.inputY)perturLoss = tf.reduce_mean(perturLosses)self.loss = loss + perturLossdef _Bi_LSTMAttention(self, embeddedWords):#定义两层双向LSTM的模型结构with tf.name_scope("Bi-LSTM"):fwHiddenLayers = []bwHiddenLayers = []for idx, hiddenSize in enumerate(config.model.hiddenSizes):with tf.name_scope("Bi-LSTM" + str(idx)):#定义前向网络结构lstmFwCell = tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.LSTMCell(num_units=hiddenSize, state_is_tuple=True),output_keep_prob=self.dropoutKeepProb)#定义反向网络结构lstmBwCell = tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.LSTMCell(num_units=hiddenSize, state_is_tuple=True),output_keep_prob=self.dropoutKeepProb)fwHiddenLayers.append(lstmFwCell)bwHiddenLayers.append(lstmBwCell)# 实现多层的LSTM结构, state_is_tuple=True,则状态会以元祖的形式组合(h, c),否则列向拼接fwMultiLstm = tf.nn.rnn_cell.MultiRNNCell(cells=fwHiddenLayers, state_is_tuple=True)bwMultiLstm = tf.nn.rnn_cell.MultiRNNCell(cells=bwHiddenLayers, state_is_tuple=True)#采用动态rnn,可以动态地输入序列的长度,若没有输入,则取序列的全长#outputs是一个元组(output_fw, output_bw), 其中两个元素的维度都是[batch_size, max_time, hidden_size], fw和bw的hiddensize一样#self.current_state是最终的状态,二元组(state_fw, state_bw), state_fw=[batch_size, s], s是一个元组(h, c)outputs, self.current_state = tf.nn.bidirectional_dynamic_rnn(fwMultiLstm, bwMultiLstm,self.embeddedWords, dtype=tf.float32,scope="bi-lstm" + str(idx))#在bi-lstm+attention论文中,将前向和后向的输出相加with tf.name_scope("Attention"):H = outputs[0] + outputs[1]#得到attention的输出output = self.attention(H)outputSize = config.model.hiddenSizes[-1]print("outputSize:{}".format(outputSize))#全连接层的输出with tf.name_scope("output"):outputW = tf.get_variable("outputW",shape=[outputSize, 1],initializer=tf.contrib.layers.xavier_initializer())outputB = tf.Variable(tf.constant(0.1, shape=[1]), name="outputB")predictions = tf.nn.xw_plus_b(output, outputW, outputB, name="predictions")return predictionsdef attention(self, H):"""利用Attention机制得到句子的向量表示"""#获得最后一层lstm神经元的数量hiddenSize = config.model.hiddenSizes[-1]#初始化一个权重向量,是可训练的参数W = tf.Variable(tf.random_normal([hiddenSize], stddev=0.1))#对bi-lstm的输出用激活函数做非线性转换M = tf.tanh(H)#对W和M做矩阵运算,W=[batch_size, time_step, hidden_size], 计算前做维度转换成[batch_size * time_step, hidden_size]#newM = [batch_size, time_step, 1], 每一个时间步的输出由向量转换成一个数字newM = tf.matmul(tf.reshape(M, [-1, hiddenSize]), tf.reshape(W, [-1, 1]))#对newM做维度转换成[batch_size, time_step]restoreM = tf.reshape(newM, [-1, config.sequenceLength])#用softmax做归一化处理[batch_size, time_step]self.alpha = tf.nn.softmax(restoreM)#利用求得的alpha的值对H进行加权求和,用矩阵运算直接操作r = tf.matmul(tf.transpose(H, [0, 2, 1]), tf.reshape(self.alpha, [-1, config.sequenceLength, 1]))#将三维压缩成二维sequeezeR = [batch_size, hissen_size]sequeezeR = tf.squeeze(r)sentenceRepren = tf.tanh(sequeezeR)#对attention的输出可以做dropout处理output = tf.nn.dropout(sentenceRepren, self.dropoutKeepProb)return outputdef _normalize(self, wordEmbedding, weights):"""对word embedding 结合权重做标准化处理"""mean = tf.matmul(weights, wordEmbedding)powWordEmbedding = tf.pow(wordEmbedding -mean, 2.)var = tf.matmul(weights, powWordEmbedding)stddev = tf.sqrt(1e-6 + var)return (wordEmbedding - mean) / stddevdef _addPerturbation(self, embedded, loss):"""添加波动到word embedding"""grad, =tf.gradients(loss,embedded,aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)grad = tf.stop_gradient(grad)perturb = self._scaleL2(grad, config.model.epsilon)#print("perturbSize:{}".format(embedded+perturb))return embedded + perturbdef _scaleL2(self, x, norm_length):#shape(x) = [batch, num_step, d]#divide x by max(abs(x)) for a numerically stable L2 norm#2norm(x) = a * 2norm(x/a)#scale over the full sequence, dim(1, 2)alpha = tf.reduce_max(tf.abs(x), (1, 2), keep_dims=True) + 1e-12l2_norm = alpha * tf.sqrt(tf.reduce_sum(tf.pow(x/alpha, 2), (1, 2), keep_dims=True) + 1e-6)x_unit = x / l2_normreturn norm_length * x_unit

  

 代码是在双向lstm+attention的基础上增加adversarial training,训练数据为imdb电影评论数据,最后的结果发现确实很快就能达到最优值,但是训练所占的空间比较大(电脑跑了几十步就停止了),每一步的时间也稍微长一点。

 

转载于:https://www.cnblogs.com/danny92/p/10636890.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/462975.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在leveldb中,为什么要有immutable memtable?

目的是:为了防止写入kv时被阻塞。 设想,如果没有immutable memtable,当memtable满了之后后台线程需要将memtable 立即flush到新建的sst中,在flush的过程中,新的KV记录是无法写入的,只能等待,就…

密码强弱提示(27)

密码的强弱提示是对用户填写登陆密码的复杂程度来给出提示,使用密码的强弱提示可以增强用户对密码的保护意识,对如今的网络是非常有必要的,本程序中当用户输入完密码后,网页会自动的对用户输入的密码给出强弱判断。 使用JavaScrip…

MySQL Cluster 用户权限共享 (各sql节点同步)

转自:http://blog.csdn.net/ylqmf/article/details/7866517 MySQL Cluster API节点 中mysql.user 表为MyISAM引擎,所以每个API都要配置权限系统,MySQL已经为我们提供了共享权限脚本。这个脚本主要作用就是将mysql.user 表MyISAM引擎更换为NDB…

vue项目创建,redis列表字典操作,django用redis的第二种方法

vue项目的创建(路飞前端) -安装node.js -安装vue的脚手架 -创建vue项目,vue create 项目名字在pycharm中开发vue -webstrom,pyacharm,goland,idea,androidStuidio,Php.... -Edit-conf----》点 选npm-----》在script对应的框中写&a…

心理学三大流派

心理学三大流派及其代表人物精神分析学派、行为主义学派、人本主义心理学影响最大,被称为心理学的三大主要势力[编辑]精神分析学派代表人物:西格蒙德弗洛伊德精神分析由弗洛伊德开创,其后被不断修正与发展,影响力远远超出心理学&a…

《In Search of an Understandable Consensus Algorithm》翻译

Abstract Raft是一种用于管理replicated log的consensus algorithm。它能和Paxos产生同样的结果,有着和Paxos同样的性能,但是结构却不同于Paxos;它让Raft比Paxos更易于理解,并且也为用它构建实际的系统提供了更好的基础。为了增强…

从软件交付看软件验收管理

软件项目交付验收是软件质量保障的最后一道防火墙,也是企业乃至每个项目成员都想要的结果,软件项目终于可以告一段落。一个软件项目的验收,一般是由一系列验收准备工作组成的,如果我们在最终验收前,已经将很多阶段的工…

十三、实现Comparable接口和new ComparatorT(){ }排序的实现过程

参考:https://www.cnblogs.com/igoodful/p/9517784.html Collections有两种比较规则方式,第一种是使用自身的比较规则: 该类必须实现Comparable接口并重写comparTo方法。 this可以想象为1,传入对象o想象为2,返回1-2即按…

Presto入门介绍

最近在调研presto查询引擎的模块,先了解了下大体的框架和基本知识。这篇文章适合入门的童鞋看,因此转载了,用于以后查询使用。 1, Presto基本认识 1.1 定义 Presto是一个分布式的查询引擎,本身并不存储数据&#xff…

职场有影帝出没,屌丝们请当心!

引子职场有影帝出没,请当心!广大屌丝请注意危险,谨慎前往。人生苦短,必须性感;职场如戏,要靠演技。不少公司正变成秀场,影帝层出不穷,屌丝们的辛苦努力一不小心就成了影帝的嫁衣。影…

深度点评五种常见WiFi搭建方案

总结十年无线搭建经验,针对企业常见的五种办公室无线网络方案做个简要分析,各种方案有何优劣,又适用于那种类型的企业。方案一:仅路由器或AP覆盖简述:使用路由器或AP覆盖多个无线盲区,多个AP的部署实现整体…

项目开发经验谈之:项目到底谁说了算

项目开发经验谈:项目的到底谁说了算 前言:项目到底是为谁而做,一个项目的成功到底是怎么样在评价:是领导阶层肯定,还是客户满意? 系列文章链接 项目开发经验谈:如何成为出色的开发人员盲目的项目…

深入理解Presto

深入理解Presto 简介 Presto是一个facebook开源的分布式SQL查询引擎,适用于交互式分析查询,数据量支持GB到PB字节。presto的架构由关系型数据库的架构演化而来。presto之所以能在各个内存计算型数据库中脱颖而出,在于以下几点: …

实战演示 bacula 软件备份功能

原文地址:http://www.linuxde.net/2012/04/9734.html 一、实例演示bacula的完全备份功能 1.创建卷组 执行如下命令,连接到bacula控制端,执行备份恢复操作: [rootbaculaserver opt]#/opt/bacula/sbin/bconsole Connec…

设置VS2017背景图片

设置方法很简单:安装扩展ClaudiaIDE 1、在这里下载扩展,https://visualstudiogallery.msdn.microsoft.com/9ba50f8d-f30c-4e33-ab19-bfd9f56eb817 2、然后双击即可完成安装。 之后重启VS,就可以看到编程背景上多了一个萌妹子,据说…

证书的应用之一 —— TCPSSL通信实例及协议分析(上)

SSL(Security Socket Layer)是TLS(Transport Layer Security)的前身,是一种加解密协议,它提供了再网络上的安全传输,它介于网络通信协议的传输层与应用层之间。 为实现TCP层之上的ssl通信,需要用到数字证书。本文通过具体例子来说…

自旋锁和互斥锁的区别

自旋锁和互斥锁的区别 POSIX threads(简称Pthreads)是在多核平台上进行并行编程的一套API。线程同步是并行编程中非常重要的通讯手段,其中最典型的应用就是用 Pthreads提供的锁机制(lock)来对多个线程之间的共享临界区(Critical Section)进行保护(另一种常用的同步…

校内模拟赛 Zbq's Music Challenge

Zbqs Music Challenge 题意: 一个长度为n的序列,每个位置可能是1或者0,1的概率是$p_i$。对于一个序列$S$,它的得分是 $$BasicScoreA\times \sum_{i1}^{n}{S_i} \tag{1}$$ $$ combo(i)\left\{ \begin{aligned} &S_i & &…

TSQL中实现ORACLE的多列IN 多列匹配。

期望效果:(我是拿到一对关系去另一表中的一对关系去对比)select * From Empoylee Where (Address1,Address2) in (Select Address1,Address2 From EmpoyleeAdresses Where Country Canada)以上无法实现用这种方案也可以实现 不过速度很慢的s…

ClickedOnce部署方法

1.ClickedOnce部署时有些DLL和配置文件无法自动部署到系统当中,只能用Manifest Manager Tool 修改manifest 文件 /Files/Tonyyang/Software/ManifestManagerUtility.rar 2.部署文件结构 3.部署方法 首先用VS自带的ClickedOnce发布应用程序(博客园有&…