ROCm上情感分析:使用循环神经网络

15.2. 情感分析:使用循环神经网络 — 动手学深度学习 2.0.0 documentation (d2l.ai)

代码

import torch
from torch import nn
from d2l import torch as d2lbatch_size = 64
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size)class BiRNN(nn.Module):def __init__(self, vocab_size, embed_size, num_hiddens,num_layers, **kwargs):super(BiRNN, self).__init__(**kwargs)self.embedding = nn.Embedding(vocab_size, embed_size)# 将bidirectional设置为True以获取双向循环神经网络self.encoder = nn.LSTM(embed_size, num_hiddens, num_layers=num_layers,bidirectional=True)self.decoder = nn.Linear(4 * num_hiddens, 2)def forward(self, inputs):# inputs的形状是(批量大小,时间步数)# 因为长短期记忆网络要求其输入的第一个维度是时间维,# 所以在获得词元表示之前,输入会被转置。# 输出形状为(时间步数,批量大小,词向量维度)embeddings = self.embedding(inputs.T)self.encoder.flatten_parameters()# 返回上一个隐藏层在不同时间步的隐状态,# outputs的形状是(时间步数,批量大小,2*隐藏单元数)outputs, _ = self.encoder(embeddings)# 连结初始和最终时间步的隐状态,作为全连接层的输入,# 其形状为(批量大小,4*隐藏单元数)encoding = torch.cat((outputs[0], outputs[-1]), dim=1)outs = self.decoder(encoding)return outsembed_size, num_hiddens, num_layers = 100, 100, 2
devices = d2l.try_all_gpus()
net = BiRNN(len(vocab), embed_size, num_hiddens, num_layers)def init_weights(m):if type(m) == nn.Linear:nn.init.xavier_uniform_(m.weight)if type(m) == nn.LSTM:for param in m._flat_weights_names:if "weight" in param:nn.init.xavier_uniform_(m._parameters[param])
net.apply(init_weights);glove_embedding = d2l.TokenEmbedding('glove.6b.100d')embeds = glove_embedding[vocab.idx_to_token]
embeds.shapenet.embedding.weight.data.copy_(embeds)
net.embedding.weight.requires_grad = Falselr, num_epochs = 0.01, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction="none")
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)#@save
def predict_sentiment(net, vocab, sequence):"""预测文本序列的情感"""sequence = torch.tensor(vocab[sequence.split()], device=d2l.try_gpu())label = torch.argmax(net(sequence.reshape(1, -1)), dim=1)return 'positive' if label == 1 else 'negative'predict_sentiment(net, vocab, 'this movie is so great')predict_sentiment(net, vocab, 'this movie is so bad')

代码解析

这段代码实现了一个用于情感分析的双向循环神经网络(BiRNN)。下面我将逐部分用中文解析它:
1. 导入所需的库和模块:

import torch
from torch import nn
from d2l import torch as d2l

这里导入了PyTorch库、神经网络模块`nn`和基于PyTorch的深度学习库`d2l`(深度学习的一本书)。
2. 加载数据集:

batch_size = 64
train_iter, test_iter, vocab = d2l.load_data_imdb(batch_size)

加载IMDB电影评论数据集,并用迭代器`train_iter`和`test_iter`进行训练和测试。`vocab`是数据集的词汇表。
3. 定义双向循环神经网络(BiRNN)模型:

class BiRNN(nn.Module):...

创建了一个名为`BiRNN`的类,用于定义双向LSTM模型。模型有一个嵌入层(`embedding`),将词汇映射到向量空间。LSTM层(`encoder`)设定为双向,输出经过全连接层(`decoder`)得到最终的分类结果。
4. 初始化模型参数:

def init_weights(m):...
net.apply(init_weights);

init_weights函数用于模型参数的初始化。`net.apply(init_weights);`使用这个函数来应用参数初始化。
5. 加载预训练的词向量:

glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]
net.embedding.weight.data.copy_(embeds)
net.embedding.weight.requires_grad = False

使用GloVe预训练的100维词向量,并将它们复制到嵌入层`net.embedding`。同时设置`requires_grad = False`使得这些词向量在训练中不被更新。
6. 训练模型:

lr, num_epochs = 0.01, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction="none")
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)

设置学习率和迭代次数,使用Adam优化器和交叉熵损失函数。用`d2l.train_ch13`函数来训练和评估模型。
7. 定义预测函数:

def predict_sentiment(net, vocab, sequence):...

这个函数用于预测给定文本序列的情感标签(积极或消极)。
8. 使用模型进行预测:

predict_sentiment(net, vocab, 'this movie is so great')
predict_sentiment(net, vocab, 'this movie is so bad')

调用`predict_sentiment`函数分别对两个句子进行情感预测。
整体来看,这段代码主要是利用循环神经网络对电影评论的情感进行分类,它通过加载预训练好的词向量,构建一个双向LSTM网络,并在IMDB评论数据集上进行训练和测试。最后定义了一个实用函数,用于预测输入句子的情感倾向。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/14860.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java抽象类,接口,枚举练习题

第一题: 答案: class Animal{//成员变量protected String name;protected int weight;//构造方法public Animal(){this.name"refer";this.weight50;}public Animal(String name,int weight){this.namename;this.weightweight;}//成员方法publ…

Bugku Crypto 部分题目简单题解(四)

目录 python_jail 简单的rsa 托马斯.杰斐逊 这不是md5 进制转换 affine Crack it rsa python_jail 启动场景 使用虚拟机nc进行连接 输入print(flag) 发现报错,经过测试只能传入10个字符多了就会报错 利用python中help()函数,借报错信息带出flag变…

【力扣刷题笔记第三期】Python 数据结构与算法

先从简单的题型开始刷起,一起加油啊!! 点个关注和收藏呗,一起刷题鸭!! 第一批题目 1.设备编号 给定一个设备编号区间[start, end],包含4或18的编号都不能使用,如:418、…

java抽象类和接口知识总结

一.抽象类 1.啥是抽象类 用专业语言描述就是:如果一个类中没有包含足够的信息来描绘一个具体的对象,这样的类就是抽象类 当然这话说的也很抽象,所以我们来用人话来解释一下抽象类 抛开编程语言这些,就以现实举例,我…

每日练习之排序——链表的合并;完全背包—— 兑换零钱

链表的合并 题目描述 运行代码 #include<iostream> #include<algorithm> using namespace std; int main() { int a[31];for(int i 1;i < 30;i)cin>>a[i];sort(a 1,a 1 30);for(int i 1;i < 30;i)cout<<a[i]<<" ";cout&…

Mysql之Innodb存储引擎

1.Innodb数据存储 innodb如今能够做到mysql的默认数据存储引擎&#xff0c;肯定有着其好处的&#xff0c;那么innodb有什么好处呢? 1. 当意外断电或者重启&#xff0c; InnoDB 能够做到奔溃恢复&#xff0c;撤销没有提交的数据 2.InnoDB 存储引擎维护自己的缓冲池&#xff0c…

医院挂号就诊系统的设计与实现

前端使用Vue.js 后端使用SpiringBoot MyBatis 数据使用MySQL 需要项目和论文加企鹅&#xff1a;2583550535 医院挂号就诊系统的设计与实现_哔哩哔哩_bilibili 随着社会的发展&#xff0c;医疗资源分布不均&#xff0c;患者就诊难、排队时间长等问题日益突出&#xff0c;传统的…

Hadoop3:HDFS的Fsimage和Edits文件介绍

一、概念 Fsimage文件&#xff1a;HDFS文件系统元数据的一个永久性的检查点&#xff0c;其中包含HDFS文件系统的所有目 录和文件inode的序列化信息。 Edits文件&#xff1a;存放HDFS文件系统的所有更新操作的路径&#xff0c;文件系统客户端执行的所有写操作首先 会被记录到Ed…

二叉树的链式结构

1.二叉树的遍历 2.二叉树链式结构的实现 3.解决单值二叉树题 1.二叉树的遍历 1.1前序&#xff0c;中序以及后序遍历 二叉树的遍历是按照某种特定的规则&#xff0c;依次对二叉树的结点进行相应的操作&#xff0c;并且每个结点只操作一次。 二叉树的遍历有这些规则&#xff…

主流电商平台商品实时数据采集API接口||抖音电商数据分析实例|可视化

— 1 — 抖音电商数据【抖音电商API数据采集】分析场景 1. 这里&#xff0c;我们选择“伊利”这个品牌作为案例进行分析&#xff0c;在短短的4个月里&#xff0c;从最初每月营收17.07万&#xff0c;到6月份达到了2485.54 万&#xff0c;伊利的牛奶&#xff0c;有点牛&#xff…

Spring 对 Junit4,Junit5 的支持上的运用

1. Spring 对 Junit4,Junit5 的支持上的运用 文章目录 1. Spring 对 Junit4,Junit5 的支持上的运用每博一文案2. Spring对Junit4 的支持3. Spring对Junit5的支持4. 总结&#xff1a;5. 最后&#xff1a; 每博一文案 关于理想主义&#xff0c;在知乎上看到一句话&#xff1a;“…

Xline社区会议Call Up|在 CURP 算法中实现联合共识的安全性

为了更全面地向大家介绍Xline的进展&#xff0c;同时促进Xline社区的发展&#xff0c;我们将于2024年5月31日北京时间11:00 p.m.召开Xline社区会议。 欢迎您届时登陆zoom观看直播&#xff0c;或点击“阅读原文”链接加入会议&#xff1a; 会议号: 832 1086 6737 密码: 41125…

通过cmd命令行使用用3dmax自带的vray渲染

有时调试需要使用vray渲染vrscene文件看效果&#xff0c;只装有3dmax下可以使用自带vray渲染&#xff0c;在3dmax的渲染日志里面看自带引擎路径 使用命令行进入到此目录 执行命令指定vr文件即可看到效果&#xff0c;如&#xff1a;vray.exe -sceneFile“C:\test15\202405241…

Cesium与Three相机同步(2)

之前实现了将Three相机同步到Cesium相机Cesium与Three相机同步(1)-CSDN博客 现在是将Cesium相机同步到Three相机,从而实现了相机双向同步。 <!DOCTYPE html> <html lang="en"><head><title>three.js webgl - orbit controls</title&g…

【教学类-58-03】黑白三角拼图03(4*4宫格)总数算不出+随机抽取10张

背景需求&#xff1a; 【教学类-58-01】黑白三角拼图01&#xff08;2*2宫格&#xff09;256种-CSDN博客文章浏览阅读318次&#xff0c;点赞10次&#xff0c;收藏12次。【教学类-58-01】黑白三角拼图01&#xff08;2*2宫格&#xff09;256种https://blog.csdn.net/reasonsummer/…

【Jmeter】使用Jmeter进行接口测试、跨线程组获取参数

Jmeter接口测试 Jmeter设置成中文实操练习-跨线程组提取参数&#xff0c;使用值HTTP请求默认值&HTTP信息头管理器 相信打算从事测试工程师的同学们&#xff0c;肯定对Jmeter是耳熟能详的。使用Jmeter可以进行接口测试、性能测试、压力测试等等&#xff1b;这个章节介绍如何…

cocos 通过 electron 打包成 exe 文件,实现通信问题

cocos 通过 electron 打包成 exe 文件&#xff0c;实现通信问题 首先&#xff0c;我使用的 cocos 版本是 2.4.12&#xff0c;遇到一个问题&#xff0c;是啥子呢&#xff0c;就是我要把用 cocos 开发出来的项目打包成一个 exe 可执行程序&#xff0c;使用的是 electron &#xf…

【C++算法】BFS解决多源最短路问题相关经典算法题

1.01矩阵 既然本章是BFS解决多源最短路问题&#xff0c;也就是说有若干个起点&#xff0c;那我们就可以暴力一点&#xff0c;直接把多源最短路径问题转化成若干个单源最短路径问题&#xff0c;然后将每次的步数比较一下&#xff0c;取到最短的就是最短路径的结果&#xff0c;这…

arcgis 10.6 工具栏操作error 001143 后台服务器抛出异常

arcgis 10.6 工具栏操作error 001143 后台服务器抛出异常 环境 win10arcgis 10.6 问题 执行定义投影要素转线出现 Error: 001143:后台服务器抛出异常&#xff08;差点重装10.6&#xff09; 如下图所示&#xff1a; 解决方法 通过在菜单工具条上单击地理处理 > 地理处…

设计模式使用(成本扣除)

前言 名词解释 基础名词 订单金额&#xff1a;用户下单时支付的金额&#xff0c;这个最好理解 产品分成&#xff1a;也就是跟其他人合做以后我方能分到的金额&#xff0c;举个例子&#xff0c;比如用户订单金额是 100 块&#xff0c;我方的分成是 80%&#xff0c;那么也就是…