【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model)

 

目录

一、引言

1.1 往期回顾

1.2 本期概要

二、Shared-Bottom Multi-task Model(SBMM)

2.1 技术原理

2.2 技术优缺点

2.3 业务代码实践

三、总结


一、引言

在朴素的深度学习ctr预估模型中(如DNN),通常以一个行为为预估目标,比如通过ctr预估点击率。但实际推荐系统业务场景中,更多是多种目标融合的结果,比如视频推荐,会存在视频点击率、视频完整播放率、视频播放时长等多个目标,而多种目标如何更好的融合,在工业界与学术界均有较多内容产出,由于该环节对实际业务影响最为直接,特开此专栏对推荐系统深度学习多目标问题进行讲述。

1.1 往期回顾

上一篇文章主要介绍了推荐系统多目标算法中的“样本Loss加权”,该方法在训练时Loss乘以样本权重实现对多种目标的加权,通过引导Loss梯度的学习方向,让模型参数朝着你设定的权重方向去学习。

1.2 本期概要

今天进一步深化,主要介绍Shared-Bottom Multi-task Model算法,该算法中文可译为“底部共享多任务模型”,该算法设定多个任务,每个任务设定多个目标,通过“Loss计算时调整每个任务的权重”,亦或是“每个塔单元内,多目标Loss计算时调整每个目标的权重”进行多任务多目标的调整。

二、Shared-Bottom Multi-task Model(SBMM)

2.1 技术原理

Shared-Bottom Multi-task Model(SBMM)全称为底层共享多任务模型,主要由底层共享网络、多任务塔、多目标输出构成。核心原理:通过构造多任务多目标样本数据,在Loss计算环节,将各任务Loss求和(或加权求和),对Loss求导(求梯度)后,逐步后向传播迭代。

  1. 底部网络:Shared-Bottom 网络通常位于底部,可以为一个DNN网络,或者emb+pooling+mlp的方式对input输入的稀疏(sparse)特征进行稠密(dense)化。
  2. ​​​​​​​多个任务塔:底部网络上层接N个任务塔(Tower),每个塔根据需要可以定义为简单或复杂的多层感知器(mlp)网络。每个塔可以对应特定的场景,比如一二级页面场景。
  3. 多个目标:每个任务塔(Tower)可以输出多个学习目标,每个学习目标还可以像上一篇文章一样进行样本Loss加权。每个目标可以对应一种特定的指标行为,比如点击、时长、下单等。

2.2 技术优缺点

相比于上一篇文章提到的样本Loss加权融合法,以及后续文章将会介绍的MoE、MMoE方法,有如下优缺点:

优点:

  • 可以对多级场景任务进行建模,使得ctcvr等点击后转化问题可以被深度学习
  • 浅层参数共享,互相补充学习,任务相关性越高,模型的loss可以降低到更低

缺点: 

  • 跷跷板问题:任务没有好的相关性时,这种Hard parameter sharing会损害效果

2.3 业务代码实践

我们以小红书推荐场景为例,用户在一级发现页场景中停留并点击了“误杀3”中的一个视频笔记,在二级场景视频播放页中观看并点赞了视频。

跨场景多目标建模:我们定义一个SBMM算法结构,底层是一个3层的MLP(64,32,16),MLP出来后接一级场景Tower和二级场景Tower,一级场景任务中分别定义视频一级页“是否停留”、“停留时长”、“是否点击”,二级场景任务中分别定义“点击后播放时长”,“播放后是否点赞”

伪代码:

导入 pytorch 库
定义 SharedBottomMultiTaskModel 类 继承自 nn.Module:定义 __init__ 方法 参数 (self, 输入维度, 隐藏层1大小, 隐藏层2大小, 隐藏层3大小, 输出任务1维度, 输出任务2维度):初始化共享底部的三层全连接层初始化任务1的三层全连接层初始化任务2的三层全连接层定义 forward 方法 参数 (self, 输入数据):计算输入数据通过共享底部后的输出从共享底部输出分别计算任务1和任务2的结果返回任务1和任务2的结果生成虚拟样本数据:创建训练集和测试集实例化模型对象
定义损失函数和优化器
训练循环:前向传播: 获取预测值计算每个任务的损失反向传播和优化

PyTorch版本:

算法逻辑

  1. 导入必要的库。
  2. 定义一个类来表示共享底部和特定任务头部的模型结构。
  3. 在初始化方法中定义共享底部和两个独立的任务头部网络层。
  4. 实现前向传播函数,处理输入数据通过共享底部后分发到不同的任务头部。
  5. 生成虚拟样本数据。
  6. 定义损失函数和优化器。
  7. 编写训练循环。
  8. 进行模型预测。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDatasetclass SharedBottomMultiTaskModel(nn.Module):def __init__(self, input_dim, hidden1_dim, hidden2_dim, hidden3_dim, output_task1_dim, output_task2_dim):super(SharedBottomMultiTaskModel, self).__init__()# 定义共享底部的三层全连接层self.shared_bottom = nn.Sequential(nn.Linear(input_dim, hidden1_dim),nn.ReLU(),nn.Linear(hidden1_dim, hidden2_dim),nn.ReLU(),nn.Linear(hidden2_dim, hidden3_dim),nn.ReLU())# 定义任务1的三层全连接层self.task1_head = nn.Sequential(nn.Linear(hidden3_dim, hidden2_dim),nn.ReLU(),nn.Linear(hidden2_dim, output_task1_dim))# 定义任务2的三层全连接层self.task2_head = nn.Sequential(nn.Linear(hidden3_dim, hidden2_dim),nn.ReLU(),nn.Linear(hidden2_dim, output_task2_dim))def forward(self, x):# 计算输入数据通过共享底部后的输出shared_output = self.shared_bottom(x)# 从共享底部输出分别计算任务1和任务2的结果task1_output = self.task1_head(shared_output)task2_output = self.task2_head(shared_output)return task1_output, task2_output# 构造虚拟样本数据
torch.manual_seed(42)  # 设置随机种子以保证结果可重复
input_dim = 10
task1_dim = 3
task2_dim = 2
num_samples = 1000
X_train = torch.randn(num_samples, input_dim)
y_train_task1 = torch.randn(num_samples, task1_dim)  # 假设任务1的输出维度为task1_dim
y_train_task2 = torch.randn(num_samples, task2_dim)  # 假设任务2的输出维度为task2_dim# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train_task1, y_train_task2)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 实例化模型对象
model = SharedBottomMultiTaskModel(input_dim, 64, 32, 16, task1_dim, task2_dim)# 定义损失函数和优化器
criterion_task1 = nn.MSELoss()
criterion_task2 = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练循环
num_epochs = 10
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (X_batch, y_task1_batch, y_task2_batch) in enumerate(train_loader):# 前向传播: 获取预测值outputs_task1, outputs_task2 = model(X_batch)# 计算每个任务的损失loss_task1 = criterion_task1(outputs_task1, y_task1_batch)loss_task2 = criterion_task2(outputs_task2, y_task2_batch)#print(f'loss_task1:{loss_task1},loss_task2:{loss_task2}')total_loss = loss_task1 + loss_task2# 反向传播和优化optimizer.zero_grad()total_loss.backward()optimizer.step()running_loss += total_loss.item()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')# 模型预测
model.eval()
with torch.no_grad():test_input = torch.randn(1, input_dim)  # 构造一个测试样本pred_task1, pred_task2 = model(test_input)print(f'任务1预测结果: {pred_task1}')print(f'任务2预测结果: {pred_task2}')

三、总结

本文从技术原理、技术优缺点方面对推荐系统深度学习多任务多目标“Shared-Bottom Multi-task Model”算法进行讲解,该模型使用深度学习模型对多个任务场景多个目标的业务问题进行建模,使得用户在多个场景连续性行为可以被学习,在现实推荐系统业务中是比较基础的方法,后面本专栏还会陆续介绍MoE、MMoE等多任务多目标算法,期待您的关注和支持。

如果您还有时间,欢迎阅读本专栏的其他文章:

【深度学习】多目标融合算法(一):样本Loss加权(Sample Loss Reweight)

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model) ​​​​​​​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/66568.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TVbox 手机、智能电视节目一网打尽

文章目录 一、简要介绍二、优点三、下载地址 一、简要介绍 TVbox是目前最火爆的多端、多源的电视影音工具,是一款开源的自定义添加站源的影音工具。TVBox,支持电视频道直播。一款TV端影视工具,软件本身不具有任何影视资源,但可以…

2025新春烟花代码(二)HTML5实现孔明灯和烟花效果

效果展示 源代码 <!DOCTYPE html> <html lang"en"> <script>var _hmt _hmt || [];(function () {var hm document.createElement("script");hm.src "https://hm.baidu.com/hm.js?45f95f1bfde85c7777c3d1157e8c2d34";var …

ue5 蒙太奇,即上半身动画和下半身组合在一起,并使用。学习b站库得科技

本文核心 正常跑步动画端枪动画跑起来也端枪 正常跑步动画 端枪动画的上半身 跑起来也端枪 三步走&#xff1a; 第一步制作动画蒙太奇和插槽 第二步动画蓝图选择使用上半身动画还是全身动画&#xff0c;将上半身端枪和下半身走路结合 第三步使用动画蒙太奇 1.开始把&a…

YOLOv8实战人员跌倒检测

本文采用YOLOv8作为核心算法框架&#xff0c;结合PyQt5构建用户界面&#xff0c;使用Python3进行开发。YOLOv8以其高效的实时检测能力&#xff0c;在多个目标检测任务中展现出卓越性能。本研究针对人员跌倒目标数据集进行训练和优化&#xff0c;该数据集包含丰富人员跌倒图像样…

Java 分布式锁:Redisson、Zookeeper、Spring 提供的 Redis 分布式锁封装详解

&#x1f4da; Java 分布式锁&#xff1a;Redisson、Zookeeper、Spring 提供的 Redis 分布式锁封装详解 在分布式系统中&#xff0c;分布式锁 用于解决多个服务实例同时访问共享资源时的 数据一致性 问题。Java 生态中&#xff0c;有多种成熟的框架可以实现分布式锁&#xff0…

01.02、判定是否互为字符重排

01.02、[简单] 判定是否互为字符重排 1、题目描述 给定两个由小写字母组成的字符串 s1 和 s2&#xff0c;请编写一个程序&#xff0c;确定其中一个字符串的字符重新排列后&#xff0c;能否变成另一个字符串。 在这道题中&#xff0c;我们的任务是判断两个字符串 s1 和 s2 是…

C#进阶-在Ubuntu上部署ASP.NET Core Web API应用

随着云计算和容器化技术的普及&#xff0c;Linux 服务器已成为部署 Web 应用程序的主流平台之一。ASP.NET Core 作为一个跨平台、高性能的框架&#xff0c;非常适合在 Linux 环境中运行。本篇博客将详细介绍如何在 Linux 服务器上部署 ASP.NET Core Web API 应用&#xff0c;包…

【网页自动化】篡改猴入门教程

安装篡改猴 打开浏览器扩展商店&#xff08;Edge、Chrome、Firefox 等&#xff09;。搜索 Tampermonkey 并安装。 如图安装后&#xff0c;浏览器右上角会显示一个带有猴子图标的按钮。 创建用户脚本 已进入篡改猴管理面板点击创建 脚本注释说明 name&#xff1a;脚本名称。…

数据结构之双链表(C语言)

​ 数据结构之双链表&#xff08;C语言&#xff09; 1 链表的分类2 双向链表的结构3 双向链表的节点创建与初始化3.1 节点创建函数3.2 初始化函数 4 双向链表插入节点与删除节点的前序分析5 双向链表尾插法与头插法5.1 尾插函数5.2 头插函数 6 双向链表的尾删法与头删法6.1尾删…

【0x007A】HCI_Write_Secure_Connections_Host_Support命令详解

目录 一、命令概述 二、命令格式及参数 2.1. HCI_Write_Secure_Connections_Host_Support命令格式 2.2. Secure_Connections_Host_Support 三、生成事件及参数 3.1. HCI_Command_Complete事件格式 3.2. Status 四、命令执行流程梳理 4.1. 命令发送阶段 4.2. 命令接收…

第一节 环境搭建

Visual Studio Visual Studio 2019 密码&#xff1a;gd24 组件 安装即可

《Spring Framework实战》4:Spring Framework 文档

欢迎观看《Spring Framework实战》视频教程 概述 历史&#xff0c; 设计理念&#xff0c; 反馈&#xff0c; 开始。 核心技术 IoC 容器、事件、资源、i18n、 验证、数据绑定、类型转换、SpEL、AOP、AOT。 测试 Mock 对象、TestContext 框架、 Spring MVC 测试&#xff0c;…

PyTorch reshape函数介绍

torch.reshape 是 PyTorch 用于改变张量形状的函数之一。它不会改变张量的数据&#xff0c;而是重新组织其元素以适应新的形状。 reshape 的使用 torch.reshape(input, shape) → Tensorinput&#xff1a;输入张量。shape&#xff1a;新形状&#xff0c;使用整数或 -1 指定各维…

Java QueryWrapper groupBy自定义字段,以及List<Map>转List<Entity>

Java queryWrapper groupby自定义字段 String sql "data_id,(select value from lz_html a where a.data_id lz_html.data_id and class_nametest-item-status) status," "(select value from lz_html a where a.data_id lz_html.data_id and class_nametes…

【adb】5分钟入门adb操作安卓设备

ADB&#xff08;Android Debug Bridge&#xff09; 是一个多功能的命令行工具&#xff0c;用于与 Android 设备进行交互、调试和管理。它提供了对设备的直接控制&#xff0c;能够帮助开发者进行调试、安装应用、传输文件等。 目录 将设备和电脑连接 adb shell 文件的基本操…

LeetCode100之组合总和(39)--Java

1.问题描述 给你一个 无重复元素 的整数数组 candidates 和一个目标整数 target &#xff0c;找出 candidates 中可以使数字和为目标数 target 的 所有 不同组合 &#xff0c;并以列表形式返回。你可以按 任意顺序 返回这些组合。 candidates 中的 同一个 数字可以 无限制重复…

Cosmos的gRPC与Go

Cosmos与Go语言 gRPC gRPC的基本概念&#xff08;维基百科&#xff09;&#xff1a; gRPC (gRPC Remote Procedure Calls) 是一个跨平台的开源高性能远程过程调用&#xff08;RPC&#xff09;框架。gRPC最初由Google创建&#xff0c;它使用一个通用的RPC基础设施Stubby来连接…

maven的中国镜像有哪些

根据您的请求&#xff0c;以下是一些可用的 Maven 中国镜像&#xff1a; 阿里云 官网&#xff1a;阿里云 Maven 镜像配置&#xff1a;<mirror><id>aliyunmaven</id><mirrorOf>*</mirrorOf><name>阿里云公共仓库</name><url>…

Apache zookeeper集群搭建

文章目录 引言I 集群搭建保证服务器基础环境一致JDK安装与配置环境变量安装与修改zk配置文件同步zk安装包与配置文件zk集群启停查看进程、状态、日志II 扩展:shell脚本一键启停引言 springCloud 脚手架项目功能模块:Java分布式锁 https://blog.csdn.net/z929118967/article/d…

Tauri教程-基础篇-第二节 Tauri的核心概念上篇

“如果结果不如你所愿&#xff0c;就在尘埃落定前奋力一搏。”——《夏目友人帐》 “有些事不是看到了希望才去坚持&#xff0c;而是因为坚持才会看到希望。”——《十宗罪》 “维持现状意味着空耗你的努力和生命。”——纪伯伦 Tauri 技术教程 * 第四章 Tauri的基础教程 第二节…