使用Bert预训练模型处理序列推荐任务

最近的工作有涉及该任务,整理一下思路以及代码细节。

流程

总体来说思路就是首先用预训练的bert模型,在训练集的序列上进行CLS任务。对序列内容(这里默认是token id的sequence)以0.3左右的概率进行随机mask,然后将相应sequence的attention mask(原来决定padding index)和label(也就是mask的ground truth)输入到bert model里面。

当然其中vocab.txt并不存在的token是需要add进去的,具体方法不再详述,网上例子很多,注意word embedding也需要初始化就行。

模型定义:
self.model = AutoModelForMaskedLM.from_pretrained('./bert')
模型的输入:
result = self.bert_model(tail_mask, attention_mask, labels)
得到模型训练的结果之后,要做一个选择:

(1)transformer的bert model可以输出要预测时间步的hidden state,可以选择取出对应的hidden state,其中需要在数据处理的时候记录下每个sequence的tail position,也就是要预测位置的idx。另外我认为既然要进行序列推荐,那么最后一个tail position的token表征一定是最重要的,所以需要对tail position的idx专门给个写死的mask,效果会好一些。然后与sequence中item的全集进行相似度的计算,再去算交叉熵loss。

bert_hidden = result.hidden_states[-1]
bert_seq_hidden = torch.zeros((self.args.batch_size, 312)).to(self.device)
for i in range(self.args.batch_size):bert_seq_hidden[i,:] = bert_hidden[i, tail_pos[i], :]
logits = torch.matmul(bert_seq_hidden, test_item_emb.transpose(0, 1))
main_loss = self.criterion(logits, targets)

(2)同时也可以result.loss直接数据mask prediction的loss,我理解这个loss面对的任务是我要求sequence中的各个token表征都要尽可能准确,都要考虑,(1)可能更加注重最后一个位置的标准的准确性。

然后在evaluate阶段,需要注意输入到模型的不再是tail_mask,而是仅仅mask掉tail token id的sequence,因为我们需要尽可能准确的序列信息,只需要保证要预测的存在mask就够了。

由于是推荐任务,而且bert得到的hidden state表征过于隐式,所以需要一定的个性化引导它进行训练。经过个人的实验也确实如此,而且结果相差很多。

以上就是我个人的总结经验,欢迎大家指点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/25277.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Jupyter Notebook 未授权访问远程命令执行漏洞

漏洞描述 Jupyter是一个开源的交互式计算环境,它支持多种编程语言,包括Python、R、Julia等。Jupyter的名称来源于三种编程语言的缩写:Ju(lia)、Py(thon)和R。 Jupyter的主要特点是它以笔记本(Notebook)的形式组织代码…

PoseFormer:基于视频的2D-to-3D单人姿态估计

3D Human Pose Estimation with Spatial and Temporal Transformers论文解析 摘要1. 简介2. Related Works2.1 2D-to-3D Lifting HPE2.2 GNNs in 3D HPE2.3 Vision Transformers 3. Method3.1 Temporal Transformer Baseline3.2 PoseFormer: Spatial-Temporal TransformerSpati…

redis缓存雪崩和缓存击穿

目录 缓存雪崩 解决方案: 缓存击穿 ​解决方案 缓存雪崩 缓存雪崩是指在同一时段大量的缓存key同时失效或者Redis服务宕机,导致大量请求到达数据库,带来巨大压力。 解决方案: u 给不同的 Key 的 TTL 添加随机值 u 利用 Redis …

MongoDB SQL

Microsoft Windows [版本 6.1.7601] 版权所有 (c) 2009 Microsoft Corporation。保留所有权利。C:\Users\Administrator>cd C:\MongoDB\Server\3.4\binC:\MongoDB\Server\3.4\bin> C:\MongoDB\Server\3.4\bin> C:\MongoDB\Server\3.4\bin>net start MongoDB 请求的…

ArcGIS Pro简介下载安装地址

ArcGIS Pro简介 ArcGIS Pro是一款功能强大的地理信息系统(GIS)软件,由Esri开发。它为用户提供了一种直观、灵活且高效的方式来处理、分析和可视化地理数据。ArcGIS Pro具有现代化的用户界面和工作流程,使用户能够更好地利用地理信…

Linux下安装VMware虚拟机

目录 1. 简介 2. 工具/原料 2.1. 下载VMware 2.2. 安装 1. 简介 ​ VMware Workstation(中文名“威睿工作站”)是一款功能强大的桌面虚拟计算机软件,提供用户可在单一的桌面上同时运行不同的操作系统,和进行开发、测试 …

自监督去噪:Noise2Void原理和调用(Tensorflow)

文章原文: https://arxiv.org/abs/1811.10980 N2V源代码: https://github.com/juglab/n2v 参考博客: https://zhuanlan.zhihu.com/p/445840211https://zhuanlan.zhihu.com/p/133961768https://zhuanlan.zhihu.com/p/563746026 文章目录 1. 方法原理1.1 Noise2Noise回…

Android google admob Timeout for show call succeed 问题解决

项目场景: 项目中需要接入 google admob sdk 实现广告商业化 问题描述 在接入Institial ad 时,onAdLoaded 成功回调,但是onAdFailedToShowFullScreenContent 也回调了错误信息 “Timeout for show call succeed.” InterstitialAd.load(act…

MySQL5.7源码编译Debug版本

编译环境Ubuntu22.04LTS 1 官方下载MySQL源码 https://dev.mysql.com/downloads/mysql/?spma2c6h.12873639.article-detail.4.68e61a14ghILh5 2 安装基础软件 cmakeclangpkg-configperl 参考:https://dev.mysql.com/doc/refman/5.7/en/source-installation-prere…

命令行快捷键Mac Iterm2

原文:Jump forwards, backwards and delete a word in iTerm2 on Mac OS iTerm2并不允许你使用 ⌥← 或 ⌥→ 来跳过单词。 你也不能使用 ⌥backspace 来删除整个单词。 下面是在Mac OS上如何配置iTerm2以便能做到这一点的方法。 退格键 首先,你需要将你的左侧 ⌥…

思想道德与法治试题

1【单选题】 公园里的花木草地、亭台椅凳,街道旁的垃圾箱、电话亭,马路上的护栏、井盖等,都是保证公共生活正常进行的物质基础,任何人都不能损坏、滥用、浪费、私占。这是社会公德中(B )。 A、助人为乐的…

第5章 运算符、表达式和语句

本章介绍以下内容: 关键字:while、typedef 运算符:、-、*、/、%、、--、(类型名) C语言的各种运算符,包括用于普通数学运算的运算符 运算符优先级以及语句、表达式的含义 while循环 复合语句、自动类型转换和强制类型转换 如何编写…

MyCat配置文件schema.xml讲解

1.MyCat配置 1.1 schema标签 如果checkSQLschema配置的为false,那么执行DB01.TB_ORDER时就会报错,必须用use切换逻辑库以后才能进行查询。 sqlMaxLimit如果未指定limit进行查询,列表查询模式默认为100,最多只查询100条。因为用mycat后默认数…

计算机网络(4) --- 协议定制

计算机网络(3) --- 网络套接字TCP_哈里沃克的博客-CSDN博客https://blog.csdn.net/m0_63488627/article/details/132035757?spm1001.2014.3001.5501 目录 1. 协议的基础知识 TCP协议通讯流程 ​编辑 2.协议 1.介绍 2.手写协议 1.内容 2.接口 …

yolo-v5学习(使用yolo-v5进行安全帽检测错误记录)

常见错误 跑YOLOv5遇到的问题_runtimeerror: a view of a leaf variable that requi_Pysonmi的博客-CSDN博客 python train.py --img 640 --batch 16 --epochs 10 --data ./data/custom_data.yaml --cfg ./models/custom_yolov5.yaml --weights ./weights/yolov5s.pt 1、梯度…

Python OpenCV读取并显示USB UVC摄像头

1. 安装Python, 略。 2. 安装 OpenCV: pip install opencv-python 3. 预览摄像头画面脚本: import cv2cap cv2.VideoCapture(0, cv2.CAP_DSHOW)if not (cap.isOpened()):print("Could not open video device")cap.set(cv2.CAP_PR…

四、ESP32链接WIFI

1. 设置工作模式 Wi-Fi是基于IEEE 802.11标准的无线网络技术 让联网设备以无线电波的形式,加入采用TCP/IP通信协议的网络 Wi-Fi网络环境通常有两种设备 Access Point(AP) 无线接入点,提供无线接入的设备,家里的光猫就是结合WiFi和internet路由功能的AP。AP和AP可以相互连接…

数据可视化(七)常用图表的绘制

1. #seaborn绘制常用图表 #折线图 #replot(x,y,kind,data) #lineplot(x,y,data) #直方图 #displot(data,rug) #条形图 #barplot&…

ucore Lab8 文件系统

练习0:填写已有实验 本实验依赖实验1/2/3/4/5/6/7。请把你做的实验1/2/3/4/5/6/7的代码填入本实验中代码中有“LAB1”/“LAB2”/“LAB3”/“LAB4”/“LAB5”/“LAB6” /“LAB7”的注释相应部分。并确保编译通过。注意:为了能够正确执行lab8的测试应用程…

接口测试——python接口开发(二)

目录 1. python接口开发框架Flask简介与安装 2. 使用Flask开发一个Get接口 3. 使用Flask开发一个Post接口 4. Flask结合PyMySQL接口与数据库的交互 1. python接口开发框架Flask简介与安装 Flask接口测试框架的简介与安装Flask是轻量级的web开发框架相比于其他框架&#xff…