NLP(六十二)HuggingFace中的Datasets使用

  Datasets库是HuggingFace生态系统中一个重要的数据集库,可用于轻松地访问和共享数据集,这些数据集是关于音频、计算机视觉、以及自然语言处理等领域。Datasets 库可以通过一行来加载一个数据集,并且可以使用 Hugging Face 强大的数据处理方法来快速准备好你的数据集。在 Apache Arrow 格式的支持下,通过 zero-copy read 来处理大型数据集,而没有任何内存限制,从而实现最佳速度和效率。

  当需要微调模型的时候,需要对数据集进行以下操作:

  1. 数据集加载:下载、加载数据集
  2. 数据集预处理:使用Dataset.map() 预处理数据
  3. 数据集评估指标:加载和计算指标

  可以在HuggingFace官网来搜共享索数据集:https://huggingface.co/datasets​ 。本文中使用的主要数据集为squad数据集,其在HuggingFace网站上的数据前几行如下:

squad数据集前几行

加载数据

  • 加载Dataset数据集

  Dataset数据集可以是HuggingFace Datasets网站上的数据集或者是本地路径对应的数据集,也可以同时加载多个数据集。

  以下是加载英语阅读理解数据集squad, 该数据集的网址为:https://huggingface.co/datasets/squad ,也是本文中使用的主要数据集。

import datasets# 加载单个数据集
raw_datasets = datasets.load_dataset('squad')
# 加载多个数据集
raw_datasets = datasets.load_dataset('glue', 'mrpc')
  • 从文件中加载数据

  支持csv, tsv, txt, json, jsonl等格式的文件

from datasets import load_datasetdata_files = {"train": "./data/sougou_mini/train.csv", "test": "./data/sougou_mini/test.csv"}
drug_dataset = load_dataset("csv", data_files=data_files, delimiter=",")
  • 从Dataframe中加载数据
import pandas as pd
from datasets import Dataset my_dict = {"a": [1, 2, 3], "b": ['A', 'B', 'C']} 
dataset1 = Dataset.from_dict(my_dict) df = pd.DataFrame(my_dict) 
dataset2 = Dataset.from_pandas(df)

查看数据

  • 数据结构

  数据结构包括:

  • 数据集的划分:train,valid,test数据集
  • 数据集的数量
  • 数据集的feature

  squad数据的数据结构如下:

DatasetDict({train: Dataset({features: ['id', 'title', 'context', 'question', 'answers'],num_rows: 87599})validation: Dataset({features: ['id', 'title', 'context', 'question', 'answers'],num_rows: 10570})
})
  • 数据切分
import datasetsraw_dataset = datasets.load_dataset('squad')# 获取某个划分数据集,比如train
train_dataset = raw_dataset['train']
# 获取前10条数据
head_dataset = train_dataset.select(range(10))
# 获取随机10条数据
shuffle_dataset = train_dataset.shuffle(seed=42).select(range(10))
# 数据切片
slice_dataset = train_dataset[10:20]

更多特性

  • 数据打乱(shuffle)

  shuffle的功能是打乱datasets中的数据,其中seed是设置打乱的参数,如果设置打乱的seed是相同的,那我们就可以得到一个完全相同的打乱结果,这样用相同的打乱结果才能重复的进行模型试验。

import datasetsraw_dataset = datasets.load_dataset('squad')
# 打乱数据集
shuffle_dataset = train_dataset.shuffle(seed=42)
  • 数据流(stream)

  stream的功能是将数据集进行流式化,可以不用在下载整个数据集的情况下使用该数据集。这在以下场景中特别有用:

  1. 你不想等待整个庞大的数据集下载完毕
  2. 数据集大小超过了你计算机的可用硬盘空间
  3. 你想快速探索数据集的少数样本
from datasets import load_datasetdataset = load_dataset('oscar-corpus/OSCAR-2201', 'en', split='train', streaming=True)
print(next(iter(dataset)))
  • 数据列重命名(rename columns)

  数据集支持对列重命名。下面的代码将squad数据集中的context列重命名为text:

from datasets import load_datasetsquad = load_dataset('squad')
squad = squad.rename_column('context', 'text')
  • 数据丢弃列(drop columns)

  数据集支持对列进行丢弃,在删除一个或多个列时,向remove_columns()函数提供要删除的列名。单个列删除传入列名,多个列删除传入列名的列表。下面的代码将squad数据集中的id列丢弃:

from datasets import load_datasetsquad = load_dataset('squad')
# 删除一个列
squad = squad.remove_columns('id')
# 删除多个列
squad = squad.remove_columns(['title', 'text'])
  • 数据新增列(add new columns)

  数据集支持新增列。下面的代码在squad数据集上新增一列test,内容全为字符串111:

from datasets import load_datasetsquad = load_dataset('squad')
# 新增列
new_train_squad = squad['train'].add_column("test", ['111'] * squad['train'].num_rows)
  • 数据类型转换(cast)

  cast()函数对一个或多个列的特征类型进行转换。这个函数接受你的新特征作为其参数。

from datasets import load_datasetsquad = load_dataset('squad')
# 新增列
new_train_squad = squad['train'].add_column("test", ['111'] * squad['train'].num_rows)
print(new_train_squad.features)
# 转换test列的数据类型
new_features = new_train_squad.features.copy()
new_features["test"] = Value("int64")
new_train_squad = new_train_squad.cast(new_features)
# 输出转换后的数据类型
print(new_train_squad.features)
  • 数据展平(flatten)

  针对嵌套结构的数据类型,可使用flatten()函数将子字段提取到它们自己的独立列中。

from datasets import load_datasetsquad = load_dataset('squad')
flatten_dataset = squad['train'].flatten()
print(flatten_dataset)

输出结果为:

Dataset({features: ['id', 'title', 'context', 'question', 'answers.text', 'answers.answer_start'],num_rows: 87599
})
  • 数据合并(Concatenate Multiple Datasets)

  如果独立的数据集有相同的列类型,那么它们可以被串联起来。用concatenate_datasets()来连接不同的数据集。

from datasets import concatenate_datasets, load_datasetsquad = load_dataset('squad')
squad_v2 = load_dataset('squad_v2')
# 合并数据集
squad_all = concatenate_datasets([squad['train'], squad_v2['train']])
  • 数据过滤(filter)

  filter()函数支持对数据集进行过滤,一般采用lambda函数实现。下面的代码对squad数据集中的训练集的question字段,过滤掉split后长度小于等于10的数据:

from datasets import load_datasetsquad = load_dataset('squad')
filter_dataset = squad['train'].filter(lambda x: len(x["question"].split()) > 10)

输出结果如下:

Dataset({features: ['id', 'title', 'context', 'question', 'answers'],num_rows: 34261
})
  • 数据排序(sort)

  使用sort()对列值根据其数值进行排序。下面的代码是对squad数据集中的训练集按照标题长度进行排序:

from datasets import load_datasetsquad = load_dataset('squad')
# 新增列, title_length, 标题长度
new_train_squad = squad['train'].add_column("title_length", [len(_) for _ in squad['train']['title']])
# 按照title_length排序
new_train_squad = new_train_squad.sort("title_length")
  • 数据格式(set_format)

  set_format()函数改变了一个列的格式,使之与一些常见的数据格式兼容。在类型参数中指定你想要的输出和你想要格式化的列。格式化是即时应用的。支持的数据格式有:None, numpy, torch, tensorflow, pandas, arrow, 如果选择None,就会返回python对象。

  下面的代码将新增标题长度列,并将其转化为numpy格式:

from datasets import load_datasetsquad = load_dataset('squad')
# 新增列, title_length, 标题长度
new_train_squad = squad['train'].add_column("title_length", [len(_) for _ in squad['train']['title']])
# 转换为numpy支持的数据格式
new_train_squad.set_format(type="numpy", columns=["title_length"])
  • 数据指标(load metrics)

  HuggingFace Hub上提供了一系列的评估指标(metrics),前20个指标如下:

from datasets import list_metrics
metrics_list = list_metrics()
print(', '.join(metric for metric in metrics_list[:20]))

输出结果如下:

accuracy, bertscore, bleu, bleurt, brier_score, cer, character, charcut_mt, chrf, code_eval, comet, competition_math, coval, cuad, exact_match, f1, frugalscore, glue, google_bleu, indic_glue

  从Hub中加载一个指标,使用 datasets.load_metric() 命令,比如加载squad数据集的指标:

from datasets import load_metric
metric = load_metric('squad')

  输出结果如下:

Metric(name: "squad", features: {'predictions': {'id': Value(dtype='string', id=None), 'prediction_text': Value(dtype='string', id=None)}, 'references': {'id': Value(dtype='string', id=None), 'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None)}}, usage: """
Computes SQuAD scores (F1 and EM).
Args:predictions: List of question-answers dictionaries with the following key-values:- 'id': id of the question-answer pair as given in the references (see below)- 'prediction_text': the text of the answerreferences: List of question-answers dictionaries with the following key-values:- 'id': id of the question-answer pair (see above),- 'answers': a Dict in the SQuAD dataset format{'text': list of possible texts for the answer, as a list of strings'answer_start': list of start positions for the answer, as a list of ints}Note that answer_start values are not taken into account to compute the metric.
Returns:'exact_match': Exact match (the normalized answer exactly match the gold answer)'f1': The F-score of predicted tokens versus the gold answer
Examples:>>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22'}]>>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}]>>> squad_metric = datasets.load_metric("squad")>>> results = squad_metric.compute(predictions=predictions, references=references)>>> print(results){'exact_match': 100.0, 'f1': 100.0}
""", stored examples: 0)

  load_metric还支持分布式计算,本文不再详细讲述。

  load_metric现在已经是老版本了,新版本将用evaluate模块代替,访问网址为:https://github.com/huggingface/evaluate 。

  • 数据映射(map)

  map就是映射,它接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset。常见的map函数的应用是对文本进行tokenize:

from datasets import load_dataset
from transformers import AutoTokenizersquad_dataset = load_dataset('squad')checkpoint = 'bert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)def tokenize_function(sample):return tokenizer(sample['context'], truncation=True, max_length=256)tokenized_dataset = squad_dataset.map(tokenize_function, batched=True)

  输出结果如下:

DatasetDict({train: Dataset({features: ['id', 'title', 'context', 'question', 'answers', 'input_ids', 'token_type_ids', 'attention_mask'],num_rows: 87599})validation: Dataset({features: ['id', 'title', 'context', 'question', 'answers', 'input_ids', 'token_type_ids', 'attention_mask'],num_rows: 10570})
})
  • 数据保存/加载(save to disk/ load from disk)

  使用save_to_disk()来保存数据集,方便在以后重新使用它,使用 load_from_disk()函数重新加载数据集。我们将上面map后的tokenized_dataset数据集进行保存:

tokenized_dataset.save_to_disk("squad_tokenized")

保存后的文件结构如下:

squad_tokenized/
├── dataset_dict.json
├── train
│   ├── data-00000-of-00001.arrow
│   ├── dataset_info.json
│   └── state.json
└── validation├── data-00000-of-00001.arrow├── dataset_info.json└── state.json

  加载数据的代码如下:

from datasets import load_from_disk
reloaded_dataset = load_from_disk("squad_tokenized") 

总结

  本文可作为dataset库的入门,详细介绍了数据集的各种操作,这样方便后续进行模型训练。

参考文献

  1. Datasets: https://www.huaxiaozhuan.com/工具/huggingface_transformer/chapters/2_datasets.html
  2. Huggingface详细入门介绍之dataset库:https://zhuanlan.zhihu.com/p/554678463
  3. Stream: https://huggingface.co/docs/datasets/stream
  4. HuggingFace教程 Datasets基本操作: Process: https://zhuanlan.zhihu.com/p/557032513

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/7645.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[Tools: tiny-cuda-nn] Linux安装

official repo: https://github.com/NVlabs/tiny-cuda-nn 该包可以显著提高NeRF训练速度,是Instant-NGP、Threestudio和NeRFstudio等框架中,必须使用的。 1. 命令行安装 最便捷的安装方式,如果安装失败考虑本地编译。 pip install ninja g…

区块链与加密货币在Web3中的融入及意义

Web3是指下一代互联网,也被称为去中心化互联网。它的核心理念是建立一个去中心化的经济和社会系统,使得个人和社区能够更加自治和自主,而不依赖于中心化的机构和权力。 在Web3中,区块链和加密货币是非常重要的技术和概念。区块链是…

1.前端入门

文章目录 一、基础认知1.1 认识网页:1.2 五大浏览器1.3 Web标准 总结 提示:以下是本篇文章正文内容,下面案例可供参考 一、基础认知 1.1 认识网页: 1.网页由哪些部分组成? 文字、图片、音频、视频、超链接。 2.我们…

【机器学习】异常检测

异常检测 假设你是一名飞机涡扇引擎工程师,你在每个引擎出厂之前都需要检测两个指标——启动震动幅度和温度,查看其是否正常。在此之前你已经积累了相当多合格的发动机的出厂检测数据,如下图所示 我们把上述的正常启动的数据集总结为 D a t…

Jmeter常见问题之URI异常

这篇文章介绍一下"http://"重复导致的URI异常问题,通常从浏览器地址栏复制url,直接粘贴到Jmeter的http请求的服务器地址中会默认带上“http://”,要将http://删除,只写IP地址,如下图: 否则&…

项目开启启动命令整合

启动RabbitMQ管理插件 1.启动 RabbitMQ 管理插件。 rabbitmq-plugins enable rabbitmq_management rabbitmq-server # 直接启动,如果关闭窗⼝或需要在该窗⼝使⽤其他命令时应⽤就会停⽌ rabbitmq-server -detached # 后台启动 rabbitmq-server start # 启⽤服务 rab…

16.喝水

喝水 html部分 <h1>Goal: 2 Liters</h1> <div class"cup cupbig"><div class"remained"><span id"liters">2L</span><small>Remained</small></div><div class"percentage&quo…

PHY芯片的使用(三)在linux下网络PHY的移植

1 前言 配置设备树请参考上一章。此次说明还是以裕太的YT8511芯片为例。 2 需要配置的文件及路径 a. 在 .. /drivers/net/phy 目录下添加 yt_phy.c 文件&#xff08;一般来说该驱动文件由厂家提供&#xff09;&#xff1b; b. 修改.. /drivers/net/phy 目录下的 Kconfig 文…

win10电脑便签常驻桌面怎么设置?

你是否曾经因为繁忙的工作而忘记了一些重要的事项&#xff1f;相信很多人都会回答&#xff1a;忘记过&#xff01;其实在快节奏的职场中&#xff0c;我们经常需要记录一些重要的信息&#xff0c;例如会议时间、约见客户时间、今天需要完成的工作任务等。而为了能够方便地记录和…

nodejs+vue+elementui学习交流和学习笔记分享系统

Node.js 是一个基于 Chrome JavaScript 运行时建立的一个平台。 前端技术&#xff1a;nodejsvueelementui,视图层其实质就是vue页面&#xff0c;通过编写vue页面从而展示在浏览器中&#xff0c;编写完成的vue页面要能够和控制器类进行交互&#xff0c;从而使得用户在点击网页进…

Spring Cloud Alibaba 集成 Skywalking 链路追踪

Spring Cloud Alibaba 集成 Skywalking 链路追踪 简介 skywalking 是一个国产开源框架&#xff0c;2015 年由吴晟开源 &#xff0c; 2017 年加入 Apache 孵化器。skywalking 是分布式系统的应用程序性能监视工具&#xff0c;专为微服务、云原生架构和基于容器&#xff08;Doc…

redis中使用bloomfilter的白名单功能解决缓存预热问题

一 缓存预热 1.1 缓存预热 将需要的数据提前缓存到缓存redis中&#xff0c;可以在服务启动时候&#xff0c;或者在使用前一天完成数据的同步等操作。保证后续能够正常使用。 1.2 解决办法PostConstruct注解初始化

【复习16-18天】【我们一起60天准备考研算法面试(大全)-第二十四天 24/60】

专注 效率 记忆 预习 笔记 复习 做题 欢迎观看我的博客&#xff0c;如有问题交流&#xff0c;欢迎评论区留言&#xff0c;一定尽快回复&#xff01;&#xff08;大家可以去看我的专栏&#xff0c;是所有文章的目录&#xff09;   文章字体风格&#xff1a; 红色文字表示&#…

【MATLAB】GM(1,1) 灰色预测模型及算法

一、灰色预测模型概念 灰色预测是一种对含有不确定因素的系统进行预测的方法。 灰色预测通过鉴别系统因素之间发展趋势的相异程度&#xff0c;即进行关联分析&#xff0c;并对原始数据进行生成处理来寻找系统变动的规律&#xff0c;生成有较强规律性的数据序列&#xff0c;然后…

Python TypeError: unsupported operand type(s) for +: ‘int‘ and ‘str‘

在键入数值进行相加运算时&#xff0c;报了这样一个错误 类型错误&#xff1a;不支持操作类型为整数和字符串 错误分析&#xff1a;sumsuminput() 未被系统识别&#xff0c;导致程序错误 解决方法&#xff1a;给键入的数值定义&#xff0c;声明为整数 sumsumint(input()) 即…

【论文阅读】DEPIMPACT:反向传播系统依赖对攻击调查的影响(USENIX-2022)

Fang P, Gao P, Liu C, et al. Back-Propagating System Dependency Impact for Attack Investigation[C]//31st USENIX Security Symposium (USENIX Security 22). 2022: 2461-2478. 攻击调查、关键边、入口点 开源&#xff1a;GitHub - usenixsub/DepImpact 目录 1. 摘要2. 引…

前端学习——ajax (Day3)

AJAX原理 - XMLHttpRequest 使用 XMLHttpRequest <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport&…

消息队列(一)-- RabbitMQ入门(1)

初识 RabbitMQ 核心思想&#xff1a;接收并转发消息。可以把它想象成一个邮局。 producer&#xff1a;生产者 queue&#xff1a;队列 consumer&#xff1a;消费者什么是消息队列 MQ&#xff08;Message Queue&#xff09;&#xff1a;本质是队列&#xff0c;FIFO先入先出&…

【【直流电机驱动PWN】】

直流电机驱动PWN 前面都是沙县小吃&#xff0c;这里才是满汉全席 直流电机是一种电能转化成机械能的装置 直流电机有两个电极 当电机正接 电机正转 当电机负接 电机倒转 电机还有步进电机 舵机 无刷电机 空心杯电机 因为电机是一个大功率器件并不太好直接接在IO端口上所以我…

脑电信号处理与特征提取——1. 脑电、诱发电位和事件相关电位(胡理)

目录 一、 脑电、诱发电位和事件相关电位 1.1 EEG基本知识 1.2 经典的ERPs成分及研究 1.2.1 ERPs命名规则及分类 1.2.2 常见的脑电成分 1.2.3 P300及Oddball范式 1.2.4 N400成分 一、 脑电、诱发电位和事件相关电位 1.1 EEG基本知识 EEG(Electroencephalogram)&#x…