在商品推荐系统中,粗排和精排环节的知识蒸馏方法主要通过复杂模型(Teacher)指导简单模型(Student)的训练,以提升粗排效果及与精排的一致性。本文将以淘宝的一篇论文《Privileged Features Distillation at Taobao Recommendations》中介绍的 PFD(Privileged Features Distillation)方法为例实现一个Demo,帮助读者学习知识蒸馏。
1.知识蒸馏方法概述
知识蒸馏诞生至今,早已不局限于粗排,而是在粗排和精排均有应用。粗排和精排的知识蒸馏核心在于通过不同形式的知识迁移(logits、排序结果、特征)提升模型效果与一致性。粗排侧重从精排获取排序偏好,而精排侧重模型压缩。实际应用中需结合业务场景选择蒸馏策略,并权衡性能与效果。本节简要介绍一下知识蒸馏方法。
一、粗排环节的典型蒸馏方法
粗排需平衡性能和效果,通常以精排为Teacher进行知识迁移,主要方法包括:
(1)Logits蒸馏
- 原理:利用精排模型的输出logits(未归一化的预测值)作为软标签(soft label),指导粗排模型学习。通过引入温度系数(Temperature Scaling)调整软标签的分布,增强非主导类别的信息传递。
- 损失函数:粗排模型的损失由两部分组成: Hard Loss:基于真实标签的交叉熵损失; Soft Loss:基于精排输出logits的KL散度或MSE损失。
- 应用:美团、爱奇艺等采用两阶段训练,先训练精排Teacher,再固定其参数指导粗排Student56。
(2)排序结果蒸馏
-
原理:直接利用精排输出的有序列表信息,构造粗排的训练样本。常见方法包括:
1. Point-wise:将精排Top-K结果作为正样本,其余作为负样本,并引入位置权重。2. Pair-wise:从精排列表中随机抽取商品对,学习偏序关系(如BPR损失)。、3. List-wise:通过NDCG等指标对齐粗排与精排的整体排序。
-
优势:缓解样本选择偏差,增强粗排对精排排序偏好的拟合。
(3)特征蒸馏
- 原理:迁移精排模型的中间层特征,要求粗排和精排的网络结构部分对齐。例如: 隐层特征对齐:通过MSE损失约束粗排与精排的隐层输出(如淘宝的 PFD(Privileged Features Distillation) 方法)。
- 优势特征蒸馏:将精排使用的交叉特征等“特权特征”迁移到粗排(如用户与商品的交互特征)。
- 应用:淘宝在 KDD 2020 提出的 PFD 方法中,精排 Teacher 使用交叉特征,粗排 Student 仅用基础特征,通过蒸馏提升效果。
二、精排环节的典型蒸馏方法
精排蒸馏主要用于模型压缩,将复杂模型(如集成模型)的能力迁移至轻量级模型:
(1)Logits蒸馏
- 原理:与粗排类似,使用复杂精排模型的 logits 指导轻量级 Student 模型训练。例如: 阿里 Rocket Launching 框架:Teacher 和 Student 共享 Embedding 层,联合训练并通过 logits 对齐。
- 改进:爱奇艺双 DNN 模型进一步约束 Student 隐层与 Teacher 隐层的激活值相似性。
(2)多目标蒸馏
- 原理:将精排的多任务输出(如CTR、CVR)迁移至 Student。例如: 腾讯在 SIGIR 2021 提出通过 KL 散度对齐多任务 logits,提升粗排/召回模型的多目标一致性。
- 损失设计:结合多任务损失和蒸馏损失,如加权交叉熵或对比学习损失。
三、关键技术与实践
(1)温度系数(Temperature)
调节 softmax 输出的平滑度,温度值越大,分布越平滑,帮助 Student 学习 Teacher 的暗知识(Dark Knowledge)。
(2)两阶段训练 vs 联合训练
- 两阶段:先独立训练 Teacher,再固定其参数指导 Student(稳定性高)。
- 联合训练:Teacher 和 Student 同步更新(减少耗时,但需设计梯度阻断防止相互干扰)。
(3)实际应用案例
- 美团:通过对比学习强化粗排与精排的特征对齐,粗排CTR提升 0.15%。
- 淘宝:优势特征蒸馏使粗排 CTR 提升 5%,精排CVR提升 2.3%。
- 腾讯音乐:多目标蒸馏在粗排阶段实现阅读时长与点击率的联合优化。
2. PFD(Privileged Features Distillation)方法介绍
PFD(Privileged Features Distillation)方法出自论文《Privileged Features Distillation at Taobao Recommendations》。论文中描述:在离线环境下同时训练两个模型:一个学生模型以及一个教师模型。其中学生模型和原始模型完全相同,而教师模型额外利用了优势特征, 其准确率也因此更高。通过将教师模型蒸馏出的知识(Knowlege, 本文特指教师模型中最后一层的输出)传递给学生模型,可以辅助其训练以进一步提升准确率。在线上服务时,我们只抽取学生模型进行部署,因为输入不依赖于优势特征,离线、在线的一致性得以保证。在 PFD 中,所有的优势特征都被统一到教师模型作为输入,加入更多的优势特征往往能带来模型更高的准确度。
PFD 不同于常见的模型蒸馏(Model Disitillation, 简称 MD)。 在 MD 中,教师模型和学生模型处理同样的输入特征,其中教师模型会比学生模型更为复杂, 比如,教师模型会用更深的网络结构来指导使用浅层网络的学生模型进行学习。在 PFD 中,教师和学生模型会使用相同网络结构,而处理不同的输入特征。MD 和 PFD 两者的差异如下图所示。
如上图所示:模型蒸馏(Model Distill, 简称 MD)与优势特征蒸馏(PFD)对比; 在 MD 中,知识(Knowledge)是从更复杂的模型中蒸馏出来,而在 PFD 中,知识是从优势特征中蒸馏出来。
由此可见,我们可以训练一个使用了复杂特征(如交叉特征)的模型作为老师,指导训练一个仅使用简单特征的学生模型,从而实现提升模型效果,而又不增加线上耗时(线上使用交叉特征等复杂特征通常会导致耗时大幅增加,因此,在粗排环节几乎不直接使用交叉特征)。
3.基于 Wide&Deep 指导训练 TowTower 模型
基于 PFD 方法的原理,在本节我们将实现一个知识蒸馏的 Demo。其中,Teacher 模型基于 Wide&Deep 模型;Student 模型则采用简单的“双塔模型”。为了简单起见,Wide&Deep 模型和 “双塔模型” 均为单目标(CTR )模型。
3.1 模拟数据构造
"""
Part-1:模拟数据构造本部分模拟真实场景,人工构造用户数据、商品数据、用户-商品交互数据(点击、转化),并进行必要的预处
"""
# 设置随机种子保证可复现性
np.random.seed(42)
tf.random.set_seed(42)# 生成用户、商品和交互数据
num_users = 100
num_items = 200
num_interactions = 1000# 用户特征
user_data = {'user_id': np.arange(1, num_users + 1),'user_age': np.random.randint(18, 65, size=num_users),'user_gender': np.random.choice(['male', 'female'], size=num_users),'user_occupation': np.random.choice(['student', 'worker', 'teacher'], size=num_users),'city_code': np.random.randint(1, 2856, size=num_users),'device_type': np.random.randint(0, 5, size=num_users)
}# 商品特征
item_data = {'item_id': np.arange(1, num_items + 1),'item_category': np.random.choice(['electronics', 'books', 'clothing'], size=num_items),'item_brand': np.random.choice(['brandA', 'brandB', 'brandC'], size=num_items),'item_price': np.random.randint(1, 199, size=num_items)
}# 交互数据
# 包括:点击和转化(购买)数据
interactions = []
for _ in range(num_interactions):user_id = np.random.randint(1, num_users + 1)item_id = np.random.randint(1, num_items + 1)# 点击标签。0: 未点击, 1: 点击。在真实场景中可通过客户端埋点上报获得用户的点击行为数据click_label = np.random.randint(0, 2)interactions.append([user_id, item_id, click_label])# 合并用户特征、商品特征和交互数据
interaction_df = pd.DataFrame(interactions, columns=['user_id', 'item_id', 'click_label'])
user_df = pd.DataFrame(user_data)
item_df = pd.DataFrame(item_data)
df = interaction_df.merge(user_df, on='user_id').merge(item_df, on='item_id')# 划分数据集
labels = df[['click_label']]
features = df.drop(['click_label'], axis=1)
train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.2,random_state=42)
3.2 特征工程
代码如下,相较于 Student 模型,作为 Teacher 的 Wide&Deep 模型采用了更多的特征,特别是交叉特征。
"""
Part-2:特征工程本部分对原始用户数据、商品数据、用户-商品交互数据进行分类处理,加工为模型训练需要的特征1.数值型特征:如用户年龄、价格,少数场景下可直接使用,但最好进行标准化,从而消除量纲差异2.类别型特征:需要进行 Embedding 处理3.交叉特征:由于维度高,需要哈希技巧处理高维组合特征
"""
# 用户特征处理
user_id = feature_column.categorical_column_with_identity('user_id', num_buckets=num_users + 1)
user_id_emb = feature_column.embedding_column(user_id, dimension=8)scaler_age = StandardScaler()
df['user_age'] = scaler_age.fit_transform(df[['user_age']])
user_age = feature_column.numeric_column('user_age')user_gender = feature_column.categorical_column_with_vocabulary_list('user_gender', ['male', 'female'])
user_gender_emb = feature_column.embedding_column(user_gender, dimension=2)user_occupation = feature_column.categorical_column_with_vocabulary_list('user_occupation',['student', 'worker', 'teacher'])
user_occupation_emb = feature_column.embedding_column(user_occupation, dimension=2)city_code_column = feature_column.categorical_column_with_identity(key='city_code', num_buckets=2856)
city_code_emb = feature_column.embedding_column(city_code_column, dimension=8)device_types_column = feature_column.categorical_column_with_identity(key='device_type', num_buckets=5)
device_types_emb = feature_column.embedding_column(device_types_column, dimension=8)# 商品特征处理
item_id = feature_column.categorical_column_with_identity('item_id', num_buckets=num_items + 1)
item_id_emb = feature_column.embedding_column(item_id, dimension=8)scaler_price = StandardScaler()
df['item_price'] = scaler_price.fit_transform(df[['item_price']])
item_price = feature_column.numeric_column('item_price')item_category = feature_column.categorical_column_with_vocabulary_list('item_category',['electronics', 'books', 'clothing'])
item_category_emb = feature_column.embedding_column(item_category, dimension=2)item_brand = feature_column.categorical_column_with_vocabulary_list('item_brand', ['brandA', 'brandB', 'brandC'])
item_brand_emb = feature_column.embedding_column(item_brand, dimension=2)"""
交叉特征预处理
"""
# 使用TensorFlow的交叉特征(crossed_column)定义了Wide部分的特征列,主要用于捕捉用户与商品特征之间的组合效应
# 将用户ID(user_id)和商品ID(item_id)组合成一个新特征,捕捉**“特定用户对特定商品的偏好”**
# 用户ID和商品ID的组合总数可能非常大(num_users * num_items),直接编码会导致维度爆炸。
# hash_bucket_size=10000:使用哈希函数将组合映射到固定数量的桶(10,000个),控制内存和计算开销,适用于稀疏高维特征(如用户-商品对)
user_id_x_item_id = feature_column.crossed_column([user_id, item_id], hash_bucket_size=10000)
user_id_x_item_id = feature_column.indicator_column(user_id_x_item_id)
user_gender_x_item_category = feature_column.crossed_column([user_gender, item_category], hash_bucket_size=1000)
user_gender_x_item_category = feature_column.indicator_column(user_gender_x_item_category)
user_occupation_x_item_brand = feature_column.crossed_column([user_occupation, item_brand], hash_bucket_size=1000)
user_occupation_x_item_brand = feature_column.indicator_column(user_occupation_x_item_brand)"""
特征列定义
"""
# ESMM 模型相关特征列定义
user_tower_columns = [user_id_emb, user_age, user_gender_emb, user_occupation_emb, city_code_emb, device_types_emb]
item_tower_columns = [item_id_emb, item_category_emb, item_brand_emb, item_price]# Wide&Deep 模型相关特征列定义
deep_feature_columns = [user_id_emb,user_age,user_gender_emb,user_occupation_emb,item_id_emb,item_category_emb,item_brand_emb,item_price
]wide_feature_columns = [user_id_x_item_id,user_gender_x_item_category,user_occupation_x_item_brand
]
3.3 模型架构设计
Teacher 模型:采用 Wide&Deep 模型(模拟精排模型);Student 模型:采用普通 “双塔模型”(模拟粗排模型)。
"""
Part-3:模型架构设计
"""
# 教师模型:采用 Wide&Deep 模型
class WideDeepModel(tf.keras.Model):"""Wide部分:线性模型,擅长记忆(Memorization),通过交叉特征捕捉明确的特征组合模式(如用户A常点击商品B)。Deep部分:深度神经网络,擅长泛化(Generalization),通过嵌入向量学习特征的潜在关系(如女性用户与服装品类的关联)。结合优势:同时处理稀疏特征(如用户ID、商品ID)和密集特征(如价格、年龄),平衡记忆与泛化能力"""def __init__(self, wide_feature_columns, deep_feature_columns):super(WideDeepModel, self).__init__()# Wide部分(线性模型)self.linear_features = tf.keras.layers.DenseFeatures(wide_feature_columns)self.wide_out = tf.keras.layers.Dense(1, activation='sigmoid')# Deep部分(深度神经网络)self.dnn_features = tf.keras.layers.DenseFeatures(deep_feature_columns)self.dnn_layer = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu')])self.deep_out = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# Wide部分:预测CTRlinear_features = self.linear_features(inputs)ctr_wide_logits = self.wide_out(linear_features)# Deep部分:预测CTRdnn_features = self.dnn_features(inputs)dnn_layer = self.dnn_layer(dnn_features)ctr_deep_logits = self.deep_out(dnn_layer)# 将Wide和Deep的logits相加,通过Sigmoid输出点击概率ctr_logits = tf.sigmoid(ctr_wide_logits + ctr_deep_logits)# 返回return {'ctr_logits': ctr_logits}# 学生模型:采用普通双塔模型
class TowTowerStudent(tf.keras.Model):"""普通双塔模型:User Tower + Item Tower"""def __init__(self, user_columns, item_columns):super(TowTowerStudent, self).__init__()# 共享特征处理层self.user_feature = tf.keras.layers.DenseFeatures(user_columns)self.item_feature = tf.keras.layers.DenseFeatures(item_columns)# User塔self.user_tower = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu'),])# Item塔self.item_tower = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu'),])self.tower_out = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# 双塔结构user_feature = self.user_feature(inputs)item_feature = self.item_feature(inputs)user_emb = self.user_tower(user_feature)item_emb = self.item_tower(item_feature)# CTR预测# 点积交互(即用户Embedding和商品Embedding求取余弦相似度)interaction = tf.keras.layers.Dot(axes=1)([user_emb, item_emb])ctr_logits = self.tower_out(interaction)return {'ctr_logits': ctr_logits}
3.4 知识蒸馏实现
本质上就是用 Teacher 模型指导 Student 模型训练。使得 Student 模型的预测结果逼近 Teacher 模型。
"""
Part-4:知识蒸馏实现
"""
class DistillationModel(tf.keras.Model):def __init__(self, teacher, student):super(DistillationModel, self).__init__()self.teacher = teacherself.student = student# 温度参数:典型取值2-5之间self.temperature = 2.0def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn):super().compile(optimizer=optimizer, metrics=metrics)self.student_loss_fn = student_loss_fnself.distillation_loss_fn = distillation_loss_fndef call(self, inputs):# 推理时直接使用学生模型return self.student(inputs)def train_step(self, data):# 解包数据x, y = data# 教师模型前向传播(仅推理)teacher_predictions = self.teacher(x, training=False) # 冻结教师模型teacher_ctr = teacher_predictions['ctr_logits']# 使用tf.GradientTape实现动态梯度计算with tf.GradientTape() as tape:# 学生模型前向传播student_outputs = self.student(x, training=True)student_ctr = student_outputs['ctr_logits']# 计算学生损失# 学生损失(student_loss):直接拟合真实标签# y['ctr_logits'] = labels['click_label'],在输入数据时有定义student_loss_ctr = self.student_loss_fn(y['ctr_logits'], student_ctr)# 计算蒸馏损失distillation_loss_ctr = self.distillation_loss_fn(# 蒸馏损失(distillation_loss):学习教师模型的软标签分布teacher_ctr / self.temperature, # 教师输出软化student_ctr / self.temperature # 学生输出对齐)# 总损失total_loss = 0.7 * student_loss_ctr + 0.3 * distillation_loss_ctr# 计算梯度并更新(仅更新学生参数)trainable_vars = self.student.trainable_variablesgradients = tape.gradient(total_loss, trainable_vars)self.optimizer.apply_gradients(zip(gradients, trainable_vars))# 更新指标self.compiled_metrics.update_state(y, {'ctr_logits': student_ctr,})return {m.name: m.result() for m in self.metrics}
3.5 模型训练与评估
- 第一步:数据准备;
- 第二步:模型初始化;
- 第三步:编译、训练 Teacher 模型;
- 第四步:编译、训练 Student 模型;
- 第五步:评估、可视化效果
"""
Part-5:模型训练与评估
"""
# 数据输入管道
def df_to_dataset(features, labels, shuffle=True, batch_size=32):ds = tf.data.Dataset.from_tensor_slices((dict(features),{# 这里做了一个映射,主要为了对齐学生模型和教师模型的输出,从而便于计算损失'ctr_logits': labels['click_label']}))if shuffle:ds = ds.shuffle(1000)ds = ds.batch(batch_size)return ds# 转换数据集
train_ds = df_to_dataset(train_features, train_labels)
test_ds = df_to_dataset(test_features, test_labels, shuffle=False)# 初始化模型
teacher = WideDeepModel(wide_feature_columns, deep_feature_columns)
student = TowTowerStudent(user_tower_columns, item_tower_columns)
distiller = DistillationModel(teacher, student)# 编译教师模型(先单独训练)
teacher.compile(optimizer='adam',loss={'ctr_logits': 'binary_crossentropy'},metrics=['accuracy'],loss_weights=[0.7, 0.3] # 可选:设置不同任务的损失权重
)# 训练教师模型
print("训练教师模型...")
teacher.fit(train_ds, epochs=5, validation_data=test_ds)# 编译蒸馏模型
distiller.compile(optimizer='adam',metrics={'ctr_logits': ['accuracy']},student_loss_fn=tf.keras.losses.BinaryCrossentropy(),distillation_loss_fn=tf.keras.losses.KLDivergence()
)# 训练学生模型(带蒸馏)
print("训练学生模型...")
history = distiller.fit(train_ds, epochs=10, validation_data=test_ds)
print(history.history)# 可视化训练过程
plt.plot(history.history['accuracy'], label='CTR Accuracy')plt.title('Training Metrics')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
3.6 模型服务化与测试
保存训练好的学生模型,在另一个工程中可以加载这个模型,并执行预测。
"""
Part-6:模型服务化(示例)
"""
# 保存学生模型
student.save('esmm_student_model')# 加载模型进行推理
loaded_model = tf.keras.models.load_model('esmm_student_model')# 查看模型输入层名称
loaded_model.summary()# 示例预测:从 test_features 数据框中提取第一行数据
sample = test_features.iloc[0]sample_dict = {col: tf.expand_dims(value, -1)for col, value in dict(sample).items()
}predictions = loaded_model.predict(sample_dict)
print(f"预测结果:CTR={predictions['ctr_logits'][0][0]:.3f}")
3.7 知识蒸馏完整代码
完整代码如下:
import tensorflow as tftf.config.set_visible_devices([], 'GPU') # 禁用GPU设备
from tensorflow import feature_column
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler"""
Part-1:模拟数据构造本部分模拟真实场景,人工构造用户数据、商品数据、用户-商品交互数据(点击、转化),并进行必要的预处
"""
# 设置随机种子保证可复现性
np.random.seed(42)
tf.random.set_seed(42)# 生成用户、商品和交互数据
num_users = 100
num_items = 200
num_interactions = 1000# 用户特征
user_data = {'user_id': np.arange(1, num_users + 1),'user_age': np.random.randint(18, 65, size=num_users),'user_gender': np.random.choice(['male', 'female'], size=num_users),'user_occupation': np.random.choice(['student', 'worker', 'teacher'], size=num_users),'city_code': np.random.randint(1, 2856, size=num_users),'device_type': np.random.randint(0, 5, size=num_users)
}# 商品特征
item_data = {'item_id': np.arange(1, num_items + 1),'item_category': np.random.choice(['electronics', 'books', 'clothing'], size=num_items),'item_brand': np.random.choice(['brandA', 'brandB', 'brandC'], size=num_items),'item_price': np.random.randint(1, 199, size=num_items)
}# 交互数据
# 包括:点击和转化(购买)数据
interactions = []
for _ in range(num_interactions):user_id = np.random.randint(1, num_users + 1)item_id = np.random.randint(1, num_items + 1)# 点击标签。0: 未点击, 1: 点击。在真实场景中可通过客户端埋点上报获得用户的点击行为数据click_label = np.random.randint(0, 2)interactions.append([user_id, item_id, click_label])# 合并用户特征、商品特征和交互数据
interaction_df = pd.DataFrame(interactions, columns=['user_id', 'item_id', 'click_label'])
user_df = pd.DataFrame(user_data)
item_df = pd.DataFrame(item_data)
df = interaction_df.merge(user_df, on='user_id').merge(item_df, on='item_id')# 划分数据集
labels = df[['click_label']]
features = df.drop(['click_label'], axis=1)
train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.2,random_state=42)"""
Part-2:特征工程本部分对原始用户数据、商品数据、用户-商品交互数据进行分类处理,加工为模型训练需要的特征1.数值型特征:如用户年龄、价格,少数场景下可直接使用,但最好进行标准化,从而消除量纲差异2.类别型特征:需要进行 Embedding 处理3.交叉特征:由于维度高,需要哈希技巧处理高维组合特征
"""
# 用户特征处理
user_id = feature_column.categorical_column_with_identity('user_id', num_buckets=num_users + 1)
user_id_emb = feature_column.embedding_column(user_id, dimension=8)scaler_age = StandardScaler()
df['user_age'] = scaler_age.fit_transform(df[['user_age']])
user_age = feature_column.numeric_column('user_age')user_gender = feature_column.categorical_column_with_vocabulary_list('user_gender', ['male', 'female'])
user_gender_emb = feature_column.embedding_column(user_gender, dimension=2)user_occupation = feature_column.categorical_column_with_vocabulary_list('user_occupation',['student', 'worker', 'teacher'])
user_occupation_emb = feature_column.embedding_column(user_occupation, dimension=2)city_code_column = feature_column.categorical_column_with_identity(key='city_code', num_buckets=2856)
city_code_emb = feature_column.embedding_column(city_code_column, dimension=8)device_types_column = feature_column.categorical_column_with_identity(key='device_type', num_buckets=5)
device_types_emb = feature_column.embedding_column(device_types_column, dimension=8)# 商品特征处理
item_id = feature_column.categorical_column_with_identity('item_id', num_buckets=num_items + 1)
item_id_emb = feature_column.embedding_column(item_id, dimension=8)scaler_price = StandardScaler()
df['item_price'] = scaler_price.fit_transform(df[['item_price']])
item_price = feature_column.numeric_column('item_price')item_category = feature_column.categorical_column_with_vocabulary_list('item_category',['electronics', 'books', 'clothing'])
item_category_emb = feature_column.embedding_column(item_category, dimension=2)item_brand = feature_column.categorical_column_with_vocabulary_list('item_brand', ['brandA', 'brandB', 'brandC'])
item_brand_emb = feature_column.embedding_column(item_brand, dimension=2)"""
交叉特征预处理
"""
# 使用TensorFlow的交叉特征(crossed_column)定义了Wide部分的特征列,主要用于捕捉用户与商品特征之间的组合效应
# 将用户ID(user_id)和商品ID(item_id)组合成一个新特征,捕捉**“特定用户对特定商品的偏好”**
# 用户ID和商品ID的组合总数可能非常大(num_users * num_items),直接编码会导致维度爆炸。
# hash_bucket_size=10000:使用哈希函数将组合映射到固定数量的桶(10,000个),控制内存和计算开销,适用于稀疏高维特征(如用户-商品对)
user_id_x_item_id = feature_column.crossed_column([user_id, item_id], hash_bucket_size=10000)
user_id_x_item_id = feature_column.indicator_column(user_id_x_item_id)
user_gender_x_item_category = feature_column.crossed_column([user_gender, item_category], hash_bucket_size=1000)
user_gender_x_item_category = feature_column.indicator_column(user_gender_x_item_category)
user_occupation_x_item_brand = feature_column.crossed_column([user_occupation, item_brand], hash_bucket_size=1000)
user_occupation_x_item_brand = feature_column.indicator_column(user_occupation_x_item_brand)"""
特征列定义
"""
# ESMM 模型相关特征列定义
user_tower_columns = [user_id_emb, user_age, user_gender_emb, user_occupation_emb, city_code_emb, device_types_emb]
item_tower_columns = [item_id_emb, item_category_emb, item_brand_emb, item_price]# Wide&Deep 模型相关特征列定义
deep_feature_columns = [user_id_emb,user_age,user_gender_emb,user_occupation_emb,item_id_emb,item_category_emb,item_brand_emb,item_price
]wide_feature_columns = [user_id_x_item_id,user_gender_x_item_category,user_occupation_x_item_brand
]"""
Part-3:模型架构设计
"""
# 教师模型:采用 Wide&Deep 模型
class WideDeepModel(tf.keras.Model):"""Wide部分:线性模型,擅长记忆(Memorization),通过交叉特征捕捉明确的特征组合模式(如用户A常点击商品B)。Deep部分:深度神经网络,擅长泛化(Generalization),通过嵌入向量学习特征的潜在关系(如女性用户与服装品类的关联)。结合优势:同时处理稀疏特征(如用户ID、商品ID)和密集特征(如价格、年龄),平衡记忆与泛化能力"""def __init__(self, wide_feature_columns, deep_feature_columns):super(WideDeepModel, self).__init__()# Wide部分(线性模型)self.linear_features = tf.keras.layers.DenseFeatures(wide_feature_columns)self.wide_out = tf.keras.layers.Dense(1, activation='sigmoid')# Deep部分(深度神经网络)self.dnn_features = tf.keras.layers.DenseFeatures(deep_feature_columns)self.dnn_layer = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu')])self.deep_out = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# Wide部分:预测CTRlinear_features = self.linear_features(inputs)ctr_wide_logits = self.wide_out(linear_features)# Deep部分:预测CTRdnn_features = self.dnn_features(inputs)dnn_layer = self.dnn_layer(dnn_features)ctr_deep_logits = self.deep_out(dnn_layer)# 将Wide和Deep的logits相加,通过Sigmoid输出点击概率ctr_logits = tf.sigmoid(ctr_wide_logits + ctr_deep_logits)# 返回return {'ctr_logits': ctr_logits}# 学生模型:采用普通双塔模型
class TowTowerStudent(tf.keras.Model):"""普通双塔模型:User Tower + Item Tower"""def __init__(self, user_columns, item_columns):super(TowTowerStudent, self).__init__()# 共享特征处理层self.user_feature = tf.keras.layers.DenseFeatures(user_columns)self.item_feature = tf.keras.layers.DenseFeatures(item_columns)# User塔self.user_tower = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu'),])# Item塔self.item_tower = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu'),])self.tower_out = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# 双塔结构user_feature = self.user_feature(inputs)item_feature = self.item_feature(inputs)user_emb = self.user_tower(user_feature)item_emb = self.item_tower(item_feature)# CTR预测# 点积交互(即用户Embedding和商品Embedding求取余弦相似度)interaction = tf.keras.layers.Dot(axes=1)([user_emb, item_emb])ctr_logits = self.tower_out(interaction)return {'ctr_logits': ctr_logits}"""
Part-4:知识蒸馏实现
"""
class DistillationModel(tf.keras.Model):def __init__(self, teacher, student):super(DistillationModel, self).__init__()self.teacher = teacherself.student = student# 温度参数:典型取值2-5之间self.temperature = 2.0def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn):super().compile(optimizer=optimizer, metrics=metrics)self.student_loss_fn = student_loss_fnself.distillation_loss_fn = distillation_loss_fndef call(self, inputs):# 推理时直接使用学生模型return self.student(inputs)def train_step(self, data):# 解包数据x, y = data# 教师模型前向传播(仅推理)teacher_predictions = self.teacher(x, training=False) # 冻结教师模型teacher_ctr = teacher_predictions['ctr_logits']# 使用tf.GradientTape实现动态梯度计算with tf.GradientTape() as tape:# 学生模型前向传播student_outputs = self.student(x, training=True)student_ctr = student_outputs['ctr_logits']# 计算学生损失# 学生损失(student_loss):直接拟合真实标签# y['ctr_logits'] = labels['click_label'],在输入数据时有定义student_loss_ctr = self.student_loss_fn(y['ctr_logits'], student_ctr)# 计算蒸馏损失distillation_loss_ctr = self.distillation_loss_fn(# 蒸馏损失(distillation_loss):学习教师模型的软标签分布teacher_ctr / self.temperature, # 教师输出软化student_ctr / self.temperature # 学生输出对齐)# 总损失total_loss = 0.7 * student_loss_ctr + 0.3 * distillation_loss_ctr# 计算梯度并更新(仅更新学生参数)trainable_vars = self.student.trainable_variablesgradients = tape.gradient(total_loss, trainable_vars)self.optimizer.apply_gradients(zip(gradients, trainable_vars))# 更新指标self.compiled_metrics.update_state(y, {'ctr_logits': student_ctr,})return {m.name: m.result() for m in self.metrics}"""
Part-5:模型训练与评估
"""
# 数据输入管道
def df_to_dataset(features, labels, shuffle=True, batch_size=32):ds = tf.data.Dataset.from_tensor_slices((dict(features),{# 这里做了一个映射,主要为了对齐学生模型和教师模型的输出,从而便于计算损失'ctr_logits': labels['click_label']}))if shuffle:ds = ds.shuffle(1000)ds = ds.batch(batch_size)return ds# 转换数据集
train_ds = df_to_dataset(train_features, train_labels)
test_ds = df_to_dataset(test_features, test_labels, shuffle=False)# 初始化模型
teacher = WideDeepModel(wide_feature_columns, deep_feature_columns)
student = TowTowerStudent(user_tower_columns, item_tower_columns)
distiller = DistillationModel(teacher, student)# 编译教师模型(先单独训练)
teacher.compile(optimizer='adam',loss={'ctr_logits': 'binary_crossentropy'},metrics=['accuracy'],loss_weights=[0.7, 0.3] # 可选:设置不同任务的损失权重
)# 训练教师模型
print("训练教师模型...")
teacher.fit(train_ds, epochs=5, validation_data=test_ds)# 编译蒸馏模型
distiller.compile(optimizer='adam',metrics={'ctr_logits': ['accuracy']},student_loss_fn=tf.keras.losses.BinaryCrossentropy(),distillation_loss_fn=tf.keras.losses.KLDivergence()
)# 训练学生模型(带蒸馏)
print("训练学生模型...")
history = distiller.fit(train_ds, epochs=10, validation_data=test_ds)
print(history.history)# 可视化训练过程
plt.plot(history.history['accuracy'], label='CTR Accuracy')plt.title('Training Metrics')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()"""
Part-6:模型服务化(示例)
"""
# 保存学生模型
student.save('esmm_student_model')# 加载模型进行推理
loaded_model = tf.keras.models.load_model('esmm_student_model')# 查看模型输入层名称
loaded_model.summary()# 示例预测:从 test_features 数据框中提取第一行数据
sample = test_features.iloc[0]sample_dict = {col: tf.expand_dims(value, -1)for col, value in dict(sample).items()
}predictions = loaded_model.predict(sample_dict)
print(f"预测结果:CTR={predictions['ctr_logits'][0][0]:.3f}")
3.8 运行效果
Teacher 模型训练过程:
训练教师模型...
Epoch 1/5
2025-03-30 21:41:55.398982: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
25/25 [==============================] - 2s 13ms/step - loss: 0.5838 - accuracy: 0.5013 - val_loss: 0.5115 - val_accuracy: 0.4850
Epoch 2/5
25/25 [==============================] - 0s 3ms/step - loss: 0.5049 - accuracy: 0.5013 - val_loss: 0.5101 - val_accuracy: 0.4850
Epoch 3/5
25/25 [==============================] - 0s 2ms/step - loss: 0.5037 - accuracy: 0.5013 - val_loss: 0.5093 - val_accuracy: 0.4850
Epoch 4/5
25/25 [==============================] - 0s 2ms/step - loss: 0.5026 - accuracy: 0.5013 - val_loss: 0.5085 - val_accuracy: 0.4850
Epoch 5/5
25/25 [==============================] - 0s 5ms/step - loss: 0.5014 - accuracy: 0.5013 - val_loss: 0.5077 - val_accuracy: 0.4850
Student 模型训练过程:
训练学生模型...
Epoch 1/10
25/25 [==============================] - 2s 11ms/step - accuracy: 0.4975 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 2/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5038 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 3/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5063 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 4/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5050 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 5/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 6/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 7/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 8/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 9/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 10/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5088 - val_loss: 0.0000e+00 - val_accuracy: 0.4900
模型结构及预测示例:
Model: "tow_tower_student"
_________________________________________________________________Layer (type) Output Shape Param #
=================================================================dense_features_2 (DenseFeat multiple 23706 ures) dense_features_3 (DenseFeat multiple 1620 ures) sequential_1 (Sequential) (None, 32) 4000 sequential_2 (Sequential) (None, 32) 2976 dense_8 (Dense) multiple 2 =================================================================
Total params: 32,304
Trainable params: 32,304
Non-trainable params: 0
_________________________________________________________________
1/1 [==============================] - 0s 178ms/step
预测结果:CTR=0.526
可视化训练过程:
3.9 模型预测
在另一个工程中加载通过蒸馏训练好的 Student 模型,并执行预测,代码示例如下:
# 导入必要的库
import tensorflow as tf
import pandas as pd
import numpy as np# 人工构造数据
num_users = 100
num_items = 200# 重新生成新的样本,模拟真实数据进行预测
def generate_new_samples(num_samples=5):new_samples = []for _ in range(num_samples):user_id = np.random.randint(1, num_users + 1)item_id = np.random.randint(1, num_items + 1)user_age = np.random.randint(18, 65)user_gender = np.random.choice(['male', 'female'])user_occupation = np.random.choice(['student', 'worker', 'teacher'])city_code = np.random.randint(1, 2856)device_type = np.random.randint(0, 5)item_category = np.random.choice(['electronics', 'books', 'clothing'])item_brand = np.random.choice(['brandA', 'brandB', 'brandC'])item_price = np.random.randint(1, 199)new_samples.append({'user_id': user_id,'user_age': user_age,'user_gender': user_gender,'user_occupation': user_occupation,'city_code': city_code,'device_type': device_type,'item_id': item_id,'item_category': item_category,'item_brand': item_brand,'item_price': item_price})return pd.DataFrame(new_samples)# 生成并打印预览新的样本数据
new_samples = generate_new_samples(num_samples=5)
# 设置display.max_columns为None,强制显示全部列:
pd.set_option('display.max_columns', None)
print("\nGenerated New Samples:\n", new_samples)# 准备输入数据
input_dict = {'user_id': tf.convert_to_tensor(new_samples['user_id'].values, dtype=tf.int64),'user_age': tf.convert_to_tensor(new_samples['user_age'].values, dtype=tf.int64),'user_gender': tf.convert_to_tensor(new_samples['user_gender'].values, dtype=tf.string),'user_occupation': tf.convert_to_tensor(new_samples['user_occupation'].values, dtype=tf.string),'city_code': tf.convert_to_tensor(new_samples['city_code'].values, dtype=tf.int64),'device_type': tf.convert_to_tensor(new_samples['device_type'].values, dtype=tf.int64),'item_id': tf.convert_to_tensor(new_samples['item_id'].values, dtype=tf.int64),'item_category': tf.convert_to_tensor(new_samples['item_category'].values, dtype=tf.string),'item_brand': tf.convert_to_tensor(new_samples['item_brand'].values, dtype=tf.string),'item_price': tf.convert_to_tensor(new_samples['item_price'].values, dtype=tf.int64)
}# 加载模型进行推理
loaded_model = tf.keras.models.load_model('esmm_student_model')
# 明确使用默认签名
predict_fn = loaded_model.signatures['serving_default']
predictions = predict_fn(**input_dict)# 提取并打印预测结果
# 预测结果是一个 CTCVR 综合分
predicted_ctr = predictions['ctr_logits'].numpy().flatten()
new_samples['ctr_prob'] = predicted_ctr
print("\nPrediction Results:")
for idx, row in new_samples.iterrows():print(f"Item ID: {row['item_id']} | CTR Final Score: {row['ctr_prob']:.4f}")
运行结果如下:
Generated New Samples:user_id user_age user_gender user_occupation city_code device_type \
0 34 49 female teacher 843 0
1 15 30 female student 564 3
2 26 37 male teacher 2229 0
3 31 35 male worker 2494 0
4 41 57 female student 1668 3 item_id item_category item_brand item_price
0 147 electronics brandA 127
1 196 clothing brandC 190
2 1 books brandA 1
3 150 clothing brandA 5
4 128 electronics brandA 156
Metal device set to: Apple M1 ProPrediction Results:
Item ID: 147 | CTR Final Score: 0.5263
Item ID: 196 | CTR Final Score: 0.5263
Item ID: 1 | CTR Final Score: 0.5263
Item ID: 150 | CTR Final Score: 0.4793
Item ID: 128 | CTR Final Score: 0.5263