推荐系统经典模型YouTubeDNN代码

文章目录

    • 前言
    • 数据预处理部分
    • 模型训练预测部分
    • 总结与问答

前言

  • 上一篇讲到过YouTubeDNN论文部分内容,但是没有代码部分。最近网上教学视频里看到一段关于YouTubeDNN召回算法的代码,现在我分享一下给大家参考看一下,并附上一些我对代码的理解。

数据预处理部分

  • 首先我们需要对数据集进行预处理,数据集格式如下图所示
    在这里插入图片描述
  • 根据YouTubeDNN论文,输入的数据是用户的信息、视频的ID序列、用户搜索的特征和一些地理信息等其他信息。到了基于文章内容的信息流产品中,就变成了用户 ID、年龄、性别、城市、阅读的时间戳再加上视频的ID。我们把这些内容可以组合成YouTubeDNN需要的内容,最后处理成需要的Embedding。
from tqdm import tqdm
import numpy as np
import random
from tensorflow.python.keras.preprocessing.sequence import pad_sequencesdef gen_data_set(data, negsample=0):# 根据timestamp排序数据,并替换data.sort_values("timestamp", inplace=True)#根据item_id进行去重item_ids = data['item_id'].unique()# 构建训练与测试listtrain_set = list()test_set = list()for reviewrID, hist in tqdm(data.groupby('user_id')):# 正样本列表pos_list = hist['item_id'].tolist()rating_list = hist['rating'].tolist()if negsample > 0:# 候选集中去掉用户看过的item项目candidate_set = list(set(item_ids) - set(pos_list))# 随机选择负采样样本neg_list = np.random.choice(candidate_set, size=len(pos_list) * negsample, replace=True)for i in range(1, len(pos_list)):if i != len(pos_list) - 1:# 训练集和测试集划分train_set.append((reviewrID, hist[::-1], pos_list[i], 1, len(hist[:: -1]), rating_list[i]))for negi in range(negsample):train_set.append((reviewrID, hist[::-1], neg_list[i * negsample + negi], 0, len(hist[::-1])))else:test_set.append((reviewrID, hist[::-1], pos_list[i], 1, len(hist[::-1]), rating_list[i]))# 打乱数据集random.shuffle(train_set)random.shuffle(test_set)return train_set, test_setdef gen_model_input(train_set, user_profile, seq_max_len):# 用户idtrain_uid = np.array([line[0] for line in train_set])# 历史交互序列train_seq = [line[1] for line in train_set]# 物品idtrain_iid = np.array([line[2] for line in train_set])# 正负样本标签train_label = np.array([line[3] for line in train_set])# 历史交互序列长度train_hist_len = np.array([line[4] for line in train_set])train_seq_pad = pad_sequences(train_seq, maxlen=seq_max_len, padding='post', truncating='post', value=0 )train_model_input = {"user_id": train_uid, "item_id": train_iid, "hist_item_id": train_seq_pad, "hist_len": train_hist_len}for key in {"gender", "age", "city"}:train_model_input[key] = user_profile.loc[train_model_input['user_id']][key].valuesreturn train_model_input, train_label
  • 代码解释:
    • **gen_data_set() **主要作用是接收数据集(data)和一个负采样(negsample)参数,返回一个训练集列表(trainset)和一个测试集列表(testset)。具体流程是先通过timetamp列对数据进行排序,根据item_id进行去重;然后根据user_id分组形成正负样本(正样本为购买过的,负样本为没有购买过的),对于negsample大于0,我们就要进行负采样,也就是随机选择一些没有购买过的商品为负样本,然后将它们保存到训练集中;最后,将正负样本数据以及其他信息(如历史交互序列、用户 ID 和历史交互序列的长度)保存到训练集列表和测试集列表中。
    • gen_model_input() 主要作用就是接收一个训练集列表、用户画像信息和序列最大长度参数,返回训练模型的输入和标签。首先将训练集列表拆分成 5 个列表(train_uid train_seq train_iid train_label train_hist_len);然后使用pad_sequences() 函数对历史交互序列进行填充处理,将其变成长度相同的序列。最后,将用户画像信息(gender age city)加入到训练模型的关键字中,返回训练模型的输入和标签。
    • pad_sequences():pad_sequences()这个函数是来自于TensorFlow中数据预处理的一种方法,主要就是数据预填充。在TensorFlow2.8版本之前可以通过from tensorflow.python.keras.preprocessing.sequence import pad_sequences调用,后期版本则是在keras.utils里,这里建议使用低版本tesorflow2,具体版本信息请参考链接。

模型训练预测部分

  • 进入模型训练阶段,我们需要先了解一下,代码里我们所使用的一些包和函数介绍
    • sklearn.preprocessing.LabelEncoder:对数据进行特征编码
    • deepctr.feature_column.SparseFeat, VarLenSparseFeat:用户构建用户和物品特征输入。
    • deepmatch:用于构建和训练推荐模型
    • faiss:高效向量相似性搜索库
    • models.recall.preprocess.gen_data_set, gen_model_input:数据预处理部分(自建)
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from models.recall.preprocess import gen_data_set, gen_model_input
from deepctr.feature_column import SparseFeat, VarLenSparseFeat
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.models import Model
import tensorflow as tf
from deepmatch.models import *
from deepmatch.utils import recall_N
from deepmatch.utils import sampledsoftmaxloss
import numpy as np
from tqdm import tqdm
import faissclass YouTubeModel(object):def __init__(self, embedding_dim=32):self.SEQ_LEN = 50self.embedding_dim = embedding_dimself.user_feature_columns = Noneself.item_feature_columns = Nonedef training_set_construct(self):# 数据加载data = pd.read_csv('../../data/read_history.csv')# 负采样个数negsample = 0# 特征编码features = ["user_id", "item_id", "gender", "age", "city"]features_max_idx={}for feature in features:lbe = LabelEncoder()data[feature] = lbe.fit_transform(data[feature]) + 1features_max_idx[feature] = data[feature].max() + 1# 抽取用户、物品特征(并去重)user_info = data[["user_id", "gender", "age", "city"]].drop_duplicates('user_id')item_info = data[["item_id"]].drop_duplicates('item_id')# 构建输入数据train_set, test_set = gen_data_set(data, negsample)# 转化模型输入train_model_input, train_label = gen_model_input(train_set, user_info, self.SEQ_LEN)test_model_input, test_label = gen_model_input(test_set, user_info, self.SEQ_LEN)# 用户端特征输入self.user_feature_columns = [SparseFeat('user_id', features_max_idx['user_id'], 16),SparseFeat('gender', features_max_idx['gender'], 16),SparseFeat('age', features_max_idx['age'], 16),SparseFeat('city', features_max_idx['city'], 16),VarLenSparseFeat(SparseFeat('hist_item_id', features_max_idx['item_id'],self.embedding_dim, embedding_name='item_id'),self.SEQ_LEN, 'mean', 'hist_len')]# 物品端特征输入self.item_feature_columns = [SparseFeat('item_id', features_max_idx['item_id'], self.embedding_dim)]return train_model_input, train_label, test_model_input, test_label, train_set, test_set, user_info, item_infodef training_model(self, train_model_input, train_label):K.set_learning_phase(True)if tf.__version__ >= '2.0.0':tf.compat.v1.disable_eager_execution()# 定义模型model = YoutubeDNN(self.user_feature_columns, self.item_feature_columns, num_sampled=100,user_dnn_hidden_units=(128, 64, self.embedding_dim))# 使用adam优化,损失函数使用softmax+cross_entropymodel.compile(optimizer="adam", loss=sampledsoftmaxloss)# 训练并保存训练过程中的数据model.fit(train_model_input, train_label, batch_size=512, epochs=20, verbose=1, validation_split=0.0,)return model# 提取用户和物品的embedding layerdef extract_embedding_layer(self, model, test_model_input, item_info):all_item_model_input = {"item_id": item_info['item_id'].values, }# 获取用户、item的embedding_layeruser_embedding_model = Model(inputs=model.user_input, outputs=model.user_embedding)item_embedding_model = Model(inputs=model.item_input, outputs=model.item_embedding)user_embs = user_embedding_model.predict(test_model_input, batch_size=2 ** 12)item_embs = item_embedding_model.predict(all_item_model_input, batch_size=2 ** 12)print(user_embs.shape)print(item_embs.shape)return user_embs, item_embs# 计算召回率和命中率def eval(self, user_embs, item_embs, test_model_input, item_info, test_set):test_true_label = {line[0]: line[2] for line in test_set}index = faiss.IndexFlatIP(self.embedding_dim)index.add(item_embs)D, I = index.search(np.ascontiguousarray(user_embs), 50)s = []hit = 0# 统计预测结果for i, uid in tqdm(enumerate(test_model_input['user_id'])):try:pred = [item_info['item_id'].value[x] for x in I[i]]recall_score = recall_N(test_true_label[uid], pred, N=50)s.append(recall_score)if test_true_label[uid] in pred:hit += 1except:print(i)# 计算召回率和命中率recall = np.mean(s)hit_rate = hit / len(test_model_input['user_id'])return recall, hit_ratedef scheduler(self):# 构建训练集、测试集train_model_input, train_label, test_model_input, test_label, \train_set, test_set, user_info, item_info = self.training_set_construct()self.training_model(train_model_input, train_label)# 获取用户、item的layeruser_embs, item_embs = self.extract_embedding_layer(model, test_model_input, item_info)# 评估模型recall, hit_rate = self.eval(user_embs, item_embs, test_model_input, item_info, test_set)print(recall, hit_rate)if __name__ == '__main__':model = YouTubeModel()model.scheduler()
  • 代码解释:
    • training_set_construct:加载数据集,特征编码,数据集预处理,使用deepctr库中的SparseFeat(离散), VarLenSparseFeat(变长)实现用户物品的特征输入。
    • training_model:YoutubeDNN构建训练模型,compile编译训练模型,fit模型训练。
    • extract_embedding_layer:提取用户和物品的Embedding Layer。
    • eval:评估模型计算召回率和命中率,使用faiss中的faiss.IndexFlatIP(余弦距离搜索并非余弦相似度),统计预测结果,计算召回率为recall_score的平均值;命中率则是集中次数hit与test_model_input的总数。
    • scheduler:串联整个召回代码的函数,负责调用。

总结与问答

  1. 代码中提到的离散特征和变长特征该如何选择?
  • 答:首先我们要理解一下什么事离散特征,什么是变长特征?
    • 离散特征:是指具有有限取值或离散类别的特征,例如性别、国家、城市等(用户画像信息)。对于离散特征,可以使用embedding来将其映射到低维连续向量空间中。这使得模型能够学习离散特征之间的相关性和交互关系。通常情况下,离散特征需要经过编码(例如one-hot multi-hot)并与其他特征一起输入到模型中。
    • 变长特征:是指具有可变长度的特征,例如用户的历史行为序列或商品的标签列表。对于变长特征,可以使用循环神经网络(RNN)或Transformer等模型来建模。这些模型可以处理可变长度的序列,并捕捉序列中的时序关系和上下文信息。
    • 所以对于多特征输入,通常需要混合使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/709741.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一张图读懂人工智能

一、生成人工智能的概念和应用,以及如何使用大型语言模型进行聊天和创造原创内容。这项技术将会对人类和企业产生深远影响。 计算机获得学习、思考和交流的能力,被称为生成人工智能。生成人工智能可以立即获得人类所有知识的总和,并回答任何…

go语言基础 -- map的定义与使用

map的定义与使用 map声明基础语法map的基本使用map的遍历map切片map排序 map声明基础语法 // map的声明 var xxx_map map[key_type]value_typemap的key可以是基本数据类型,channel,接口,结构体,数组,但不能是slice&am…

leetcode刷题(剑指offer) 47.全排列Ⅱ

47.全排列Ⅱ 给定一个可包含重复数字的序列 nums ,按任意顺序 返回所有不重复的全排列。 示例 1: 输入:nums [1,1,2] 输出: [[1,1,2],[1,2,1],[2,1,1]]示例 2: 输入:nums [1,2,3] 输出:[[…

c#委托、lambda、事件

Lambda Lambda表达式是一种匿名函数,Lambda表达式通常以箭头“>”分隔左侧的输入和右侧的输出。 (parameter_list) > { statement_block } parameter_list 是由一个或多个参数组成的逗号分隔列表,每个参数都包括类型和名称,可以为空。…

综合实战(volume and Compose)

"让我,重获新生~" MySQL 灾难恢复 熟练掌握挂载卷的使用,将Mysql的业务数据存储在 外部。 实战思想: 使用 MySQL 5.7 的镜像创建容器并创建一个普通数据卷 "mysql-data"用来保存容器中产生的数据。我们需要容器连接到Mysql服务&a…

TeXiFy IDEA 编译后文献引用为 “[?]“

文章目录 1. 问题描述2. 原因分析3. 解决方案3.1 添加自动化脚本3.2 附录——配置一览表 1. 问题描述 在 IDEA 中使用 TeXiFy IDEA 编译后的文章文献引用是 [?] 2. 原因分析 根据网上教程所生成的目录结构如下: 报错日志: 根据 /out 目录结构&#x…

C++ | 使用正则表达式匹配特定形式的字符串

C | 使用正则表达式匹配特定形式的字符串 在 C 中&#xff0c;可以使用 <regex> 头文件提供的正则表达式库来对特定形式的字符串进行匹配操作。 常用的正则表达式模式语法 普通字符&#xff1a; 普通字符会按照其字面意义进行匹配&#xff0c;例如 a 会匹配字符 a。 转…

【vmware安装群晖】

vmware安装群晖 vmware安装群辉&#xff1a; vmware版本&#xff1a;17pro 下载链接&#xff0c; https://customerconnect.vmware.com/cn/downloads/details?downloadGroupWKST-1751-WIN&productId1376&rPId116859 激活码可自行搜索 教程&#xff1a; https://b…

软考笔记--企业应用集成

在企业信息化建设的过程中&#xff0c;由于缺乏统一规划和总体布局&#xff0c;往往形成多个信息孤岛。信息孤岛试数据的一致性无法得到保证&#xff0c;信息无法共享和反馈&#xff0c;需要重复多次的采集和输入。信息孤岛是企业信息话的一个重要的负面因素&#xff0c;其主要…

C++重新入门-string容器

目录 1.包含头文件 2.创建字符串 3.获取字符串长度 4.字符串拼接 5.字符串比较 相等性比较 大小比较 使用比较函数 6.访问字符串 7.查找子串 8.字符串修改 替换子串 插入字符或子串 删除字符或子串 9.提取子串 10.总结 当谈到C中的字符串时&#xff0c;std::str…

135.乐理基础-半音是小二度吗?全音是大二度吗?三全音

内存参考于&#xff1a;三分钟音乐社 上一个内容&#xff1a;134.乐理基础-音程名字的简写-CSDN博客 上一个内容里练习的答案&#xff1a; 半音可以与小二度划等号吗&#xff1f;全音可以和大二度划等号吗&#xff1f; 严格来说它们是不能划等号的&#xff0c;半音与全音是侧…

基于springboot实现的健康监控管理系统

一、系统架构 前端&#xff1a;html | bootstrap | jquery | css 后端&#xff1a;springboot | thymeleaf | mybatis 环境&#xff1a;jdk1.8 | mysql | maven 二、代码及数据库 三、功能介绍 01. 体检测评 02. 运动处方 03. 运动处方明细 04. 运动处方-打卡…

基于transform的scale属性,动态缩放整个页面,实现数据可视化大屏自适应,保持比例不变形,满足不同分辨率的需求

文章目录 一、需求背景&#xff1a;二、需求分析&#xff1a;三、选择方案&#xff1a;四、实现代码&#xff1a;五、效果预览&#xff1a;六、封装组件&#xff1a; 一、需求背景&#xff1a; 数据可视化大屏是一种将数据、信息和可视化效果集中展示在一块或多块大屏幕上的技…

PyTorch基础(19)-- torch.take_along_dim()方法

一、前言 在深挖ML4CO的代码过程中&#xff0c;遇到了torch.take_along_dim()这个方法&#xff0c;影响到我后续的代码阅读&#xff1b;加之在上网搜索资料的过程中&#xff0c;网络上对此函数的介绍文章少之又少&#xff0c;即使有&#xff0c;也是对torch官网文档中的解释进…

代码随想录算法训练营总结篇

时间好快&#xff0c;随着春节的穿插&#xff0c;两个月的算法训练营的一刷旅程在今天就落下了帷幕。回顾这两个月来的刷题经历&#xff0c;首先第一感受是学到很多&#xff0c;见识到了很多新的解题思想&#xff0c;如线性表中的双指针方法&#xff0c;快慢双指针、首位双指针…

价格腰斩:腾讯云和阿里云服务器优惠价格对比

2024年阿里云服务器和腾讯云服务器价格战已经打响&#xff0c;阿里云服务器优惠61元一年起&#xff0c;腾讯云服务器62元一年&#xff0c;2核2G3M、2核4G、4核8G、8核16G、16核32G、16核64G等配置价格对比&#xff0c;阿腾云atengyun.com整理阿里云和腾讯云服务器详细配置价格表…

jvm面试题目补充

jdk&jre Java程序设计语言、Java虚拟机、Java API类库这三部分统称为JDK&#xff08;Java Development Kit&#xff09;。 把Java API类库中的Java SE API子集 [1] 和Java虚拟机这两部分统称为JRE&#xff08;Java Runtime Environment&#xff09;&#xff0c;JRE是支持…

信号的学习

1.信号 1.pause int pause(void); 功能: 让进程睡眠,直到接收到信号(捕捉)才能继续向下执行 2.alarm unsigned int alarm(unsigned int seconds); 功能: 定时seconds秒后给调用进程发送SIGALRM信号 参数: seconds:定时的秒数 …

CUDA C:查看GPU设备信息

相关阅读 CUDA Chttps://blog.csdn.net/weixin_45791458/category_12530616.html?spm1001.2014.3001.5482 了解自己设备的性能是很有必要的&#xff0c;为此CUDA 运行时(runtime)API给用户也提供了一些查询设备信息的函数&#xff0c;下面的函数用于查看GPU设备的一切信息。 …

MyBatis 学习(二)之 第一个 MyBatis 案例

目录 1 配置 MyBatis 方式 1.1 XML 配置文件 1.2 Java 注解配置 1.3. Java API 配置 2 在 MySQL 中创建一张表 3 创建一个基于 Maven 的 JavaWeb 工程 4 编写 User 实体类 5 创建 Mybatis 全局配置文件 6 编写一个 DAO 或 Mapper 接口 7 编写 SQL 映射配置文件&#…