ctr/cvr预估之DeepFM模型
在数字营销的浪潮中,点击率(CTR)和转化率(CVR)预估已成为精准广告投放和个性化推荐系统的核心。随着深度学习技术的蓬勃发展,传统的机器学习方法,如逻辑回归和因子分解机(FM),虽然在某些方面表现出色,但在处理高维稀疏数据和捕捉复杂特征交互方面逐渐显现出局限性。为了突破这一瓶颈,DeepFM模型应运而生,它巧妙地结合了深度学习与因子分解的思想,旨在更深层次地挖掘用户行为背后的模式。
文章目录
- ctr/cvr预估之DeepFM模型
- 一、什么是DeepFM模型
- 二、DeepFM模型提出背景
- 三、DeepFM模型原理
- 四、DeepFM模型注意事项
- 五、DeepFM模型的核心参数
- 六、DeepFM模型实现代码
一、什么是DeepFM模型
DeepFM模型是一种结合了因子分解机(Factorization Machines, FM)和深度神经网络(Deep Neural Network, DNN)的推荐系统模型,它专门设计用于处理点击率(CTR)预估问题。DeepFM模型的核心思想是同时捕获低阶和高阶的特征交互,这在提高模型预测准确性方面非常关键。
二、DeepFM模型提出背景
DeepFM模型的提出背景主要是为了解决传统推荐系统中特征工程复杂度高、难以处理大规模数据等问题。随着大数据时代的到来,推荐系统成为互联网应用中不可或缺的一部分,但同时也面临着新的挑战。传统的推荐系统在特征处理上通常需要大量的手工特征工程,这不仅增加了系统的复杂性,也限制了模型处理大规模稀疏数据的能力。
为了克服这些挑战,研究者们提出了多种深度学习模型,DeepFM便是其中的佼佼者。它是基于Wide & Deep模型的改进和提升,将因子分解机(FM)和深度神经网络(DNN)相结合,通过同时学习低阶和高阶特征交互,实现了强大的推荐功能。DeepFM模型通过FM部分学习低阶特征交互,而DNN部分则负责学习高阶特征交互,两者共享相同的输入特征,从而简化了特征工程。
此外,DeepFM模型的提出也是为了充分利用深度学习在自动特征学习方面的优势,同时保留FM在处理稀疏数据和捕获特征交互方面的高效性。通过这种结合,DeepFM能够自动学习数据中的复杂模式,并提供更加精准的推荐结果。模型的端到端学习能力和不依赖于特征工程的特点,使其在CTR预估领域得到了广泛应用。
三、DeepFM模型原理
DeepFM模型结构
DeepFM模型由Wide&Deep模型演化而来,在Wide部分使用FM代替了Wide&Deep中的LR模型,有了FM自动构造学习二阶(考虑到时间复杂度原因,通常都是二阶)交叉特征的能力,因此不再需要特征工程。Wide&Deep模型中LR部分依然需要人工的特征交叉,比如【用户已安装的app】与【给用户曝光的app】两个特征做交叉。并且在DeepFM模型中,FM模型与DNN模型共享底层Embedding向量。
- 因子分解机(FM)部分
线性交互:FM的第一部分是线性回归部分,它处理所有特征的一阶线性关系。
二阶交互:FM的核心在于它可以有效地计算任意两个特征间的二阶交互,而不需要显式地为每对特征创建交互项。它通过将每个特征映射到一个低维空间来实现,交互项的权重是通过这些低维向量的点积来计算的。这种方法使得模型能够捕捉到特征间的相互作用,同时保持参数数量和计算复杂度相对较低。
- 深度神经网络(DNN)部分
高阶交互:深度网络部分能够学习输入特征的高阶交互关系。与传统的多层感知机(MLP)相似,这一部分由多个隐藏层组成,每一层都是前一层输出的非线性变换。
特征嵌入共享:DeepFM的一个关键创新是FM部分和DNN部分共享相同的特征嵌入。这意味着输入到FM和DNN的特征表示是相同的,这样可以有效地减少模型需要学习的参数量,并且使得学习到的特征表示在两个模型部分中都是一致的。
- 结合FM和DNN
输出层:DeepFM模型的输出是FM部分和DNN部分的结果的组合。这通常通过将两部分的输出求和或连接后通过一个或多个全连接层来实现。
优化和训练:整个模型可以端到端地进行训练,通常使用如梯度下降的优化算法来最小化预测误差(如二元交叉熵损失)。
四、DeepFM模型注意事项
对于长度不一致的特征,FM模型通过将这些特征转换为固定长度的向量来处理它们之间的交叉项,通常通过特征的嵌入(Embedding)实现。
对于多特征的场景,一般会将各个特征嵌入到相同的维度,然后进行拼接,拼接后的总维度就是各个嵌入维度之和,在FM层处理时,关注的是处理后的嵌入特征,而非原始的输入维度。
FMLayer层的关键在于计算两个部分:一是所有嵌入向量的和的平方,二是所有嵌入向量的平方的和。然后,将前者减去后者,乘以0.5得到交叉项。
五、DeepFM模型的核心参数
DeepFM模型的核心参数
六、DeepFM模型实现代码
DeepFM模型实现代码