4.1 文本相似度(二)

目录

1 文本相似度评估

2 代码

2.1 load_dataset 方法

2.2 AutoTokenizer、AutoModelForSequenceClassification


1 文本相似度评估

      对两个文本拼接起来,然后作为一个样本喂给模型,作为一个二分类的任务;

数据处理的方式以及训练的基本流程与上文相似。

2 代码

  1. 数据预处理,把需要对比的文本放置一起,作为一个样本; tokenizer: 输入的语句是两个。分类标签的类型必须是 int,不能是其他的类型;
  2. 加载模型。
  3. 输出结果; 


2.1 load_dataset 方法

datasets是抱抱脸开发的一个数据集python库,可以很方便的从Hugging Face Hub里下载数据,也可很方便的从本地加载数据集,本文主要对load_dataset方法的使用进行详细说明。

def load_dataset(
    path: str,
    name: Optional[str] = None,
    data_dir: Optional[str] = None,
    data_files: Union[Dict, List] = None,
    split: Optional[Union[str, Split]] = None,
    cache_dir: Optional[str] = None,
    features: Optional[Features] = None,
    download_config: Optional[DownloadConfig] = None,
    download_mode: Optional[GenerateMode] = None,
    ignore_verifications: bool = False,
    save_infos: bool = False,
    script_version: Optional[Union[str, Version]] = None,
    **config_kwargs,
) -> Union[DatasetDict, Dataset]:

path:参数path表示数据集的名字或者路径。可以是如下几种形式(每种形式的使用方式后面会详细说明)
数据集的名字,比如imdb、glue
数据集文件格式,比如json、csv、parquet、txt
数据集目录中的处理数据集的脚本(.py)文件,比如“glue/glue.py”
name:参数name表示数据集中的子数据集,当一个数据集包含多个数据集时,就需要这个参数,比如glue数据集下就包含"sst2"、“cola”、"qqp"等多个子数据集,此时就需要指定name来表示加载哪一个子数据集
data_dir:数据集所在的目录
data_files:数据集文件
cache_dir:构建的数据集缓存目录,方便下次快速加载。

2.2 AutoTokenizer、AutoModelForSequenceClassification

类名称介绍
AutoTokenizerAutoTokenizer 是 Hugging Face Transformers 库中的一个类,用于自动选择适合特定预训练模型的 tokenizer。该类可以根据指定的模型名称或路径,自动选择对应的 tokenizer 类型,无需手动指定。这样可以方便地在不同的预训练模型之间切换,而无需更改代码中的 tokenizer 类型。
AutoModelForSequenceClassificationAutoModelForSequenceClassification 是 Hugging Face Transformers 库中的一个类,用于自动选择适合特定预训练模型的用于序列分类任务的模型。这个类会根据指定的模型名称或路径自动选择对应的模型类型,无需手动指定。这样可以方便地在不同的预训练模型之间切换,而无需更改代码中的模型类型。
TrainerTrainer 是 Hugging Face Transformers 库中用于训练和评估模型的高级 API。它提供了一个简单而强大的接口,用于管理训练循环、验证循环、日志记录、保存模型等任务。使用 Trainer 可以方便地训练和微调预训练模型,同时还支持分布式训练和混合精度训练等功能。
TrainingArgumentsTrainingArguments 是 Hugging Face Transformers 库中用于配置训练参数的类。通过 TrainingArguments 类,可以指定训练过程中的各种参数,如训练轮数、学习率、批次大小、日志路径、模型保存路径等。这些参数可以帮助控制训练过程的行为,并对训练过程进行定制。

from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset,load_from_disk
import traceback
from sklearn.model_selection import train_test_split#dataset = load_dataset("json", data_files="../data/train_pair_1w.json", split="train")
dataset = load_dataset("csv", data_files="/Users/user/studyFile/2024/nlp/text_similar/data/Chinese_Text_Similarity.csv", split="train")datasets = dataset.train_test_split(test_size=0.2,shuffle=True)import torchtokenizer = AutoTokenizer.from_pretrained("../chinese_macbert_base")
def process_function(examples):tokenized_examples = tokenizer(examples["sentence1"], examples["sentence2"], max_length=128, truncation=True)#  注意int(label)tokenized_examples["labels"] = [int(label) for label in examples["label"]]return tokenized_examplestokenized_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)
#tokenized_datasets# 创建模型
from transformers import BertForSequenceClassification 
model = AutoModelForSequenceClassification.from_pretrained("../chinese_macbert_base")import evaluate
acc_metric = evaluate.load("./metric_accuracy.py")
f1_metirc = evaluate.load("./metric_f1.py")
# 
# acc_metric = evaluate.load("accuracy")
# f1_metirc = evaluate.load("f1")
def eval_metric(eval_predict):predictions, labels = eval_predict#print(predictions,labels)predictions = predictions.argmax(axis=-1)#predictions = [int(p > 0.5) for p in predictions]labels = [int(l) for l in labels]# predictions = predictions.argmax(axis=-1)acc = acc_metric.compute(predictions=predictions, references=labels)f1 = f1_metirc.compute(predictions=predictions, references=labels)acc.update(f1)return acc
train_args = TrainingArguments(output_dir="./cross_model",      # 输出文件夹per_device_train_batch_size=32,  # 训练时的batch_sizeper_device_eval_batch_size=32,  # 验证时的batch_sizelogging_steps=10,                # log 打印的频率evaluation_strategy="epoch",     # 评估策略save_strategy="epoch",           # 保存策略save_total_limit=3,              # 最大保存数learning_rate=2e-5,              # 学习率weight_decay=0.01,               # weight_decaymetric_for_best_model="f1",      # 设定评估指标load_best_model_at_end=True)     # 训练完成后加载最优模型
train_args
from transformers import DataCollatorWithPadding
trainer = Trainer(model=model, args=train_args, train_dataset=tokenized_datasets["train"], eval_dataset=tokenized_datasets["test"], data_collator=DataCollatorWithPadding(tokenizer=tokenizer),compute_metrics=eval_metric)
trainer.train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/11294.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c 指针基础

/* 指针练习*/ #include <stdio.h> #include <stdlib.h> void printAll(int n1, int n2, int *p1, int *p2); int main(){ //赋值操作语法演示 int num1 1111; int num2 2222; int *prt1 &num1; int *prt2 &num2; printAll(num1, num2, prt1…

maven .lastUpdated文件作用

现象 有时候我在用maven管理项目时会发现有些依赖报错&#xff0c;这时你可以看一下本地仓库中是否有.lastUpdated文件&#xff0c;也许与它有关。 原因 有这个文件就表示依赖下载过程中发生了错误导致依赖没成功下载&#xff0c;可能是网络原因&#xff0c;也有可能是远程…

平面设计基础指南:从零开始的学习之旅!

平面设计师主要做什么&#xff1f; 平面设计师通过创建视觉概念来传达信息。他们创造了从海报和广告牌到包装、标志和营销材料的所有内容&#xff0c;并通过使用形状、颜色、排版、图像和其他元素向观众传达了他们的想法。平面设计师可以在内部工作&#xff0c;专门为品牌创建…

Mac安装jadx

1、使用命令brew安装 : brew install jadx 输入完命令,等待安装完毕 备注&#xff08;关于Homebrew &#xff09;&#xff1a; Homebrew 是 MacOS 下的包管理工具&#xff0c;类似 apt-get/apt 之于 Linux&#xff0c;yum 之于 CentOS。如果一款软件发布时支持了 homebrew 安…

mac定时任务、自启动任务

https://quail.ink/mynotes/p/mac-startup-configuration-detailed-explanation <?xml version"1.0" encoding"UTF-8"?> <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.d…

【2024年5月备考新增】】 考前篇(2)《官方平台 - 考生模拟练习平台常用操作(一)》

软考考生常用操作说明 说明:模拟作答系统是旨在让考生熟悉计算机化考试环境和作答方式,模拟作答不保存考生作答 历史记录。考试题型、题量、分值、界面及文字内容以正式考试答题系统为准。 1 如何标记试题、切换试题 2 简答题如何查看历史记录、切换输入法 3 选做题,已作答…

游戏找不到steam_api64.dll如何解决,介绍5种简单有效的方法

面对“找不到steam_api64.dll&#xff0c;无法继续执行代码”的问题&#xff0c;许多游戏玩家或软件使用者可能会感到手足无措。这个错误提示意味着你的计算机系统在尝试运行某个游戏或应用程序时&#xff0c;无法定位到一个至关重要的动态链接库文件——steam_api64.dll&#…

《深入Linux内核架构》第3章 内存管理(6)

目录 3.5.7 内核中不连续页的分配 3.5.8 内核映射 本专栏文章将有70篇左右&#xff0c;欢迎关注&#xff0c;订阅后续文章。 本节讲解vmalloc, vmap&#xff0c;kmap原理。 3.5.7 内核中不连续页的分配 kmalloc函数&#xff1a;分配物理地址和虚拟地址都连续的内存。 kmall…

MongoDB聚合运算符:$type

MongoDB聚合运算符&#xff1a;$type 文章目录 MongoDB聚合运算符&#xff1a;$type语法使用可用的类型 举例 $type聚合运算符用来返回指定参数的BSON类型的字符串。。 语法 { $type: <expression> }<expression>可以是任何合法的表达式。 使用 不像查询操作符$…

Selenium + Pytest自动化测试框架实战(上)

前言 今天呢笔者想和大家来聊聊selenium自动化 pytest测试框架&#xff0c;在这篇文章里你需要知道一定的python基础——至少明白类与对象&#xff0c;封装继承&#xff1b;一定的selenium基础。这篇文章不会selenium&#xff0c;不会的可以自己去看selenium中文翻译网哟。 一…

六西格玛管理培训公司:事业进阶的充电站,助你冲破职场天花板!

六西格玛&#xff0c;源于制造业&#xff0c;却不仅仅局限于制造业。它是一种以数据为基础、以顾客为中心、以流程优化为手段的全面质量管理方法。通过六西格玛管理&#xff0c;企业可以系统性地识别并解决运营过程中的问题&#xff0c;提高产品和服务的质量&#xff0c;降低成…

导航app为什么知道还有几秒变绿灯?

在使用地图导航app行驶至信号灯的交叉路口时&#xff0c;这些应用程序会贴心地告知用户距信号灯变化还有多少秒&#xff0c;无论是即将转为绿灯还是红灯。这一智能化提示不仅使得驾驶员能适时做好起步或刹车的准备&#xff0c;有效缓解了因等待时间不确定而产生的焦虑情绪&…

GBPC2510-ASEMI工业电源专用GBPC2510

编辑&#xff1a;ll GBPC2510-ASEMI工业电源专用GBPC2510 型号&#xff1a;GBPC2510 品牌&#xff1a;ASEMI 封装&#xff1a;GBPC-4 最大重复峰值反向电压&#xff1a;1000V 最大正向平均整流电流(Vdss)&#xff1a;25A 功率(Pd)&#xff1a;中小功率 芯片个数&#x…

分布式锁之RedissonLock

什么是Redisson&#xff1f; 俗话说他就是看门狗&#xff0c;看门狗机制是一种用于保持Redis连接活跃性的方法&#xff0c;通常用于分布式锁的场景。看门狗的工作原理是&#xff1a;当客户端获取到锁之后&#xff0c;会对Redis中的一个特定的键设置一个有限的过期时间&#xff…

[附源码]传世手游_玲珑传世_GM_安卓搭建教程

本教程仅限学习使用&#xff0c;禁止商用&#xff0c;一切后果与本人无关&#xff0c;此声明具有法律效应&#xff01;&#xff01;&#xff01;&#xff01; 教程是本人亲自搭建成功的&#xff0c;绝对是完整可运行的&#xff0c;踩过的坑都给你们填上了。 如果你是小白也没…

C++ 509. 斐波那契数

文章目录 一、题目描述二、参考代码 一、题目描述 示例 1&#xff1a; 输入&#xff1a;n 2 输出&#xff1a;1 解释&#xff1a;F(2) F(1) F(0) 1 0 1 示例 2&#xff1a; 输入&#xff1a;n 3 输出&#xff1a;2 解释&#xff1a;F(3) F(2) F(1) 1 1 2 示例 3…

设计模式——访问者模式(Visitor)

访问者模式&#xff08;Visitor Pattern&#xff09;是一种将数据操作与数据结构分离的设计模式。这种模式适用于数据结构相对稳定&#xff0c;而操作算法经常改变的情况。访问者模式将数据结构&#xff08;稳定的部分&#xff09;中的元素&#xff08;Element&#xff09;的访…

C语言题目:一元二次方程

题目描述 解一元二次方程ax^2bxc0的解。 输入格式 a,b,c的值。 输出格式 输出两个解&#xff0c;按照大小顺序输出&#xff0c;一个解时需要打印两次&#xff0c;不用考虑无解问题&#xff0c;保留两位小数 样例输入 1 5 -2样例输出 0.37 -5.37 代码解析 首先&#xff0…

了解进程和线程

一、进程和线程 类比&#xff1a; 一个工厂&#xff0c;至少有一个车间&#xff0c;一个车间中至少有一个工人&#xff0c;最终是工人在工作。 一个程序&#xff0c;至少有一个进程&#xff0c;一个进程中至少有一个线程&#xff0c;最终是线程在工作。 进程&#xff1a;是计…

C#正则表达式,提取信息使用

正则表达式简介 在C#中&#xff0c;正则表达式&#xff08;Regular Expression&#xff0c;通常简写为regex或regexp&#xff09;是一种功能强大的文本处理工具&#xff0c;它使用特定的字符序列来定义搜索模式&#xff0c;从而实现对文本的高效搜索、匹配和替换操作。正则表达…