Mindspore框架循环神经网络RNN模型实现情感分类|(五)模型训练

Mindspore框架循环神经网络RNN模型实现情感分类

Mindspore框架循环神经网络RNN模型实现情感分类|(一)IMDB影评数据集准备
Mindspore框架循环神经网络RNN模型实现情感分类|(二)预训练词向量
Mindspore框架循环神经网络RNN模型实现情感分类|(三)RNN模型构建
Mindspore框架循环神经网络RNN模型实现情感分类|(四)损失函数与优化器
Mindspore框架循环神经网络RNN模型实现情感分类|(五)模型训练
Mindspore框架循环神经网络RNN模型实现情感分类|(六)模型加载和推理(情感分类模型资源下载)
Mindspore框架循环神经网络RNN模型实现情感分类|(七)模型导出ONNX与应用部署

模型训练与推理

1. 模型加载

hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
lr = 0.001
pad_idx = vocab.tokens_to_ids('<pad>')
# 模型加载
model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)

其中:vocab, embeddings = load_glove(glove_path)
模型构建和实例化参数:

  embeddings:输入向量,是数据集经过glove模型统一处理的词向量数值特征,hidden_dim:隐藏层特征的维度, output_dim:输出维数, n_layers:RNN 层的数量,bidirectional:是否为双向 RNN, pad_idx:padding_idx参数用于标记输入中的填充值(padding value)。在自然语言处理任务中,文本序列的长度不一致是非常常见的。为了能够对不同长度的文本序列进行批处理,我们通常会使用填充值对较短的序列进行填补。

2.模型训练

def train():# 音频数据集imdb_path = r'./IMDB/aclImdb_v1.tar.gz'# 训练集和测试集生成imdb_train, imdb_test = load_imdb(imdb_path)  # review评论-标签,数据集# 预训练词向量表glove_path = r"./IMDB/glove.6B.zip"vocab, embeddings = load_glove(glove_path)  # 预定义词向量表# 语句标签-数据集。将文本序列统一长度,不足的使用<pad>补齐,超出的进行截断。每条评论500字。lookup_op = ds.text.Lookup(vocab, unknown_token='<unk>')pad_op = ds.transforms.PadEnd([500],pad_value=vocab.tokens_to_ids('<pad>'))  # 使用PadEnd接口,定义最大长度和补齐值(pad_value),取最大长度为500type_cast_op = ds.transforms.TypeCast(ms.float32)  # 将label数据转为float32格式# 预处理操作流水线imdb_train = imdb_train.map(operations=[lookup_op, pad_op], input_columns=['text'])imdb_train = imdb_train.map(operations=[type_cast_op], input_columns=['label'])imdb_test = imdb_test.map(operations=[lookup_op, pad_op], input_columns=['text'])imdb_test = imdb_test.map(operations=[type_cast_op], input_columns=['label'])# 由于IMDB数据集本身不包含验证集,我们手动将其分割为训练和验证两部分,比例取0.7, 0.3。imdb_train, imdb_valid = imdb_train.split([0.7, 0.3])# 调用数据集的map、split、batch为数据集处理流水线增加对应操作,返回值为新的Dataset类型。现在仅定义流水线操作,在执行时开始执行数据处理流水线,获取最终处理好的数据并送入模型进行训练。imdb_train = imdb_train.batch(64, drop_remainder=True)imdb_valid = imdb_valid.batch(64, drop_remainder=True)# 定义训练参数hidden_size = 256output_size = 1num_layers = 2bidirectional = Truelr = 0.001pad_idx = vocab.tokens_to_ids('<pad>')model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)loss_fn = nn.BCEWithLogitsLoss(reduction='mean')optimizer = nn.Adam(model.trainable_params(), learning_rate=lr)def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return lossgrad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)def train_step(data, label):loss, grads = grad_fn(data, label)optimizer(grads)return lossdef train_one_epoch(model, train_dataset, epoch=0):model.set_train()total = train_dataset.get_dataset_size()loss_total = 0step_total = 0with tqdm(total=total) as t:t.set_description('Epoch %i' % epoch)for i in train_dataset.create_tuple_iterator():loss = train_step(*i)loss_total += loss.asnumpy()step_total += 1t.set_postfix(loss=loss_total / step_total)t.update(1)num_epochs = 50best_valid_loss = float('inf')ckpt_file_name = os.path.join(cache_dir, 'sentiment-analysis.ckpt')for epoch in range(num_epochs):train_one_epoch(model, imdb_train, epoch)valid_loss = evaluate(model, imdb_valid, loss_fn, epoch)if valid_loss < best_valid_loss:best_valid_loss = valid_lossms.save_checkpoint(model, ckpt_file_name)if __name__ == "__main__":train()

训练完成:
在这里插入图片描述

3.小结

本节实现了情感分类模型的训练,精度达到99.9%,损失值降到0.0027。下一节将进行模型部署应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/51139.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

登录案例(JAVA)

练习1 package lx2;import java.io.BufferedReader; import java.io.FileNotFoundException; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.Scanner;public class demo1 {/*需求&#xff1a;写一个登陆小案例。步骤…

[工具] GitHub+Gridea+GitTalk 搭建个人免费博客

文章目录 起因GitHub创建个人仓库存主页创建用于Gridea连接的Token Gridea配置 GitTalk大功告成 起因 想要搭建自己的博客网站&#xff0c;又不想花钱买域名&#xff0c;也不会前端技术&#xff0c;只能求助于简单(傻逼式)且免费的博客搭建方式。偶然间看到这种方式&#xff0…

微信答题小程序产品研发-UI界面设计

高保真原型虽然已经很接近产品形态了&#xff0c;但毕竟还不能够直接交付给开发。这时就需要UI设计师依据之前的原型设计&#xff0c;进一步细化和实现界面的视觉元素&#xff0c;包括整体视觉风格、颜色、字体、图标、按钮以及交互细节优化等。 UI设计不仅关系到用户的直观感…

java项目数据库 mysql 迁移到 达梦

目录 一、下载安装达梦数据库 1、下载 2、解压 3、安装 二、迁移 三、更改SpringBoot 的 yml文件 1、达梦创建用户 2、修改yml 一、下载安装达梦数据库 1、下载 下载地址 https://eco.dameng.com/download/ 点击下载 开发版 (X86平台) , 然后选择操作系统并点击立…

uniapp实现局域网(内网)中APP自动检测版本,弹窗提醒升级

uniapp实现局域网&#xff08;内网&#xff09;中APP自动检测版本&#xff0c;弹窗提醒升级 在开发MES系统的过程中&#xff0c;涉及到了平板端APP的开发&#xff0c;既然是移动端的应用&#xff0c;那么肯定需要APP版本的自动更新功能。 查阅相关资料后&#xff0c;在uniapp的…

【初阶数据结构】复杂度算法题篇

旋转数组 力扣原题 方案一 循环K次将数组所有元素向后移动⼀位&#xff08;代码不通过) 时间复杂度O(n2) 空间复杂度O(1) void rotate(int* nums, int numsSize, int k) {while (k--) {int end nums[numsSize - 1];for (int i numsSize - 1; i > 0; i--) {nums[i] num…

Redis:十大数据类型

键&#xff08;key&#xff09; 常用命令 1. 字符串&#xff08;String&#xff09; 1.1 基本命令 set key value 如下&#xff1a;设置kv键值对&#xff0c;存货时长为30秒 get key mset key value [key value ...]mget key [key ...] 同时设置或者获取多个键值对 getrange…

【NPU 系列专栏 2.6 -- - NVIDIA Xavier SoC】

文章目录 NVIDIA Xavier SoCXavier 主要组件Xavier SoC 的型号Xavier SoC 的算力Xavier AGXXavier NXXavier 应用场景自动驾驶机器人物联网(IoT)医疗设备NPU 对比SummaryNVIDIA Xavier SoC 英伟达 Xavier SoC 是英伟达推出的一款高性能系统级芯片,专门为人工智能(AI)和自…

scratch聊天机器人 2024年6月scratch四级 中国电子学会图形化编程 少儿编程等级考试四级真题和答案解析

目录 scratch聊天机器人 一、题目要求 1、准备工作 2、功能实现 二、案例分析 1、角色分析 2、背景分析 3、前期准备 三、解题思路 1、思路分析 四、程序编写 五、考点分析 六、推荐资料 1、入门基础 2、蓝桥杯比赛 3、考级资料 4、视频课程 5、python资料 s…

面试题:为什么 一般 weight 选择对称量化,activation 选择非对称量化?

模型的剪枝是为了减少参数量和运算量&#xff0c;而量化是为了压缩数据的占用量。 量化概念 所谓的模型量化就是将浮点存储&#xff08;运算&#xff09;转换为整型存储&#xff08;运算&#xff09;的一种模型压缩技术。 优势&#xff1a; 可以提升计算效率&#xff1b;减少…

泛微开发修炼之旅--40考勤管理篇:根据班次规则、考勤组规则(含固定值和排班制),在三方系统中获取考勤签到数据,并同步到考勤管理中的解决方案

一、需求描述 我们最近在项目上遇到了一个需求&#xff0c;需要将工厂门禁刷脸数据&#xff0c;通过考勤管理配置的规则&#xff0c;获取到对应的考勤签到数据&#xff0c;依次作为上下班打卡的时间&#xff0c;以此作为员工每天考勤的依据&#xff0c;客户的考勤比较复杂&…

《python程序语言设计》第6章15题财务应用程序:打印税款表。利用程序清单4-7的代码

6.15 打印税款表 def computeTax(status_n, income):tax 0if status_n 0:if income < 8350:tax income * 0.10elif income < 33950:tax 8350 * 0.10 (income - 8350) * 0.15elif income < 82250:tax 8350 * 0.10 (33950 - 8350) * 0.15 (income - 33950) * 0.…

《九界ol游戏源码》(游戏源码+客户端+服务端+工具+视频教程)喜欢研究游戏源码的看过来...

《九界》游戏以网络同名热门小说为文化蓝本&#xff0c;构筑了一个地海陆空四维冒险的庞大游戏世界。《九界》以“团队修真”为核心研发理念&#xff0c;引擎采用OGRE引擎&#xff0c;GUI的设计采用CEGUI&#xff0c;游戏设计&#xff0c;地图&#xff0c;音效都是花费了相当的…

哈默纳科HarmonicDrive谐波减速机的使用寿命计算

在机械传动系统中&#xff0c;减速机的应用无处不在&#xff0c;而HarmonicDrive哈默纳科谐波减速机以其独特的优势&#xff0c;如轻量、小型、传动效率高、减速范围广、精度高等特点&#xff0c;成为了众多领域的选择。然而&#xff0c;任何机械设备都有其使用寿命&#xff0c…

Python爬虫-中国汽车市场月销量数据

前言 本文是该专栏的第34篇,后面会持续分享python爬虫干货知识,记得关注。 在本文中,笔者将通过某汽车平台,来采集“中国汽车市场”的月销量数据。 具体实现思路和详细逻辑,笔者将在正文结合完整代码进行详细介绍。废话不多说,下面跟着笔者直接往下看正文详细内容。(附…

硅纪元视角 | 语音克隆突破:微软VALL-E 2,Deepfake新纪元!

在数字化浪潮的推动下&#xff0c;人工智能&#xff08;AI&#xff09;正成为塑造未来的关键力量。硅纪元视角栏目紧跟AI科技的最新发展&#xff0c;捕捉行业动态&#xff1b;提供深入的新闻解读&#xff0c;助您洞悉技术背后的逻辑&#xff1b;汇聚行业专家的见解&#xff0c;…

Web前端知识视频教程分享(五) Bootstrap

资料下载地址&#xff1a; https://545c.com/f/45573183-1336822373-45bb4f?p7526 (访问密码: 7526)

Flink内存管理机制

前言 在Flink的后台界面&#xff0c;可以看到整个Flink的内存情况。 如JobManager的内存情况&#xff1a; TaskManager的内存情况 一、Flink内存管理 Flink TaskManager内存组成整体结构图如下&#xff1a; 二、总内存管理 三、JobManager内存管理内存管理 四、TaskManager内…

vue3前端架构---打包配置

最近看到几篇vue3配置项的文章&#xff0c;转载记录一下 Vue3.2 vue/cli-service 打包 chunk-vendors.js 文件过大导致页面加载缓慢解决方案-CSDN博客文章浏览阅读2k次&#xff0c;点赞8次&#xff0c;收藏9次。Vue3.2 vue/cli-service 打包 chunk-vendors.js 文件过大导致页…

Lago - 使用 ClickHouse 扩展事件引擎

本文字数&#xff1a;4540&#xff1b;估计阅读时间&#xff1a;12 分钟 作者&#xff1a;Mathew Pregasen 本文在公众号【ClickHouseInc】首发 本周&#xff0c;我们欢迎来自 Lago 的一篇博客文章&#xff0c;介绍了他们如何使用 ClickHouse 扩展一个事件引擎&#xff0c;并在…