bert ranking pairwise demo

下面是用bert 训练pairwise rank 的 demo

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizer
from sklearn.metrics import pairwise_distances_argmin_minclass PairwiseRankingDataset(Dataset):def __init__(self, sentence_pairs, tokenizer, max_length):self.input_ids = []self.attention_masks = []for pair in sentence_pairs:encoded_pair = tokenizer(pair, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')self.input_ids.append(encoded_pair['input_ids'])self.attention_masks.append(encoded_pair['attention_mask'])self.input_ids = torch.cat(self.input_ids, dim=0)self.attention_masks = torch.cat(self.attention_masks, dim=0)def __len__(self):return len(self.input_ids)def __getitem__(self, idx):input_id = self.input_ids[idx]attention_mask = self.attention_masks[idx]return input_id, attention_maskclass BERTPairwiseRankingModel(torch.nn.Module):def __init__(self, bert_model_name):super(BERTPairwiseRankingModel, self).__init__()self.bert = BertModel.from_pretrained(bert_model_name)self.dropout = torch.nn.Dropout(0.1)self.fc = torch.nn.Linear(self.bert.config.hidden_size, 1)def forward(self, input_ids, attention_mask):outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)pooled_output = self.dropout(outputs[1])logits = self.fc(pooled_output)return logits.squeeze()# 初始化BERT模型和分词器
bert_model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_model_name)# 示例输入数据
sentence_pairs = [('I like cats', 'I like dogs'),('The sun is shining', 'It is raining'),('Apple is a fruit', 'Car is a vehicle')
]# 超参数
batch_size = 8
max_length = 128
learning_rate = 1e-5
num_epochs = 5# 创建数据集和数据加载器
dataset = PairwiseRankingDataset(sentence_pairs, tokenizer, max_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 初始化模型并加载预训练权重
model = BERTPairwiseRankingModel(bert_model_name)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)# 训练模型
model.train()for epoch in range(num_epochs):total_loss = 0for input_ids, attention_masks in dataloader:optimizer.zero_grad()logits = model(input_ids, attention_masks)# 计算损失函数(使用对比损失函数)pos_scores = logits[::2]  # 正样本分数neg_scores = logits[1::2]  # 负样本分数loss = torch.relu(1 - pos_scores + neg_scores).mean()total_loss += loss.item()loss.backward()optimizer.step()print(f"Epoch {epoch+1}/{num_epochs} - Loss: {total_loss:.4f}")# 推断模型
model.eval()with torch.no_grad():embeddings = model.bert.embeddings.word_embeddings(dataset.input_ids)pairwise_distances = pairwise_distances_argmin_min(embeddings.numpy())# 输出结果
for i, pair in enumerate(sentence_pairs):pos_idx = pairwise_distances[0][2 * i]neg_idx = pairwise_distances[0][2 * i + 1]pos_dist = pairwise_distances[1][2 * i]neg_dist = pairwise_distances[1][2 * i + 1]print(f"Pair: {pair}")print(f"Positive example index: {pos_idx}, Distance: {pos_dist:.4f}")print(f"Negative example index: {neg_idx}, Distance: {neg_dist:.4f}")print()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/76367.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ClickHouse进阶(十三):Clickhouse数据字典-3-文件数据源及Mysql数据源

进入正文前,感谢宝子们订阅专题、点赞、评论、收藏!关注IT贫道,获取高质量博客内容! 🏡个人主页:含各种IT体系技术,IT贫道_大数据OLAP体系技术栈,Apache Doris,Kerberos安全认证-CSDN博客 📌订阅…

鸿蒙应用程序入口UIAbility详解

一、UIAbility概述 UIAbility是一种包含用户界面的应用组件,主要用于和用户进行交互。UIAbility也是系统调度的单元,为应用提供窗口在其中绘制界面。每一个UIAbility实例,都对应于一个最近任务列表中的任务。一个应用可以有一个UIAbility&am…

LVS + Keepalived群集

文章目录 1. Keepalived工具概述1.1 什么是Keepalived1.2 工作原理1.3 Keepailved实现原理1.4 Keepalived体系主要模块及其作用1.5 keepalived的抢占与非抢占模式 2. 脑裂现象 (拓展)2.1 什么是脑裂2.2 脑裂的产生原因2.3 如何解决脑裂2.4 如何预防脑裂 …

介绍PHP

PHP是一种流行的服务器端编程语言,用于开发Web应用程序。它是一种开源的编程语言,具有易学易用的语法和强大的功能。PHP支持在服务器上运行的动态网页和Web应用程序的快速开发。 PHP可以与HTML标记语言结合使用,从而能够生成动态的Web页面&a…

关于前端就业前景的一点看法

一、前言 最近,很多人在知乎上鼓吹前端未来会没掉的概念。在此我想说这个说法并不真实,而是一种极端的观点。 事实上,前端开发在当今的互联网行业中扮演着至关重要的角色,它是构建 Web 应用程序所必需的一部分,能够实现…

项目中应该使用nginx还是拦截器来封禁IP

项目中应该使用nginx还是拦截器来封禁IP 在项目中,使用 Nginx 或拦截器(例如 Spring Interceptor)来封禁 IP 地址都是可行的方法,具体选择取决于你的需求和项目架构。 Nginx 是一种高性能的 Web 服务器和反向代理服务器&#xf…

小白学go基础04-命名惯例对标识符进行命名

计算机科学中只有两件难事:缓存失效和命名。 命名是编程语言的要求,但是好的命名却是为了提高程序的可读性和可维护性。好的命名是什么样子的呢?Go语言的贡献者和布道师Dave Cheney给出了一个说法:“一个好笑话,如果你…

如何使用命令行参数?

使用命令行参数是C语言编程中非常常见和有用的技巧。命令行参数允许您在运行程序时向程序传递信息,这样程序可以根据不同的输入执行不同的操作。在本文中,我们将详细讨论如何在C语言中使用命令行参数,包括如何访问和解析命令行参数、处理不同…

TSINGSEE青犀AI视频分析/边缘计算/AI算法·人脸识别功能——多场景高效运用

旭帆科技AI智能分析网关可提供海量算法供应,涵盖目标监测、分析、抓拍、动作分析、AI识别等,可应用于各行各业的视觉场景中。同时针对小众化场景可快速定制AI算法,主动适配大厂近百款芯片,打通云/边/端灵活部署,算法一…

前端中的事件委托

前端小知识 事 件 委 托 作者:李俊才 (jcLee95):https://blog.csdn.net/qq_28550263 邮箱 :291148484163.com 本文地址:https://blog.csdn.net/qq_28550263/article/details/132819265 【介绍】&#xff1…

目标检测YOLO实战应用案例100讲-森林野火预警的小目标检测(续)

目录 3.2 实验数据评价指标 3.3 YOLO算法 3.3.1 YOLO算法原理 3.3.2 YOLO v5 网络模型

VM安装RedHat7虚机ens33网络不显示IP问题解决

1、今天在VMware中安装RedHat7.4虚拟机,网络连接使用的是 NAT 连接方式,刚开始安装成功之后输入ifconfig 还能看到ens33自动分配的IP地址,但是当虚机关机重启后,再查看IP发现原来的ens33网络已经没有了,只变成了这两个…

XML格式转JSON格式

前言: XML和JSON是两种常见的数据交换格式,它们在现代软件开发中扮演着重要的角色。本文将介绍这两种格式的基本概念、特点以及它们的使用场景,以帮助更好地理解和应用它们。 XML(可扩展标记语言)和JSON(Ja…

Mybatis -- 读取 DATE 类型字段时可能遇到的问题(夏令时问题)

在使用 MYBATIS 读取数据库字段的时候,我们一般需要为查询字段指定数据类型。特别是当我们使用 mybatis generator 去生成对应的接口代码时,会自动按照数据库字段类型生成响应映射规则的代码。   如下,左侧是 date 类型生成的字段映射规则&…

MySQL数据库——存储引擎(2)-存储引擎特点(InnoDB、MyISAM、Memory)、存储引擎选择

目录 存储引擎特点 InnoDB 介绍 特点 文件 逻辑存储结构 MyISAM 介绍 特点 文件 Memory 介绍 特点 文件 区别及特点 存储引擎选择 存储引擎特点 上面我们介绍了什么是存储引擎,以及如何在建表时如何指定存储引擎,接下来我们就来介绍比较…

JS中 bind()的用法,call(),apply(),bind()异同点及使用,如何手写一个bind()

✨什么是bind() bind()的MDN地址 bind() 方法创建一个新函数,当调用该新函数时,它会调用原始函数并将其 this 关键字设置为给定的值,同时,还可以传入一系列指定的参数,这些参数会插入到调用新函数时传入的参数的前面。…

Map集合

Map中常见的API Map<键值对类型&#xff0c; 键值对对象类型> put&#xff08;K key , V value&#xff09;【可以有返回值】 添加/覆盖元素 在添加数据的时候&#xff0c;如果键不存在&#xff0c;那么直接将键对对象添加到map集合当中 在添加数据的时候&#xff0c…

云原生Kubernetes:pod基础

目录 一、理论 1.pod 2.pod容器分类 3.镜像拉取策略&#xff08;image PullPolicy&#xff09; 二、实验 1.Pod容器的分类 2.镜像拉取策略 三、问题 1.apiVersion 报错 2.pod v1版本资源未注册 3.取行显示指定pod信息 四、总结 一、理论 1.pod (1) 概念 Pod是ku…

jwt自定义表签发、jwt 多方式登录(auth的user表)

补充 # 1 接口文档编写规范&#xff1a;-1 描述-2 请求地址-3 请求方式-4 请求参数-headers-请求体-请求参数-5 请求编码格式-6 返回格式-示例-返回数据字段含义-其他&#xff1a;-错误状态码-...-接口文档编写位置-写在文件中&#xff1a;word&#xff0c;md&#xff0c;跟前…

centos定期清理磁盘

centos/linux定期清理磁盘 要定时清理空间&#xff0c;我们需要了解一个命令&#xff0c;find 命令&#xff0c;这个命令可以查询目录下特定文件名&#xff0c;生成日期的文件 小白教程&#xff0c;一看就会&#xff0c;一做就成。 1.查找需要删除的 find /data_back/zhhyba…