【mT5多语言翻译】之四——加载:加载数据集与模型

·请参考本系列目录:【mT5多语言翻译】之一——实战项目总览

[1] 加载数据集

  在上一篇实战博客中,我们介绍了如何下载和预处理数据集,并且介绍了如何将数据集内的文本进行分词然后保存为pt文件。

  接下来,我们讲解在模型训练时加载数据集的写法。

def build_dataset(config):def load_dataset(path, name):data = torch.cat(([torch.load(_) for _ in path]), dim=0)# 设置随机种子seed = 42torch.manual_seed(seed)# 生成一个随机排列random_permutation = torch.randperm(data.size(0))# 根据随机排列打乱张量data = data[random_permutation]for p in path:logger.info(f"{name}数据集地址========{p}")logger.info(f"数据集文本总数========{data.size(0)}")return datadev = load_dataset(config.dev_path, 'dev')train = load_dataset(config.train_path, 'train')return train, dev

  由于我们的项目有2个数据集(中韩翻译和中日翻译),因此config.dev_pathconfig.train_path是数组,内部分别是2个语料的pt文件地址:

'dev_path': [ '/home/20240321/@_Dataset_Processed/kor-zho/eval.pt','/home/20240321/@_Dataset_Processed/jpn-zho/eval.pt'],
'train_path': [ '/home/20240321/@_Dataset_Processed/kor-zho/train.pt','/home/20240321/@_Dataset_Processed/jpn-zho/train.pt']

  首先使用cat函数将两个数据集拼接起来,然后乱序打乱,最后返回即可。

  打乱的目的是让中韩翻译文本和中日翻译文本均匀分布,模型在训练时能够尽可能地均匀学习特征。

[2] 构造数据加载器

  数据加载器的作用是对上面加载的(1200万条)数据按照自定义的batch_size自动地进行批划分。代码如下:

class DatasetIterator(object):"""初始化时把数据加载到CPU迭代器每次加载batch_size个样本到GPU或CPU"""def __init__(self, batches, batch_size, device):if batch_size <= 0:raise ValueError("batch_size 必须大于 0。")self.batch_size = batch_sizeself.batches = batchesself.n_batches = len(batches) // batch_sizeself.has_residual = False  # 记录batch数量是否为整数if len(batches) % self.n_batches != 0:self.has_residual = Trueself.index = 0self.device = devicedef _to_tensor(self, datas):input_token_ids = torch.stack([data[0] for data in datas], dim=0).to(self.device)target_token_ids = torch.stack([data[1] for data in datas], dim=0).to(self.device)return input_token_ids, target_token_idsdef __next__(self):if self.has_residual and self.index == self.n_batches:batches = self.batches[self.index * self.batch_size: len(self.batches)]self.index += 1batches = self._to_tensor(batches)return batcheselif self.index >= self.n_batches:self.index = 0raise StopIterationelse:batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]self.index += 1batches = self._to_tensor(batches)return batchesdef __iter__(self):return selfdef __len__(self):if self.has_residual:return self.n_batches + 1else:return self.n_batches

  目前很多深度学习库都有成熟的数据加载器可以调用。这里选择手搓的原因是我觉得手搓的数据加载器更轻,可以非常直观的进行自定义,虽然代码少但是完全满足我们项目的功能需要。

【注】其实transformers库也提供了非常友好的代码训练api,可能都不用写多少代码,即可开始自定义任务、模型的训练。但是在使用过程中,还是遇到了一些问题:1)莫名其妙的多消耗显存,由于封装的功能太多,我很难定位到哪里出了问题;2)它提供的训练配置中有很多我不需要的功能,比如检查点保存、内部日志等等,同样也是集成的功能太多,但是我不需要。
————————————
在训练过程中,我想使项目代码尽可能的“轻量化”,并且直观、易于修改。第三方提供的训练方案都像黑盒一样,自定义起来比较麻烦。因此我还是选择了手写训练代码,并且代码并不多。

  DatasetIterator自定义了一些在迭代(for循环)时才会触发的函数。迭代时会先触发__next__函数,然后去调用_to_tensor函数将batch_size条数据加载到GPU上,之后送给模型训练。

  因此,也不同担心数据集太大,显存会承受不了的问题。

  项目中的1200万条数据先是加载到内存,然后每次只加载batch_size条数据加载到GPU上训练。

[3] 访问批数据

  定义好数据集加载函数和数据加载器之后,我们只需要写如下三行代码即可构造训练集和验证集的DataLoader:

train_data, dev_data = build_dataset(conf)
train_iter = build_iterator(train_data, conf)
dev_iter = build_iterator(dev_data, conf)

  最后我们写个循环去遍历DataLoader,即可访问批数据:

for i, (input_batch, label_batch) in enumerate(train_iter):passfor i, (input_batch, label_batch) in enumerate(dev_iter):pass

[4] 进行下一篇实战

  【mT5多语言翻译】之五——训练:中央日志、训练可视化、PEFT微调

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/806506.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【TensorRT】TensorRT C# API 项目更新 (1):支持动态Bath输入模型推理(下篇)

4. 接口应用 关于该项目的调用方式在上一篇文章中已经进行了详细介绍&#xff0c;具体使用可以参考《最新发布&#xff01;TensorRT C# API &#xff1a;基于C#与TensorRT部署深度学习模型》&#xff0c;下面结合Yolov8-cls模型详细介绍一下更新的接口使用方法。 4.1 创建并配…

Java零基础入门-Java反射机制

一、概述 我们都听说过java有个反射机制&#xff0c;通过反射机制我们可以更深入的控制程序的运行过程。例如&#xff0c;在程序进入到运行期间&#xff0c;由用户输入一个类名&#xff0c;然后我们可以动态获取到该类拥有的所有类结构、属性名和方法&#xff0c;甚至还可以任意…

sift 解释

转载 https://www.cnblogs.com/xc90/articles/11571995.html opensift code Rob Hess robwhess/opensift: Open-Source SIFT Library (github.com) 在构建图像尺度空间的过程中&#xff0c;唯一使用的核函数是高斯核&#xff0c;这一点被T Lindeber在文献《Scale-space th…

Java快速入门系列-9(Spring框架与Spring Boot —— 深度探索及实践指南)

第九章:Spring框架与Spring Boot —— 深度探索及实践指南 9.1 Spring框架概述9.2 Spring IoC容器9.3 Spring AOP9.4 Spring MVC9.5 Spring Data JPA/Hibernate9.6 Spring Boot快速入门与核心特性9.7 Spring Boot的自动配置与启动流程详解9.8 创建RESTful服务与数据库交互实践…

专为苹果系统设计的精美可视化图表 | 开源日报 No.219

danielgindi/Charts Stars: 27.3k License: Apache-2.0 Charts 是为 iOS/tvOS/OSX 提供美观图表的开源项目&#xff0c;是跨平台 MPAndroidChart 在苹果设备上的实现。该项目提供了以下主要功能和优势&#xff1a; 支持 iOS、tvOS 和 macOS 平台使用 Swift 编写&#xff0c;可…

Ceph学习 -6.Nautilus版本集群部署

文章目录 1.集群部署1.1 环境概述1.1.1 基础知识1.1.2 环境规划1.1.3 小结 1.2 准备工作1.2.1 基本环境1.2.2 软件安装1.2.3 小结 1.3 Ceph部署1.3.1 集群创建1.3.2 部署Mon1.3.3 小结 1.4 Ceph部署21.4.1 Mon认证1.4.2 Mgr环境1.4.3 小结 1.5 OSD环境1.5.1 基本环境1.5.2 OSD实…

数据结构-移除元素(简单)

题目描述 给你一个数组 nums 和一个值 val&#xff0c;你需要 原地 移除所有数值等于 val 的元素&#xff0c;并返回移除后数组的新长度。 不要使用额外的数组空间&#xff0c;你必须仅使用 O(1) 额外空间并 原地 修改输入数组。 元素的顺序可以改变。你不需要考虑数组中超出…

可视化大屏的应用(11):智慧运维领域的得力助手

一、什么是智慧运维 智慧运维&#xff08;Smart Operations and Maintenance&#xff0c;简称智慧运维&#xff09;是一种利用先进的信息技术和数据分析手段&#xff0c;对设备、设施或系统进行监测、分析和优化管理的运维方式。它通过实时监测数据、智能分析和预测&#xff0…

wpf viewmodel和界面双向通知

通知模型&#xff08;Model&#xff09;或视图模型&#xff08;ViewModel&#xff09; 这两个有什么区别?分别给我代码例子 在MVVM&#xff08;Model-View-ViewModel&#xff09;架构中&#xff0c;Model和ViewModel扮演不同的角色&#xff1a; Model表示应用程序的数据域&am…

Redis中的集群(五)

集群 在集群中执行命令 MOVED错误。 当节点发现键所在的槽并非由自己负责处理的时候&#xff0c;节点就会向客户端返回一个MOVED错误&#xff0c;指引客户端转向至正在负责槽的节点&#xff0c;MOVED错误的格式为: MOVED <slot> <ip>:<port>其中slot为键…

centos 7.9 nginx本地化安装,把镜像改成阿里云

1.把centos7.9系统切换到阿里云的镜像源 1.1.先备份 mv /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.repo.backup1.2.下载新的CentOS-Base.repo配置文件 wget -O /etc/yum.repos.d/CentOS-Base.repo http://mirrors.aliyun.com/repo/Centos-7.repo特别…

盲人独立出行的新里程:“盲人软件”赋能无障碍生活

作为一名资深记者&#xff0c;我始终致力于探索并分享那些以科技之力提升特殊群体生活质量的故事。最近&#xff0c;一款名为蝙蝠避障的盲人软件进入了我的视野&#xff0c;其强大的避障导航功能正悄然改变着视障人士的出行方式&#xff0c;赋予他们前所未有的独立生活能力。 …

【多线程】线程池Future和FutureTask

【多线程】线程池Future和FutureTask 【一】Future概述【1】Future的出现原因【2】Future结构图 【二】Future详解【1】Future接口源码【2】Future的5个方法【3】ThreadPoolExecutor提供了三个方法&#xff0c;来获取返回值&#xff08;1&#xff09;submit(Runnable r)&#x…

利用数组储存表格数据

原理以及普通数组储存表格信息 在介绍数组的时候说过&#xff0c;数组能够用来储存任何同类型的数据&#xff0c;这里的意思就表明只要是同一个类型的数组据就可以储存到一个数组中。那么在表格中同一行的数据是否可以储存到同一个数组中呢&#xff1f;答案自然是可以&#xff…

【ARM 裸机】汇编 led 驱动之原理分析

1、我们为什么要学习汇编&#xff1f;&#xff1f;&#xff1f; 之前我们或许接触过 STM32 以及其他的 32 位的 MCU ,都是基于 C 语言环境进行编程的&#xff0c;都没怎么注意汇编&#xff0c;是因为 ST 公司早已将启动文件写好了&#xff0c;新建一个 STM32 工程的时候&#…

回归预测 | Matlab实现SSA-GRNN麻雀算法优化广义回归神经网络多变量回归预测(含优化前后预测可视化)

回归预测 | Matlab实现SSA-GRNN麻雀算法优化广义回归神经网络多变量回归预测(含优化前后预测可视化) 目录 回归预测 | Matlab实现SSA-GRNN麻雀算法优化广义回归神经网络多变量回归预测(含优化前后预测可视化)预测效果基本介绍程序设计参考资料预测效果

图像生成:Pytorch实现一个简单的对抗生成网络模型

图像生成&#xff1a;Pytorch实现一个简单的对抗生成网络模型 前言相关介绍具体步骤准备并读取数据集定义生成器定义判别器定义损失函数定义优化器开始训练完整代码 训练生成的图片 前言 由于本人水平有限&#xff0c;难免出现错漏&#xff0c;敬请批评改正。更多精彩内容&…

Linux losetup命令教程:设置和控制循环设备(附实例详解和注意事项)

Linux losetup命令介绍 losetup&#xff08;Loop device setup&#xff09;命令在Linux操作系统中用于设置和控制循环设备。循环设备是一种伪设备&#xff0c;它使文件可以作为块设备进行访问。如果只给出了loopdev参数&#xff0c;那么将显示相应循环设备的状态。 Linux los…

实战Java高并发程序设计课

课程介绍 实战Java高并发程序设计课是一门针对Java开发者的培训课程&#xff0c;重点关注如何设计和优化高并发的程序。学员将学习到并发编程的基本概念、线程池的使用、锁机制、并发集合等技术&#xff0c;并通过实际案例进行实践操作。这门课程旨在帮助开发者掌握并发编程的…

【2024年5月备考新增】《软考案例分析答题技巧(5)采购、配置与变更、其他》

2.10 项目采购管理 采购管理过程:规划采购管理-实施采购-控制采购。 釆购步骤: ①准备釆购工作说明书(SOW)或工作大纲(TOR); ②准备高层级的成本估算,制定预算; ③发布招标广告; ④确定合格卖方的名单; ⑤准备并发布招标文件; ⑥由卖方准备并提交建议书; ⑦对建…