深圳企业建站设计公司/百度提交网站入口

深圳企业建站设计公司,百度提交网站入口,合肥建设网络赌博网站,合肥做网站的公最近阅读了有关文本分类的文章,其中有一篇名为《Adversarail Training for Semi-supervised Text Classification》, 其主要思路实在文本训练时增加了一个扰动因子,即在embedding层加入一个小的扰动,发现训练的结果比不加要好很多。 模型的网…

最近阅读了有关文本分类的文章,其中有一篇名为《Adversarail Training for Semi-supervised Text Classification》, 其主要思路实在文本训练时增加了一个扰动因子,即在embedding层加入一个小的扰动,发现训练的结果比不加要好很多。

模型的网络结构如下图:

 

下面就介绍一下这个对抗因子r的生成过程:

在进入lstm网络前先进行从w到v的计算,即将wordembedding 归一化:

然后定义模型的损失函数,令输入为x,参数为θ,Radv为对抗训练因子,损失函数为:

其中一个细节,虽然θˆ 是θ的复制,但是它是计算扰动的过程,不会参与到计算梯度的反向传播算法中。

然后就是求扰动:

 

 

先对表达式求导得到倒数g,然后对倒数g进行l2正则化的线性变换。

至此扰动则计算完成然后加入之前的wordembedding中参与模型训练。

下面则是模型的代码部分:

#构建adversarailLSTM模型class AdversarailLSTM(object):def __init__(self, config, wordEmbedding, indexFreqs):#定义输入self.inputX = tf.placeholder(tf.int32, [None, config.sequenceLength], name="inputX")self.inputY = tf.placeholder(tf.float32, [None, 1], name="inputY")self.dropoutKeepProb = tf.placeholder(tf.float32, name="dropoutKeepProb")#根据词频计算权重indexFreqs[0], indexFreqs[1] = 20000, 10000weights = tf.cast(tf.reshape(indexFreqs / tf.reduce_sum(indexFreqs), [1, len(indexFreqs)]), dtype=tf.float32)#词嵌入层with tf.name_scope("wordEmbedding"):#利用预训练的词向量初始化词嵌入矩阵normWordEmbedding = self._normalize(tf.cast(wordEmbedding, dtype=tf.float32, name="word2vec"), weights)#self.W = tf.Variable(tf.cast(wordEmbedding, dtype=tf.float32, name="word2vec"), name="W")self.embeddedWords = tf.nn.embedding_lookup(normWordEmbedding, self.inputX)#计算二元交叉熵损失with tf.name_scope("loss"):with tf.variable_scope("Bi-LSTM", reuse=None):self.predictions = self._Bi_LSTMAttention(self.embeddedWords)self.binaryPreds = tf.cast(tf.greater_equal(self.predictions, 0.5), tf.float32, name="binaryPreds")losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.predictions, labels=self.inputY)loss = tf.reduce_mean(losses)with tf.name_scope("perturloss"):with tf.variable_scope("Bi-LSTM", reuse=True):perturWordEmbedding = self._addPerturbation(self.embeddedWords, loss)print("perturbSize:{}".format(perturWordEmbedding))perturPredictions = self._Bi_LSTMAttention(perturWordEmbedding)perturLosses = tf.nn.sigmoid_cross_entropy_with_logits(logits=perturPredictions, labels=self.inputY)perturLoss = tf.reduce_mean(perturLosses)self.loss = loss + perturLossdef _Bi_LSTMAttention(self, embeddedWords):#定义两层双向LSTM的模型结构with tf.name_scope("Bi-LSTM"):fwHiddenLayers = []bwHiddenLayers = []for idx, hiddenSize in enumerate(config.model.hiddenSizes):with tf.name_scope("Bi-LSTM" + str(idx)):#定义前向网络结构lstmFwCell = tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.LSTMCell(num_units=hiddenSize, state_is_tuple=True),output_keep_prob=self.dropoutKeepProb)#定义反向网络结构lstmBwCell = tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.LSTMCell(num_units=hiddenSize, state_is_tuple=True),output_keep_prob=self.dropoutKeepProb)fwHiddenLayers.append(lstmFwCell)bwHiddenLayers.append(lstmBwCell)# 实现多层的LSTM结构, state_is_tuple=True,则状态会以元祖的形式组合(h, c),否则列向拼接fwMultiLstm = tf.nn.rnn_cell.MultiRNNCell(cells=fwHiddenLayers, state_is_tuple=True)bwMultiLstm = tf.nn.rnn_cell.MultiRNNCell(cells=bwHiddenLayers, state_is_tuple=True)#采用动态rnn,可以动态地输入序列的长度,若没有输入,则取序列的全长#outputs是一个元组(output_fw, output_bw), 其中两个元素的维度都是[batch_size, max_time, hidden_size], fw和bw的hiddensize一样#self.current_state是最终的状态,二元组(state_fw, state_bw), state_fw=[batch_size, s], s是一个元组(h, c)outputs, self.current_state = tf.nn.bidirectional_dynamic_rnn(fwMultiLstm, bwMultiLstm,self.embeddedWords, dtype=tf.float32,scope="bi-lstm" + str(idx))#在bi-lstm+attention论文中,将前向和后向的输出相加with tf.name_scope("Attention"):H = outputs[0] + outputs[1]#得到attention的输出output = self.attention(H)outputSize = config.model.hiddenSizes[-1]print("outputSize:{}".format(outputSize))#全连接层的输出with tf.name_scope("output"):outputW = tf.get_variable("outputW",shape=[outputSize, 1],initializer=tf.contrib.layers.xavier_initializer())outputB = tf.Variable(tf.constant(0.1, shape=[1]), name="outputB")predictions = tf.nn.xw_plus_b(output, outputW, outputB, name="predictions")return predictionsdef attention(self, H):"""利用Attention机制得到句子的向量表示"""#获得最后一层lstm神经元的数量hiddenSize = config.model.hiddenSizes[-1]#初始化一个权重向量,是可训练的参数W = tf.Variable(tf.random_normal([hiddenSize], stddev=0.1))#对bi-lstm的输出用激活函数做非线性转换M = tf.tanh(H)#对W和M做矩阵运算,W=[batch_size, time_step, hidden_size], 计算前做维度转换成[batch_size * time_step, hidden_size]#newM = [batch_size, time_step, 1], 每一个时间步的输出由向量转换成一个数字newM = tf.matmul(tf.reshape(M, [-1, hiddenSize]), tf.reshape(W, [-1, 1]))#对newM做维度转换成[batch_size, time_step]restoreM = tf.reshape(newM, [-1, config.sequenceLength])#用softmax做归一化处理[batch_size, time_step]self.alpha = tf.nn.softmax(restoreM)#利用求得的alpha的值对H进行加权求和,用矩阵运算直接操作r = tf.matmul(tf.transpose(H, [0, 2, 1]), tf.reshape(self.alpha, [-1, config.sequenceLength, 1]))#将三维压缩成二维sequeezeR = [batch_size, hissen_size]sequeezeR = tf.squeeze(r)sentenceRepren = tf.tanh(sequeezeR)#对attention的输出可以做dropout处理output = tf.nn.dropout(sentenceRepren, self.dropoutKeepProb)return outputdef _normalize(self, wordEmbedding, weights):"""对word embedding 结合权重做标准化处理"""mean = tf.matmul(weights, wordEmbedding)powWordEmbedding = tf.pow(wordEmbedding -mean, 2.)var = tf.matmul(weights, powWordEmbedding)stddev = tf.sqrt(1e-6 + var)return (wordEmbedding - mean) / stddevdef _addPerturbation(self, embedded, loss):"""添加波动到word embedding"""grad, =tf.gradients(loss,embedded,aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)grad = tf.stop_gradient(grad)perturb = self._scaleL2(grad, config.model.epsilon)#print("perturbSize:{}".format(embedded+perturb))return embedded + perturbdef _scaleL2(self, x, norm_length):#shape(x) = [batch, num_step, d]#divide x by max(abs(x)) for a numerically stable L2 norm#2norm(x) = a * 2norm(x/a)#scale over the full sequence, dim(1, 2)alpha = tf.reduce_max(tf.abs(x), (1, 2), keep_dims=True) + 1e-12l2_norm = alpha * tf.sqrt(tf.reduce_sum(tf.pow(x/alpha, 2), (1, 2), keep_dims=True) + 1e-6)x_unit = x / l2_normreturn norm_length * x_unit

  

 代码是在双向lstm+attention的基础上增加adversarial training,训练数据为imdb电影评论数据,最后的结果发现确实很快就能达到最优值,但是训练所占的空间比较大(电脑跑了几十步就停止了),每一步的时间也稍微长一点。

 

转载于:https://www.cnblogs.com/danny92/p/10636890.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/462975.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在leveldb中,为什么要有immutable memtable?

目的是:为了防止写入kv时被阻塞。 设想,如果没有immutable memtable,当memtable满了之后后台线程需要将memtable 立即flush到新建的sst中,在flush的过程中,新的KV记录是无法写入的,只能等待,就…

密码强弱提示(27)

密码的强弱提示是对用户填写登陆密码的复杂程度来给出提示,使用密码的强弱提示可以增强用户对密码的保护意识,对如今的网络是非常有必要的,本程序中当用户输入完密码后,网页会自动的对用户输入的密码给出强弱判断。 使用JavaScrip…

《In Search of an Understandable Consensus Algorithm》翻译

Abstract Raft是一种用于管理replicated log的consensus algorithm。它能和Paxos产生同样的结果,有着和Paxos同样的性能,但是结构却不同于Paxos;它让Raft比Paxos更易于理解,并且也为用它构建实际的系统提供了更好的基础。为了增强…

深度点评五种常见WiFi搭建方案

总结十年无线搭建经验,针对企业常见的五种办公室无线网络方案做个简要分析,各种方案有何优劣,又适用于那种类型的企业。方案一:仅路由器或AP覆盖简述:使用路由器或AP覆盖多个无线盲区,多个AP的部署实现整体…

深入理解Presto

深入理解Presto 简介 Presto是一个facebook开源的分布式SQL查询引擎,适用于交互式分析查询,数据量支持GB到PB字节。presto的架构由关系型数据库的架构演化而来。presto之所以能在各个内存计算型数据库中脱颖而出,在于以下几点: …

设置VS2017背景图片

设置方法很简单:安装扩展ClaudiaIDE 1、在这里下载扩展,https://visualstudiogallery.msdn.microsoft.com/9ba50f8d-f30c-4e33-ab19-bfd9f56eb817 2、然后双击即可完成安装。 之后重启VS,就可以看到编程背景上多了一个萌妹子,据说…

证书的应用之一 —— TCPSSL通信实例及协议分析(上)

SSL(Security Socket Layer)是TLS(Transport Layer Security)的前身,是一种加解密协议,它提供了再网络上的安全传输,它介于网络通信协议的传输层与应用层之间。 为实现TCP层之上的ssl通信,需要用到数字证书。本文通过具体例子来说…

ClickedOnce部署方法

1.ClickedOnce部署时有些DLL和配置文件无法自动部署到系统当中,只能用Manifest Manager Tool 修改manifest 文件 /Files/Tonyyang/Software/ManifestManagerUtility.rar 2.部署文件结构 3.部署方法 首先用VS自带的ClickedOnce发布应用程序(博客园有&…

树莓派安装MySQL数据库与卸载

出处: 1、http://www.cnblogs.com/liyangLife/p/4500115.html 2、https://blog.csdn.net/huayucong/article/details/49736427 3、https://www.imooc.com/article/23132?block_idtuijian_wz 4、http://www.runoob.com/mysql/mysql-install.html(Debian系…

Visual Studio.net 2010 Windows Service 开发,安装与调试

本示例完成一个每隔一分钟向C:\log.txt文件写入一条记录为例,讲述一个Windows Service 程序的开发,安装与调试 原程序,加文档示例下载 /Files/zycblog/SourceCode.rar 目录索引 1 开发工具 2 开发过程 3 安装 4 开发调试 5 注意事项 6 参考资料…

ArcGis dbf读写——挂接Excel到属性表 C#

ArcMap提供了挂接Excel表格信息到属性表的功能,但是当数据量较大到以万计甚至十万计的时候这个功能就歇菜了,当然,你可以考虑分段挂接。这个挂接功能只是做了一个表关联,属性记录每个字段的信息需要通过“字段计算器”计算过来。 …

VisualStudioAddIn2017.vsix的下载安装和使用

本加载项是用于Visual Studio的,下载以后按照如下步骤进行安装: 完全退出Visual Studio把下载了的文件解压缩,会产生一个VisualStudioAddIn2017.vsix文件双击该文件,按照提示安装重启Visual Studio安装完成后的使用方法&#xff0…

Presto基本概念

Presto基本概念 Presto是Facebook开源的MPP SQL引擎,旨在填补Hive在速度和灵活性(对接多种数据源)上的不足。相似的SQL on Hadoop竞品还有Impala和Spark SQL等。这里我们介绍下Presto的基本概念,为后续的笔记做基础。 Operator …

2019春第六周编程总结

这个作业属于哪个课程C语言程序设计Ⅱ这个作业要求在哪里https://edu.cnblogs.com/campus/zswxy/MS/homework/2829我在这个课程的目标是利用指针知识解决相关实际问题在具体哪方面帮我实现目标设计密码开锁、交换变量解决问题以及电码加密参考文献C语言基础、http://www.w3scho…

Exchange企业实战技巧(26)在Outlook中打开多个邮箱

工作中,有时要需要让某个用户在outlook中同时打开多个exchange邮箱,对于outlook2010来说,是支持多个Exchange邮箱用户账户的并存,而outlook2007则不支持。那该功能有没其他实现方法呢?答案是有的。 如果你的Exchange是…

Emulator Error: Could not load OpenGLES emulati...

为什么80%的码农都做不了架构师?>>> 模拟器提示警告:Emulator Error: Could not load OpenGLES emulation library: Could not load DLL! 亲测可用: 从SDK\tools\lib目录下将一下四个dll文件复制到SDK\tools,重启模…

关于.net的垃圾回收和大对象处理_标记

CLR垃圾回收器根据所占空间大小划分对象。大对象和小对象的处理方式有很大区别。比如内存碎片整理 —— 在内存中移动大对象的成本是昂贵的,让我们研究一下垃圾回收器是如何处理大对象的,大对象对程序性能有哪些潜在的影响。 大对象堆和垃圾回收 在.Net …

三篇文章了解 TiDB 技术内幕——说存储

此文转载了 申砾 PingCAP 的文章:https://mp.weixin.qq.com/s?__bizMzI3NDIxNTQyOQ&mid2247484822&idx1&sn5434362800d8dcc0ca69d2f3f3260173&chksmeb1622fcdc61abea428f74b26a24bc589d524dd3b666d9b124809300f488d00b33a315a87792&scene21#we…

复制数组方法总结

为什么80%的码农都做不了架构师&#xff1f;>>> 在java中&#xff0c;对数组复制有多种 1.通过循环来复制 比如用for循环 int a[]{1,2,3}; int b[]new int[a.length]; for(int i0;i<a.length;i){ b[i]a[i]; } 2.直接复制 int a[]{1,2,3}; int b[]a…

JS/Cs相互调用

js调用cs中函数的方法 在前台js代码里写上<%method();%>举例:cs文件中写的有public void method(){....执行某些操作.}这个函数,然后在前台页面的js里面调用.<script type"text/javascript"><%method();%></script>在cs中调用js函数法一:C…