【深度学习】NLP中的对抗训练

        在NLP中,对抗训练往往都是针对嵌入层(包括词嵌入,位置嵌入,segment嵌入等等)开展的,思想很简单,即针对嵌入层添加干扰,从而提高模型的鲁棒性和泛化能力,下面结合具体代码讲解一些NLP中常见对抗训练算法。

1.Fast Gradient Method(FGM)

        FGM的思想是针对词嵌入加入梯度方向的干扰,至于干扰的大小是我们可以调节的,增加干扰后的样本可以作为额外的对抗样本进行训练,以此提高模型的效果。由于我们在训练时会针对每个样本都进行一次额外的增加干扰后的训练,所以使用FGM后训练时间理论上也会大概增加一倍。

        FGM在原训练代码的基础上,主要增加了以下几个额外的操作:针对嵌入层添加干扰并备份参数,计算添加干扰后的损失,梯度回传从而累积添加干扰后的梯度,恢复原来的嵌入层参数。

1.1 算法流程

对于每个x:1.计算x的前向loss、反向传播得到梯度2.根据embedding矩阵的梯度计算出r,并加到当前embedding上,相当于x+r3.计算x+r的前向loss,反向传播得到对抗的梯度,累加到(1)的梯度上4.将embedding恢复为(1)时的值5.根据(3)的梯度对参数进行更新

  1.2 具体代码

import torch
class FGM():def __init__(self, model):self.model = modelself.backup = {}def attack(self, epsilon=1., emb_name='word_embeddings'):# emb_name这个参数要换成你模型中embedding的参数名for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name:#print('增加扰动的对象是', name)#print(type(param.grad))self.backup[name] = param.data.clone()norm = torch.norm(param.grad)if norm != 0 and not torch.isnan(norm):r_at = epsilon * param.grad / normparam.data.add_(r_at)def restore(self, emb_name='word_embeddings'):# emb_name这个参数要换成你模型中embedding的参数名for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name: assert name in self.backupparam.data = self.backup[name]self.backup = {}

1.3 具体用法

fgm = FGM(model) # (#1)初始化
for batch_input, batch_label in data:loss = model(batch_input, batch_label) # 正常训练loss.backward() # 反向传播,得到正常的grad# 对抗训练fgm.attack() # (#2)在embedding上添加对抗扰动loss_adv = model(batch_input, batch_label) # (#3)计算含有扰动的对抗样本的lossloss_adv.backward() # (#4)反向传播,并在正常的grad基础上,累加对抗训练的梯度fgm.restore() # (#5)恢复embedding参数# 梯度下降,更新参数optimizer.step()model.zero_grad()

2.Projected Gradient Descent (PGD

        Project Gradient Descent(PGD)是一种迭代攻击算法,相比于普通的FGM 仅做一次迭代,PGD是做多次迭代,每次走一小步,每次迭代都会将扰动投射到规定范围内。其中r为扰动约束空间(一个半径为r的球体),原始的输入样本对应的初识点为球心,避免扰动超过球面。迭代多次后,保证扰动在一定范围内,如下图所示:

 2.1 算法流程

对于每个x:1.计算x的前向loss、反向传播得到梯度并备份对于每步t:2.根据embedding矩阵的梯度计算出r,并加到当前embedding上,相当于x+r(超出范围则投影回epsilon内)3.t不是最后一步: 将梯度归0,根据1的x+r计算前后向并得到梯度4.t是最后一步: 恢复(1)的梯度,计算最后的x+r并将梯度累加到(1)上5.将embedding恢复为(1)时的值6.根据(4)的梯度对参数进行更新

 2.2  具体代码

import torch
class PGD():def __init__(self, model):self.model = modelself.emb_backup = {}self.grad_backup = {}def attack(self, epsilon=1., alpha=0.3, emb_name='word_embeddings', is_first_attack=False):for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name:if is_first_attack:self.emb_backup[name] = param.data.clone()norm = torch.norm(param.grad)if norm != 0 and not torch.isnan(norm):r_at = alpha * param.grad / normparam.data.add_(r_at)param.data = self.project(name, param.data, epsilon)def restore(self, emb_name='word_embeddings'):for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name: assert name in self.emb_backupparam.data = self.emb_backup[name]self.emb_backup = {}def project(self, param_name, param_data, epsilon):r = param_data - self.emb_backup[param_name]if torch.norm(r) > epsilon:r = epsilon * r / torch.norm(r)return self.emb_backup[param_name] + rdef backup_grad(self):for name, param in self.model.named_parameters():if param.requires_grad:self.grad_backup[name] = param.grad.clone()def restore_grad(self):for name, param in self.model.named_parameters():if param.requires_grad:param.grad = self.grad_backup[name]

2.3 具体用法

pgd = PGD(model)
K = 3
for batch_input, batch_label in data:# 正常训练loss = model(batch_input, batch_label)loss.backward() # 反向传播,得到正常的gradpgd.backup_grad()# 累积多次对抗训练——每次生成对抗样本后,进行一次对抗训练,并不断累积梯度for t in range(K):pgd.attack(is_first_attack=(t==0)) # 在embedding上添加对抗扰动, first attack时备份param.dataif t != K-1:model.zero_grad()else:pgd.restore_grad()loss_adv = model(batch_input, batch_label)loss_adv.backward() # 反向传播,并在正常的grad基础上,累加对抗训练的梯度pgd.restore() # 恢复embedding参数# 梯度下降,更新参数optimizer.step()model.zero_grad()

Reference:

1.NLP中的对抗训练_colourmind的博客-CSDN博客

2.【NLP】NLP中的对抗训练_风度78的博客-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/36979.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spark 学习记录

基础 SparkContext是什么?有什么作用? https://blog.csdn.net/Shockang/article/details/118344357 SparkContext 是什么? SparkContext 是通往 Spark 集群的唯一入口,可以用来在 Spark 集群中创建 RDDs 、累加和广播变量( Br…

【数据库基础】Mysql下载安装及配置

下载 下载地址:https://downloads.mysql.com/archives/community/ 当前最新版本为 8.0版本,可以在Product Version中选择指定版本,在Operating System中选择安装平台,如下 安装 MySQL安装文件分两种 .msi和.zip [外链图片转存失…

C++11时间日期库chrono的使用

chrono是C11中新加入的时间日期操作库,可以方便地进行时间日期操作,主要包含了:duration, time_point, clock。 时钟与时间点 chrono中用time_point模板类表示时间点,其支持基本算术操作;不同时钟clock分别返回其对应…

Docker中部署Nginx

1.Nginx部署需求 2.操作教程 3.实际步骤 把配置粘过来。

Cookie、Session、Token的区别

有人或许还停留在它们只是验证身份信息的机制,但是它们之间的关系你真的弄懂了么? 发展史: Coolie: Netscape Communications 公司引入了 Cookie 概念,作为在客户端存储状态信息的一种方法。初始目的是为了解决 HTTP 的无状态性…

Python爬虫:单线程、多线程、多进程

前言 在使用爬虫爬取数据的时候,当需要爬取的数据量比较大,且急需很快获取到数据的时候,可以考虑将单线程的爬虫写成多线程的爬虫。下面来学习一些它的基础知识和代码编写方法。 一、进程和线程 进程可以理解为是正在运行的程序的实例。进…

python爬虫数据解析xpath、jsonpath,bs4

数据的解析 解析数据的方式大概有三种 xpathJsonPathBeautifulSoup xpath 安装xpath插件 打开谷歌浏览器扩展程序,打开开发者模式,拖入插件,重启浏览器,ctrlshiftx,打开插件页面 安装lxml库 安装在python环境中的Scri…

并发服务器模型,多线程并发

一、多线程并发完整代码 #include <stdio.h> #include <sys/types.h> #include <sys/socket.h> #include <arpa/inet.h> #include <string.h> #include <unistd.h> #include <sys/wait.h> #include <stdlib.h> #include <…

突然让做性能测试?试试RunnerGo

当前&#xff0c;性能测试已经是一名软件测试工程师必须要了解&#xff0c;甚至熟练使用的一项技能了&#xff0c;在工作时可能每次发版都要跑一遍性能&#xff0c;跑一遍自动化。性能测试入门容易&#xff0c;深入则需要太多的知识量&#xff0c;今天这篇文章给大家带来&#…

Rocky Linux更换为国内源

Rocky Linux提供的可供切换的源列表&#xff1a;Mirrors - Mirror Manager 其中以 COUNTRY 列为 CN 的是国内源。 选择其中一个Rocky Linux 源使用帮助 — USTC Mirror Help 文档 操作前请做好备份 对于 Rocky Linux 8&#xff0c;使用以下命令替换默认的配置 sed -e s|^mirr…

新能源汽车电控系统

新能源汽车电控系统主要分为&#xff1a;三电系统电控系统、高压系统电控系统、低压系统电控系统 三电系统电控系统 包括整车控制器、电池管理系统、驱动电机控制器等。 整车控制器VCU 整车控制器作为电动汽车中央控制单元&#xff0c;是整个控制系统的核心&#xff0c;也是…

zabbix监控mysql数据库、nginx、Tomcat

zabbix监控mysql数据库、nginx、Tomcat 一.zabbix监控mysql数据库 1.环境规划 hostIP部署zabbix-server192.168.198.17zabbix服务器搭建zabbix-mysql192.168.198.15zabbix客户端搭建 2.zabbix-server安装部署&#xff08;192.168.198.17&#xff09; 请参考以下配置&#…

Azure概念介绍

云计算定义 云计算是一种使用网络进行存储和处理数据的计算方式。它通过将数据和应用程序存储在云端服务器上&#xff0c;使用户能够通过互联网访问和使用这些资源&#xff0c;而无需依赖于本地硬件和软件。 发展历史 云计算的概念最早可以追溯到20世纪60年代的时候&#x…

年至年的选择仿elementui的样式

组件&#xff1a;<!--* Author: liuyu liuyuxizhengtech.com* Date: 2023-02-01 16:57:27* LastEditors: wangping wangpingxizhengtech.com* LastEditTime: 2023-06-30 17:25:14* Description: 时间选择年 - 年 --> <template><div class"yearPicker"…

Smart HTML Elements 16.1 Crack

Smart HTML Elements 是一个现代 Vanilla JS 和 ES6 库以及下一代前端框架。企业级 Web 组件包括辅助功能&#xff08;WAI-ARIA、第 508 节/WCAG 合规性&#xff09;、本地化、从右到左键盘导航和主题。与 Angular、ReactJS、Vue.js、Bootstrap、Meteor 和任何其他框架集成。 智…

九、多态(2)

本章概要 构造器和多态 构造器调用顺序继承和清理构造器内部多态方法的行为 协变返回类型使用继承设计 替代 vs 扩展向下转型与运行时类型信息 构造器和多态 通常&#xff0c;构造器不同于其他类型的方法。在涉及多态时也是如此。尽管构造器不具有多态性&#xff08;事实上…

【JavaScript】new 的原理以及实现

网道 - new 命令的原理 使用new命令时&#xff0c;它后面的函数依次执行下面的步骤。 创建一个空对象&#xff0c;作为将要返回的对象实例。将这个空对象的原型&#xff0c;指向构造函数的prototype属性。将这个空对象赋值给函数内部的this关键字。如果构造函数返回了一个对象…

在Visual Studio上,使用OpenCV实现人脸识别

1. 环境与说明 本文介绍了如何在Visual Studio上&#xff0c;使用OpenCV来实现人脸识别的功能 环境说明 : 操作系统 : windows 10 64位Visual Studio版本 : Visual Studio Community 2022 (社区版)OpenCV版本 : OpenCV-4.8.0 (2023年7月最新版) 实现效果如图所示&#xff0…

Linux命令200例:adduser用于创建新用户

&#x1f3c6;作者简介&#xff0c;黑夜开发者&#xff0c;全栈领域新星创作者✌。CSDN专家博主&#xff0c;阿里云社区专家博主&#xff0c;2023年6月csdn上海赛道top4。 &#x1f3c6;数年电商行业从业经验&#xff0c;历任核心研发工程师&#xff0c;项目技术负责人。 &…

代理模式【Proxy Pattern】

什么是代理模式呢&#xff1f;我很忙&#xff0c;忙的没空理你&#xff0c;那你要找我呢就先找我的代理人吧&#xff0c;那代理人总要知道 被代理人能做哪些事情不能做哪些事情吧&#xff0c;那就是两个人具备同一个接口&#xff0c;代理人虽然不能干活&#xff0c;但是被 代…