【深度学习】多目标融合算法(四):多门混合专家网络MMOE(Multi-gate Mixture-of-Experts)

目录

一、引言

二、MMoE(Multi-gate Mixture-of-Experts,多门混合专家网络)

2.1 技术原理

2.2 技术优缺点

2.3 业务代码实践

2.3.1 业务场景与建模

2.3.2 模型代码实现

2.3.3 模型训练与推理测试

2.3.4 打印模型结构 

三、总结


一、引言

上一篇我们讲了MoE混合专家网络,通过引入Gate门控,针对不同的Input分布,对多个专家网络赋予不同的权重,解决多场景或多目标任务task的底层信息共享及个性化问题。但MoE网络对于不同的Expert专家网络,采用同一个Gate门控网络,仅对不同的Input分布实现了个性化,对不同目标任务task的个性化刻画能力不足,今天在MoE的基础上,引入MMoE网络,为每一个task任务构建专属的Gate门控网络,这样的改进可以针对不同的task得到不同的Experts权重,从而实现对Experts专家的选择利用,不同的任务task对应的gate门控网络可以学习到不同的Experts网络组合模式,更容易捕捉到不容task间的相关性和差异性。

二、MMoE(Multi-gate Mixture-of-Experts,多门混合专家网络)

2.1 技术原理

MMoE(Multi-gate Mixture-of-Experts)全称为多门混合专家网络,主要由多个专家网络、多个任务塔、多个门控网络构成。核心原理:样本数据分别输入num_experts个专家网络进行推理,每个专家网络实际上是一个前馈神经网络(MLP),输入维度为x,输出维度为output_experts_dim;同时,样本数据分别输入目标task对应的门控网络Gate A及Gate B,门控网络也是一个MLP(可以为多层,也可以为一层),输出为num_experts个experts专家的概率分布,维度为num_experts(采用softmax将输出归一化,各个维度加起来和为1);对于每一个Task,将各自对应专家网络的输出,基于对应gate门控网络的softmax加权平均,作为各自Task的输入,所有Task的输入统一维度均为output_experts_dim。在每次反向传播迭代时,对Gate A、Gate B和num_experts个专家参数进行更新,Gate A、Gate B和专家网络的参数受任务Task A、B共同影响。

  • 专家网络:样本数据分别输入num_experts个专家网络进行推理,每个专家网络实际上是一个前馈神经网络(MLP),输入维度为x,输出维度为output_experts_dim。
  • 门控网络:样本数据分别输入目标task对应的门控网络Gate A及Gate B,门控网络也是一个MLP(可以为多层,也可以为一层),输出为num_experts个experts专家的概率分布,维度为num_experts(采用softmax将输出归一化,各个维度加起来和为1)
  • 任务网络:对于每一个Task,将各自对应专家网络的输出,基于对应gate门控网络的softmax加权平均,作为各自Task的输入,所有Task的输入统一维度均为output_experts_dim。

2.2 技术优缺点

相较于MoE网络,MMoE的本质是每个task自带Gate门控网络对多个专家的预估结果进行选择,相当于给每个task安排了一个个人助理,对专家的结果进行评审(而MoE对于所有task仅有一个公共助理,对task的专属需求了解不深)。相较于MoE网络:

优点:

  • 对每个task安排专属的gate网络,在专家网络赋值时更加个性化
  • 更容易捕捉到不容task间的相关性和差异性。

缺点: 

  • MMOE中所有的Expert是被所有task共享的,这可能无法捕捉到任务之间更复杂的关系,从而给部分任务带来一定的噪声
  • 不同的Expert之间没有交互,联合优化的效果有所折扣,虽然可以缓解负迁移问题,但跷跷板现象仍然存在。

2.3 业务代码实践

2.3.1 业务场景与建模

我们还是以小红书推荐场景为例,针对一个视频,用户可以点红心(互动),也可以点击视频进行播放(点击),针对互动和点击两个目标进行多目标建模

我们构建一个100维特征输入,4个experts专家网络,2个task目标,2个门控的MMoE网络,用于建模多目标学习问题,模型架构图如下:

​​​​​​​​​​​​​​​​​​​​​

如架构图所示,其中有几个注意的点:

  • num_experts:门控gate的输出维度和专家数相同,均为num_experts,因为gate的用途是对专家网络最后一层进行加权平均,gate维度与专家数是直接对应关系。
  • output_experts_dim:专家网络的输出维度和task网络的输入维度相同,task网络承接的是专家网络各维度的加权平均值,experts网络与task网络是直接对应关系。
  • Softmax:Gate门控网络对最后一层采用Softmax归一化,保证专家网络加权平均后值域相同

2.3.2 模型代码实现

基于pytorch,实现上述网络架构,如下:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDatasetclass MMoEModel(nn.Module):def __init__(self, input_dim, experts_hidden1_dim, experts_hidden2_dim, output_experts_dim, task_hidden1_dim, task_hidden2_dim, output_task1_dim, output_task2_dim, gate_hidden1_dim, gate_hidden2_dim, num_experts):super(MMoEModel, self).__init__()# 初始化函数外使用初始化变量需要赋值,否则默认使用全局变量# 初始化函数内使用初始化变量不需要赋值 self.num_experts = num_expertsself.output_experts_dim = output_experts_dim# 初始化多个专家网络self.experts = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, experts_hidden1_dim),nn.ReLU(),nn.Linear(experts_hidden1_dim, experts_hidden2_dim),nn.ReLU(),nn.Linear(experts_hidden2_dim, output_experts_dim),nn.ReLU()) for _ in range(num_experts)])# 定义任务1的输出层self.task1_head = nn.Sequential(nn.Linear(output_experts_dim, task_hidden1_dim),nn.ReLU(),nn.Linear(task_hidden1_dim, task_hidden2_dim),nn.ReLU(),nn.Linear(task_hidden2_dim, output_task1_dim),nn.Sigmoid()) # 定义任务2的输出层self.task2_head = nn.Sequential(nn.Linear(output_experts_dim, task_hidden1_dim),nn.ReLU(),nn.Linear(task_hidden1_dim, task_hidden2_dim),nn.ReLU(),nn.Linear(task_hidden2_dim, output_task2_dim),nn.Sigmoid()) # 初始化门控网络1self.gating1_network = nn.Sequential(nn.Linear(input_dim, gate_hidden1_dim),nn.ReLU(),nn.Linear(gate_hidden1_dim, gate_hidden2_dim),nn.ReLU(),nn.Linear(gate_hidden2_dim, num_experts),nn.Softmax(dim=1))# 初始化门控网络2self.gating2_network = nn.Sequential(nn.Linear(input_dim, gate_hidden1_dim),nn.ReLU(),nn.Linear(gate_hidden1_dim, gate_hidden2_dim),nn.ReLU(),nn.Linear(gate_hidden2_dim, num_experts),nn.Softmax(dim=1))def forward(self, x):# 计算输入数据通过门控网络后的权重gates1 = self.gating1_network(x)gates2 = self.gating2_network(x)#print(gates)batch_size, _ = x.shapetask1_inputs = torch.zeros(batch_size, self.output_experts_dim)task2_inputs = torch.zeros(batch_size, self.output_experts_dim)# 计算每个专家的输出并加权求和for i in range(self.num_experts):expert_output = self.experts[i](x)task1_inputs += expert_output * gates1[:, i].unsqueeze(1)task2_inputs += expert_output * gates2[:, i].unsqueeze(1)task1_outputs = self.task1_head(task1_inputs)task2_outputs = self.task2_head(task2_inputs)return task1_outputs, task2_outputs# 实例化模型对象
num_experts = 4  # 假设有4个专家
experts_hidden1_dim = 64
experts_hidden2_dim = 32
output_experts_dim = 16
gate_hidden1_dim = 16
gate_hidden2_dim = 8
task_hidden1_dim = 32
task_hidden2_dim = 16
output_task1_dim = 1
output_task2_dim = 1# 构造虚拟样本数据
torch.manual_seed(42)  # 设置随机种子以保证结果可重复
input_dim = 100
num_samples = 1024
X_train = torch.randint(0, 2, (num_samples, input_dim)).float()
y_train_task1 = torch.rand(num_samples, output_task1_dim)  # 假设任务1的输出维度为1
y_train_task2 = torch.rand(num_samples, output_task2_dim)  # 假设任务2的输出维度为1# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train_task1, y_train_task2)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)model = MMoEModel(input_dim, experts_hidden1_dim, experts_hidden2_dim, output_experts_dim, task_hidden1_dim, task_hidden2_dim, output_task1_dim, output_task2_dim, gate_hidden1_dim, gate_hidden2_dim, num_experts)# 定义损失函数和优化器
criterion_task1 = nn.MSELoss()
criterion_task2 = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练循环
num_epochs = 100
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (X_batch, y_task1_batch, y_task2_batch) in enumerate(train_loader):# 前向传播: 获取预测值#print(batch_idx, X_batch )#print(f'Epoch [{epoch+1}/{num_epochs}-{batch_idx}], Loss: {running_loss/len(train_loader):.4f}')outputs_task1, outputs_task2 = model(X_batch)# 计算每个任务的损失loss_task1 = criterion_task1(outputs_task1, y_task1_batch)loss_task2 = criterion_task2(outputs_task2, y_task2_batch)total_loss = loss_task1 + loss_task2# 反向传播和优化optimizer.zero_grad()total_loss.backward()optimizer.step()running_loss += total_loss.item()if epoch % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')print(model)
#for param_tensor in model.state_dict():
#    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# 模型预测
model.eval()
with torch.no_grad():test_input = torch.randint(0, 2, (1, input_dim)).float()  # 构造一个测试样本pred_task1, pred_task2 = model(test_input)print(f'互动目标预测结果: {pred_task1}')print(f'点击目标预测结果: {pred_task2}')

相比于上一篇MoE中的代码,MMoE初始化了gating1_network和gating2_network两个门控网络,在forward前向传播网络结构定义中,两个gate分别以input为输入,通过多层MLP后得到task相对应的加权平均权重。

2.3.3 模型训练与推理测试

运行上述代码,模型启动训练,Loss逐渐收敛,测试结果如下:

2.3.4 打印模型结构 ​​​​​​​

三、总结

本文详细介绍了MMoE多任务模型的算法原理、算法优势,并以小红书业务场景为例,构建网络结构并使用pytorch代码实现对应的网络结构、训练流程。相比于MoE,MMoE可以更好的学习不同Task任务的相关性和差异性。是深度学习推荐系统中多目标或多场景类问题中必须掌握的根基模型。

如果您还有时间,欢迎阅读本专栏的其他文章:

【深度学习】多目标融合算法(一):样本Loss加权(Sample Loss Reweight)

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model) ​​​​​​​

【深度学习】多目标融合算法(三):混合专家网络MOE(Mixture-of-Experts) 

 【深度学习】多目标融合算法(四):多门混合专家网络MMOE(Multi-gate Mixture-of-Experts)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/70713.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

9 数据流图

9 数据流图 9.1数据平衡原则 子图缺少处理后的数据操作结果返回前端应用以及后端数据库返回操作结果到数据管理中间件。 9.2解题技巧 实件名 存储名 加工名 数据流

Hdoop之MapReduce的原理

简单版本 AppMaster: 整个Job任务的核心协调工具 MapTask: 主要用于Map任务的执行 ReduceTask: 主要用于Reduce任务的执行 一个任务提交Job --> AppMaster(项目经理)--> 根据切片的数量统计出需要多少个MapTask任务 --> 向ResourceManager(Yarn平台的老大)索要资源 --…

Linux云计算SRE-第六周

1. 总结openssh服务安全加固和总结openssh免密认证原理,及免认证实现过程。 1、 openssh服务安全加固 OpenSSH(Open Secure Shell)服务安全加固是确保远程登录会话和其他网络服务安全性的关键步骤。以下是一些常见的OpenSSH服务安全加固措施…

Excel 笔记

实际问题记录 VBA脚本实现特殊的行转列 已知:位于同一Excel工作簿文件中的两个工作表:Sheet1、Sheet2。 问题:现要将Sheet2中的每一行,按Sheet1中的样子进行转置: Sheet2中每一行的黄色单元格,为列头。…

react使用if判断

1、第一种 function Dade(req:any){console.log(req)if(req.data.id 1){return <span>66666</span>}return <span style{{color:"red"}}>8888</span>}2、使用 {win.map((req,index) > ( <> <Dade data{req}/>{req.id 1 ?…

Java从入门到精通 第三版 读书笔记

第一章 初识Java Java同时是编译型(编译器将Java源代码静态编译为Java字节码)和解释型(JVM将Java字节码动态解释为本地机器码)语言。Java程序的运行需要解释器(如JVM)。因Java字节码本具有平台无关性,那么若要在一个新目标平台上运行一个Java程序,则仅需解释器做好目标…

【零基础学习CAPL】——Panel之弹窗的创建与使用

🙋‍♂️【零基础学习CAPL】系列💁‍♂️点击跳转 ——————————————————————————————————–—— 从0开始学习CANoe使用 从0开始学习车载测试 相信时间的力量 星光不负赶路者,时光不负有心人。 文章目录 1.概述2. panel制作2.1 panel窗体…

C# OpenCV机器视觉:对位贴合

在热闹非凡的手机维修街上&#xff0c;阿强开了一家小小的手机贴膜店。每天看着顾客们自己贴膜贴得歪歪扭扭&#xff0c;不是膜的边缘贴不整齐&#xff0c;就是里面充满了气泡&#xff0c;阿强心里就想&#xff1a;“要是我能有个自动贴膜的神器&#xff0c;那该多好啊&#xf…

推荐一个免费的、开源的大数据工程学习教程

在当今信息爆炸的时代&#xff0c;每一个企业都会产生大量的数据&#xff0c;而大数据也已经成为很多企业发展的重要驱动力&#xff0c;然而如何有效得处理和分析这些海量的数据&#xff0c;却是一个非常有挑战的技术。 今天推荐一个免费的数据工程教程&#xff0c;带你系统化…

2月10日QT

作业> 将文本编辑器功能完善 include "widget.h" #include "ui_widget.h" #include <QMessageBox> //消息对话框类 #include <QFontDialog> //字体类对话框 #include <QFont> //字体类 #include <QColorDialog> //颜…

【Java】多线程和高并发编程(四):阻塞队列(上)基础概念、ArrayBlockingQueue

文章目录 四、阻塞队列1、基础概念1.1 生产者消费者概念1.2 JUC阻塞队列的存取方法 2、ArrayBlockingQueue2.1 ArrayBlockingQueue的基本使用2.2 生产者方法实现原理2.2.1 ArrayBlockingQueue的常见属性2.2.2 add方法实现2.2.3 offer方法实现2.2.4 offer(time,unit)方法2.2.5 p…

【Java】多线程和高并发编程(三):锁(下)深入ReentrantReadWriteLock

文章目录 4、深入ReentrantReadWriteLock4.1 为什么要出现读写锁4.2 读写锁的实现原理4.3 写锁分析4.3.1 写锁加锁流程概述4.3.2 写锁加锁源码分析4.3.3 写锁释放锁流程概述&释放锁源码 4.4 读锁分析4.4.1 读锁加锁流程概述4.4.1.1 基础读锁流程4.4.1.2 读锁重入流程4.4.1.…

【R语言】相关系数

一、cor()函数 cor()函数是R语言中用于计算相关系数的函数&#xff0c;相关系数用于衡量两个变量之间的线性关系强度和方向。 常见的相关系数有皮尔逊相关系数&#xff08;Pearson correlation coefficient&#xff09;、斯皮尔曼秩相关系数&#xff08;Spearmans rank corre…

编译和链接【一】

文章目录 编译和链接【一】从翻译单元到二进制文件 编译和链接【一】 在我大一的时候&#xff0c; 我使用VC6.0对C语言程序进行编译链接和运行 &#xff0c; 然后我接触了VS&#xff0c; VS code等众多IDE&#xff0c; 这些IDE界面友好&#xff0c; 使用方便&#xff0c; 例如…

Linux: ASoC 声卡硬件参数的设置过程简析

文章目录 1. 前言2. ASoC 声卡设备硬件参数2.1 将 DAI、Machine 平台的硬件参数添加到声卡2.2 打开 PCM 流时将声卡硬件参数配置到 PCM 流2.3 应用程序对 PCM 流参数进行修改调整 1. 前言 限于作者能力水平&#xff0c;本文可能存在谬误&#xff0c;因此而给读者带来的损失&am…

ansible使用学习

一、查询手册 1、官网 ansible官网地址&#xff1a;https://docs.ansible.com 模块查看路径&#xff1a;https://docs.ansible.com/ansible/latest/collections/ansible/builtin/index.html#plugins-in-ansible-builtin 2、命令 ansible-doc -s command二、相关脚本 1、服务…

jmap使用

常用命令 jmap -heap PID jmap -histo PID | head -20 jmap -dump:formatb,fileheap_dump.hprof PID jmap 是 Java 开发工具包&#xff08;JDK&#xff09;提供的一个命令行工具&#xff0c;用于生成 Java 进程的内存映射信息。它可以帮助开发者分析 Java 堆内存的使用情况…

基于 SpringBoot 和 Vue 的智能腰带健康监测数据可视化平台开发(文末联系,整套资料提供)

基于 SpringBoot 和 Vue 的智能腰带健康监测数据可视化平台开发 一、系统介绍 随着人们生活水平的提高和健康意识的增强&#xff0c;智能健康监测设备越来越受到关注。智能腰带作为一种新型的健康监测设备&#xff0c;能够实时采集用户的腰部健康数据&#xff0c;如姿势、运动…

docker离线安装及部署各类中间件(x86系统架构)

前言&#xff1a;此文主要针对需要在x86内网服务器搭建系统的情况 一、docker离线安装 1、下载docker镜像 https://download.docker.com/linux/static/stable/x86_64/ 版本&#xff1a;docker-23.0.6.tgz 2、将docker-23.0.6.tgz 文件上传到服务器上面&#xff0c;这里放在…

从零到一:我的元宵灯谜小程序诞生记

缘起&#xff1a;一碗汤圆引发的灵感 去年元宵节&#xff0c;我正捧着热腾腾的汤圆刷朋友圈&#xff0c;满屏都是"转发锦鲤求灯谜答案"的动态。看着大家对着手机手忙脚乱地切换浏览器查答案&#xff0c;我突然拍案而起&#xff1a;为什么不做一个能即时猜灯谜的微信…