特征交叉系列:DCN-Mix 混合低秩交叉网络理论和实践

DCN-Mix和DCN-V2的关系

DCN-Mix(a mixture of low-rank DCN)是基于DCN-V2的改进版,它提出使用矩阵分解降低DCN-V2的时间空间复杂度,又引入多次矩阵分解来达到类似混合专家网络MOE的效果从而提升交叉层的表征能力,若读者对DCN-V2不甚了解可以参考上一节[特征交叉系列:Deep&Cross(DCN-V2)理论和实践]做知识铺垫。


DCN-V2权重矩阵的低秩性和矩阵分解

在DCN-V2中核心的参数是交叉层的权重矩阵W,该参数是M×M的方阵,其中M是所有输入embedding拼接后的向量总长度,每一层交叉之间W不共享,W矩阵需要学习的参数数量能占到所有参数量的70%以上,而进一步作者发现随着网络的训练,W矩阵的奇异值出现快速下降呈现出低秩特性,代表该矩阵存在信息冗余,因此可以考虑通过矩阵分解来进行特征提取和信息压缩。
在PyTorch中可以通过torch.linalg.svd计算出矩阵的奇异值,例如

>>> a = torch.tensor([[1, 1], [1, 1.1]])
>>> u, s, v = torch.linalg.svd(a)
>>> print(s)
tensor([2.0512, 0.0488])

其中s是对角阵,斜对角线上的值就是奇异值,a矩阵的第二行几乎可以从第一行线性变换而来,因此s各位置上的奇异值差距极大,第一个奇异值基本携带了全部的矩阵信息。
在DCN-V2的训练代码里面,打印出第一个交叉层初始化的W矩阵和训练早停后W矩阵的奇异值,奇异值的长度和输入长度M一致,代码如下

# 初始化时
model = DCN(field_num=10, feat_dim=72, emb_num=16, order_num=2, dropout=0.1, method='parallel').to(DEVICE)
init_s = torch.linalg.svd(model.cross_net.cell_list[0].w)[1].cpu().detach().numpy().tolist()
# 早停时
if early_stop_flag:train_s = torch.linalg.svd(model.cross_net.cell_list[0].w)[1].cpu().detach().numpy().tolist()break

奇异值列表中元素大小逐个递减,对init_s和train_s分别做最大最小归一化,要求第一个奇异值归因化为1,

init_s = [(x - min(init_s)) / (max(init_s) - min(init_s)) for x in init_s]
train_s = [(x - min(train_s)) / (max(train_s) - min(train_s)) for x in train_s]

然后做图看一下初始矩阵的奇异值和收敛后的奇异值的各个位置元素的大小情况

import matplotlib.pylab as plt
plt.scatter(list(range(len(init_s))), init_s, label='init', s=3)
plt.scatter(list(range(len(train_s))), train_s, label='learned', s=3)
plt.legend(loc=0)
plt.show()

init和learned奇异值下降对比

相比于初始化阶段(蓝线),模型收敛后(橙线)的W矩阵奇异值急速下降,说明头部的奇异值已经携带了大部分矩阵信息,W矩阵可以考虑做压缩。
在论文中作者将W分解为U,V两个矩阵的相乘,其中U,V都是维度为[M, R]的二维矩阵,M和输入等长,R<=M/2,公式如下

矩阵分解

此时一个交叉权重的参数数量由M平方降低为2×MR。


DCN-Mix的混合专家网络

DCN-Mix使用矩阵UV分解来逼近原始的交叉矩阵W,受到MOE(Mixture of Experts)混合专家网络的启发,作者对W进行多次矩阵分解,单个矩阵分解相当于单个专家网络(Expert)在子空间学习特征交叉,再引入门控机制(Gate)对多个子空间的交叉结果进行自适应地融合,从而提高交叉层的表达能力,DCN结合MOE的示意图如下

MOE示意图

其中该层的输入Input x分别进入n个Expert专家网络,专家网络中包含UV矩阵相乘,同时Input x输入给一个门控网络Gate+Softmax输出n个权重标量,最后Input x会和加权求和的专家网络结果做残差连接。
将矩阵分解和MOE结合起来形成最终的交叉层公式如下

结合MOE的矩阵分解交叉层

相比于DCN-V2,等号左侧的哈达玛积部分改为了一个Σ加权求和的UV矩阵逼近,而右侧的残差连接放到最后和MOE的结果一起做残差连接。


DCN-Mix在PyTorch下的实践

本次实践的数据集和上一篇特征交叉系列:完全理解FM因子分解机原理和代码实战一致,采用用户的购买记录流水作为训练数据,用户侧特征是年龄,性别,会员年限等离散特征,商品侧特征采用商品的二级类目,产地,品牌三个离散特征,随机构造负样本,一共有10个特征域,全部是离散特征,对于枚举值过多的特征采用hash分箱,得到一共72个特征。
DCN-Mix的PyTorch代码实现如下

class Embedding(nn.Module):def __init__(self, feat_num, emb_num):super(Embedding, self).__init__()self.embedding = nn.Embedding(feat_num, emb_num)nn.init.xavier_normal_(self.embedding.weight.data)def forward(self, x):# [None, filed_num] => [None, filed_num, emb_num] => [None, filed_num * emb_num]return self.embedding(x).flatten(1)class DNN(nn.Module):def __init__(self, input_num, hidden_nums, dropout=0.1):super(DNN, self).__init__()layers = []input_num = input_numfor hidden_num in hidden_nums:layers.append(nn.Linear(input_num, hidden_num))layers.append(nn.BatchNorm1d(hidden_num))layers.append(nn.ReLU())layers.append(nn.Dropout(p=dropout))input_num = hidden_numself.mlp = nn.Sequential(*layers)for layer in self.mlp:if isinstance(layer, nn.Linear):nn.init.xavier_normal_(layer.weight.data)def forward(self, x):return self.mlp(x)class CrossCell(nn.Module):"""一个交叉单元"""def __init__(self, input_num, r):super(CrossCell, self).__init__()self.v = nn.Parameter(torch.randn(input_num, r))self.u = nn.Parameter(torch.randn(input_num, r))self.b = nn.Parameter(torch.randn(input_num, 1))nn.init.xavier_normal_(self.v.data)nn.init.xavier_normal_(self.u.data)def forward(self, x0, xi):# [None, emb_num] => [None, emb_num, 1]xi = xi.unsqueeze(2)x0 = x0.unsqueeze(2)# [r, input_num] * [None, emb_num, 1] => [None, r, 1]# [input_num, r] * [None, r, 1] => [None, emb_num, 1]xii = (torch.matmul(self.u, torch.matmul(self.v.t(), xi)) + self.b) * x0return xii  # [None, emb_num, 1]class MOECrossCell(nn.Module):def __init__(self, input_num, r, k):super(MOECrossCell, self).__init__()self.k = kself.cross_cell = nn.ModuleList([CrossCell(input_num, r) for i in range(self.k)])self.gate = nn.Linear(input_num, self.k)nn.init.xavier_normal_(self.gate.weight.data)def forward(self, x0, xi):# [None, emb_num] => [None, emb_num, 1]xii = xi.unsqueeze(2)export_out = []for i in range(self.k):cross_out = self.cross_cell[i](x0, xi)# [[None, emb_num, 1], [None, emb_num, 1], [None, emb_num, 1], [None, emb_num, 1]]export_out.append(cross_out)export_out = torch.concat(export_out, dim=2)  # [None, emb_num, 4]# [None, k] => [None, 1, k]gate_out = self.gate(xi).softmax(dim=1).unsqueeze(dim=1)# [None, emb_num, 4] * [None, 1, k] = [None, emb_num, k] => [None, emb_num, 1]out = torch.sum(export_out * gate_out, dim=2, keepdim=True)out = out + xii  # [None, emb_num, 1]return out.squeeze(2)class CrossNet(nn.Module):def __init__(self, order_num, input_num, r, k):super(CrossNet, self).__init__()self.order = order_numself.cell_list = nn.ModuleList([MOECrossCell(input_num, r, k) for i in range(order_num)])def forward(self, x0):xi = x0for i in range(self.order):xi = self.cell_list[i](x0=x0, xi=xi)return xiclass DCN(nn.Module):def __init__(self, field_num, feat_dim, emb_num, order_num, r=16, k=4, dropout=0.1, method='parallel',hidden_nums=(128, 64, 32)):super(DCN, self).__init__()input_num = field_num * emb_numself.embedding = Embedding(feat_num=feat_dim, emb_num=emb_num)self.dnn = DNN(input_num=input_num, hidden_nums=hidden_nums, dropout=dropout)self.cross_net = CrossNet(order_num=order_num, input_num=input_num, r=r, k=k)if method not in ('parallel', 'stacked'):raise ValueError('unknown combine type: ' + method)self.method = methodlinear_dim = hidden_nums[-1]if self.method == 'parallel':linear_dim = linear_dim + input_numself.linear = nn.Linear(linear_dim, 1)nn.init.xavier_normal_(self.linear.weight.data)def forward(self, x):emb = self.embedding(x)  # [None, field * emb_num]cross_out = self.cross_net(emb)  # [None, input_num]if self.method == 'parallel':dnn_out = self.dnn(emb)  # [None, input_num]out = torch.concat([cross_out, dnn_out], dim=1)else:out = self.dnn(cross_out)  # [None, input_num]out = self.linear(out)return torch.sigmoid(out).squeeze(dim=1)

在CrossCell模块中完成了一个给予UV逼近的交叉操作,在MOECrossCell模块中完成了MOE和残差连接,其中export_out和gate_out分别为专家网络的输出和门控机制的权重。
本例全部是离散分箱变量,所有有值的特征都是1,因此只要输入有值位置的索引即可,一条输入例如

>>> train_data[0]
Out[120]: (tensor([ 2, 10, 14, 18, 34, 39, 47, 51, 58, 64]), tensor(0))

x的长度为10代表10个特征域,每个域的值是特征的全局位置索引,从0到71,一共72个特征。


DCN-Mix调参和效果对比

对阶数(order_num)和融合策略(method)这两个参数进行调参,分别尝试1~4层交叉层,stacked和parallel两种策略,采用10次验证集AUC不上升作为早停条件,验证集的平均AUC如下

DCN调参AUC并行parallel串行stacked
1层交叉(2阶)0.63450.6321
2层交叉(3阶)0.63280.6323
3层交叉(4阶)0.63310.6333
4层交叉(5阶)0.63400.6331

结论依旧是parallel效果好于stacked,其中一层交叉的并行parallel达到验证集最优AUC为0.6345。
再对比一下之前文章中实践的FM,FFM,PNN,DCN-V2等一系列算法,验证集AUC和参数规模如下

算法AUC参数量
FM0.6274361
FFM0.63172953
PNN*0.634229953
DeepFM0.632212746
NFM0.632910186
DCN-parallel-30.6348110017
DCN-stacked-30.6344109857
DCN-Mix-parallel-10.634554501
DCN-Mix-stacked-30.633397869

使用矩阵分解逼近策略的DCN-Mix略低于原生的DCN-V2,但是还是超越一众FM系列的算法,其中以同样是三层交叉的stacked DCN为例,DCN-Mix的参数量相比于DCN-V2有所降低,也印证了论文中提到的“在模型效果和部署延迟之间找到一个平衡”。

最后的最后

感谢你们的阅读和喜欢,我收藏了很多技术干货,可以共享给喜欢我文章的朋友们,如果你肯花时间沉下心去学习,它们一定能帮到你。

因为这个行业不同于其他行业,知识体系实在是过于庞大,知识更新也非常快。作为一个普通人,无法全部学完,所以我们在提升技术的时候,首先需要明确一个目标,然后制定好完整的计划,同时找到好的学习方法,这样才能更快的提升自己。

这份完整版的大模型 AI 学习资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

一、全套AGI大模型学习路线

AI大模型时代的学习之旅:从基础到前沿,掌握人工智能的核心技能!

img

二、640套AI大模型报告合集

这套包含640份报告的合集,涵盖了AI大模型的理论研究、技术实现、行业应用等多个方面。无论您是科研人员、工程师,还是对AI大模型感兴趣的爱好者,这套报告合集都将为您提供宝贵的信息和启示。

img

三、AI大模型经典PDF籍

随着人工智能技术的飞速发展,AI大模型已经成为了当今科技领域的一大热点。这些大型预训练模型,如GPT-3、BERT、XLNet等,以其强大的语言理解和生成能力,正在改变我们对人工智能的认识。 那以下这些PDF籍就是非常不错的学习资源。

img

四、AI大模型商业化落地方案

img

五、面试资料

我们学习AI大模型必然是想找到高薪的工作,下面这些面试题都是总结当前最新、最热、最高频的面试题,并且每道题都有详细的答案,面试前刷完这套面试题资料,小小offer,不在话下。
在这里插入图片描述

这份完整版的大模型 AI 学习资料已经上传CSDN,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/23846.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

部署kubesphere报错

安装kubesphere报错命名空间terminted [rootk8smaster ~]# kubectl apply -f kubesphere-installer.yaml Warning: apiextensions.k8s.io/v1beta1 CustomResourceDefinition is deprecated in v1.16, unavailable in v1.22; use apiextensions.k8s.io/v1 CustomResourceDefini…

618科技好物清单:物超所值的产品推荐,总有一款适合你!

随着科技的不断发展&#xff0c;我们生活中涌现出了越来越多的科技创新产品。这些产品不仅让我们的生活变得更加便捷&#xff0c;还提升了我们的生活品质。而在即将到来的618购物节&#xff0c;正是我们购买这些物超所值科技好物的绝佳时机。 本文将为您推荐一些在618期间值得关…

英伟达的数字孪生地球是什么

1 英伟达的数字孪生地球 Earth-2是一个全栈式开放平台&#xff0c;包含&#xff1a;ICON 和 IFS 等数值模型的物理模拟&#xff1b;多种机器学习模型&#xff0c;例如 FourCastNet、GraphCast 和通过 NVIDIA Modulus 实现的深度学习天气预测 (DLWP)&#xff1b;以及通过 NVIDI…

手撕设计模式——克隆对象之原型模式

1.业务需求 ​ 大家好&#xff0c;我是菠菜啊&#xff0c;前俩天有点忙&#xff0c;今天继续更新了。今天给大家介绍克隆对象——原型模式。老规矩&#xff0c;在介绍这期之前&#xff0c;我们先来看看这样的需求&#xff1a;《西游记》中每次孙悟空拔出一撮猴毛吹一下&#x…

pytorch-nn.Module

目录 1. nn.Module2. nn.Sequential容器3. 网络参数parameters4. Modules内部管理5. checkpoint6. train/test状态切换6. 实现自己的网络层6.1 实现打平操作6.2 实现自己的线性层 7. 代码 1. nn.Module 是所有nn.类的父类&#xff0c;其中包括nn.Linear nn.BatchNorm2d nn.Con…

肺结节14问,查出肺结节怎么办?哪些能用中医调治消散?快来了解一下吧

近些年&#xff0c;随着大众防癌意识的加强&#xff0c;和胸部低剂量CT的普及&#xff0c;肺结节的检出率也逐年升高&#xff0c;不少患者CT报告上&#xff0c;写着“肺小结”“肺部磨玻璃结节”的字样&#xff0c;当你看到这几个字时&#xff0c;会不会瞬间紧张起来&#xff1…

编程规范-代码检测-格式化-规范化提交

适用于vue项目的编程规范 – 在多人开发时统一编程规范至关重要 1、代码检测 --Eslint Eslint&#xff1a;一个插件化的 javascript 代码检测工具 在 .eslintrc.js 文件中进行配置 // ESLint 配置文件遵循 commonJS 的导出规则&#xff0c;所导出的对象就是 ESLint 的配置对…

简化电动汽车充电器和光伏逆变器的高压电流检测

在任何电气系统中&#xff0c;电流都是一个至关重要的参数。电动汽车 (EV) 充电系统和太阳能系统都需要检测电流的大小&#xff0c;以便控制和监测功率转换、充电和放电。电流传感器通过监测分流电阻器上的压降或导体中电流产生的磁场来测量电流。 金属氧化物半导体场效应晶体…

DBeaver连接MySQL提示“Public Key Retrieval is not allowed“问题的解决方式

问题描述 客户端root用户连接数据库出现出现Public Key Retrieval is not allowed 原因分析&#xff1a; 加上allowPublicKeyRetrievalfalse&#xff1a; 解决方案&#xff1a; allowPublicKeyRetrievaltrue&#xff1a;

Java Web学习笔记14——BOM对象

BOM&#xff1a; 概念&#xff1a;浏览器对象模型&#xff08;Browser Object Model&#xff09;&#xff0c;允许JavaScript与浏览器对话&#xff0c;JavaScript将浏览器的各个组成部分封装为对象。 组成&#xff1a; Window&#xff1a;浏览器窗口对象 介绍&#xff1a;浏览…

光伏电站鸟害解决方案,列式冲击波声压光伏驱鸟器

光伏电站的运营过程中&#xff0c;最怕遇上鸟粪污染。鸟粪不仅难以清洗&#xff0c;还可能导致光伏组件损坏、降低发电效率。因此&#xff0c;制定并实施有效的驱鸟策略对于光伏电站的稳定运营至关重要。 针对光伏电站的鸟害问题&#xff0c;我们可以从以下几个方面来解决&…

知名优秀定制线缆生产源头工厂推荐-精工电联:全程跟踪监制,打造水下机器人线缆定制新标杆

在科技飞速发展的今天&#xff0c;精工电联作为高科技智能化产品及自动化设备专用连接线束和连接器配套服务商&#xff0c;始终站在行业前沿。我们专注于为高科技行业提供高品质、优匹配的集成线缆和连接器定制服务&#xff0c;特别是在水下机器人线缆定制领域&#xff0c;通过…

sql死锁分析

一、重要参数 获取事务信息:SELECT * FROM information_schema.INNODB_TRX; 获取锁等待:SELECT * FROM information_schema.INNODB_LOCK_WAITS; 查看锁信息:SELECT * FROM information_schema.INNODB_LOCKS WHERE lock_trx_id IN () 二、case1:间隙锁和x锁互斥导致死锁 1、背景…

大厂AI团战高考作文,华师一附中特级教师这样打分

在人工智能的浪潮中&#xff0c; 人们不禁疑问&#xff1a; AI真的能超越人类吗&#xff1f; 这究竟是现实还是幻想&#xff1f; 我们将目睹一场前所未有的较量&#xff1a; 百度文心一言、阿里通义千问、 腾讯混元、字节豆包 四家国内顶尖互联网企业 精心打造的AI大模…

HBM简介

1、什么是HBM HBMHigh Bandwidth Memory 是一种用于某些 GPU的 3D 堆叠 DRAM存储器 &#xff08;动态随机存取存储器&#xff09;以及服务器、高性能计算 &#xff08;HPC&#xff09; 、网络连接的内存接口。其实就是将很多个DDR芯片堆叠在一起后和GPU封装在一起&#xff0c;实…

ROS socketcan_bridge使用说明

ROS socketcan_bridge使用说明&#xff08;以ubuntu20.04为例&#xff09; socketcan_bridge是什么 ROS针对socketcan提供了三个层次的驱动库&#xff0c;分别是ros_canopen&#xff0c;socketcan_bridge和socketcan_interface。 socketcan_interface&#xff1a; 功能&#x…

政安晨【零基础玩转各类开源AI项目】:解析开源项目:Champ 利用三维参数指导制作可控且一致的人体图像动画

目录 论文题目 Champ: 利用三维参数指导制作可控且一致的人体图像动画 安装 创建 conda 环境&#xff1a; 使用 pip 安装软件包 推理 1. 下载预训练模型 2. 准备准备引导动作数据 运行推理 训练模型 准备数据集 运行训练脚本 数据集 政安晨的个人主页&#xff1a;…

工业无线通信解决方案,企业在进行智能化升级改造

某大型制造企业在进行智能化升级改造,需要将分布在各个车间的数控机床、自动化生产线、AGV小车等设备连接到云端,实现设备的远程监控、数据采集分析等功能。之前工厂内部是用工业以太网连接,存在布线难、成本高、灵活性差等问题。 在了解客户需求后,我司星创易联的工程师建议客…

淘宝扭蛋机小程序,扭蛋市场创新模式

扭蛋机作为潮玩市场的娱乐消费方式&#xff0c;成为了当下消费者的新宠。扭蛋机凭借自身性价比高、商品多样、惊喜性等特点&#xff0c;吸引了各个年龄层的消费者&#xff0c;不仅年轻人喜欢&#xff0c;不少小学生和老年人也非常喜欢&#xff0c;扭蛋机市场迎来了快速发展期。…

简单聊下办公白环境

在当今信息化时代&#xff0c;办公环境对于工作效率和员工满意度有着至关重要的影响。而白名单作为一种网络安全策略&#xff0c;其是否适合办公环境&#xff0c;成为了许多企业和组织需要思考的问题。本文将从白名单的定义、特点及其在办公环境中的应用等方面&#xff0c;探讨…