TensorFlow 多任务学习

多任务学习

多任务学习,顾名思义,就是多个任务模型同时执行,进行模型的训练,利用模型的共性部分来简化多任务的模型,实现模型之间的融合与参数共享,可以在一定程度上优化模型的运算,提高计算机的效率,但模型本身并没有什么改变。

多任务学习的核心在于如何训练上:

  • 交替训练
  • 联合训练

通过一个简单的线性变换来展示多任务学习模型的运用。

首先,导入需要的包

import tensorflow as tf
import numpy as np

使用numpy制造两组假数据

x_data = np.float32(np.random.rand(2, 100))  # 随机输入
y1_data = np.dot([0.100, 0.200], x_data) + 0.300
y2_data = np.dot([0.500, 0.900], x_data) + 3.000

构造两个线性模型

b1 = tf.Variable(tf.zeros([1]))
W1 = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0))
y1 = tf.matmul(W1, x_data) + b1b2 = tf.Variable(tf.zeros([1]))
W2 = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0))
y2 = tf.matmul(W2, x_data) + b2

计算方差,使方差最小化,使模型不断的靠近真实解

# 最小化方差
loss1 = tf.reduce_mean(tf.square(y1 - y1_data))
loss2 = tf.reduce_mean(tf.square(y2 - y2_data))

构造优化器

# 构建优化器
optimizer = tf.train.GradientDescentOptimizer(0.5)
train1 = optimizer.minimize(loss1)
train2 = optimizer.minimize(loss2)

 交替训练

基本思想:使两个模型交替进行训练

# 初始化全局变量
init = tf.global_variables_initializer()# 启动图 (graph)
with tf.Session() as sess:sess.run(init)for step in range(1, 1001):if np.random.rand() < 0.5:sess.run(train1)print(step, 'W1,b1:', sess.run(W1), sess.run(b1))else:sess.run(train2)print(step, 'W2,b2:', sess.run(W2), sess.run(b2))

输出结果为:

 从最终的结果可以看出W1,W2,b1,b2已经非常接近真实值了,说明模型的建立还是非常有效的。

联合训练

基本思想:将两个模型的损失函数结合起来,共同进行优化训练

# 联合训练
loss = loss1 + loss2
# 构建优化器
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)# 初始化全局变量
init = tf.global_variables_initializer()# 启动图
with tf.Session() as sess:sess.run(init)for step in range(1, 300):sess.run(train)print(step, 'W1,b1,W2,b2:', sess.run(W1), sess.run(b1), sess.run(W2), sess.run(b2))

输出结果为:

从结果可以看出模型的参数不断的接近真实值。

应用场景

当你需要同一组数据集去处理不同的任务时,交替训练是一个很好地选择。

当两个甚至多个任务需要联合考虑时,为了整体的最优而放弃局部最优的时候,使用联合训练非常的合适。

 

欢迎关注和评论!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/490806.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【工业互联网】全球工业互联网十大最具成长性技术展望(2019-2020年)

来源&#xff1a;中国工业互联网研究院来源&#xff1a;中国工业互联网研究院全球工业互联网十大最具成长性技术展望&#xff08;2019-2020年&#xff09;工业互联网工业互联网是第四次工业革命的重要基石&#xff0c;在世界范围已步入发展快车道&#xff0c;正处于面临重大突破…

win10调节屏幕亮度_自动调节电脑屏幕亮度软件,保护你的眼睛

本文共514个字&#xff0c;预计用时2分钟小伙伴们&#xff0c;今天给大家分享一个小软件&#xff0c;名字叫做 EyeCareApp&#xff0c;中文名&#xff1a;护眼软件EyeCareApp是一款能够调节屏幕亮度的软件&#xff0c;它可以调整屏幕亮度&#xff0c;滤除蓝光&#xff0c;有效减…

UIAlertAction添加输入框

* UIAlertController & UIAlertAction * 1. 在iOS8中&#xff0c;我们失去了两个非常简单的控件&#xff0c;那就是UIAlertView、UIActionSheet&#xff0c;取而代之的是UIAlertController和UIAlertAction * 2. 在iOS8中&#xff0c;UIAlertController控件使用两种样式代…

国家计划统筹布局哪些人工智能创新平台?

来源&#xff1a;智造智库建设布局人工智能创新平台&#xff0c;是强化对人工智能研发应用的基础支撑。未来&#xff0c;国家层面计划大力促进各类通用软件和技术平台的开源开放&#xff0c;且按照军民深度融合的要求和相关规定&#xff0c;推进军民共享共用。人工智能开源软硬…

install package vif包_2019-10-03【百宝箱】如何使用wireshark实时远程抓取openwrt路由器包...

前言经常遇到问题的时候需要抓取wifi数据包&#xff0c;常用的做法有三种&#xff1a;1、使用 专用网卡omnipeekwindows软件抓包2、使用 macbook pro的airtool软件抓包3、在路由器上使用tcpdump除了omnipeek 其他都不能实时操作。如何结合openwrt来达到远程实时抓包呢&#xff…

CNN中的卷积操作与权值共享

CNN中非常有特点的地方就在于它的局部连接和权值共享&#xff0c;通过卷积操作实现局部连接&#xff0c;这个局部区域的大小就是滤波器filter&#xff0c;避免了全连接中参数过多造成无法计算的情况&#xff0c;再通过参数共享来缩减实际参数的数量&#xff0c;为实现多层网络提…

python根据矩阵数值大小涂上不同深浅颜色

绘制矩阵颜色图 import matplotlib.pyplot as pltplt.matshow(np.random.rand(5,5), cmapplt.get_cmap(Greens), alpha0.5) # , alpha0.3 plt.show()

20150210--Smarty1-02

20150210--Smarty1-02 三、设计篇 1、Smarty注释 基本语法&#xff1a; {*注释内容*} 示例代码&#xff1a; 2、Smarty中的变量 1&#xff09;从PHP中分配的变量(普通的变量、数组、对象) 基本语法: $smarty->assign(); 示例代码&#xff1a; demo02.php demo02.html 运行效…

python 绘制时频图 plt.specgram

时频图以横轴为时间&#xff0c;纵轴为频率&#xff0c;用颜色表示幅值。在一幅图中表示信号的频率、幅度随时间的变化 matplotlib.pyplot.specgram(x, NFFTNone, FsNone, FcNone, detrendNone, windowNone, noverlapNone, cmapNone, xextentNone, pad_toNone, sidesNone, s…

UserWarning: The default mode, 'constant', will be changed to 'reflect'

问题&#xff1a;UserWarning: The default mode, constant, will be changed to reflect in skimage 0.15. warn("The default mode, constant, will be changed to reflect in " skimage.transform.resize(image, output_shape, order1, modeNone, cval0, clipT…

华为鸿蒙系统四大特性:基于微内核,面向全场景,分布式架构

来源&#xff1a;今日头条8月9日&#xff0c;在广东东莞召开的华为开发者大会上&#xff0c;华为正式发布了自研操作系统&#xff1a;鸿蒙OS。据华为消费者业务CEO、华为技术有限公司常务董事余承东介绍&#xff0c;鸿蒙OS是基于微内核的面向全场景的分布式操作系统。随着华为全…

HTML5 音频视频

audio 元素能够播放声音文件或者音频流。 <!DOCTYPE html> <html> <head lang"en"><meta charset"utf-8"><title>HTML5 音频播放</title> </head> <body><!-- audio&#xff08;音频&#xff09;contr…

VM虚拟机中 localhost login_UTM 2.0 虚拟机来了,解决上网和无声音问题

今天主要讲一下UTM虚拟机&#xff0c;如果你对UTM这款APP不太熟悉&#xff0c;我在这里大致讲一下&#xff0c;这款应用工具&#xff0c;它可以安装在 iPad 和 iPhone 上刷入电脑系统&#xff0c;举例子&#xff1a;在UTM中刷入win7系统。甚至还能刷入 Android 安卓系统&#x…

python计算ROC曲线和面积AUC

ROC曲线是根据一系列不同的二分类方式&#xff08;分界值或决定阈&#xff09;&#xff0c;以真正率&#xff08;也就是灵敏度&#xff09;&#xff08;True Positive Rate,TPR&#xff09;为纵坐标&#xff0c;假正率&#xff08;1-特效性&#xff09;&#xff08;False Posit…

python from __future__ import division

python from __future__ import division 之前一直很困惑&#xff0c;为什么这个模块叫future呢&#xff0c;难道有什么特殊功能能够让人们想到未来吗&#xff0c;最近才恍然大悟。 python的更新和前进是由社区进行推动的&#xff0c;而且是免费开源的&#xff0c;不受大型…

【VS开发】CTimeSpan类

CTimeSpan类。 日期和时间类简介 CTime类的对象表示的时间是基于格林威治标准时间&#xff08;GMT&#xff09;的。CTimeSpan类的对象表示的是时间间隔。 CTime类和CTimeSpan类一般不会被继承使用。两者对象的大小都是8个字节。 CTime表示的日期上…

搅动世界的两大因素

原创&#xff1a;张晓峰提要&#xff1a;移动互联、云计算、大数据、人工智能等技术因素逐步成为新基础设施&#xff0c;而泛连接、泛共享、泛融合与泛协同为代表的非技术因素正在重构这个世界。二者叠加融汇、相因相生。每个人都渐进或主动或被动地“被”函数化、数字化、孪生…

python计算PR曲线sklearn.metrics.precision_recall_curve

PR曲线实则是以precision&#xff08;精准率&#xff09;和recall&#xff08;召回率&#xff09;这两个为变量而做出的曲线&#xff0c;其中recall为横坐标&#xff0c;precision为纵坐标。设定一系列阈值&#xff0c;计算每个阈值对应的recall和precision&#xff0c;即可计算…

amigo幸运字符什么意思_转载 | 史上最全 python 字符串操作指南

点击蓝字关注&#xff0c;创智助你长姿势【本文已由 清风Python 授权转载(原创)作者&#xff1a;王翔&#xff0c;转载请联系出处】字符串的定义完了&#xff0c;估计很多人看到这个标题就要关网页了&#xff0c;稍等不妨再往下看看&#xff1f;python 定义字符、字符串没有 j…

在物理学的语言里,“生命”是什么?

转自&#xff1a;Darthusian“想象一种语言就像想象一种形式的生命。”--- 路德维希.维特根斯坦&#xff0c;《哲学研究》当今世界人们使用大约6,800种不同的语言。不是每个词都能在不同的语言之间完美地翻译&#xff0c;意义有时会落入语义的裂缝。例如&#xff0c;日语词wabi…