【古诗生成AI实战】之四——模型包装器与模型的训练

  在上一篇博客中,我们已经利用任务加载器task成功地从数据集文件中加载了文本数据,并通过预处理器processor构建了词典和编码器。在这一过程中,我们还完成了词向量的提取。

  接下来的步骤涉及到定义模型、加载数据,并开始训练过程。

  为了确保项目代码能够快速切换到不同的模型,并且能够有效地支持transformers库中的预训练模型,我们不仅仅是定义模型那么简单。为此,我们采取了进一步的措施:在模型外面再套上一个额外的层,我称之为模型包装器NNModelWrapper。此外,为了提高配置的灵活性和可维护性,我们将所有的配置项(如批量大小、数据集地址、训练周期数、学习率等)抽取出来,统一放置在一个名为WrapperConfig的配置容器中。通过这种方式,我们就可以避免直接在代码中修改配置参数,而是通过更改配置文件来实现,从而使得整个项目更加模块化和易于管理。

  本章内容属于模型训练阶段,将分别介绍包装器配置WrapperConfig、模型包装器NNModelWrapper和模型Model

在这里插入图片描述

[1] 包装器配置WrapperConfig

  我们把配置全部放在yaml文件里,然后读取里面的配置,赋值给WrapperConfig类。定义如下:

class WrapperConfig(object):"""A configuration for a :class:`NNModelWrapper`."""def __init__(self,tokenizer,max_seq_len: int,vocab_num: int,word2vec_path: str,batch_size: int = 1,epoch_num: int = 1,learning_rate: float = 0.001):self.tokenizer = tokenizerself.max_seq_len = max_seq_lenself.batch_size = batch_sizeself.epoch_num = epoch_numself.learning_rate = learning_rateself.word2vec_path = word2vec_pathself.vocab_num = vocab_num

   WrapperConfig 类用于配置神经网络模型包装器(NNModelWrapper)。类的构造函数接受多个参数来初始化配置:

  tokenizer: 分词器对象,用于文本处理或文本转换为模型可理解的格式。其实就是预处理器processor提供的tokenizer

  max_seq_len (int): 模型可以处理的最大序列长度。

  vocab_num (int): 词汇表的大小。

  word2vec_path (str):预训练的词向量模型的文件路径。即上文提取的词向量。

  batch_size (int): 每个批次处理的数据样本数量。

  epoch_num (int): 训练轮次。

  learning_rate (float): 学习率。

[2] 模型包装器NNModelWrapper

  模型包装器NNModelWrapper接受2个参数,一个是包装器配置WrapperConfig,另外一个是自定义模型Model。代码如下:

class NNModelWrapper:"""A wrapper around a Transformer-based language model."""def __init__(self, config: WrapperConfig, model):"""Create a new wrapper from the given config."""self.config = configself.model = model(self.config)def generate_dataset(self, data, labeled=True):"""Generate a dataset from the given examples."""features = self._convert_examples_to_features(data)feature_dict = {'input_ids': torch.tensor([f.input_ids for f in features], dtype=torch.long),'labels': torch.tensor([f.labels for f in features], dtype=torch.long),}if not labeled:del feature_dict['labels']return DictDataset(**feature_dict)def _convert_examples_to_features(self, examples) -> List[InputFeatures]:"""Convert a set of examples into a list of input features."""features = []for (ex_index, example) in tqdm(enumerate(examples)):if ex_index % 5000 == 0:logging.info("Writing example {}".format(ex_index))input_features = self.get_input_features(example)features.append(input_features)# logging.info(f"最终数据构造形式:{features[0]}")return featuresdef get_input_features(self, example) -> InputFeatures:"""Convert the given example into a set of input features"""text = example.textinput_ids = self.config.tokenizer(text)labels = np.copy(input_ids)labels[:-1] = input_ids[1:]assert len(input_ids) == self.config.max_seq_lenreturn InputFeatures(input_ids=input_ids, attention_mask=None, token_type_ids=None, labels=labels)

   NNModelWrapper 类是围绕一个神经网络语言模型的封装器,提供了模型的初始化和数据处理的方法。

  · 类初始化 (init):
  config: 接收一个 WrapperConfig 类的实例,包含模型的配置信息。
  model: 接收一个模型构造函数,该函数使用配置信息来初始化模型。

  · 生成数据集 (generate_dataset):从给定的数据样本中生成一个数据集。首先把数据样本转换为特征(通过 _convert_examples_to_features 方法),然后根据这些特征创建一个 DictDataset 对象。如果数据未标记(labeled=False),则从特征字典中删除 labels 键。

  · 转换样本为特征 (_convert_examples_to_features):这是个私有方法,把数据样本转换为模型可以理解的输入特征。对于每个样本,使用 get_input_features 方法来生成输入特征。使用 tqdm 显示处理进度,并利用 logging 记录处理信息。

  · 获取输入特征 (get_input_features):此方法将单个数据样本转换为输入特征。首先获取文本内容,然后使用配置中的分词器(tokenizer)将文本转换为 input_ids。标签(labels)是 input_ids 的一个变体,其中每个元素都向右移动一个位置。用断言确保 input_ids 的长度与配置中的 max_seq_len 相等。

[3] 模型Model

  模型包装器NNModelWrapper里面的第二个参数Model才是我们真正的模型。

  在古诗生成AI任务中,RNN是比较适配任务的模型,我们定义的RNN模型如下:

class Model(nn.Module):def __init__(self, config):super(Model, self).__init__()V = config.vocab_num  # vocab_numE = 300  # embed_dimH = 256  # hidden_sizeembedding_pretrained = torch.tensor(np.load(config.word2vec_path)["embeddings"].astype('float32'))self.embedding = nn.Embedding.from_pretrained(embedding_pretrained, freeze=False)self.lstm = nn.LSTM(E, H, 1, bidirectional=False, batch_first=True, dropout=0.1)self.fc = nn.Linear(H, V)self.loss = nn.CrossEntropyLoss()def forward(self, input_ids, labels=None):embed = self.embedding(input_ids)  # [batch_size, seq_len, embed_dim]out, _ = self.lstm(embed)  # [batch_size, seq_len, hidden_size]logit = self.fc(out)  # [batch_size, seq_len, vocab_num]if labels is not None:loss = self.loss(logit.view(-1, logit.shape[-1]), labels.view(-1))return loss, logitelse:return logit[None, :]

  在我们的模型中,特别值得一提的是嵌入层(embedding layer)。在初始化这一层时,我们使用的是之前提取出的词向量。这种做法有助于模型更好地理解和处理文本数据。

  此次我们定义的模型是一个基于RNN的结构,它包括三个主要部分:embedding层、lstm层和fc(全连接)层。

  在模型的前向传播(forward)过程中,输入input_ids的形状为[batch_size, seq_len],即每个批次有多少文本,每个文本的序列长度是多少。输入数据首先通过嵌入层处理,输出的embed的形状为[batch_size, seq_len, embed_dim],即每个单词都被转换成了对应的嵌入向量。接着,数据通过一个单层的lstm网络,得到输出out,最后经过全连接层fc,得到最终的概率分布logit

  这个概率分布logit的含义是:对于每个批次中的文本,每个文本在序列的每个位置上,都有vocab_num个可能的词可以填入,而logit中存储的正是这些词的概率。为了生成文本,我们提取每个位置上概率最高的词的索引,然后根据这些索引在词典中查找对应的词。这就是我们通过模型运行文本生成得到的结果。

[4] 训练

  所有的工作都准备好了,下面我们正式开始模型的训练。

  对于神经网络的训练、验证、测试、优化等等操作,我采用了transformersTrainer极大的简化了项目操作。

  第一步,加载yaml配置文件,读取所有配置项:

    with open('config.yaml', 'r', encoding='utf-8') as f:conf = yaml.load(f.read(),Loader=yaml.FullLoader)conf_train = conf['train']conf_sys = conf['sys']

  第二步,初始化任务加载器,加载数据集:

	Task = TASKS[conf_train['task_name']]()data = Task.get_train_examples(conf_train['dataset_url'])index = int(len(data) * conf_train['rate'])train_data, dev_data = data[:index], data[index:]

  第三步,初始化数据预处理器,并向外提供tokenizer

	Processor = PROCESSORS[conf_train['task_name']](data, conf_train['max_seq_len'], conf_train['vocab_path'])tokenizer = lambda text: Processor.tokenizer(text)

  第四步,初始化模型包装配置:

	wrapper_config = WrapperConfig(tokenizer=tokenizer,max_seq_len=conf_train['max_seq_len'],batch_size=conf_train['batch_size'],epoch_num=conf_train['epoch_num'],learning_rate=conf_train['learning_rate'],word2vec_path=conf_train['word2vec_path'],vocab_num=len(Processor.vocab))

  第五步,加载模型,初始化模型包装器:

	x = import_module(f'main.model.{conf_train["model_name"]}')wrapper = NNModelWrapper(wrapper_config, x.Model)print(f'模型有 {sum(p.numel() for p in wrapper.model.parameters() if p.requires_grad):,} 个训练参数')

  第六步,使用模型包装器生成数据集向量:

train_dataset = wrapper.generate_dataset(train_data)val_dataset = wrapper.generate_dataset(dev_data)

  第七步,创建训练器:

# 构建trainer
def create_trainer(wrapper, train_dataset, val_dataset):# 模型model = wrapper.modelargs = TrainingArguments('./checkpoints',  # 模型保存的输出目录save_strategy=IntervalStrategy.STEPS,  # 模型保存策略save_steps=50,  # 每n步保存一次模型  1步表示一个batch训练结束evaluation_strategy=IntervalStrategy.STEPS,eval_steps=50,overwrite_output_dir=True,  # 设置overwrite_output_dir参数为True,表示覆盖输出目录中已有的模型文件logging_dir='./logs',  # 可视化数据文件存储地址log_level="warning",logging_steps=50,  # 每n步保存一次评价指标  1步表示一个batch训练结束 | 还控制着控制台的打印频率 每n步打印一下评价指标 | n过大时,只会保存最后一次的评价指标disable_tqdm=True,  # 是否不显示数据训练进度条learning_rate=wrapper.config.learning_rate,per_device_train_batch_size=wrapper.config.batch_size,per_device_eval_batch_size=wrapper.config.batch_size,num_train_epochs=wrapper.config.epoch_num,dataloader_num_workers=2,  # 数据加载的子进程数weight_decay=0.01,save_total_limit=2,load_best_model_at_end=True)# 早停设置early_stopping = EarlyStoppingCallback(early_stopping_patience=3,  # 如果8次验证集性能没有提升,则停止训练early_stopping_threshold=0,  # 验证集的性能提高不到0时也停止训练)trainer = Trainer(model,args,train_dataset=train_dataset,eval_dataset=val_dataset,callbacks=[early_stopping],  # 添加EarlyStoppingCallback回调函数)return trainertrainer = create_trainer(wrapper, train_dataset, val_dataset)

  第八步,开始训练并设置保存模型:

	trainer.train()trainer.save_model(conf_train['model_save_dir'] + conf_train['task_name'] + '/' + conf_train['model_name'])

  训练的整体代码如下:

# 构建trainer
def create_trainer(wrapper, train_dataset, val_dataset):# 模型model = wrapper.modelargs = TrainingArguments('./checkpoints',  # 模型保存的输出目录save_strategy=IntervalStrategy.STEPS,  # 模型保存策略save_steps=50,  # 每n步保存一次模型  1步表示一个batch训练结束evaluation_strategy=IntervalStrategy.STEPS,eval_steps=50,overwrite_output_dir=True,  # 设置overwrite_output_dir参数为True,表示覆盖输出目录中已有的模型文件logging_dir='./logs',  # 可视化数据文件存储地址log_level="warning",logging_steps=50,  # 每n步保存一次评价指标  1步表示一个batch训练结束 | 还控制着控制台的打印频率 每n步打印一下评价指标 | n过大时,只会保存最后一次的评价指标disable_tqdm=True,  # 是否不显示数据训练进度条learning_rate=wrapper.config.learning_rate,per_device_train_batch_size=wrapper.config.batch_size,per_device_eval_batch_size=wrapper.config.batch_size,num_train_epochs=wrapper.config.epoch_num,dataloader_num_workers=2,  # 数据加载的子进程数weight_decay=0.01,save_total_limit=2,load_best_model_at_end=True)# 早停设置early_stopping = EarlyStoppingCallback(early_stopping_patience=3,  # 如果8次验证集性能没有提升,则停止训练early_stopping_threshold=0,  # 验证集的性能提高不到0时也停止训练)trainer = Trainer(model,args,train_dataset=train_dataset,eval_dataset=val_dataset,callbacks=[early_stopping],  # 添加EarlyStoppingCallback回调函数)return trainerdef main():# ### @通用配置# ##with open('config.yaml', 'r', encoding='utf-8') as f:conf = yaml.load(f.read(),Loader=yaml.FullLoader)conf_train = conf['train']conf_sys = conf['sys']# 系统设置初始化System(conf_sys).init_system()# 初始化任务加载器Task = TASKS[conf_train['task_name']]()data = Task.get_train_examples(conf_train['dataset_url'])index = int(len(data) * conf_train['rate'])train_data, dev_data = data[:index], data[index:]# 初始化数据预处理器Processor = PROCESSORS[conf_train['task_name']](data, conf_train['max_seq_len'], conf_train['vocab_path'])tokenizer = lambda text: Processor.tokenizer(text)# 初始化模型包装配置wrapper_config = WrapperConfig(tokenizer=tokenizer,max_seq_len=conf_train['max_seq_len'],batch_size=conf_train['batch_size'],epoch_num=conf_train['epoch_num'],learning_rate=conf_train['learning_rate'],word2vec_path=conf_train['word2vec_path'],vocab_num=len(Processor.vocab))x = import_module(f'main.model.{conf_train["model_name"]}')wrapper = NNModelWrapper(wrapper_config, x.Model)print(f'模型有 {sum(p.numel() for p in wrapper.model.parameters() if p.requires_grad):,} 个训练参数')# 生成数据集train_dataset = wrapper.generate_dataset(train_data)val_dataset = wrapper.generate_dataset(dev_data)# 训练与保存trainer = create_trainer(wrapper, train_dataset, val_dataset)trainer.train()trainer.save_model(conf_train['model_save_dir'] + conf_train['task_name'] + '/' + conf_train['model_name'])if __name__ == '__main__':main()

  运行之后,看到下面输出代表项目成功运行:

在这里插入图片描述

[5] 进行下一篇实战

  【古诗生成AI实战】之五——加载模型进行古诗生成

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/173651.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何快速检测硬盘健康程度?

当我们使用Windows11/10/8/7计算机时,可能会遇到各种各样的问题,比如蓝屏报错、系统崩溃或其他运行不正常的状况。很多时候都是因为硬盘错误或故障导致的。那么,我们该如何快速检测硬盘健康程度呢? 在驱动器属性中执行硬盘查错 硬…

【Cisco Packet Tracer】电子邮箱仿真搭建

本文使用Cisco Packet Tracer,搭建电子邮箱仿真系统,使得zhangsancisco.com可以和lisicisco.com可以互相发送邮件。 电子邮箱账号(为了简单起见,账号密码设置一致):zhangsan/lisi 域名:cisco.…

office tool plus工具破解word、visio等软件步骤

第一步:下载工具 破解需要用到office tool plus软件 office tool plus软件下载地址:Office Tool Plus 官方网站 - 一键部署 Office 选择其中一个下载到本地(本人选择的是第一个的云图小镇下载方式) 第二步:启动工具 …

Sass混合器的详细使用教程

文章目录 前言混合器何时使用混合器混合器中的CSS规则给混合器传参默认参数值后言 前言 hello world欢迎来到前端的新世界 😜当前文章系列专栏:Sass和Less 🐱‍👓博主在前端领域还有很多知识和技术需要掌握,正在不断努…

vue3对象reactive()数据改变页面不刷新

问题vue3对象reactive()数据改变页面不刷新 首先定义一个对象 const tableData reactive({ })原因 调用后端接口赋值后页面不刷新 reactive生成的响应式数据属性 但是赋值后变成了普通数据 导致失去响应式 页面无法更新 解决方法 1.里面定义一个属性a并赋值给属性a cons…

回归预测 | MATLAB实现SMA+WOA+BOA-LSSVM基于黏菌算法+鲸鱼算法+蝴蝶算法优化LSSVM回归预测

回归预测 | MATLAB实现SMAWOABOA-LSSVM基于黏菌算法鲸鱼算法蝴蝶算法优化LSSVM回归预测 目录 回归预测 | MATLAB实现SMAWOABOA-LSSVM基于黏菌算法鲸鱼算法蝴蝶算法优化LSSVM回归预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 MATLAB实现SMAWOABOA-LSSVM基于黏菌算法…

【腾讯云云上实验室】用向量数据库——实现高效文本检索功能

文章目录 前言Tencent Cloud VectorDB 简介Tencent Cloud VectorDB 使用实战申请腾讯云向量数据库腾讯云向量数据库使用步骤腾讯云向量数据库实现文本检索 结论和建议 前言 想必各位开发者一定使用过关系型数据库MySQL去存储我们的项目的数据,也有部分人使用过非关…

CloudCompare 源码编译

一、下载源码 二、cmake 编译 这里面有四个比较重要的地方 1、源码的位置 2、生成的位置 3、项目的位置 4、qt 的位置 三、编译 开始测试,先用那个项目做测试 没有问题 然后用build的那个打开 加入Qt 的相关库到qcc中 启动项目生成cloudcompare 启动 ok ,完成…

本地Nginx服务搭建结合内网穿透实现多个Windows Web站点公网访问

文章目录 1. 下载windows版Nginx2. 配置Nginx3. 测试局域网访问4. cpolar内网穿透5. 测试公网访问6. 配置固定二级子域名7. 测试访问公网固定二级子域名 1. 下载windows版Nginx 进入官方网站(http://nginx.org/en/download.html)下载windows版的nginx 下载好后解压进入nginx目…

动态规划学习——等差子序列问题

目录 一,最长等差子序列 1.题目 2.题目接口 3.解题思路及其代码 二,等差序列的划分——子序列 1.题目 2.题目接口 3.解题思路及其代码 一,最长等差子序列 1.题目 给你一个整数数组 nums,返回 nums 中最长等差子序列的长度…

Golang Proxy Protocol详解

在计算机网络中,代理服务器是一种位于客户端和目标服务器之间的中间服务器,用来转发客户端请求和响应,从而实现一些特定的功能,如访问控制、安全过滤、负载均衡等。在Go语言中,我们可以使用代理协议来实现自定义的代理…

NX二次开发UF_CURVE_create_arc_center_radius 函数介绍

文章作者:里海 来源网站:https://blog.csdn.net/WangPaiFeiXingYuan UF_CURVE_create_arc_center_radius Defined in: uf_curve.h int UF_CURVE_create_arc_center_radius(tag_t center, double radius, tag_t help_point, UF_CURVE_limit_p_t limit_p…

SparkDesk知识库 + ChuanhuChatGPT前端 = 实现轻量化知识库问答

上一篇 讯飞星火知识库文档问答Web API的使用(二) 把星火知识库搞明白了; 然后又花了时间学习了一下gradio的一些基础内容: 在Gradio实现两个下拉框进行联动案例解读:change/click/input实践(三) 在Gradio实…

49-设计问题-最小栈

原题链接: 198. 打家劫舍 题目描述: 你是一个专业的小偷,计划偷窃沿街的房屋。每间房内都藏有一定的现金,影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统,如果两间相邻的房屋在同一晚上被小偷闯入&a…

作为IT行业的过来人,宝贵的经验分享给刚入行的你

恍然间,发现自己已经在这个行业五年之久,回顾过往,思绪良多,一路走来,或多或少都经历过一些坎坷,也碰到过不少大大小小的困难。在此就不多加叙述了。 本篇文章主要想写给刚入门的程序员几个忠告&#xff0…

vue项目门店官网页面, 根据视口大小自动跳转页面逻辑(pc --> mobile / mobile -->pc)

vue门店官网页面, 根据视口大小自动跳转页面逻辑(pc --> mobile / mobile -->pc) 在app.html文件添加以下代码逻辑 pc --> mobile // PC切换M端 ;(function () {function resizeEventHandler() {var isMobile /(iPhone|iPad|iPod|iOS|Android)/i.test(window.navig…

数据结构与算法编程题27

计算二叉树深度 #define _CRT_SECURE_NO_WARNINGS#include <iostream> using namespace std;typedef char ElemType; #define ERROR 0 #define OK 1 #define Maxsize 100 #define STR_SIZE 1024typedef struct BiTNode {ElemType data;BiTNode* lchild, * rchild; }BiTNo…

2023中学生古诗文阅读专辑(初中适用)使用和备考的几点建议

上周六的2023年第八届小学生古诗文大会复选结束后&#xff0c;很多孩子和家长大呼“太难了”&#xff0c;平时刷的题好像都没用&#xff0c;蓦然回首&#xff0c;发现很多题目都在主办方出版的《古诗文阅读专辑》上&#xff0c;只是考得非常的细。 所以&#xff0c;昨天有家长在…

计算机毕业设计|基于SpringBoot+MyBatis框架的电脑商城的设计与实现(系统概述与环境搭建)

计算机毕业设计|基于SpringBootMyBatis框架的电脑商城的设计与实现&#xff08;系统概述与环境搭建&#xff09; 该项目分析着重于设计和实现基于SpringBootMyBatis框架的电脑商城。首先&#xff0c;通过深入分析项目所需数据&#xff0c;包括用户、商品、商品类别、收藏、订单…

Vue组件的自定义属性Props

Vue的组件相当于HTML中的自定义标签&#xff0c;与HTML标签属性对应的概念就是组件的Props。组件的Props是给父组件使用的&#xff0c;使用时需要明确指定属性的值&#xff0c;或者是在组件定义时&#xff0c;给属性提供默认值。组件对象使用Props时&#xff0c;要更多的地应用…