推荐系统经典模型YouTubeDNN代码

文章目录

    • 前言
    • 数据预处理部分
    • 模型训练预测部分
    • 总结与问答

前言

  • 上一篇讲到过YouTubeDNN论文部分内容,但是没有代码部分。最近网上教学视频里看到一段关于YouTubeDNN召回算法的代码,现在我分享一下给大家参考看一下,并附上一些我对代码的理解。

数据预处理部分

  • 首先我们需要对数据集进行预处理,数据集格式如下图所示
    在这里插入图片描述
  • 根据YouTubeDNN论文,输入的数据是用户的信息、视频的ID序列、用户搜索的特征和一些地理信息等其他信息。到了基于文章内容的信息流产品中,就变成了用户 ID、年龄、性别、城市、阅读的时间戳再加上视频的ID。我们把这些内容可以组合成YouTubeDNN需要的内容,最后处理成需要的Embedding。
from tqdm import tqdm
import numpy as np
import random
from tensorflow.python.keras.preprocessing.sequence import pad_sequencesdef gen_data_set(data, negsample=0):# 根据timestamp排序数据,并替换data.sort_values("timestamp", inplace=True)#根据item_id进行去重item_ids = data['item_id'].unique()# 构建训练与测试listtrain_set = list()test_set = list()for reviewrID, hist in tqdm(data.groupby('user_id')):# 正样本列表pos_list = hist['item_id'].tolist()rating_list = hist['rating'].tolist()if negsample > 0:# 候选集中去掉用户看过的item项目candidate_set = list(set(item_ids) - set(pos_list))# 随机选择负采样样本neg_list = np.random.choice(candidate_set, size=len(pos_list) * negsample, replace=True)for i in range(1, len(pos_list)):if i != len(pos_list) - 1:# 训练集和测试集划分train_set.append((reviewrID, hist[::-1], pos_list[i], 1, len(hist[:: -1]), rating_list[i]))for negi in range(negsample):train_set.append((reviewrID, hist[::-1], neg_list[i * negsample + negi], 0, len(hist[::-1])))else:test_set.append((reviewrID, hist[::-1], pos_list[i], 1, len(hist[::-1]), rating_list[i]))# 打乱数据集random.shuffle(train_set)random.shuffle(test_set)return train_set, test_setdef gen_model_input(train_set, user_profile, seq_max_len):# 用户idtrain_uid = np.array([line[0] for line in train_set])# 历史交互序列train_seq = [line[1] for line in train_set]# 物品idtrain_iid = np.array([line[2] for line in train_set])# 正负样本标签train_label = np.array([line[3] for line in train_set])# 历史交互序列长度train_hist_len = np.array([line[4] for line in train_set])train_seq_pad = pad_sequences(train_seq, maxlen=seq_max_len, padding='post', truncating='post', value=0 )train_model_input = {"user_id": train_uid, "item_id": train_iid, "hist_item_id": train_seq_pad, "hist_len": train_hist_len}for key in {"gender", "age", "city"}:train_model_input[key] = user_profile.loc[train_model_input['user_id']][key].valuesreturn train_model_input, train_label
  • 代码解释:
    • **gen_data_set() **主要作用是接收数据集(data)和一个负采样(negsample)参数,返回一个训练集列表(trainset)和一个测试集列表(testset)。具体流程是先通过timetamp列对数据进行排序,根据item_id进行去重;然后根据user_id分组形成正负样本(正样本为购买过的,负样本为没有购买过的),对于negsample大于0,我们就要进行负采样,也就是随机选择一些没有购买过的商品为负样本,然后将它们保存到训练集中;最后,将正负样本数据以及其他信息(如历史交互序列、用户 ID 和历史交互序列的长度)保存到训练集列表和测试集列表中。
    • gen_model_input() 主要作用就是接收一个训练集列表、用户画像信息和序列最大长度参数,返回训练模型的输入和标签。首先将训练集列表拆分成 5 个列表(train_uid train_seq train_iid train_label train_hist_len);然后使用pad_sequences() 函数对历史交互序列进行填充处理,将其变成长度相同的序列。最后,将用户画像信息(gender age city)加入到训练模型的关键字中,返回训练模型的输入和标签。
    • pad_sequences():pad_sequences()这个函数是来自于TensorFlow中数据预处理的一种方法,主要就是数据预填充。在TensorFlow2.8版本之前可以通过from tensorflow.python.keras.preprocessing.sequence import pad_sequences调用,后期版本则是在keras.utils里,这里建议使用低版本tesorflow2,具体版本信息请参考链接。

模型训练预测部分

  • 进入模型训练阶段,我们需要先了解一下,代码里我们所使用的一些包和函数介绍
    • sklearn.preprocessing.LabelEncoder:对数据进行特征编码
    • deepctr.feature_column.SparseFeat, VarLenSparseFeat:用户构建用户和物品特征输入。
    • deepmatch:用于构建和训练推荐模型
    • faiss:高效向量相似性搜索库
    • models.recall.preprocess.gen_data_set, gen_model_input:数据预处理部分(自建)
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from models.recall.preprocess import gen_data_set, gen_model_input
from deepctr.feature_column import SparseFeat, VarLenSparseFeat
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.models import Model
import tensorflow as tf
from deepmatch.models import *
from deepmatch.utils import recall_N
from deepmatch.utils import sampledsoftmaxloss
import numpy as np
from tqdm import tqdm
import faissclass YouTubeModel(object):def __init__(self, embedding_dim=32):self.SEQ_LEN = 50self.embedding_dim = embedding_dimself.user_feature_columns = Noneself.item_feature_columns = Nonedef training_set_construct(self):# 数据加载data = pd.read_csv('../../data/read_history.csv')# 负采样个数negsample = 0# 特征编码features = ["user_id", "item_id", "gender", "age", "city"]features_max_idx={}for feature in features:lbe = LabelEncoder()data[feature] = lbe.fit_transform(data[feature]) + 1features_max_idx[feature] = data[feature].max() + 1# 抽取用户、物品特征(并去重)user_info = data[["user_id", "gender", "age", "city"]].drop_duplicates('user_id')item_info = data[["item_id"]].drop_duplicates('item_id')# 构建输入数据train_set, test_set = gen_data_set(data, negsample)# 转化模型输入train_model_input, train_label = gen_model_input(train_set, user_info, self.SEQ_LEN)test_model_input, test_label = gen_model_input(test_set, user_info, self.SEQ_LEN)# 用户端特征输入self.user_feature_columns = [SparseFeat('user_id', features_max_idx['user_id'], 16),SparseFeat('gender', features_max_idx['gender'], 16),SparseFeat('age', features_max_idx['age'], 16),SparseFeat('city', features_max_idx['city'], 16),VarLenSparseFeat(SparseFeat('hist_item_id', features_max_idx['item_id'],self.embedding_dim, embedding_name='item_id'),self.SEQ_LEN, 'mean', 'hist_len')]# 物品端特征输入self.item_feature_columns = [SparseFeat('item_id', features_max_idx['item_id'], self.embedding_dim)]return train_model_input, train_label, test_model_input, test_label, train_set, test_set, user_info, item_infodef training_model(self, train_model_input, train_label):K.set_learning_phase(True)if tf.__version__ >= '2.0.0':tf.compat.v1.disable_eager_execution()# 定义模型model = YoutubeDNN(self.user_feature_columns, self.item_feature_columns, num_sampled=100,user_dnn_hidden_units=(128, 64, self.embedding_dim))# 使用adam优化,损失函数使用softmax+cross_entropymodel.compile(optimizer="adam", loss=sampledsoftmaxloss)# 训练并保存训练过程中的数据model.fit(train_model_input, train_label, batch_size=512, epochs=20, verbose=1, validation_split=0.0,)return model# 提取用户和物品的embedding layerdef extract_embedding_layer(self, model, test_model_input, item_info):all_item_model_input = {"item_id": item_info['item_id'].values, }# 获取用户、item的embedding_layeruser_embedding_model = Model(inputs=model.user_input, outputs=model.user_embedding)item_embedding_model = Model(inputs=model.item_input, outputs=model.item_embedding)user_embs = user_embedding_model.predict(test_model_input, batch_size=2 ** 12)item_embs = item_embedding_model.predict(all_item_model_input, batch_size=2 ** 12)print(user_embs.shape)print(item_embs.shape)return user_embs, item_embs# 计算召回率和命中率def eval(self, user_embs, item_embs, test_model_input, item_info, test_set):test_true_label = {line[0]: line[2] for line in test_set}index = faiss.IndexFlatIP(self.embedding_dim)index.add(item_embs)D, I = index.search(np.ascontiguousarray(user_embs), 50)s = []hit = 0# 统计预测结果for i, uid in tqdm(enumerate(test_model_input['user_id'])):try:pred = [item_info['item_id'].value[x] for x in I[i]]recall_score = recall_N(test_true_label[uid], pred, N=50)s.append(recall_score)if test_true_label[uid] in pred:hit += 1except:print(i)# 计算召回率和命中率recall = np.mean(s)hit_rate = hit / len(test_model_input['user_id'])return recall, hit_ratedef scheduler(self):# 构建训练集、测试集train_model_input, train_label, test_model_input, test_label, \train_set, test_set, user_info, item_info = self.training_set_construct()self.training_model(train_model_input, train_label)# 获取用户、item的layeruser_embs, item_embs = self.extract_embedding_layer(model, test_model_input, item_info)# 评估模型recall, hit_rate = self.eval(user_embs, item_embs, test_model_input, item_info, test_set)print(recall, hit_rate)if __name__ == '__main__':model = YouTubeModel()model.scheduler()
  • 代码解释:
    • training_set_construct:加载数据集,特征编码,数据集预处理,使用deepctr库中的SparseFeat(离散), VarLenSparseFeat(变长)实现用户物品的特征输入。
    • training_model:YoutubeDNN构建训练模型,compile编译训练模型,fit模型训练。
    • extract_embedding_layer:提取用户和物品的Embedding Layer。
    • eval:评估模型计算召回率和命中率,使用faiss中的faiss.IndexFlatIP(余弦距离搜索并非余弦相似度),统计预测结果,计算召回率为recall_score的平均值;命中率则是集中次数hit与test_model_input的总数。
    • scheduler:串联整个召回代码的函数,负责调用。

总结与问答

  1. 代码中提到的离散特征和变长特征该如何选择?
  • 答:首先我们要理解一下什么事离散特征,什么是变长特征?
    • 离散特征:是指具有有限取值或离散类别的特征,例如性别、国家、城市等(用户画像信息)。对于离散特征,可以使用embedding来将其映射到低维连续向量空间中。这使得模型能够学习离散特征之间的相关性和交互关系。通常情况下,离散特征需要经过编码(例如one-hot multi-hot)并与其他特征一起输入到模型中。
    • 变长特征:是指具有可变长度的特征,例如用户的历史行为序列或商品的标签列表。对于变长特征,可以使用循环神经网络(RNN)或Transformer等模型来建模。这些模型可以处理可变长度的序列,并捕捉序列中的时序关系和上下文信息。
    • 所以对于多特征输入,通常需要混合使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/709741.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一张图读懂人工智能

一、生成人工智能的概念和应用,以及如何使用大型语言模型进行聊天和创造原创内容。这项技术将会对人类和企业产生深远影响。 计算机获得学习、思考和交流的能力,被称为生成人工智能。生成人工智能可以立即获得人类所有知识的总和,并回答任何…

综合实战(volume and Compose)

"让我,重获新生~" MySQL 灾难恢复 熟练掌握挂载卷的使用,将Mysql的业务数据存储在 外部。 实战思想: 使用 MySQL 5.7 的镜像创建容器并创建一个普通数据卷 "mysql-data"用来保存容器中产生的数据。我们需要容器连接到Mysql服务&a…

TeXiFy IDEA 编译后文献引用为 “[?]“

文章目录 1. 问题描述2. 原因分析3. 解决方案3.1 添加自动化脚本3.2 附录——配置一览表 1. 问题描述 在 IDEA 中使用 TeXiFy IDEA 编译后的文章文献引用是 [?] 2. 原因分析 根据网上教程所生成的目录结构如下: 报错日志: 根据 /out 目录结构&#x…

【vmware安装群晖】

vmware安装群晖 vmware安装群辉: vmware版本:17pro 下载链接, https://customerconnect.vmware.com/cn/downloads/details?downloadGroupWKST-1751-WIN&productId1376&rPId116859 激活码可自行搜索 教程: https://b…

C++重新入门-string容器

目录 1.包含头文件 2.创建字符串 3.获取字符串长度 4.字符串拼接 5.字符串比较 相等性比较 大小比较 使用比较函数 6.访问字符串 7.查找子串 8.字符串修改 替换子串 插入字符或子串 删除字符或子串 9.提取子串 10.总结 当谈到C中的字符串时,std::str…

135.乐理基础-半音是小二度吗?全音是大二度吗?三全音

内存参考于:三分钟音乐社 上一个内容:134.乐理基础-音程名字的简写-CSDN博客 上一个内容里练习的答案: 半音可以与小二度划等号吗?全音可以和大二度划等号吗? 严格来说它们是不能划等号的,半音与全音是侧…

基于springboot实现的健康监控管理系统

一、系统架构 前端:html | bootstrap | jquery | css 后端:springboot | thymeleaf | mybatis 环境:jdk1.8 | mysql | maven 二、代码及数据库 三、功能介绍 01. 体检测评 02. 运动处方 03. 运动处方明细 04. 运动处方-打卡…

基于transform的scale属性,动态缩放整个页面,实现数据可视化大屏自适应,保持比例不变形,满足不同分辨率的需求

文章目录 一、需求背景:二、需求分析:三、选择方案:四、实现代码:五、效果预览:六、封装组件: 一、需求背景: 数据可视化大屏是一种将数据、信息和可视化效果集中展示在一块或多块大屏幕上的技…

PyTorch基础(19)-- torch.take_along_dim()方法

一、前言 在深挖ML4CO的代码过程中,遇到了torch.take_along_dim()这个方法,影响到我后续的代码阅读;加之在上网搜索资料的过程中,网络上对此函数的介绍文章少之又少,即使有,也是对torch官网文档中的解释进…

价格腰斩:腾讯云和阿里云服务器优惠价格对比

2024年阿里云服务器和腾讯云服务器价格战已经打响,阿里云服务器优惠61元一年起,腾讯云服务器62元一年,2核2G3M、2核4G、4核8G、8核16G、16核32G、16核64G等配置价格对比,阿腾云atengyun.com整理阿里云和腾讯云服务器详细配置价格表…

jvm面试题目补充

jdk&jre Java程序设计语言、Java虚拟机、Java API类库这三部分统称为JDK(Java Development Kit)。 把Java API类库中的Java SE API子集 [1] 和Java虚拟机这两部分统称为JRE(Java Runtime Environment),JRE是支持…

CUDA C:查看GPU设备信息

相关阅读 CUDA Chttps://blog.csdn.net/weixin_45791458/category_12530616.html?spm1001.2014.3001.5482 了解自己设备的性能是很有必要的,为此CUDA 运行时(runtime)API给用户也提供了一些查询设备信息的函数,下面的函数用于查看GPU设备的一切信息。 …

MyBatis 学习(二)之 第一个 MyBatis 案例

目录 1 配置 MyBatis 方式 1.1 XML 配置文件 1.2 Java 注解配置 1.3. Java API 配置 2 在 MySQL 中创建一张表 3 创建一个基于 Maven 的 JavaWeb 工程 4 编写 User 实体类 5 创建 Mybatis 全局配置文件 6 编写一个 DAO 或 Mapper 接口 7 编写 SQL 映射配置文件&#…

找不到mfc140.dll怎么办?教你五种mfc140.dll丢失的解决方法

当计算机系统中mfc140.dll文件丢失时,可能会引发一系列运行问题,影响到系统的正常功能及应用程序的稳定执行。具体来说,由于mfc140.dll是Microsoft Visual C Redistributable Package的重要组成部分,它的缺失会导致依赖于该动态链…

如何用好应用权限,保护隐私数据?银河麒麟桌面操作系统V10 SP1 2303 update2新功能解析

为您介绍银河麒麟桌面操作系统V10 SP1 2303 update2隐私设置和权限管理功能,为您的个人数据安全保驾护航。 说到个人数据隐私,在科技重塑生活本质的数字世界,个人信息遭受持续威胁。2018年,某国际知名社交平台因安全系统漏洞而遭…

深入解析Mybatis-Plus框架:简化Java持久层开发(六)

🍀 前言 博客地址: CSDN:https://blog.csdn.net/powerbiubiu 👋 简介 上一章介绍了新增的操作,后续删除,更新,查询的操作相对新增要复杂一些,因为有些方法的使用涉及到了条件&…

学习使用paddle来构造hrnet网络模型

1、首先阅读了hrnet的网络结构分析,了解到了网络构造如下: 参考博文姿态估计之2D人体姿态估计 - (HRNet)Deep High-Resolution Representation Learning for Human Pose Estimation(多家综合)-CSDN博客 最…

vue 部署后修改配置文件(接口IP)

近期,有一个项目,运维在部署的时候,接口ip还没有确定,而且ip后面的路径一直有变动,导致我这里一天打包至少四五次才行,很麻烦,然后看了下有没有打包后修改配置文件修改接口ip的方法,…

大话设计模式——4.装饰模式(Decorator Pattern)

1.定义 1)可以在不改动原有对象代码的情况下扩展对象的功能,通过聚合的方式相较于继承更加灵活。 2)UML图 2.示例 汽车有很多装饰可选,如座椅、音响、轮胎等都可以进行自定义组装 1)抽象汽车对象 public interfac…

数据结构------栈(Stack)和队列(Queue)

也是好久没写博客了,那今天就回归一下,写一篇数据结构的博客吧。今天要写的是栈和队列,也是数据结构中比较基础的知识。那么下面开始今天要写的博客了。 目录 栈(Stack) 队列(Queue) 喜欢就点…