6.2 通过构建情感分类器训练词向量

        在上一节中,我们简要地了解了词向量,但并没有去实现它。在本节中,我们将下载一个名为IMDB的数据集(其中包含了评论),然后构建一个用于计算评论的情感是正面、负面还是未知的情感分类器。在构建过程中,还将为 IMDB 数据集中存在的词进行词向量的训练。我们将使用一个名为 torchtext 的库,这个库使下载、向量化文本和批处理等许多过程变得更加容易。训练情感分类器将包括以下步骤。

  1. 下载 IMDB 数据并对文本分词;
  2. 建立词表;
  3. 生成向量的批数据;
  4. 使用词向量创建网络模型;
  5. 训练模型。

6.2.1 下载 IMDB 数据并对文本分词

        对于与计算机视觉相关的应用,我们使用过 torchvision 库。它提供了许多实用功能,并帮助我们构建计算机视觉应用程序。同样,有一个名为 torchtext 的库,它也是 PyTorch 的一部分,它与 PyTorch 一起工作,通过为文本提供不同的数据加载器和抽象,简化了许多与自然语言处理相关的活动。在本书写作时,torchtext 没有包含在 PyTorch 包内,需要独立安装。可以在计算机的命令行中运行以下代码来安装torchtext:

pip install torchtext

        安装完成后就可以使用它了。torchtext 提供了两个重要的模块:torchtext.data和torchtext.datasets。

        1. torchtext.data

        torchtext.data 实例定义了一个名为 Field 的类,它可以用来定义数据如何读取和分词。让我们看一下使用它来准备 IMDB 数据集的示例:

from torchtext import data
TEXT = data.Field(lower=True, batch_first=True, fix_length=20)
LABEL = data.Field(sequential=False)

        在上述代码中,我们定义了两个 Field 对象,一个用于实际的文本,另一个用于标签数据。对于实际的文本,我们期望 torchtext 将所有文本都小写并对文本分词,同时将其修整为最大长度为20。如果我们正在为生产环境构建应用程序,则可以将长度修正为更大的数字。当然对于当前练习的例子,20的长度够用了。Field 的构造函数还接受另一个名为 tokenize 的参数,该参数默认使用str.split 函数。还可以指定spaCy作为参数或任何其他分词器。我们的例子将使用 str.split。

        2. torchtext.datasets

        torchtext.datasets 实例提供了使用不同数据集的封装,如IMDB、TREC(问题分类)、语言建模(WikiText-2)和一些其他数据集。我们将使用 torch.datasets 下载 IMDB 数据集并将其拆分为 train 和 test 数据集。以下代码执行此操作,当第一次运行它时,可能需要几分钟,具体取决于网络连接速度,因为它是从 Internet 上下载 IMDB 数据集的:

train,test=datasets.IMDB.splits(TEXT,LABEL)

        之前的数据集的 IMDB 类抽象出了下载、分词和将数据库拆分为 train 和 test 数据集涉及的所有复杂度。train.fields 包含一个字典,其中 TEXT 是键,值是 LABEL。让我们看看 train.fields 和 train 集合的每个元素:

print('train.fields',train.fields)

        从这些结果中可以看到,单个元素包含了一个字段 text 和表示 text 的所有 token,以及包含了文本标签的字段 label。现在已准备好对 IMDB 数据集进行批处理了。

6.2.2 构建词表

        当为 thor_review 创建独热编码时,同时创建了一个作为词表的 word2idx 字典,它包含文档中唯一词的所有细节。torchtext 实例使处理更加容易。在加载数据后,可以调用 build_vocab 并传入负责为数据构建词表的必要参数。以下代码说明了如何构建词表:

TEXT.build_vocab(train,vectors=GloVe(name=,6B,dim=300),max_size=10000,min_freq=10)
LABEL.build_vocab(train)

        在上述代码中,传入了需要构建词表的 train 对象,并让它使用维度为 300 的预训练词向量来初始化向量。当使用预训练权重训练情感分类器时,build_vocab 对象只是下载并创建稍后将使用的维度。max_size 实例限制了词表中词的数量,而min_freg删除了出现不超过10 次的词,其中 10是可配置的。
        当词汇表构建完成后,我们就可以获得例如词频、词索引和每个词的向量表示等不同的值。下面的代码演示了如何访问这些值:

print(TEXT.vocab.freqs)

        以下代码演示了如何访问结果:

print(TEXT.vocab.vectors)

        使用 stoi 访问包含词及其索引的字典。

6.2.3 生成向量的批数据

        torchtext 提供了 BucketIterator,它有助于批处理所有文本并将词替换成词的索引。BucketIterator 实例带有许多有用的参数,如batch_size、device(GPU或CPU)和 shuffle (是否必须对数据进行混洗)。下面的代码演示了如何为 train 和 test 数据集创建生成批处理的迭代器:

train_iter, test_iter = data.BucketIterator.splits((train, test),
batch_size=128,device=-1,shuffle=True)
#device = -1 表示使用 cpu,设置为 None 时使用 gpu.

        上述代码为 train 和 test 数据集提供了一个 BucketIterator 对象。以下代码将说明如何创建 batch 并显示 batch 的结果:

batch = next(iter(train_iter))
batch.text

        从上面代码段的结果中,可以看到文本数据如何转换为 batch_size * fix_len (即128x20) 大小的矩阵。

6.2.4 使用词向量创建网络模型

        我们之前简要地讨论过词向量。在本节中,我们将创建作为网络架构的一部分的词向量,并训练整个模型用以预测每个评论的情感。在训练结束时,将得到一个情感分类器模型,以及 IMDB 数据集的词向量。以下代码演示了如何使用词向量创建用于情感预测的网络架构:

class EmbNet(nn.Module):def _init_(self,emb_size,hidden_sizel,hidden_size2 = 400):super()._init_()self.embedding = nn.Embedding(emb_size,hidden_sizel)self.fc = nn.Linear(hidden_size2,3)def forward(self,x):embeds = self.embedding(x).view(x.size(0),-1)out = self.fc(embeds)return F.log_softmax(out, dim = -1)

        在上述代码中,EmbNet 创建了情感分类模型。在_init_函数中,我们使用两个参数初始化了 nn.Embedding 类的一个对象,它接收两个参数,即词表的大小和希望为每个单词创建的维度。由于限制了唯一单词的数量,因此词表的大小将为10,000,并且我们可以从一个小的向量尺寸(比如10)开始。为了快速运行程序,有必要使用个小尺寸的向量值,但是当试图为生产系统构建应用程序时,请使用大尺寸的词向量。我们还有一个线性层,将词向量映射到情感的类别(如正面、负面或未知)。
        forward 函数确定了输入数据的处理方式。对于批量大小为 32 以及最大长度为 20 个词的句子,输入形状为 32x20。第一个 embedding 层充当查找表,用相应的词向量替换掉每个词。对于向量维度 10,当每个词被其相应的词向量替换时,输出形状变成了 32x20x10。view 函数将使 embedding 层的结果变得扁平。传递给 view 函数的第一个参数将保持维数不变。在我们的例子中,我们不希望组合来自不同批次的数据,因此保留第一个维数并将张量中的其余值扁平化。在应用 view 函数后,张量形状变为 32x200。全连接层将扁平化的词向量映射到类别的编号。定义了网络后就可以像往常一样训练它了。

6.2.5 训练模型

        训练模型与在构建图像分类器时看到的非常类似,因此将使用相同的函数。我们把批数据传入模型并计算输出和损失,然后优化包括词向量权重在内的模型权重。以下代码执行此操作:

def fit(epoch,model,data_loader,phase=,training,,volatile=False):if phase == 'training':model.train()if phase == 'validation':model.eval()volatile = True running_loss = 0.0running_correct = 0for batch_idx r batch in enumerate(data_loader):text, target = batch.text r batch.labelif is_cuda:text,target = text.cuda(), target.cuda()if phase =='training':optimizer.zero_grad()output = model(text)loss = F.nll_loss(output,target)running loss += F.nll loss(output,target,size_average=False).data[0]preds = output.data.max(dim=1,keepdim=True)[1]running_correct += preds.eq(target.data.view_as(preds)).cpu().sum()if phase == 'training': loss.backward()optimizer.step()loss = running_loss/len(data_loader.dataset)accuracy = 100. * running_correct/len(data_loader.dataset)print(f'{phase} loss is (loss:(5}.{2}} and {phase} accuracy is {running_correct}/{len(data_loader.dataset)}{accuracy:{10}.{4}}')return loss,accuracytrain_losses,train_accuracy = [],[]val_losses,val_accuracy = [],[]train_iter.repeat = Falsetest_iter.repeat = Falsefor epoch in range(1,10):epoch_loss,epoch_accuracy = fit(epoch,model,train_iter,phase='training')val_epoch_loss,val_epoch_accuracy = fit(epoch,model,test_iter,phase='validation')train_losses.append(epoch_loss)train_accuracy.append(epoch_accuracy)val_losses.append(val_epoch_loss)val_accuracy.append(val_epoch_accuracy)

        在上述代码中,通过传入为批处理数据创建的 BucketIterator 对象来调用 fit 方法。默认情况下,迭代器不会停止生成批数据,因此必须将 BucketIterator 对象的 repeat 变量设置为 False。如果不将 repeat 变量设置为 False,那么 fit 函数将无限地运行。模型训练10轮后得到的验证准确率约为70%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/34741.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第二期书生·浦语大模型实战营优秀项目一览

书生浦语社区于 2023 年年底正式推出了书生浦语大模型实战营系列活动,至今已有两期五批次同学参加大模型学习、实战,线上课程累计学习超过 10 万人次。 实战营特设项目实践环节,提供 A100 算力支持,鼓励学员动手开发。第 2 期实战…

【移动应用开发期末复习】第五/六章

系列文章 第一章——Android平台概述 第一章例题 第二章——Android开发环境 第二章例题 第三章 第三章例题 第四章 系列文章界面布局设计线性布局表格布局帧布局相对布局约束布局控制视图界面的其他方法代码控制视图界面数据存储与共享首选项信息数据文件SQLite数据库Content…

HarmonyOS Next开发学习手册——进程模型线程模型

进程模型 系统的进程模型如下图所示: 应用中(同一包名)的所有PageAbility、ServiceAbility、DataAbility、FormAbility运行在同一个独立进程中,即图中绿色部分的“Main Process”。 WebView拥有独立的渲染进程,即图中…

2023: 芒种集•序言

2023: 芒种集•序言 2023: 芒种集•序言 从西南旅游回来,一直忙着整理游记“2024:追寻红色足迹”,之后又应初建平索要刘桂蓉遗作“我们一起走过”,于是把“别了,老屋”和诗作“二月”一并合编,把我写的悼念…

oceanbase数据库安装和连接实战(阿里云服务器操作)

本文主要是安装oceanbase的单机版进行数据库的基础使用,oceanbase的数据库是兼容mysql数据库的,实际的兼容程度需要更深度的测试,本文主要是安装oceanbase并使用SQLynx的mysql驱动连接使用oceanbase数据库。 目录 1. 基础介绍 2. 安装说明 …

【Python datetime模块精讲】:时间旅行者的日志,精准操控日期与时间

文章目录 前言一、datetime模块简介二、常用类和方法三、date类四、time类五、datetime类六、timedelta类七、常用的函数和属性八、代码及其演示 前言 Python的datetime模块提供了日期和时间的类,用于处理日期和时间的算术运算。这个模块包括date、time、datetime和…

STL迭代器的基础应用

STL迭代器的应用 迭代器的定义方法: 类型作用定义方式正向迭代器正序遍历STL容器容器类名::iterator 迭代器名常量正向迭代器以只读方式正序遍历STL容器容器类名::const_iterator 迭代器名反向迭代器逆序遍历STL容器容器类名::reverse_iterator 迭代器名常量反向迭…

C# SerialPort串口通讯

串口通信 在.NET平台下创建C#串口通信程序,.NET 2.0提供了串口通信的功能,其命名空间是System.IO.Ports。这个新的框架不但可以访问计算机上的串口,还可以和串口设备进行通信。 创建C#串口通信程序之命名空间 System.IO.Ports命名空间中最重…

solidity智能合约如何实现跨合约调用函数

背景 比如现在有一个需求、我需要通过外部合约获取BRC20 token的总交易量。那么我需要在brc20的转账函数里面做一些调整,主要是两个函数内统计转移量。然后再提供外部获取函数。 /*** dev Sets amount as the allowance of spender over the callers tokens.** Ret…

文化财经wh6boll带macd多空转折点提示指标公式源码

文化财经wh6boll带macd多空转折点提示指标公式源码: DIFF:EMA(CLOSE,12) - EMA(CLOSE,26); DEA:EMA(DIFF,9); MACD:2*(DIFF-DEA); MID:MA(CLOSE,26);//求N个周期的收盘价均线,称为布林通道中轨 TMP2:STD(CLOSE,26);//求M个周期内的收盘价的标准差 …

onlyoffice实现在单页面加载文档的功能

草图 实现案例的基本原型 这里我们的样式库使用的是Tailwindcss,我们的前端UI组件库使用的是Ant Design Vue。 基本原型是,有个按钮,没有点击按钮的时候,页面显示的时普通的内容。当点击这个按钮的时候,页面加载文档…

【Linux】线程Thread

🔥博客主页: 我要成为C领域大神🎥系列专栏:【C核心编程】 【计算机网络】 【Linux编程】 【操作系统】 ❤️感谢大家点赞👍收藏⭐评论✍️ 本博客致力于知识分享,与更多的人进行学习交流 ​ ​ 线程概述 …

云层区分神经网络模型——二分类

云层区分神经网络模型——二分类 问奶奶,是什么让他们维护一份感情长达年,奶奶说那个年代什么东西坏了都会想要修,现在什么坏了都想着换。 安装依赖 # 要运行脚本,请先安装以下库:pip install tensorflowpip install …

JAVA每日作业day6.26

ok了家人们,今天我们学习了面向对象-多态,话不多说我们一起来看看吧 一.多态概述 面向对象的第三大特性:封装、继承、多态 我们拿一个生活中的例子来看 生活中,比如跑的动作,小猫、小狗和大象,跑起来是不一…

山水风景视频素材去哪里下?去哪里找?山水风景下载网站分享

在这个数字时代,视频已经成为最直观、有效的传达情感和分享故事的工具。对于那些渴望通过视频传递视觉美感和情感共鸣的创作者来说,拥有高质量的山水风景视频素材是关键。互联网虽然是一个信息量庞大的平台,但找到令人赞叹的山水风景视频素材…

【Linux】使用ntpdate同步时间

ntpdate 是一个在 Linux 系统中用于同步系统时间的命令行工具,它通过与 NTP 服务器通信来调整本地系统时钟。然而,需要注意的是,ntpdate 已经被许多现代 Linux 发行版弃用。 安装 yum install -y ntpdate 查看时间 date同步时间 ntpdate ntp…

问界M9累计大定破10万,创中国豪车新纪录

ChatGPT狂飙160天,世界已经不是之前的样子。 更多资源欢迎关注 6月26日消息,华为常务董事、终端BG董事长、智能汽车解决方案BU董事长余承东今日宣布,问界M9上市6个月,累计大定突破10万辆。 这一成绩,也创造了中国市场…

postman汉化中文(Windows)

Postman 是一款专业的 API 开发工具,为开发者提供了创建、测试、调试和分享 HTTP 请求的便利性和灵活性。其主要功能包括请求构建与发送、自动化测试、团队协作与分享、实时监视与调试以及环境与变量管理。无论是个人开发者还是团队,Postman 都能有效地提…

深入了解 msvcr120.dll问题解决指南,msvcr120.dll在电脑中的重要性

在Windows操作系统中,.dll 文件扮演了非常重要的角色,它们包含许多程序运行所需的代码和数据。其中 msvcr120.dll 是一个常见的动态链接库文件,是 Microsoft Visual C Redistributable Packages 的一部分。这篇文章将探讨 msvcr120.dll 的功能…

使用Python进行并发和并行编程:提高效率的秘诀

使用Python进行并发和并行编程:提高效率的秘诀 ​ 大家好,今天我们来聊聊如何使用Python进行并发和并行编程,以提升数据处理的效率;在之前的文章中,我们探讨了Python的函数式编程和数据流处理。今天,我们将…