论文笔记:多任务学习模型:渐进式分层提取(PLE)含pytorch实现

整理了RecSys2020 Progressive Layered Extraction : A Novel Multi-Task Learning Model for Personalized Recommendations)论文的阅读笔记

  • 背景
  • 模型
  • 代码

论文地址:PLE

背景

  多任务学习(multi-task learning,MTL):给定 m 个学习任务,这m个任务或它们的一个子集彼此相关但不完全相同。简单地说就是一个模型有多个输出对应多个任务的结果。
  多任务学习在推荐系统中已经有很多成功的应用。但是存在一些问题,文章的作者观察到了一个跷跷板现象,即一个任务的性能通常通过损害其他任务的性能来提高。当任务相关性复杂时,相应的单任务模型相比,多个任务无法同时提高。
  基于这一点,本文提出了一种渐进分层提取(PLE)模型。明确分离共享组件和特定任务组件,采用渐进式路由机制,逐步提取和分离更深层次的语义知识。

模型

  利用门结构和注意网络进行信息融合在之前的模型中已经很常见,比如MMoE,单层的MMoE如图:
在这里插入图片描述

  在这种模型中,没有任务特定的概念,所有的专家被所有的任务共享,PLE就是在MMoE的基础上修改的,在PLE中,明确地分离了任务公共参数和任务特定参数,以避免复杂任务相关性导致的参数冲突。单层的PLE(CGC模块)是这样的:
在这里插入图片描述  和MMoE的区别就很明显了,在MMoE中,所有的任务同时更新所有的专家网络,没有任务特定的概念,而在PLE中,明确分离了任务通用专家和特定任务专家,特定于任务的专家仅接受对应的任务tower梯度更新参数,而共享的专家则被多任务结果更新参数,这就使得不同类型 experts 可以专注于更高效地学习不同的知识且避免不必要的交互。另外,得益于门控网络动态地融合输入,CGC可以更灵活地在不同子任务之间找到平衡且更好地处理任务之间的冲突和样本相关性问题。
  对CGC模型进行扩展,就形成了具有多级门控网络和渐进式分离路由的广义PLE模型:在这里插入图片描述
  具体的CGC(Customized Gate Control)模型:在这里插入图片描述
  具体的PLE模型:在这里插入图片描述

代码

  由于博主不是做这个方向的,仅记录这篇文章的思想,就不推公式和实验了,PLE的代码似乎没有公开,但是在网上找了一个可用的pytorch版本,稍微调试一下就可用了,代码修改自博客【推荐系统多任务学习 MTL】PLE论文精读笔记(含代码实现)

import numpy as np
import torch
from torch import nn'''专家网络'''
class Expert_net(nn.Module):def __init__(self, feature_dim, expert_dim):super(Expert_net, self).__init__()p = 0self.dnn_layer = nn.Sequential(nn.Linear(feature_dim, 256),nn.ReLU(),nn.Dropout(p),nn.Linear(256, expert_dim),nn.ReLU(),nn.Dropout(p))def forward(self, x):out = self.dnn_layer(x)return out'''特征提取层'''
class Extraction_Network(nn.Module):'''FeatureDim-输入数据的维数; ExpertOutDim-每个Expert输出的维数; TaskExpertNum-任务特定专家数;CommonExpertNum-共享专家数; GateNum-gate数(2表示最后一层,3表示中间层)'''def __init__(self, FeatureDim, ExpertOutDim, TaskExpertNum, CommonExpertNum, GateNum):super(Extraction_Network, self).__init__()self.GateNum = GateNum  # 输出几个Gate的结果,2表示最后一层只输出两个任务的Gate,3表示还要输出中间共享层的Gate'''两个任务模块,一个共享模块'''self.n_task = 2self.n_share = 1'''TaskA-Experts'''for i in range(TaskExpertNum):setattr(self, "expert_layer" + str(i + 1), Expert_net(FeatureDim, ExpertOutDim).cuda())self.Experts_A = [getattr(self, "expert_layer" + str(i + 1)) for i inrange(TaskExpertNum)]  # Experts_A模块,TaskExpertNum个Expert'''Shared-Experts'''for i in range(CommonExpertNum):setattr(self, "expert_layer" + str(i + 1), Expert_net(FeatureDim, ExpertOutDim).cuda())self.Experts_Shared = [getattr(self, "expert_layer" + str(i + 1)) for i inrange(CommonExpertNum)]  # Experts_Shared模块,CommonExpertNum个Expert'''TaskB-Experts'''for i in range(TaskExpertNum):setattr(self, "expert_layer" + str(i + 1), Expert_net(FeatureDim, ExpertOutDim).cuda())self.Experts_B = [getattr(self, "expert_layer" + str(i + 1)) for i inrange(TaskExpertNum)]  # Experts_B模块,TaskExpertNum个Expert'''Task_Gate网络结构'''for i in range(self.n_task):setattr(self, "gate_layer" + str(i + 1),nn.Sequential(nn.Linear(FeatureDim, TaskExpertNum + CommonExpertNum),nn.Softmax(dim=1)).cuda())self.Task_Gates = [getattr(self, "gate_layer" + str(i + 1)) for i inrange(self.n_task)]  # 为每个gate创建一个lr+softmax'''Shared_Gate网络结构'''for i in range(self.n_share):setattr(self, "gate_layer" + str(i + 1),nn.Sequential(nn.Linear(FeatureDim, 2 * TaskExpertNum + CommonExpertNum),nn.Softmax(dim=1)).cuda())self.Shared_Gates = [getattr(self, "gate_layer" + str(i + 1)) for i in range(self.n_share)]  # 共享gatedef forward(self, x_A, x_S, x_B):'''Experts_A模块输出'''Experts_A_Out = [expert(x_A) for expert in self.Experts_A]  #Experts_A_Out = torch.cat(([expert[:, np.newaxis, :] for expert in Experts_A_Out]),dim=1)  # 维度 (bs,TaskExpertNum,ExpertOutDim)'''Experts_Shared模块输出'''Experts_Shared_Out = [expert(x_S) for expert in self.Experts_Shared]  #Experts_Shared_Out = torch.cat(([expert[:, np.newaxis, :] for expert in Experts_Shared_Out]),dim=1)  # 维度 (bs,CommonExpertNum,ExpertOutDim)'''Experts_B模块输出'''Experts_B_Out = [expert(x_B) for expert in self.Experts_B]  #Experts_B_Out = torch.cat(([expert[:, np.newaxis, :] for expert in Experts_B_Out]),dim=1)  # 维度 (bs,TaskExpertNum,ExpertOutDim)'''Gate_A的权重'''Gate_A = self.Task_Gates[0](x_A)  # 维度 n_task个(bs,TaskExpertNum+CommonExpertNum)'''Gate_Shared的权重'''if self.GateNum == 3:Gate_Shared = self.Shared_Gates[0](x_S)  # 维度 n_task个(bs,2*TaskExpertNum+CommonExpertNum)'''Gate_B的权重'''Gate_B = self.Task_Gates[1](x_B)  # 维度 n_task个(bs,TaskExpertNum+CommonExpertNum)'''GateA输出'''g = Gate_A.unsqueeze(2)  # 维度(bs,TaskExpertNum+CommonExpertNum,1)experts = torch.cat([Experts_A_Out, Experts_Shared_Out],dim=1)  # 维度(bs,TaskExpertNum+CommonExpertNum,ExpertOutDim)Gate_A_Out = torch.matmul(experts.transpose(1, 2), g)  # 维度(bs,ExpertOutDim,1)Gate_A_Out = Gate_A_Out.squeeze(2)  # 维度(bs,ExpertOutDim)'''GateShared输出'''if self.GateNum == 3:g = Gate_Shared.unsqueeze(2)  # 维度(bs,2*TaskExpertNum+CommonExpertNum,1)experts = torch.cat([Experts_A_Out, Experts_Shared_Out, Experts_B_Out],dim=1)  # 维度(bs,2*TaskExpertNum+CommonExpertNum,ExpertOutDim)Gate_Shared_Out = torch.matmul(experts.transpose(1, 2), g)  # 维度(bs,ExpertOutDim,1)Gate_Shared_Out = Gate_Shared_Out.squeeze(2)  # 维度(bs,ExpertOutDim)'''GateB输出'''g = Gate_B.unsqueeze(2)  # 维度(bs,TaskExpertNum+CommonExpertNum,1)experts = torch.cat([Experts_B_Out, Experts_Shared_Out],dim=1)  # 维度(bs,TaskExpertNum+CommonExpertNum,ExpertOutDim)Gate_B_Out = torch.matmul(experts.transpose(1, 2), g)  # 维度(bs,ExpertOutDim,1)Gate_B_Out = Gate_B_Out.squeeze(2)  # 维度(bs,ExpertOutDim)if self.GateNum == 3:return Gate_A_Out, Gate_Shared_Out, Gate_B_Outelse:return Gate_A_Out, Gate_B_Outclass PLE(nn.Module):# FeatureDim-输入数据的维数;ExpertOutDim-每个Expert输出的维数;TaskExpertNum-任务特定专家数;CommonExpertNum-共享专家数;n_task-任务数(gate数)def __init__(self, FeatureDim, ExpertOutDim, TaskExpertNum, CommonExpertNum, n_task=2):super(PLE, self).__init__()# self.FeatureDim = x.shape[1]'''一层Extraction_Network,一层CGC'''self.Extraction_layer1 = Extraction_Network(FeatureDim, ExpertOutDim, TaskExpertNum, CommonExpertNum, GateNum=3)self.CGC = Extraction_Network(ExpertOutDim, ExpertOutDim, TaskExpertNum, CommonExpertNum, GateNum=2)'''TowerA'''p1 = 0hidden_layer1 = [64, 32]self.tower1 = nn.Sequential(nn.Linear(ExpertOutDim, hidden_layer1[0]),nn.ReLU(),nn.Dropout(p1),nn.Linear(hidden_layer1[0], hidden_layer1[1]),nn.ReLU(),nn.Dropout(p1),nn.Linear(hidden_layer1[1], 1))'''TowerB'''p2 = 0hidden_layer2 = [64, 32]self.tower2 = nn.Sequential(nn.Linear(ExpertOutDim, hidden_layer2[0]),nn.ReLU(),nn.Dropout(p2),nn.Linear(hidden_layer2[0], hidden_layer2[1]),nn.ReLU(),nn.Dropout(p2),nn.Linear(hidden_layer2[1], 1))def forward(self, x):Output_A, Output_Shared, Output_B = self.Extraction_layer1(x, x, x)Gate_A_Out, Gate_B_Out = self.CGC(Output_A, Output_Shared, Output_B)out1 = self.tower1(Gate_A_Out)out2 = self.tower2(Gate_B_Out)return out1, out2return Gate_A_Out, Gate_B_Out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/655395.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

防火墙路由

目录 1. 防火墙的智能选路 2. 策略路由 -- PBR 3. 智能选路 --- 全局路由策略 3.1 基于链路带宽的负载分担: 3.2 基于链路质量进行负载分担 3.3 基于链路权重进行负载分担 3.4 基于链路优先级的主备备份 1. 防火墙的智能选路 就近选路 --- 我们希望在访问不同运营商的服…

Vue2 通过.sync修饰符实现数据双向绑定

App.vue <template><div class"app"><buttonv-on:clickisShowtrue>退出按钮</button><BaseDialog:visible.syncisShow></BaseDialog></div> </template><script> import BaseDialog from "./components…

多符号表达式的共同子表达式提取教程

生成的符号表达式&#xff0c;可能会存在过于冗长的问题&#xff0c;且多个符号表达式中&#xff0c;有可能存在相同的计算部分&#xff0c;如果不进行处理&#xff0c;计算过程中会导致某些算式计算多次&#xff0c;从而影响计算效率。 那么多个符号表达式生成函数时&#xf…

[机器学习]KNN——K邻近算法实现

一.K邻近算法概念 二.代码实现 # 0. 引入依赖 import numpy as np import pandas as pd# 这里直接引入sklearn里的数据集&#xff0c;iris鸢尾花 from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split # 切分数据集为训练集和测试…

2024年数学建模美赛 分析与编程

2024年数学建模美赛 分析与编程 1、本专栏将在2024年美赛题目公布后&#xff0c;进行深入分析&#xff0c;建议收藏&#xff1b; 2、本专栏对2023年赛题&#xff0c;其它题目分析详见专题讨论&#xff1b; 2023年数学建模美赛A题&#xff08;A drought stricken plant communi…

JavaSE——运算符、运算符优先级、API、Scanner

目录 基本的算术运算符 自增自减运算符 赋值运算符 关系运算符 逻辑运算符 三目运算符 运算符优先级 API Scanner 基本的算术运算符 符号作用加-减*乘/除%取余 基本与C语言的基本算术运算符一致 注意&#xff1a;两个整数相除结果还是整数 public static void main…

C++PythonC# 三语言OpenCV从零开发(7):图像的阈值

文章目录 相关链接前言阈值阈值使用代码PythonCCsharpcsharp代码问题 总结 相关链接 C&Python&Csharp in OpenCV 专栏 【2022B站最好的OpenCV课程推荐】OpenCV从入门到实战 全套课程&#xff08;附带课程课件资料课件笔记&#xff09; OpenCV一个窗口同时显示多张图片 …

C 变量

目录 1. C变量 2. C变量定义 2.1 变量初始化 2.2 C中的变量声明 3. C中的左值&#xff08;Lvalues&#xff09;和右值&#xff08;Rvalues&#xff09; 1. C变量 在C语言中&#xff0c;变量可以根据其类型分为以下几种基本类型&#xff1a; 整型变量&#xff1a;用…

自然语言nlp学习 三

4-8 Prompt-Learning--应用_哔哩哔哩_bilibili Prompt Learning&#xff08;提示学习&#xff09;是近年来在自然语言处理领域中&#xff0c;特别是在预训练-微调范式下的一个热门研究方向。它主要与大规模预训练模型如GPT系列、BERT等的应用密切相关。 在传统的微调过程中&a…

将vite项目(vue/react)使用vite-plugin-pwa配置为pwa应用,只需要3分钟即可

将项目配置为pwa模式&#xff0c;就可以在浏览器里面看到安装应用的选项&#xff0c;并且可以将web网页像app一样添加到手机桌面或者pad桌面上&#xff0c;或者是电脑桌面上&#xff0c;这样带来的体验就像真的在一个app上运行一样。为了实现这个目的&#xff0c;我们可以为vue…

算法设计与分析实验:滑动窗口与二分查找

目录 一、寻找两个正序数组的中位数 1.1 具体思路 1.2 流程展示 1.3 代码实现 1.4 代码复杂度分析 1.5 运行结果 二、X的平方根 2.1 具体思路 2.2 流程展示 2.3 代码实现 2.4 代码复杂度分析 2.5 运行结果 三、两数之和 II-输入有序数组 3.1 采用二分查找的思想 …

LeetCode —— 43. 字符串相乘

&#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️Take your time ! &#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️&#x1f636;‍&#x1f32b;️…

Bio-Rad(Abd serotec)独特性抗体

当一种抗体与另一种抗体的独特型结合时&#xff0c;它被称为抗独特型抗体。抗体的可变部分包括独特的抗原结合位点&#xff0c;称为独特型。独特型(即独特型)内表位的组合对于每种抗体都是独特的。 如今开发的大多数治疗性单克隆抗体是人的或人源化的&#xff0c;用于诱导抗药…

【国产MCU】-认识CH32V307及开发环境搭建

认识CH32V307及开发环境搭建 文章目录 认识CH32V307及开发环境搭建1、CH32V307介绍2、开发环境搭建3、程序固件下载1、CH32V307介绍 CH32V307是沁恒推出的一款基于32位RISC-V设计的互联型微控制器,配备了硬件堆栈区、快速中断入口,在标准RISC-V基础上大大提高了中断响应速度…

java 社区资源管理系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 java Web社区资源管系统是一套完善的java web信息管理系统 &#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为TOMCAT7.0,Myeclipse8.5开发&#xff0c;数据库为Mysql5.…

bxCAN-中断

bxCAN中断 bxCAN 共有四个专用的中断向量。每个中断源均可通过 CAN 中断使能寄存器 (CAN_IER) 来单独地使能或禁止。 发送中断可由以下事件产生&#xff1a; 发送邮箱 0 变为空&#xff0c;CAN_TSR 寄存器的 RQCP0 位置 1。 发送邮箱 1 变为空&#xff0c;CAN_TSR 寄存器…

SkyWalking+es部署与使用

第一步下载skywalking :http://skywalking.apache.org/downloads/ 第二步下载es:https://www.elastic.co/cn/downloads/elasticsearch 注&#xff1a;skywalking 和es要版本对应&#xff0c;可从下面连接查看版本对应关系&#xff0c;8.5.0为skywalking 版本号 Index of /di…

Apache Commons Collection3.2.1反序列化分析(CC1)

Commons Collections简介 Commons Collections是Apache软件基金会的一个开源项目&#xff0c;它提供了一组可复用的数据结构和算法的实现&#xff0c;旨在扩展和增强Java集合框架&#xff0c;以便更好地满足不同类型应用的需求。该项目包含了多种不同类型的集合类、迭代器、队…

大数据学习之Redis、从零基础到入门(三)

目录 三、redis10大数据类型 1.哪十个&#xff1f; 1.1 redis字符串&#xff08;String&#xff09; 1.2 redis列表&#xff08;List&#xff09; 1.3 redis哈希表&#xff08;Hash&#xff09; 1.4 redis集合&#xff08;Set&#xff09; 1.5 redis有序集合&#xff08…

Android SystemUI 介绍

目录 一、什么是SystemUI 二、SystemUI应用源码 三、学习 SystemUI 的核心组件 四、修改状态与导航栏测试 本篇文章&#xff0c;主要科普的是Android SystemUI &#xff0c; 下一篇文章我们将介绍如何把Android SystemUI 应用转成Android Studio 工程项目。 一、什么是Syst…