多任务学习方法

  最近一直在做多任务,但是效果好象没什么提升,因为都是凭自己的想法和感觉在做。于是上网查找了一些这方面的资料,寻求一些理论上的支撑和前人经验上的帮助。

多任务学习:

  故名思意,就是多个任务一起学习。为什么要进行多任务学习呢?因为现实中样本采样的成本较高,而训练样本不足常常会出现过拟合的现象,而将多个相关任务同时学习,通过共享某个共同的知识可以提高各任务的泛化效果

分类:

在这里插入图片描述
  基于软共享的深度多任务学习在这里插入图片描述
  基于硬共享的深度多任务学习
在这里插入图片描述

一些问题:

1、损失的整合

  为多个任务定义一个损失函数,若将每个任务的损失进行简单相加,由于不同任务的收敛速度不同,可能某一任务的收敛得到不错的效果,而其他任务表现却很差。
  简单的解决办法是将简单相加变为加权相加,但这样会不时进行调参。
  论文《Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics》,提出引入不确定性来确定损失的权重:在每个任务的损失函数中学习另一个噪声参数(noise parameter)。此方法可以接受多任务(可以是回归和分类),并统一所有损失的尺度。这样就能像一开始那样,直接相加得到总损失了。该方法不仅可以得到很好的结果而且不需要考虑额外的权重超参数。

2、调节学习速率
  学习速率是最重要的超参数之一。我们发现,任务 A 和任务 B 各自合适的速率可能是不同的。这时,我们可以在各个任务的子网络(基于硬共享的深度多任务学习)分别调节各自的学习速率,而在共享网络部分,使用另一个学习速率。
  虽然听上去很复杂,但其实非常简单。通常,在利用 TensorFlow 训练神经网络时,使用的是:

optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)

  AdamOptimizer 定义如何应用梯度,而 minimize 则完成具体的计算和应用。我们可以将 minimize 替换为我们自己的实现方案,在应用梯度时,为计算图中的各变量使用各自适合的学习速率。

all_variables = shared_vars + a_vars + b_vars
all_gradients = tf.gradients(loss, all_variables)shared_subnet_gradients = all_gradients[:len(shared_vars)]
a_gradients = all_gradients[len(shared_vars):len(shared_vars + a_vars)]
b_gradients = all_gradients[len(shared_vars + a_vars):]shared_subnet_optimizer = tf.train.AdamOptimizer(shared_learning_rate)
a_optimizer = tf.train.AdamOptimizer(a_learning_rate)
b_optimizer = tf.train.AdamOptimizer(b_learning_rate)train_shared_op = shared_subnet_optimizer.apply_gradients(zip(shared_subnet_gradients, shared_vars))
train_a_op = a_optimizer.apply_gradients(zip(a_gradients, a_vars))
train_b_op = b_optimizer.apply_gradients(zip(b_gradients, b_vars))train_op = tf.group(train_shared_op, train_a_op, train_b_op)  

注:这个技巧其实在单任务网络中也很实用

3、将估计作为特征
  当完成第一阶段的工作,为预测多任务创建好神经网络后,我们可能希望将某一个任务得到的估计(estimate)作为另一个任务的特征。在前向传递(forward-pass)中,这非常简单。但在反向传播中呢?
  假设将任务 A 的估计作为特征输入给 B,我们可能并不希望将梯度从任务 B 传回任务 A,因为我们已经有了任务 A 的标签。对此,TensorFlow 的 API 所提供的 tf.stop_gradient 会有所帮助。在计算梯度时,它允许你传入一个希望作为常数的张量列表,这正是我们所需要的。

all_gradients = tf.gradients(loss, all_variables, stop_gradients=stop_tensors)    

不止如此,该技术可用在任何你希望利用 TensorFlow 计算某个值并将其作为常数的场景。

我的一些想法:

关于多个任务的训练,应该也可以不统一成一个损失函数,各个任务拥有自己的损失函数即可。
这样可以分别找到适合各个任务的学习速率,和迭代次数,然后进行次数不同迭代即可。
比如:
任务A需要迭代100次才收敛:optimizer1 = tf.train.AdamOptimizer(learning_rate1).minimize(loss1)
任务B需要迭代10次收敛:optimizer2 = tf.train.AdamOptimizer(learning_rate2).minimize(loss2)

# 训练:
for epoch in range(100):# Task Asess.run([optimizer1], feed_dict1)# Task Bif epoch % 10 == 0:sess.run([optimizer2], feed_dict2)

【当然这部分只是我的想法啦!没什么科学依据】

参考资料:

什么是多任务学习
深度神经网络中的多任务学习汇总
关于深度多任务学习的 3 点经验

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/479951.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

idea项目目录结构不是树形(横向变纵向)

关闭IDEA 删除项目文件夹下的.idea文件夹 重新用IDEA工具打开项目

曹羽 | 从知识工程到知识图谱全面回顾

本文转载自公众号:集智俱乐部。文本挖掘和图形数据库 | ©ontotext导语知识工程是符号主义人工智能的典型代表,近年来越来越火的知识图谱,就是新一代的知识工程技术。知识工程将如何影响未来人工智能领域的发展,甚至让计算机拥…

4大JVM性能分析工具详解,及内存泄漏分析方案

谈到性能优化分析一般会涉及到: Java代码层面的,典型的循环嵌套等 还会涉及到Java JVM:内存泄漏溢出等 MySQL数据库优化:分库分表、慢查询、长事务的优化等 阿里P8架构师谈:MySQL慢查询优化、索引优化、以及表等优化…

Reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are

情况:就是本来你的 tensor 是有东西的,代码也应该是没问题的,百度无果,debug无果。 原因:突然发现了这一行 failed to allocate 202.56M (212402176 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory 然后…

从 0 搭建一个工业级推荐系统

推荐系统从来没像现在这样,影响着我们的生活。当你上网购物时,天猫、京东会为你推荐商品;想了解资讯,头条、知乎会为你准备感兴趣的新闻和知识;想消遣放松,抖音、快手会为你奉上让你欲罢不能的短视频。而驱…

最全中文停用词表整理(1893个)

在网上搜罗了一下&#xff0c;发现这个停用词还是挺好用的&#xff1a; ! " # $ % &( ) *, - -- . .. ... ...... ................... ./ .一 .数 .日 / // 0 1 2 3 4 5 6 7 8 9 : :// :: ; <> >> ?A Lex [ \ ] ^ _exp sub sup | } ~ ~~~~Δ Ψ γ…

论文浅尝 | 虚拟知识图谱:软件系统和应用案例综述

本文转载自公众号&#xff1a;DI数据智能。Virtual Knowledge Graphs: An Overview of Systems and Use Cases作者&#xff1a;Guohui Xiao, Linfang Ding, Benjamin Cogrel & Diego Calvanese供稿&#xff1a;Guohui Xiao编者按&#xff1a;Data Intelligence 发表意大利博…

LeetCode 169. 求众数(摩尔投票)

文章目录1. 题目信息2. 解题思路3. 代码3.1 排序3.2 map计数3.3 摩尔投票1. 题目信息 给定一个大小为 n 的数组&#xff0c;找到其中的众数。众数是指在数组中出现次数大于 ⌊ n/2 ⌋ 的元素。 你可以假设数组是非空的&#xff0c;并且给定的数组总是存在众数。 示例 1:输入…

阿里P8架构师谈:JVM的内存分配、运行原理、回收算法机制

不管是BAT面试&#xff0c;还是工作实践中的JVM调优以及参数设置&#xff0c;或者内存溢出检测等&#xff0c;都需要涉及到Java虚拟机的内存模型、内存分配&#xff0c;以及回收算法机制等&#xff0c;这些都是必考、必会技能。 JVM内存模型 JVM内存模型可以分为两个部分&…

Keras共享某个层

对一个层的多次调用&#xff0c;就是在共享这个层。 input1 Input(shape[28,28]) input2 Input(shape[28,28]) x1 Flatten()(input1) x1 Dense(60,activation"relu")(x1) x2 Flatten()(input2) x2 Dense(60,activation"relu")(x2)d Dense(10, acti…

我的BERT!改改字典,让BERT安全提速不掉分(已开源)

文 | 苏剑林编 | 小轶背景当前&#xff0c;大部分中文预训练模型都是以字为基本单位的&#xff0c;也就是说中文语句会被拆分为一个个字。中文也有一些多粒度的语言模型&#xff0c;比如创新工场的ZEN和字节跳动的AMBERT&#xff0c;但这类模型的基本单位还是字&#xff0c;只不…

2020年考证时间表汇总!这些证书值得拥有!

原文地址&#xff1a; https://zhuanlan.zhihu.com/p/100824416 2020年考证时间表汇总&#xff01;这些证书值得拥有&#xff01;已认证的官方帐号154 人赞同了该文章昨日之日不可留&#xff0c;2019年已然过去&#xff0c;2020年的我们不能再一成不变&#xff01;快根据自身情…

征稿 | 2019年全国知识图谱与语义计算大会(CCKS2019)第二轮征稿启事

2019年全国知识图谱与语义计算大会China Conference on Knowledge Graph and Semantic Computing (CCKS 2019)2019年8月24日-27日&#xff0c;杭州征稿截止: 2019年5月18日全国知识图谱与语义计算大会&#xff08;CCKS: China Conference on Knowledge Graph and Semantic Comp…

直通BAT必考题系列:JVM的4种垃圾回收算法、垃圾回收机制与总结

BAT必考JVM系列专题 直通BAT必考题系列&#xff1a;深入详解JVM内存模型与JVM参数详细配置 垃圾回收算法 1.标记清除 标记-清除算法将垃圾回收分为两个阶段&#xff1a;标记阶段和清除阶段。 在标记阶段首先通过根节点&#xff08;GC Roots&#xff09;&#xff0c;标记所…

遗传算法及其应用实现

使用遗传算法求解函数具有最大值的点X """ Visualize Genetic Algorithm to find a maximum point in a function. """ import numpy as np import matplotlib.pyplot as pltDNA_SIZE 10 # DNA length POP_SIZE 100 # population size CROSS…

python 判断一个点(坐标)是否在一个多边形内利用射线法

看了一篇博客写的用射线法判断一个经纬度点是否在一个多边形的内部的方法 经验证可行所以拿来用作备份: class Point:lng lat def __init__(self, lng, lat):self.lng lngself.lat lat求外包矩形 def get_polygon_bounds(points):length len(points)top down left ri…

论文浅尝 | 一种嵌入效率极高的 node embedding 方式

论文笔记整理&#xff1a;叶群&#xff0c;浙江大学计算机学院&#xff0c;知识图谱、NLP方向。会议&#xff1a;WSDM 2019链接&#xff1a;https://dl.acm.org/citation.cfm?id3290961Motivation基于spring-electrical的模型在网络可视化中取得了非常成功的应用&#xff0c;一…

重要的,是那些训练中被多次遗忘的样本

文 | kid丶源 | 知乎编 | 兔子酱今天跟大家分享一篇很有意思的文章&#xff0c;是一篇探讨深度学习模型记忆&遗忘机制的文章&#xff0c;是一篇角度很新颖的题材&#xff0c;同时又有一定启发作用。这篇文章发表在深度学习顶会ICLR19&#xff0c;标题是《An empirical stud…

直通BAT必考题系列:7种JVM垃圾收集器特点,优劣势、及使用场景

直通BAT之JVM系列 直通BAT必考题系列&#xff1a;JVM的4种垃圾回收算法、垃圾回收机制与总结 直通BAT必考题系列&#xff1a;深入详解JVM内存模型与JVM参数详细配置 今天继续JVM的垃圾回收器详解&#xff0c;如果说垃圾收集算法是JVM内存回收的方法论&#xff0c;那么垃圾收集…

模拟嫁接技术

模拟嫁接技术&#xff1a;定义嫁接算子及策略剪接算子及策略GPOGA算法总结定义 收益和代价 对一棵生成树 T1&#xff0c;若将某结点的一条分枝移至另一结点作为其一条分枝后产生的生成树为 T2&#xff0c;考察分枝移动前后生成树的边长和的变化&#xff0c;则定义收益(gain)和…