昆明著名网站建设/百度指数名词解释

昆明著名网站建设,百度指数名词解释,桂林微物网络科技有限公司,广告制作费和广告服务费区别在商品推荐系统中,粗排和精排环节的知识蒸馏方法主要通过复杂模型(Teacher)指导简单模型(Student)的训练,以提升粗排效果及与精排的一致性。本文将以淘宝的一篇论文《Privileged Features Distillation at …

在商品推荐系统中,粗排和精排环节的知识蒸馏方法主要通过复杂模型(Teacher)指导简单模型(Student)的训练,以提升粗排效果及与精排的一致性。本文将以淘宝的一篇论文《Privileged Features Distillation at Taobao Recommendations》中介绍的 PFD(Privileged Features Distillation)方法为例实现一个Demo,帮助读者学习知识蒸馏。

1.知识蒸馏方法概述

知识蒸馏诞生至今,早已不局限于粗排,而是在粗排和精排均有应用。粗排和精排的知识蒸馏核心在于通过不同形式的知识迁移(logits、排序结果、特征)提升模型效果与一致性。粗排侧重从精排获取排序偏好,而精排侧重模型压缩。实际应用中需结合业务场景选择蒸馏策略,并权衡性能与效果。本节简要介绍一下知识蒸馏方法。

一、粗排环节的典型蒸馏方法

粗排需平衡性能和效果,通常以精排为Teacher进行知识迁移,主要方法包括:

(1)Logits蒸馏

  • 原理:利用精排模型的输出logits(未归一化的预测值)作为软标签(soft label),指导粗排模型学习。通过引入温度系数(Temperature Scaling)调整软标签的分布,增强非主导类别的信息传递。
  • 损失函数:粗排模型的损失由两部分组成: Hard Loss:基于真实标签的交叉熵损失; Soft Loss:基于精排输出logits的KL散度或MSE损失。
  • 应用:美团、爱奇艺等采用两阶段训练,先训练精排Teacher,再固定其参数指导粗排Student56。

(2)排序结果蒸馏

  • 原理:直接利用精排输出的有序列表信息,构造粗排的训练样本。常见方法包括:

     1. Point-wise:将精排Top-K结果作为正样本,其余作为负样本,并引入位置权重。2. Pair-wise:从精排列表中随机抽取商品对,学习偏序关系(如BPR损失)。、3. List-wise:通过NDCG等指标对齐粗排与精排的整体排序。 
    
  • 优势:缓解样本选择偏差,增强粗排对精排排序偏好的拟合。

(3)特征蒸馏

  • 原理:迁移精排模型的中间层特征,要求粗排和精排的网络结构部分对齐。例如: 隐层特征对齐:通过MSE损失约束粗排与精排的隐层输出(如淘宝的 PFD(Privileged Features Distillation) 方法)。
  • 优势特征蒸馏:将精排使用的交叉特征等“特权特征”迁移到粗排(如用户与商品的交互特征)。
  • 应用:淘宝在 KDD 2020 提出的 PFD 方法中,精排 Teacher 使用交叉特征,粗排 Student 仅用基础特征,通过蒸馏提升效果。

二、精排环节的典型蒸馏方法

精排蒸馏主要用于模型压缩,将复杂模型(如集成模型)的能力迁移至轻量级模型:

(1)Logits蒸馏

  • 原理:与粗排类似,使用复杂精排模型的 logits 指导轻量级 Student 模型训练。例如: 阿里 Rocket Launching 框架:Teacher 和 Student 共享 Embedding 层,联合训练并通过 logits 对齐。
  • 改进:爱奇艺双 DNN 模型进一步约束 Student 隐层与 Teacher 隐层的激活值相似性。

(2)多目标蒸馏

  • 原理:将精排的多任务输出(如CTR、CVR)迁移至 Student。例如: 腾讯在 SIGIR 2021 提出通过 KL 散度对齐多任务 logits,提升粗排/召回模型的多目标一致性。
  • 损失设计:结合多任务损失和蒸馏损失,如加权交叉熵或对比学习损失。

三、关键技术与实践

(1)温度系数(Temperature)

调节 softmax 输出的平滑度,温度值越大,分布越平滑,帮助 Student 学习 Teacher 的暗知识(Dark Knowledge)。

(2)两阶段训练 vs 联合训练

  • 两阶段:先独立训练 Teacher,再固定其参数指导 Student(稳定性高)。
  • 联合训练:Teacher 和 Student 同步更新(减少耗时,但需设计梯度阻断防止相互干扰)。

(3)实际应用案例

  • 美团:通过对比学习强化粗排与精排的特征对齐,粗排CTR提升 0.15%。
  • 淘宝:优势特征蒸馏使粗排 CTR 提升 5%,精排CVR提升 2.3%。
  • 腾讯音乐:多目标蒸馏在粗排阶段实现阅读时长与点击率的联合优化。

2. PFD(Privileged Features Distillation)方法介绍

PFD(Privileged Features Distillation)方法出自论文《Privileged Features Distillation at Taobao Recommendations》。论文中描述:在离线环境下同时训练两个模型:一个学生模型以及一个教师模型。其中学生模型和原始模型完全相同,而教师模型额外利用了优势特征, 其准确率也因此更高。通过将教师模型蒸馏出的知识(Knowlege, 本文特指教师模型中最后一层的输出)传递给学生模型,可以辅助其训练以进一步提升准确率。在线上服务时,我们只抽取学生模型进行部署,因为输入不依赖于优势特征,离线、在线的一致性得以保证。在 PFD 中,所有的优势特征都被统一到教师模型作为输入,加入更多的优势特征往往能带来模型更高的准确度。

PFD 不同于常见的模型蒸馏(Model Disitillation, 简称 MD)。 在 MD 中,教师模型和学生模型处理同样的输入特征,其中教师模型会比学生模型更为复杂, 比如,教师模型会用更深的网络结构来指导使用浅层网络的学生模型进行学习。在 PFD 中,教师和学生模型会使用相同网络结构,而处理不同的输入特征。MD 和 PFD 两者的差异如下图所示。
在这里插入图片描述

如上图所示:模型蒸馏(Model Distill, 简称 MD)与优势特征蒸馏(PFD)对比; 在 MD 中,知识(Knowledge)是从更复杂的模型中蒸馏出来,而在 PFD 中,知识是从优势特征中蒸馏出来。

由此可见,我们可以训练一个使用了复杂特征(如交叉特征)的模型作为老师,指导训练一个仅使用简单特征的学生模型,从而实现提升模型效果,而又不增加线上耗时(线上使用交叉特征等复杂特征通常会导致耗时大幅增加,因此,在粗排环节几乎不直接使用交叉特征)。

3.基于 Wide&Deep 指导训练 TowTower 模型

基于 PFD 方法的原理,在本节我们将实现一个知识蒸馏的 Demo。其中,Teacher 模型基于 Wide&Deep 模型;Student 模型则采用简单的“双塔模型”。为了简单起见,Wide&Deep 模型和 “双塔模型” 均为单目标(CTR )模型。

3.1 模拟数据构造

"""
Part-1:模拟数据构造本部分模拟真实场景,人工构造用户数据、商品数据、用户-商品交互数据(点击、转化),并进行必要的预处
"""
# 设置随机种子保证可复现性
np.random.seed(42)
tf.random.set_seed(42)# 生成用户、商品和交互数据
num_users = 100
num_items = 200
num_interactions = 1000# 用户特征
user_data = {'user_id': np.arange(1, num_users + 1),'user_age': np.random.randint(18, 65, size=num_users),'user_gender': np.random.choice(['male', 'female'], size=num_users),'user_occupation': np.random.choice(['student', 'worker', 'teacher'], size=num_users),'city_code': np.random.randint(1, 2856, size=num_users),'device_type': np.random.randint(0, 5, size=num_users)
}# 商品特征
item_data = {'item_id': np.arange(1, num_items + 1),'item_category': np.random.choice(['electronics', 'books', 'clothing'], size=num_items),'item_brand': np.random.choice(['brandA', 'brandB', 'brandC'], size=num_items),'item_price': np.random.randint(1, 199, size=num_items)
}# 交互数据
# 包括:点击和转化(购买)数据
interactions = []
for _ in range(num_interactions):user_id = np.random.randint(1, num_users + 1)item_id = np.random.randint(1, num_items + 1)# 点击标签。0: 未点击, 1: 点击。在真实场景中可通过客户端埋点上报获得用户的点击行为数据click_label = np.random.randint(0, 2)interactions.append([user_id, item_id, click_label])# 合并用户特征、商品特征和交互数据
interaction_df = pd.DataFrame(interactions, columns=['user_id', 'item_id', 'click_label'])
user_df = pd.DataFrame(user_data)
item_df = pd.DataFrame(item_data)
df = interaction_df.merge(user_df, on='user_id').merge(item_df, on='item_id')# 划分数据集
labels = df[['click_label']]
features = df.drop(['click_label'], axis=1)
train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.2,random_state=42)

3.2 特征工程

代码如下,相较于 Student 模型,作为 Teacher 的 Wide&Deep 模型采用了更多的特征,特别是交叉特征。

"""
Part-2:特征工程本部分对原始用户数据、商品数据、用户-商品交互数据进行分类处理,加工为模型训练需要的特征1.数值型特征:如用户年龄、价格,少数场景下可直接使用,但最好进行标准化,从而消除量纲差异2.类别型特征:需要进行 Embedding 处理3.交叉特征:由于维度高,需要哈希技巧处理高维组合特征
"""
# 用户特征处理
user_id = feature_column.categorical_column_with_identity('user_id', num_buckets=num_users + 1)
user_id_emb = feature_column.embedding_column(user_id, dimension=8)scaler_age = StandardScaler()
df['user_age'] = scaler_age.fit_transform(df[['user_age']])
user_age = feature_column.numeric_column('user_age')user_gender = feature_column.categorical_column_with_vocabulary_list('user_gender', ['male', 'female'])
user_gender_emb = feature_column.embedding_column(user_gender, dimension=2)user_occupation = feature_column.categorical_column_with_vocabulary_list('user_occupation',['student', 'worker', 'teacher'])
user_occupation_emb = feature_column.embedding_column(user_occupation, dimension=2)city_code_column = feature_column.categorical_column_with_identity(key='city_code', num_buckets=2856)
city_code_emb = feature_column.embedding_column(city_code_column, dimension=8)device_types_column = feature_column.categorical_column_with_identity(key='device_type', num_buckets=5)
device_types_emb = feature_column.embedding_column(device_types_column, dimension=8)# 商品特征处理
item_id = feature_column.categorical_column_with_identity('item_id', num_buckets=num_items + 1)
item_id_emb = feature_column.embedding_column(item_id, dimension=8)scaler_price = StandardScaler()
df['item_price'] = scaler_price.fit_transform(df[['item_price']])
item_price = feature_column.numeric_column('item_price')item_category = feature_column.categorical_column_with_vocabulary_list('item_category',['electronics', 'books', 'clothing'])
item_category_emb = feature_column.embedding_column(item_category, dimension=2)item_brand = feature_column.categorical_column_with_vocabulary_list('item_brand', ['brandA', 'brandB', 'brandC'])
item_brand_emb = feature_column.embedding_column(item_brand, dimension=2)""" 
交叉特征预处理 
"""
# 使用TensorFlow的交叉特征(crossed_column)定义了Wide部分的特征列,主要用于捕捉用户与商品特征之间的组合效应
# 将用户ID(user_id)和商品ID(item_id)组合成一个新特征,捕捉**“特定用户对特定商品的偏好”**
# 用户ID和商品ID的组合总数可能非常大(num_users * num_items),直接编码会导致维度爆炸。
# hash_bucket_size=10000:使用哈希函数将组合映射到固定数量的桶(10,000个),控制内存和计算开销,适用于稀疏高维特征(如用户-商品对)
user_id_x_item_id = feature_column.crossed_column([user_id, item_id], hash_bucket_size=10000)
user_id_x_item_id = feature_column.indicator_column(user_id_x_item_id)
user_gender_x_item_category = feature_column.crossed_column([user_gender, item_category], hash_bucket_size=1000)
user_gender_x_item_category = feature_column.indicator_column(user_gender_x_item_category)
user_occupation_x_item_brand = feature_column.crossed_column([user_occupation, item_brand], hash_bucket_size=1000)
user_occupation_x_item_brand = feature_column.indicator_column(user_occupation_x_item_brand)""" 
特征列定义 
"""
# ESMM 模型相关特征列定义
user_tower_columns = [user_id_emb, user_age, user_gender_emb, user_occupation_emb, city_code_emb, device_types_emb]
item_tower_columns = [item_id_emb, item_category_emb, item_brand_emb, item_price]# Wide&Deep 模型相关特征列定义
deep_feature_columns = [user_id_emb,user_age,user_gender_emb,user_occupation_emb,item_id_emb,item_category_emb,item_brand_emb,item_price
]wide_feature_columns = [user_id_x_item_id,user_gender_x_item_category,user_occupation_x_item_brand
]

3.3 模型架构设计

Teacher 模型:采用 Wide&Deep 模型(模拟精排模型);Student 模型:采用普通 “双塔模型”(模拟粗排模型)。

"""
Part-3:模型架构设计
"""
# 教师模型:采用 Wide&Deep 模型
class WideDeepModel(tf.keras.Model):"""Wide部分:线性模型,擅长记忆(Memorization),通过交叉特征捕捉明确的特征组合模式(如用户A常点击商品B)。Deep部分:深度神经网络,擅长泛化(Generalization),通过嵌入向量学习特征的潜在关系(如女性用户与服装品类的关联)。结合优势:同时处理稀疏特征(如用户ID、商品ID)和密集特征(如价格、年龄),平衡记忆与泛化能力"""def __init__(self, wide_feature_columns, deep_feature_columns):super(WideDeepModel, self).__init__()# Wide部分(线性模型)self.linear_features = tf.keras.layers.DenseFeatures(wide_feature_columns)self.wide_out = tf.keras.layers.Dense(1, activation='sigmoid')# Deep部分(深度神经网络)self.dnn_features = tf.keras.layers.DenseFeatures(deep_feature_columns)self.dnn_layer = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu')])self.deep_out = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# Wide部分:预测CTRlinear_features = self.linear_features(inputs)ctr_wide_logits = self.wide_out(linear_features)# Deep部分:预测CTRdnn_features = self.dnn_features(inputs)dnn_layer = self.dnn_layer(dnn_features)ctr_deep_logits = self.deep_out(dnn_layer)# 将Wide和Deep的logits相加,通过Sigmoid输出点击概率ctr_logits = tf.sigmoid(ctr_wide_logits + ctr_deep_logits)# 返回return {'ctr_logits': ctr_logits}# 学生模型:采用普通双塔模型
class TowTowerStudent(tf.keras.Model):"""普通双塔模型:User Tower + Item Tower"""def __init__(self, user_columns, item_columns):super(TowTowerStudent, self).__init__()# 共享特征处理层self.user_feature = tf.keras.layers.DenseFeatures(user_columns)self.item_feature = tf.keras.layers.DenseFeatures(item_columns)# User塔self.user_tower = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu'),])# Item塔self.item_tower = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu'),])self.tower_out = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# 双塔结构user_feature = self.user_feature(inputs)item_feature = self.item_feature(inputs)user_emb = self.user_tower(user_feature)item_emb = self.item_tower(item_feature)# CTR预测# 点积交互(即用户Embedding和商品Embedding求取余弦相似度)interaction = tf.keras.layers.Dot(axes=1)([user_emb, item_emb])ctr_logits = self.tower_out(interaction)return {'ctr_logits': ctr_logits}

3.4 知识蒸馏实现

本质上就是用 Teacher 模型指导 Student 模型训练。使得 Student 模型的预测结果逼近 Teacher 模型。

"""
Part-4:知识蒸馏实现
"""
class DistillationModel(tf.keras.Model):def __init__(self, teacher, student):super(DistillationModel, self).__init__()self.teacher = teacherself.student = student# 温度参数:典型取值2-5之间self.temperature = 2.0def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn):super().compile(optimizer=optimizer, metrics=metrics)self.student_loss_fn = student_loss_fnself.distillation_loss_fn = distillation_loss_fndef call(self, inputs):# 推理时直接使用学生模型return self.student(inputs)def train_step(self, data):# 解包数据x, y = data# 教师模型前向传播(仅推理)teacher_predictions = self.teacher(x, training=False)  # 冻结教师模型teacher_ctr = teacher_predictions['ctr_logits']# 使用tf.GradientTape实现动态梯度计算with tf.GradientTape() as tape:# 学生模型前向传播student_outputs = self.student(x, training=True)student_ctr = student_outputs['ctr_logits']# 计算学生损失# 学生损失(student_loss):直接拟合真实标签# y['ctr_logits'] = labels['click_label'],在输入数据时有定义student_loss_ctr = self.student_loss_fn(y['ctr_logits'], student_ctr)# 计算蒸馏损失distillation_loss_ctr = self.distillation_loss_fn(# 蒸馏损失(distillation_loss):学习教师模型的软标签分布teacher_ctr / self.temperature,  # 教师输出软化student_ctr / self.temperature  # 学生输出对齐)# 总损失total_loss = 0.7 * student_loss_ctr + 0.3 * distillation_loss_ctr# 计算梯度并更新(仅更新学生参数)trainable_vars = self.student.trainable_variablesgradients = tape.gradient(total_loss, trainable_vars)self.optimizer.apply_gradients(zip(gradients, trainable_vars))# 更新指标self.compiled_metrics.update_state(y, {'ctr_logits': student_ctr,})return {m.name: m.result() for m in self.metrics}

3.5 模型训练与评估

  • 第一步:数据准备;
  • 第二步:模型初始化;
  • 第三步:编译、训练 Teacher 模型;
  • 第四步:编译、训练 Student 模型;
  • 第五步:评估、可视化效果
"""
Part-5:模型训练与评估
"""
# 数据输入管道
def df_to_dataset(features, labels, shuffle=True, batch_size=32):ds = tf.data.Dataset.from_tensor_slices((dict(features),{# 这里做了一个映射,主要为了对齐学生模型和教师模型的输出,从而便于计算损失'ctr_logits': labels['click_label']}))if shuffle:ds = ds.shuffle(1000)ds = ds.batch(batch_size)return ds# 转换数据集
train_ds = df_to_dataset(train_features, train_labels)
test_ds = df_to_dataset(test_features, test_labels, shuffle=False)# 初始化模型
teacher = WideDeepModel(wide_feature_columns, deep_feature_columns)
student = TowTowerStudent(user_tower_columns, item_tower_columns)
distiller = DistillationModel(teacher, student)# 编译教师模型(先单独训练)
teacher.compile(optimizer='adam',loss={'ctr_logits': 'binary_crossentropy'},metrics=['accuracy'],loss_weights=[0.7, 0.3]  # 可选:设置不同任务的损失权重
)# 训练教师模型
print("训练教师模型...")
teacher.fit(train_ds, epochs=5, validation_data=test_ds)# 编译蒸馏模型
distiller.compile(optimizer='adam',metrics={'ctr_logits': ['accuracy']},student_loss_fn=tf.keras.losses.BinaryCrossentropy(),distillation_loss_fn=tf.keras.losses.KLDivergence()
)# 训练学生模型(带蒸馏)
print("训练学生模型...")
history = distiller.fit(train_ds, epochs=10, validation_data=test_ds)
print(history.history)# 可视化训练过程
plt.plot(history.history['accuracy'], label='CTR Accuracy')plt.title('Training Metrics')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

3.6 模型服务化与测试

保存训练好的学生模型,在另一个工程中可以加载这个模型,并执行预测。

"""
Part-6:模型服务化(示例)
"""
# 保存学生模型
student.save('esmm_student_model')# 加载模型进行推理
loaded_model = tf.keras.models.load_model('esmm_student_model')# 查看模型输入层名称
loaded_model.summary()# 示例预测:从 test_features 数据框中提取第一行数据
sample = test_features.iloc[0]sample_dict = {col: tf.expand_dims(value, -1)for col, value in dict(sample).items()
}predictions = loaded_model.predict(sample_dict)
print(f"预测结果:CTR={predictions['ctr_logits'][0][0]:.3f}")

3.7 知识蒸馏完整代码

完整代码如下:

import tensorflow as tftf.config.set_visible_devices([], 'GPU')  # 禁用GPU设备
from tensorflow import feature_column
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler"""
Part-1:模拟数据构造本部分模拟真实场景,人工构造用户数据、商品数据、用户-商品交互数据(点击、转化),并进行必要的预处
"""
# 设置随机种子保证可复现性
np.random.seed(42)
tf.random.set_seed(42)# 生成用户、商品和交互数据
num_users = 100
num_items = 200
num_interactions = 1000# 用户特征
user_data = {'user_id': np.arange(1, num_users + 1),'user_age': np.random.randint(18, 65, size=num_users),'user_gender': np.random.choice(['male', 'female'], size=num_users),'user_occupation': np.random.choice(['student', 'worker', 'teacher'], size=num_users),'city_code': np.random.randint(1, 2856, size=num_users),'device_type': np.random.randint(0, 5, size=num_users)
}# 商品特征
item_data = {'item_id': np.arange(1, num_items + 1),'item_category': np.random.choice(['electronics', 'books', 'clothing'], size=num_items),'item_brand': np.random.choice(['brandA', 'brandB', 'brandC'], size=num_items),'item_price': np.random.randint(1, 199, size=num_items)
}# 交互数据
# 包括:点击和转化(购买)数据
interactions = []
for _ in range(num_interactions):user_id = np.random.randint(1, num_users + 1)item_id = np.random.randint(1, num_items + 1)# 点击标签。0: 未点击, 1: 点击。在真实场景中可通过客户端埋点上报获得用户的点击行为数据click_label = np.random.randint(0, 2)interactions.append([user_id, item_id, click_label])# 合并用户特征、商品特征和交互数据
interaction_df = pd.DataFrame(interactions, columns=['user_id', 'item_id', 'click_label'])
user_df = pd.DataFrame(user_data)
item_df = pd.DataFrame(item_data)
df = interaction_df.merge(user_df, on='user_id').merge(item_df, on='item_id')# 划分数据集
labels = df[['click_label']]
features = df.drop(['click_label'], axis=1)
train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.2,random_state=42)"""
Part-2:特征工程本部分对原始用户数据、商品数据、用户-商品交互数据进行分类处理,加工为模型训练需要的特征1.数值型特征:如用户年龄、价格,少数场景下可直接使用,但最好进行标准化,从而消除量纲差异2.类别型特征:需要进行 Embedding 处理3.交叉特征:由于维度高,需要哈希技巧处理高维组合特征
"""
# 用户特征处理
user_id = feature_column.categorical_column_with_identity('user_id', num_buckets=num_users + 1)
user_id_emb = feature_column.embedding_column(user_id, dimension=8)scaler_age = StandardScaler()
df['user_age'] = scaler_age.fit_transform(df[['user_age']])
user_age = feature_column.numeric_column('user_age')user_gender = feature_column.categorical_column_with_vocabulary_list('user_gender', ['male', 'female'])
user_gender_emb = feature_column.embedding_column(user_gender, dimension=2)user_occupation = feature_column.categorical_column_with_vocabulary_list('user_occupation',['student', 'worker', 'teacher'])
user_occupation_emb = feature_column.embedding_column(user_occupation, dimension=2)city_code_column = feature_column.categorical_column_with_identity(key='city_code', num_buckets=2856)
city_code_emb = feature_column.embedding_column(city_code_column, dimension=8)device_types_column = feature_column.categorical_column_with_identity(key='device_type', num_buckets=5)
device_types_emb = feature_column.embedding_column(device_types_column, dimension=8)# 商品特征处理
item_id = feature_column.categorical_column_with_identity('item_id', num_buckets=num_items + 1)
item_id_emb = feature_column.embedding_column(item_id, dimension=8)scaler_price = StandardScaler()
df['item_price'] = scaler_price.fit_transform(df[['item_price']])
item_price = feature_column.numeric_column('item_price')item_category = feature_column.categorical_column_with_vocabulary_list('item_category',['electronics', 'books', 'clothing'])
item_category_emb = feature_column.embedding_column(item_category, dimension=2)item_brand = feature_column.categorical_column_with_vocabulary_list('item_brand', ['brandA', 'brandB', 'brandC'])
item_brand_emb = feature_column.embedding_column(item_brand, dimension=2)""" 
交叉特征预处理 
"""
# 使用TensorFlow的交叉特征(crossed_column)定义了Wide部分的特征列,主要用于捕捉用户与商品特征之间的组合效应
# 将用户ID(user_id)和商品ID(item_id)组合成一个新特征,捕捉**“特定用户对特定商品的偏好”**
# 用户ID和商品ID的组合总数可能非常大(num_users * num_items),直接编码会导致维度爆炸。
# hash_bucket_size=10000:使用哈希函数将组合映射到固定数量的桶(10,000个),控制内存和计算开销,适用于稀疏高维特征(如用户-商品对)
user_id_x_item_id = feature_column.crossed_column([user_id, item_id], hash_bucket_size=10000)
user_id_x_item_id = feature_column.indicator_column(user_id_x_item_id)
user_gender_x_item_category = feature_column.crossed_column([user_gender, item_category], hash_bucket_size=1000)
user_gender_x_item_category = feature_column.indicator_column(user_gender_x_item_category)
user_occupation_x_item_brand = feature_column.crossed_column([user_occupation, item_brand], hash_bucket_size=1000)
user_occupation_x_item_brand = feature_column.indicator_column(user_occupation_x_item_brand)""" 
特征列定义 
"""
# ESMM 模型相关特征列定义
user_tower_columns = [user_id_emb, user_age, user_gender_emb, user_occupation_emb, city_code_emb, device_types_emb]
item_tower_columns = [item_id_emb, item_category_emb, item_brand_emb, item_price]# Wide&Deep 模型相关特征列定义
deep_feature_columns = [user_id_emb,user_age,user_gender_emb,user_occupation_emb,item_id_emb,item_category_emb,item_brand_emb,item_price
]wide_feature_columns = [user_id_x_item_id,user_gender_x_item_category,user_occupation_x_item_brand
]"""
Part-3:模型架构设计
"""
# 教师模型:采用 Wide&Deep 模型
class WideDeepModel(tf.keras.Model):"""Wide部分:线性模型,擅长记忆(Memorization),通过交叉特征捕捉明确的特征组合模式(如用户A常点击商品B)。Deep部分:深度神经网络,擅长泛化(Generalization),通过嵌入向量学习特征的潜在关系(如女性用户与服装品类的关联)。结合优势:同时处理稀疏特征(如用户ID、商品ID)和密集特征(如价格、年龄),平衡记忆与泛化能力"""def __init__(self, wide_feature_columns, deep_feature_columns):super(WideDeepModel, self).__init__()# Wide部分(线性模型)self.linear_features = tf.keras.layers.DenseFeatures(wide_feature_columns)self.wide_out = tf.keras.layers.Dense(1, activation='sigmoid')# Deep部分(深度神经网络)self.dnn_features = tf.keras.layers.DenseFeatures(deep_feature_columns)self.dnn_layer = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu')])self.deep_out = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# Wide部分:预测CTRlinear_features = self.linear_features(inputs)ctr_wide_logits = self.wide_out(linear_features)# Deep部分:预测CTRdnn_features = self.dnn_features(inputs)dnn_layer = self.dnn_layer(dnn_features)ctr_deep_logits = self.deep_out(dnn_layer)# 将Wide和Deep的logits相加,通过Sigmoid输出点击概率ctr_logits = tf.sigmoid(ctr_wide_logits + ctr_deep_logits)# 返回return {'ctr_logits': ctr_logits}# 学生模型:采用普通双塔模型
class TowTowerStudent(tf.keras.Model):"""普通双塔模型:User Tower + Item Tower"""def __init__(self, user_columns, item_columns):super(TowTowerStudent, self).__init__()# 共享特征处理层self.user_feature = tf.keras.layers.DenseFeatures(user_columns)self.item_feature = tf.keras.layers.DenseFeatures(item_columns)# User塔self.user_tower = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu'),])# Item塔self.item_tower = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(32, activation='relu'),])self.tower_out = tf.keras.layers.Dense(1, activation='sigmoid')def call(self, inputs):# 双塔结构user_feature = self.user_feature(inputs)item_feature = self.item_feature(inputs)user_emb = self.user_tower(user_feature)item_emb = self.item_tower(item_feature)# CTR预测# 点积交互(即用户Embedding和商品Embedding求取余弦相似度)interaction = tf.keras.layers.Dot(axes=1)([user_emb, item_emb])ctr_logits = self.tower_out(interaction)return {'ctr_logits': ctr_logits}"""
Part-4:知识蒸馏实现
"""
class DistillationModel(tf.keras.Model):def __init__(self, teacher, student):super(DistillationModel, self).__init__()self.teacher = teacherself.student = student# 温度参数:典型取值2-5之间self.temperature = 2.0def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn):super().compile(optimizer=optimizer, metrics=metrics)self.student_loss_fn = student_loss_fnself.distillation_loss_fn = distillation_loss_fndef call(self, inputs):# 推理时直接使用学生模型return self.student(inputs)def train_step(self, data):# 解包数据x, y = data# 教师模型前向传播(仅推理)teacher_predictions = self.teacher(x, training=False)  # 冻结教师模型teacher_ctr = teacher_predictions['ctr_logits']# 使用tf.GradientTape实现动态梯度计算with tf.GradientTape() as tape:# 学生模型前向传播student_outputs = self.student(x, training=True)student_ctr = student_outputs['ctr_logits']# 计算学生损失# 学生损失(student_loss):直接拟合真实标签# y['ctr_logits'] = labels['click_label'],在输入数据时有定义student_loss_ctr = self.student_loss_fn(y['ctr_logits'], student_ctr)# 计算蒸馏损失distillation_loss_ctr = self.distillation_loss_fn(# 蒸馏损失(distillation_loss):学习教师模型的软标签分布teacher_ctr / self.temperature,  # 教师输出软化student_ctr / self.temperature  # 学生输出对齐)# 总损失total_loss = 0.7 * student_loss_ctr + 0.3 * distillation_loss_ctr# 计算梯度并更新(仅更新学生参数)trainable_vars = self.student.trainable_variablesgradients = tape.gradient(total_loss, trainable_vars)self.optimizer.apply_gradients(zip(gradients, trainable_vars))# 更新指标self.compiled_metrics.update_state(y, {'ctr_logits': student_ctr,})return {m.name: m.result() for m in self.metrics}"""
Part-5:模型训练与评估
"""
# 数据输入管道
def df_to_dataset(features, labels, shuffle=True, batch_size=32):ds = tf.data.Dataset.from_tensor_slices((dict(features),{# 这里做了一个映射,主要为了对齐学生模型和教师模型的输出,从而便于计算损失'ctr_logits': labels['click_label']}))if shuffle:ds = ds.shuffle(1000)ds = ds.batch(batch_size)return ds# 转换数据集
train_ds = df_to_dataset(train_features, train_labels)
test_ds = df_to_dataset(test_features, test_labels, shuffle=False)# 初始化模型
teacher = WideDeepModel(wide_feature_columns, deep_feature_columns)
student = TowTowerStudent(user_tower_columns, item_tower_columns)
distiller = DistillationModel(teacher, student)# 编译教师模型(先单独训练)
teacher.compile(optimizer='adam',loss={'ctr_logits': 'binary_crossentropy'},metrics=['accuracy'],loss_weights=[0.7, 0.3]  # 可选:设置不同任务的损失权重
)# 训练教师模型
print("训练教师模型...")
teacher.fit(train_ds, epochs=5, validation_data=test_ds)# 编译蒸馏模型
distiller.compile(optimizer='adam',metrics={'ctr_logits': ['accuracy']},student_loss_fn=tf.keras.losses.BinaryCrossentropy(),distillation_loss_fn=tf.keras.losses.KLDivergence()
)# 训练学生模型(带蒸馏)
print("训练学生模型...")
history = distiller.fit(train_ds, epochs=10, validation_data=test_ds)
print(history.history)# 可视化训练过程
plt.plot(history.history['accuracy'], label='CTR Accuracy')plt.title('Training Metrics')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()"""
Part-6:模型服务化(示例)
"""
# 保存学生模型
student.save('esmm_student_model')# 加载模型进行推理
loaded_model = tf.keras.models.load_model('esmm_student_model')# 查看模型输入层名称
loaded_model.summary()# 示例预测:从 test_features 数据框中提取第一行数据
sample = test_features.iloc[0]sample_dict = {col: tf.expand_dims(value, -1)for col, value in dict(sample).items()
}predictions = loaded_model.predict(sample_dict)
print(f"预测结果:CTR={predictions['ctr_logits'][0][0]:.3f}")

3.8 运行效果

Teacher 模型训练过程:

训练教师模型...
Epoch 1/5
2025-03-30 21:41:55.398982: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
25/25 [==============================] - 2s 13ms/step - loss: 0.5838 - accuracy: 0.5013 - val_loss: 0.5115 - val_accuracy: 0.4850
Epoch 2/5
25/25 [==============================] - 0s 3ms/step - loss: 0.5049 - accuracy: 0.5013 - val_loss: 0.5101 - val_accuracy: 0.4850
Epoch 3/5
25/25 [==============================] - 0s 2ms/step - loss: 0.5037 - accuracy: 0.5013 - val_loss: 0.5093 - val_accuracy: 0.4850
Epoch 4/5
25/25 [==============================] - 0s 2ms/step - loss: 0.5026 - accuracy: 0.5013 - val_loss: 0.5085 - val_accuracy: 0.4850
Epoch 5/5
25/25 [==============================] - 0s 5ms/step - loss: 0.5014 - accuracy: 0.5013 - val_loss: 0.5077 - val_accuracy: 0.4850

Student 模型训练过程:

训练学生模型...
Epoch 1/10
25/25 [==============================] - 2s 11ms/step - accuracy: 0.4975 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 2/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5038 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 3/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5063 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 4/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5050 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 5/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 6/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 7/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 8/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 9/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5075 - val_loss: 0.0000e+00 - val_accuracy: 0.4850
Epoch 10/10
25/25 [==============================] - 0s 2ms/step - accuracy: 0.5088 - val_loss: 0.0000e+00 - val_accuracy: 0.4900

模型结构及预测示例:

Model: "tow_tower_student"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================dense_features_2 (DenseFeat  multiple                 23706     ures)                                                           dense_features_3 (DenseFeat  multiple                 1620      ures)                                                           sequential_1 (Sequential)   (None, 32)                4000      sequential_2 (Sequential)   (None, 32)                2976      dense_8 (Dense)             multiple                  2         =================================================================
Total params: 32,304
Trainable params: 32,304
Non-trainable params: 0
_________________________________________________________________
1/1 [==============================] - 0s 178ms/step
预测结果:CTR=0.526

可视化训练过程:
在这里插入图片描述

3.9 模型预测

在另一个工程中加载通过蒸馏训练好的 Student 模型,并执行预测,代码示例如下:

# 导入必要的库
import tensorflow as tf
import pandas as pd
import numpy as np# 人工构造数据
num_users = 100
num_items = 200# 重新生成新的样本,模拟真实数据进行预测
def generate_new_samples(num_samples=5):new_samples = []for _ in range(num_samples):user_id = np.random.randint(1, num_users + 1)item_id = np.random.randint(1, num_items + 1)user_age = np.random.randint(18, 65)user_gender = np.random.choice(['male', 'female'])user_occupation = np.random.choice(['student', 'worker', 'teacher'])city_code = np.random.randint(1, 2856)device_type = np.random.randint(0, 5)item_category = np.random.choice(['electronics', 'books', 'clothing'])item_brand = np.random.choice(['brandA', 'brandB', 'brandC'])item_price = np.random.randint(1, 199)new_samples.append({'user_id': user_id,'user_age': user_age,'user_gender': user_gender,'user_occupation': user_occupation,'city_code': city_code,'device_type': device_type,'item_id': item_id,'item_category': item_category,'item_brand': item_brand,'item_price': item_price})return pd.DataFrame(new_samples)# 生成并打印预览新的样本数据
new_samples = generate_new_samples(num_samples=5)
# 设置display.max_columns为None,强制显示全部列:
pd.set_option('display.max_columns', None)
print("\nGenerated New Samples:\n", new_samples)# 准备输入数据
input_dict = {'user_id': tf.convert_to_tensor(new_samples['user_id'].values, dtype=tf.int64),'user_age': tf.convert_to_tensor(new_samples['user_age'].values, dtype=tf.int64),'user_gender': tf.convert_to_tensor(new_samples['user_gender'].values, dtype=tf.string),'user_occupation': tf.convert_to_tensor(new_samples['user_occupation'].values, dtype=tf.string),'city_code': tf.convert_to_tensor(new_samples['city_code'].values, dtype=tf.int64),'device_type': tf.convert_to_tensor(new_samples['device_type'].values, dtype=tf.int64),'item_id': tf.convert_to_tensor(new_samples['item_id'].values, dtype=tf.int64),'item_category': tf.convert_to_tensor(new_samples['item_category'].values, dtype=tf.string),'item_brand': tf.convert_to_tensor(new_samples['item_brand'].values, dtype=tf.string),'item_price': tf.convert_to_tensor(new_samples['item_price'].values, dtype=tf.int64)
}# 加载模型进行推理
loaded_model = tf.keras.models.load_model('esmm_student_model')
# 明确使用默认签名
predict_fn = loaded_model.signatures['serving_default']
predictions = predict_fn(**input_dict)# 提取并打印预测结果
# 预测结果是一个 CTCVR 综合分
predicted_ctr = predictions['ctr_logits'].numpy().flatten()
new_samples['ctr_prob'] = predicted_ctr
print("\nPrediction Results:")
for idx, row in new_samples.iterrows():print(f"Item ID: {row['item_id']} | CTR Final Score: {row['ctr_prob']:.4f}")

运行结果如下:

Generated New Samples:user_id  user_age user_gender user_occupation  city_code  device_type  \
0       34        49      female         teacher        843            0   
1       15        30      female         student        564            3   
2       26        37        male         teacher       2229            0   
3       31        35        male          worker       2494            0   
4       41        57      female         student       1668            3   item_id item_category item_brand  item_price  
0      147   electronics     brandA         127  
1      196      clothing     brandC         190  
2        1         books     brandA           1  
3      150      clothing     brandA           5  
4      128   electronics     brandA         156  
Metal device set to: Apple M1 ProPrediction Results:
Item ID: 147 | CTR Final Score: 0.5263
Item ID: 196 | CTR Final Score: 0.5263
Item ID: 1 | CTR Final Score: 0.5263
Item ID: 150 | CTR Final Score: 0.4793
Item ID: 128 | CTR Final Score: 0.5263

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/75100.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习四大核心架构:神经网络(NN)、卷积神经网络(CNN)、循环神经网络(RNN)与Transformer全概述

目录 📂 深度学习四大核心架构 🌰 知识点概述 🧠 核心区别对比表 ⚡ 生活化案例理解 🔑 选型指南 📂 深度学习四大核心架构 第一篇: 神经网络基础(NN) 🌰 知识点概述…

R语言对偏态换数据进行转换(对数、平方根、立方根)

我们进行研究的时候经常会遇见偏态数据,数据转换是统计分析和数据预处理中的一项基本技术。使用 R 时,了解如何正确转换数据有助于满足统计假设、标准化分布并提高分析的准确性。在 R 中实现和可视化最常见的数据转换:对数、平方根和立方根转…

第十四届蓝桥杯省赛电子类单片机学习记录(客观题)

01.一个8位的DAC转换器,供电电压为3.3V,参考电压2.4V,其ILSB产生的输出电压增量是(D)V。 A. 0.0129 B. 0.0047 C. 0.0064 D. 0.0094 解析: ILSB(最低有效位)的电压增量计算公式…

【算法】手撕快速排序

快速排序的思想 任取一个元素作为枢轴,然后想办法把这个区间划分为两部分,小于等于枢轴的放左边,大于等于枢轴的放右边 然后递归处理左右区间,直到空或只剩一个 具体动画演示详见 数据结构合集 - 快速排序(算法过程, 效率分析…

《八大排序算法》

相关概念 排序:使一串记录,按照其中某个或某些关键字的大小,递增或递减的排列起来。稳定性:它描述了在排序过程中,相等元素的相对顺序是否保持不变。假设在待排序的序列中,有两个元素a和b,它们…

JavaScript DOM与元素操作

目录 DOM 树、DOM 对象、元素操作 一、DOM 树与 DOM 对象 二、获取 DOM 元素 1. 基础方法 2. 现代方法(ES6) 三、修改元素内容 四、修改元素常见属性 1. 标准属性 2. 通用方法 五、通过 style 修改样式 六、通过类名修改样式 1. className 属…

大模型学习:从零到一实现一个BERT微调

目录 一、准备阶段 1.导入模块 2.指定使用的是GPU还是CPU 3.加载数据集 二、对数据添加词元和分词 1.根据BERT的预训练,我们要将一个句子的句头添加[CLS]句尾添加[SEP] 2.激活BERT词元分析器 3.填充句子为固定长度 代码解释: 三、数据处理 1.…

10组时尚复古美学自然冷色调肖像电影照片调色Lightroom预设 De La Mer – Nautical Lightroom Presets

De La Mer 预设系列包含 10 种真实的调色预设,适用于肖像、时尚和美术。为您的肖像摄影带来电影美学和个性! De La Mer 预设非常适合专业人士和业余爱好者,可在桌面或移动设备上使用,为您的摄影项目提供轻松的工作流程。这套包括…

机器学习的一百个概念(4)下采样

前言 本文隶属于专栏《机器学习的一百个概念》,该专栏为笔者原创,引用请注明来源,不足和错误之处请在评论区帮忙指出,谢谢! 本专栏目录结构和参考文献请见[《机器学习的一百个概念》 ima 知识库 知识库广场搜索&…

数据安全系列4:密码技术的应用-接口调用的身份识别

传送门 数据安全系列1:开篇 数据安全系列2:单向散列函数概念 数据安全系列3:密码技术概述 什么是认证? 一谈到认证,多数人的反应可能就是"用户认证" 。就是应用系统如何识别用户的身份,直接…

STL之map和set

1. 关联式容器 vector、list、deque、 forward_list(C11)等,这些容器统称为序列式容器,因为其底层为线性序列的数据结构,里面存储的是元素本身。 关联式容器也是用来存储数据的,与序列式容器不同的是,其里面存储的是结…

Vue3 其它API Teleport 传送门

Vue3 其它API Teleport 传送门 在定义一个模态框时,父组件的filter属性会影响子组件的position属性,导致模态框定位错误使用Teleport解决这个问题把模态框代码传送到body标签下

《Python Web网站部署应知应会》No4:基于Flask的调用AI大模型的高性能博客网站的设计思路和实战(上)

基于Flask的调用AI大模型的高性能博客网站的设计思路和实战(上) 摘要 本文详细探讨了一个基于Flask框架的高性能博客系统的设计与实现,该系统集成了本地AI大模型生成内容的功能。我们重点关注如何在高并发、高负载状态下保持系统的高性能和…

力扣刷题-热题100题-第27题(c++、python)

21. 合并两个有序链表 - 力扣(LeetCode)https://leetcode.cn/problems/merge-two-sorted-lists/description/?envTypestudy-plan-v2&envIdtop-100-liked 常规法 创建一个新链表,遍历list1与list2,将新链表指向list1与list2…

AI加Python的文本数据情感分析流程效果展示与代码实现

本文所使用数据来自于梯田景区评价数据。 一、数据预处理 数据清洗 去除重复值、空值及无关字符(如表情符号、特殊符号等)。 提取中文文本,过滤非中文字符。 统一文本格式(如全角转半角、繁体转简体)。 中文分词与去停用词 使用 jieba 分词工具进行分词。 加载自定义词…

Microi吾码界面设计引擎之基础组件用法大全【内置组件篇·上】

🎀🎀🎀 microi-pageengine 界面引擎系列 🎀🎀🎀 一、Microi吾码:一款高效、灵活的低代码开发开源框架【低代码框架】 二、Vue3项目快速集成界面引擎 三、Vue3 界面设计插件 microi-pageengine …

【多线程】单例模式和阻塞队列

目录 一.单例模式 1. 饿汉模式 2. 懒汉模式 二.阻塞队列 1. 阻塞队列的概念 2. BlockingQueue接口 3.生产者-消费者模型 4.模拟生产者-消费者模型 一.单例模式 单例模式(Singleton Pattern)是一种常用的软件设计模式,其核心思想是确保…

Vuex状态管理

Vuex Vuex是一个专为Vue.js应用程序开发的状态管理模式。它采用集中式管理应用的所有组件状态,并以相应的规则保证状态以一种可预测的方式发生变化。(类似于在前端的数据库,这里的数据存储在内存当中) 一、安装并配置 在项目的…

从代码学习深度学习 - 使用块的网络(VGG)PyTorch版

文章目录 前言一、VGG网络简介1.1 VGG的核心特点1.2 VGG的典型结构1.3 优点与局限性1.4 本文的实现目标二、搭建VGG网络2.1 数据准备2.2 定义VGG块2.3 构建VGG网络2.4 辅助工具2.4.1 计时器和累加器2.4.2 准确率计算2.4.3 可视化工具2.5 训练模型2.6 运行实验总结前言 深度学习…

Baklib激活企业知识管理新动能

Baklib核心技术架构解析 Baklib的底层架构以模块化设计为核心,融合知识中台的核心理念,通过分布式存储引擎与智能语义分析系统构建三层技术体系。数据层采用多源异构数据接入协议,支持文档、音视频、代码片段等非结构化数据的实时解析与分类…