NLP_新闻主题分类_7(代码示例)

目标

  • 有关新闻主题分类和有关数据.
  • 使用浅层网络构建新闻主题分类器的实现过程.

1 案例说明

  • 关于新闻主题分类任务:

    • 以一段新闻报道中的文本描述内容为输入, 使用模型帮助我们判断它最有可能属于哪一种类型的新闻, 这是典型的文本分类问题, 我们这里假定每种类型是互斥的, 即文本描述有且只有一种类型.
    • 新闻主题分类数据:        
  • 数据文件预览:
  • # 数据集在虚拟机/root/data/ag_news_csv下
    - data/- ag_news_csv/classes.txtreadme.txttest.csvtrain.csv

    文件说明:

    • train.csv表示训练数据, 共12万条数据; test.csv表示验证数据, 共7600条数据; classes.txt是标签(新闻主题)含义文件, 里面有四个单词'World', 'Sports', 'Business', 'Sci/Tech'代表新闻的四个主题, readme.txt是该数据集的英文说明.
    • train.csv预览:
    • "3","Wall St. Bears Claw Back Into the Black (Reuters)","Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again."
      "3","Carlyle Looks Toward Commercial Aerospace (Reuters)","Reuters - Private investment firm Carlyle Group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market."
      "3","Oil and Economy Cloud Stocks' Outlook (Reuters)","Reuters - Soaring crude prices plus worries\about the economy and the outlook for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums."
      "3","Iraq Halts Oil Exports from Main Southern Pipeline (Reuters)","Reuters - Authorities have halted oil export\flows from the main pipeline in southern Iraq after\intelligence showed a rebel militia could strike\infrastructure, an oil official said on Saturday."
      "3","Oil prices soar to all-time record, posing new menace to US economy (AFP)","AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections."
      "3","Stocks End Up, But Near Year Lows (Reuters)","Reuters - Stocks ended slightly higher on Friday\but stayed near lows for the year as oil prices surged past  #36;46\a barrel, offsetting a positive outlook from computer maker\Dell Inc. (DELL.O)"
      "3","Money Funds Fell in Latest Week (AP)","AP - Assets of the nation's retail money market mutual funds fell by  #36;1.17 billion in the latest week to  #36;849.98 trillion, the Investment Company Institute said Thursday."
      "3","Fed minutes show dissent over inflation (USATODAY.com)","USATODAY.com - Retail sales bounced back a bit in July, and new claims for jobless benefits fell last week, the government said Thursday, indicating the economy is improving from a midsummer slump."
      "3","Safety Net (Forbes.com)","Forbes.com - After earning a PH.D. in Sociology, Danny Bazil Riley started to work as the general manager at a commercial real estate firm at an annual base salary of  #36;70,000. Soon after, a financial planner stopped by his desk to drop off brochures about insurance benefits available through his employer. But, at 32, ""buying insurance was the furthest thing from my mind,"" says Riley."
      "3","Wall St. Bears Claw Back Into the Black"," NEW YORK (Reuters) - Short-sellers, Wall Street's dwindling  band of ultra-cynics, are seeing green again."

    • 文件内容说明:
      • train.csv共由3列组成, 使用','进行分隔, 分别代表: 标签, 新闻标题, 新闻简述; 其中标签用"1", "2", "3", "4"表示, 依次对应classes中的内容.
      • test.csv与train.csv内容格式与含义相同.
  • 从本地进行数据的加载,实现代码如下:
from torchtext.legacy.datasets.text_classification import _csv_iterator, _create_data_from_iterator, TextClassificationDataset
from torchtext.utils import extract_archive
from torchtext.vocab import build_vocab_from_iterator, Vocab
# 从本地加载数据的方式,本地数据在虚拟机/root/data/ag_news_csv中
# 定义加载函数
def setup_datasets(ngrams=2, vocab_train=None, vocab_test=None, include_unk=False):train_csv_path = 'data/ag_news_csv/train.csv'test_csv_path = 'data/ag_news_csv/test.csv'if vocab_train is None:vocab_train = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))else:if not isinstance(vocab, Vocab):raise TypeError("Passed vocabulary is not of type Vocab")if vocab_test is None:vocab_test = build_vocab_from_iterator(_csv_iterator(test_csv_path, ngrams))else:if not isinstance(vocab, Vocab):raise TypeError("Passed vocabulary is not of type Vocab")train_data, train_labels = _create_data_from_iterator(vocab_train, _csv_iterator(train_csv_path, ngrams, yield_cls=True), include_unk)test_data, test_labels = _create_data_from_iterator(vocab_test, _csv_iterator(test_csv_path, ngrams, yield_cls=True), include_unk)if len(train_labels ^ test_labels) > 0:raise ValueError("Training and test labels don't match")return (TextClassificationDataset(vocab_train, train_data, train_labels),TextClassificationDataset(vocab_test, test_data, test_labels))# 调用函数, 加载本地数据
train_dataset, test_dataset = setup_datasets()
print("train_dataset", train_dataset)

2 案例实现

整个案例的实现可分为以下五个步骤

  • 第一步: 构建带有Embedding层的文本分类模型.
  • 第二步: 对数据进行batch处理.
  • 第三步: 构建训练与验证函数.
  • 第四步: 进行模型训练和验证.
  • 第五步: 查看embedding层嵌入的词向量.

2.1 构建带有Embedding层的文本分类模型

# 导入必备的torch模型构建工具
import torch.nn as nn
import torch.nn.functional as F# 指定BATCH_SIZE的大小
BATCH_SIZE = 16# 进行可用设备检测, 有GPU的话将优先使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")class TextSentiment(nn.Module):"""文本分类模型"""def __init__(self, vocab_size, embed_dim, num_class):"""description: 类的初始化函数:param vocab_size: 整个语料包含的不同词汇总数:param embed_dim: 指定词嵌入的维度:param num_class: 文本分类的类别总数""" super().__init__()# 实例化embedding层, sparse=True代表每次对该层求解梯度时, 只更新部分权重.self.embedding = nn.Embedding(vocab_size, embed_dim, sparse=True)# 实例化线性层, 参数分别是embed_dim和num_class.self.fc = nn.Linear(embed_dim, num_class)# 为各层初始化权重self.init_weights()def init_weights(self):"""初始化权重函数"""# 指定初始权重的取值范围数initrange = 0.5# 各层的权重参数都是初始化为均匀分布self.embedding.weight.data.uniform_(-initrange, initrange)self.fc.weight.data.uniform_(-initrange, initrange)# 偏置初始化为0self.fc.bias.data.zero_()def forward(self, text):""":param text: 文本数值映射后的结果:return: 与类别数尺寸相同的张量, 用以判断文本类别"""# 获得embedding的结果embedded# >>> embedded.shape# (m, 32) 其中m是BATCH_SIZE大小的数据中词汇总数embedded = self.embedding(text)# 接下来我们需要将(m, 32)转化成(BATCH_SIZE, 32)# 以便通过fc层后能计算相应的损失# 首先, 我们已知m的值远大于BATCH_SIZE=16,# 用m整除BATCH_SIZE, 获得m中共包含c个BATCH_SIZEc = embedded.size(0) // BATCH_SIZE# 之后再从embedded中取c*BATCH_SIZE个向量得到新的embedded# 这个新的embedded中的向量个数可以整除BATCH_SIZEembedded = embedded[:BATCH_SIZE*c]# 因为我们想利用平均池化的方法求embedded中指定行数的列的平均数,# 但平均池化方法是作用在行上的, 并且需要3维输入# 因此我们对新的embedded进行转置并拓展维度embedded = embedded.transpose(1, 0).unsqueeze(0)# 然后就是调用平均池化的方法, 并且核的大小为c# 即取每c的元素计算一次均值作为结果embedded = F.avg_pool1d(embedded, kernel_size=c)# 最后,还需要减去新增的维度, 然后转置回去输送给fc层return self.fc(embedded[0].transpose(1, 0))
  • 实例化模型:
# 获得整个语料包含的不同词汇总数
VOCAB_SIZE = len(train_dataset.get_vocab())
# 指定词嵌入维度
EMBED_DIM = 32
# 获得类别总数
NUN_CLASS = len(train_dataset.get_labels())
# 实例化模型
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device)

2.2 对数据进行batch处理

def generate_batch(batch):"""description: 生成batch数据函数:param batch: 由样本张量和对应标签的元组组成的batch_size大小的列表形如:[(label1, sample1), (lable2, sample2), ..., (labelN, sampleN)]return: 样本张量和标签各自的列表形式(张量)形如:text = tensor([sample1, sample2, ..., sampleN])label = tensor([label1, label2, ..., labelN])"""# 从batch中获得标签张量label = torch.tensor([entry[0] for entry in batch])# 从batch中获得样本张量text = [entry[1] for entry in batch]text = torch.cat(text)# 返回结果return text, label
  • 调用:
# 假设一个输入:
batch = [(1, torch.tensor([3, 23, 2, 8])), (0, torch.tensor([3, 45, 21, 6]))]
res = generate_batch(batch)
print(res)
  • 输出效果:
# 对应输入的两条数据进行了相应的拼接
(tensor([ 3, 23,  2,  8,  3, 45, 21,  6]), tensor([1, 0]))

2.3 构建训练与验证函数

# 导入torch中的数据加载器方法
from torch.utils.data import DataLoaderdef train(train_data):"""模型训练函数"""# 初始化训练损失和准确率为0train_loss = 0train_acc = 0# 使用数据加载器生成BATCH_SIZE大小的数据进行批次训练# data就是N多个generate_batch函数处理后的BATCH_SIZE大小的数据生成器data = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True,collate_fn=generate_batch)# 对data进行循环遍历, 使用每个batch的数据进行参数更新for i, (text, cls) in enumerate(data):# 设置优化器初始梯度为0optimizer.zero_grad()# 模型输入一个批次数据, 获得输出output = model(text)# 根据真实标签与模型输出计算损失loss = criterion(output, cls)# 将该批次的损失加到总损失中train_loss += loss.item()# 误差反向传播loss.backward()# 参数进行更新optimizer.step()# 将该批次的准确率加到总准确率中train_acc += (output.argmax(1) == cls).sum().item()# 调整优化器学习率  scheduler.step()# 返回本轮训练的平均损失和平均准确率return train_loss / len(train_data), train_acc / len(train_data)def valid(valid_data):"""模型验证函数"""# 初始化验证损失和准确率为0loss = 0acc = 0# 和训练相同, 使用DataLoader获得训练数据生成器data = DataLoader(valid_data, batch_size=BATCH_SIZE, collate_fn=generate_batch)# 按批次取出数据验证for text, cls in data:# 验证阶段, 不再求解梯度with torch.no_grad():# 使用模型获得输出output = model(text)# 计算损失loss = criterion(output, cls)# 将损失和准确率加到总损失和准确率中loss += loss.item()acc += (output.argmax(1) == cls).sum().item()# 返回本轮验证的平均损失和平均准确率return loss / len(valid_data), acc / len(valid_data)

2.4 进行模型训练和验证

# 导入时间工具包
import time# 导入数据随机划分方法工具
from torch.utils.data.dataset import random_split# 指定训练轮数
N_EPOCHS = 10# 定义初始的验证损失
min_valid_loss = float('inf')# 选择损失函数, 这里选择预定义的交叉熵损失函数
criterion = torch.nn.CrossEntropyLoss().to(device)
# 选择随机梯度下降优化器
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
# 选择优化器步长调节方法StepLR, 用来衰减学习率
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)# 从train_dataset取出0.95作为训练集, 先取其长度
train_len = int(len(train_dataset) * 0.95)# 然后使用random_split进行乱序划分, 得到对应的训练集和验证集
sub_train_, sub_valid_ = \random_split(train_dataset, [train_len, len(train_dataset) - train_len])# 开始每一轮训练
for epoch in range(N_EPOCHS):# 记录概论训练的开始时间start_time = time.time()# 调用train和valid函数得到训练和验证的平均损失, 平均准确率train_loss, train_acc = train(sub_train_)valid_loss, valid_acc = valid(sub_valid_)# 计算训练和验证的总耗时(秒)secs = int(time.time() - start_time)# 用分钟和秒表示mins = secs / 60secs = secs % 60# 打印训练和验证耗时,平均损失,平均准确率print('Epoch: %d' %(epoch + 1), " | time in %d minutes, %d seconds" %(mins, secs))print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')
  • 输出效果:
120000lines [00:06, 17834.17lines/s]
120000lines [00:11, 10071.77lines/s]
7600lines [00:00, 10432.95lines/s]Epoch: 1  | time in 0 minutes, 36 secondsLoss: 0.0592(train) |   Acc: 63.9%(train)Loss: 0.0005(valid) |   Acc: 69.2%(valid)
Epoch: 2  | time in 0 minutes, 37 secondsLoss: 0.0507(train) |   Acc: 71.3%(train)Loss: 0.0005(valid) |   Acc: 70.7%(valid)
Epoch: 3  | time in 0 minutes, 36 secondsLoss: 0.0484(train) |   Acc: 72.8%(train)Loss: 0.0005(valid) |   Acc: 71.4%(valid)
Epoch: 4  | time in 0 minutes, 36 secondsLoss: 0.0474(train) |   Acc: 73.4%(train)Loss: 0.0004(valid) |   Acc: 72.0%(valid)
Epoch: 5  | time in 0 minutes, 36 secondsLoss: 0.0455(train) |   Acc: 74.8%(train)Loss: 0.0004(valid) |   Acc: 72.5%(valid)
Epoch: 6  | time in 0 minutes, 36 secondsLoss: 0.0451(train) |   Acc: 74.9%(train)Loss: 0.0004(valid) |   Acc: 72.3%(valid)
Epoch: 7  | time in 0 minutes, 36 secondsLoss: 0.0446(train) |   Acc: 75.3%(train)Loss: 0.0004(valid) |   Acc: 72.0%(valid)
Epoch: 8  | time in 0 minutes, 36 secondsLoss: 0.0437(train) |   Acc: 75.9%(train)Loss: 0.0004(valid) |   Acc: 71.4%(valid)
Epoch: 9  | time in 0 minutes, 36 secondsLoss: 0.0431(train) |   Acc: 76.2%(train)Loss: 0.0004(valid) |   Acc: 72.7%(valid)
Epoch: 10  | time in 0 minutes, 36 secondsLoss: 0.0426(train) |   Acc: 76.6%(train)Loss: 0.0004(valid) |   Acc: 72.6%(valid)

2.5 查看embedding层嵌入的词向量

# 打印从模型的状态字典中获得的Embedding矩阵
print(model.state_dict()['embedding.weight'])
  • 输出效果:
tensor([[ 0.4401, -0.4177, -0.4161,  ...,  0.2497, -0.4657, -0.1861],[-0.2574, -0.1952,  0.1443,  ..., -0.4687, -0.0742,  0.2606],[-0.1926, -0.1153, -0.0167,  ..., -0.0954,  0.0134, -0.0632],...,[-0.0780, -0.2331, -0.3656,  ..., -0.1899,  0.4083,  0.3002],[-0.0696,  0.4396, -0.1350,  ...,  0.1019,  0.2792, -0.4749],[-0.2978,  0.1872, -0.1994,  ...,  0.3435,  0.4729, -0.2608]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/720298.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

11. Nginx进阶-HTTPS

简介 基本概述 SSL SSL是安全套接层。 主要用于认证用户和服务器,确保数据发送到正确的客户机和服务器上。 SSL可以加密数据,防止数据中途被窃取。 SSL也可以维护数据的完整性,确保数据在传输过程中不被改变。 HTTPS HTTPS就是基于SSL来…

【Unity】Node.js安装与配置环境

引言 我们在使用unity开发的时候,有时候会使用一些辅助工具。 Node.js就是开发中,经常会遇到的一款软件。 1.下载Node.js 下载地址:https://nodejs.org/en 2.安装Node.js ①点击直接点击Next下一步 ②把协议勾上,继续点击…

【lua】lua内存优化记录

这边有一个Unity项目用的tolua, 游戏运行后手机上lua内存占用 基本要到 189M, 之前峰值有200多。 优化点1 加快gc频度: 用uwa抓取的lua内存, 和unity的mono很像,内存会先涨 然后突然gc一下,降下来。 这样…

java数据结构与算法刷题-----LeetCode687. 最长同值路径

java数据结构与算法刷题目录(剑指Offer、LeetCode、ACM)-----主目录-----持续更新(进不去说明我没写完):https://blog.csdn.net/grd_java/article/details/123063846 文章目录 1. 深度优先,用下面的儿子判断2. 深度优先&#xff0…

UI 自动化测试实战(二)| 测试数据的数据驱动

数据驱动就是通过数据的改变驱动自动化测试的执行,最终引起测试结果的改变。简单来说,就是参数化在自动化测试中的应用。 测试过程中使用数据驱动的优势主要体现在以下几点: 1.提高代码复用率,相同的测试逻辑只需编写一条测试用例…

【01】区块链科普100天-模块化区块链

模块化区块链以其高拓展性的特点成为热点 导言: 模块化区块链通过将不同功能分解为不同的模块(层)来提高系统的可拓展性、安全性和灵活性 例如Celestia公链,通过将数据可用性与共识分离来提高网络的可拓展性和灵活性 1.层次架构&a…

服务器后端是学习java还是php

没有绝对的"最好"语言,每种后端语言都有其适用的场景和特点。以下是几种常用的后端语言: 1. Java:Java是一种通用且强大的语言,广泛用于企业级应用和大型系统。它有很好的性能和可靠性,并且具有优秀的生态系…

光辐射测量(1)基本介绍+辐射度量、光辐射度量基础

基本情况:本门课就是对“三度学”进行学习。“三度学”包括辐射度学、光度学、色度学。主要掌握其基本概念、原理、物理量的互相转换关系、计算分析方法、测量仪器与测试计量方法等。 三者所覆盖的范围如图。 辐射度学: 辐射度学是一门研究电磁辐射能测…

自测-5 Shuffling Machine(python版本)

文章预览: 题目翻译算法python代码oj反馈结果 题目 翻译 shuffle是用于随机化一副扑克牌的过程。由于标准的洗牌技术被认为是薄弱的,并且为了避免员工通过不适当的洗牌与赌徒合作的“内部工作”,许多赌场使用了自动洗牌机。你的任务是模拟一…

XGB-17:模型截距

在 XGBoost 中,模型截距(也称为基本分数)是一个值,表示在考虑任何特征之前模型的起始预测。它本质上是处理回归任务时训练数据的平均目标值,或者是分类任务的赔率对数。 在 XGBoost 中,每个叶子节点都会输…

H5小游戏,象棋

H5小游戏源码、JS开发网页小游戏开源源码大合集。无需运行环境,解压后浏览器直接打开。有需要的订阅后,私信本人,发源码,含60小游戏源码。如五子棋、象棋、植物大战僵尸、贪吃蛇、飞机大战、坦克大战、开心消消乐、扑鱼达人、扫雷…

C++:Vector的使用

一、vector的介绍 vector的文档介绍 1. vector是表示可变大小数组的序列容器。 2. 就像数组一样,vector也采用的连续存储空间来存储元素。也就是意味着可以采用下标对vector的元素进行访问,和数组一样高效。但是又不像数组,它的大小是可以…

ABAP - 增强:一代增强User exit

一代增强是基于源代码的增强,一般是名字UserExit_开头空代码的子例程,所以一代增强的别称用户出口。需要修改SAP标准标准代码集中在名称倒数第二位为’Z‘的include程序里面。所有的全局数据可用那么该如何找到一代增强呢?以销售订单为例&…

《操作系统真相还原》读书笔记一:环境搭建 32位centos6.3+bochs

下载32位的centos6.3centos6.3 https://archive.kernel.org/centos-vault/6.3/isos/i386/

ubuntu22.04 成功编译llvm和clang 3.4.0,及 bitcode 函数名示例,备忘

1, 获取llvm 仓库 从github上获取: $ git clone --recursive https://github.com/llvm/llvm-project.git 2, 检出 llvmorg-3.4.0 tag 针对llvm 3.4.0版本,检出 $ cd llvm-project $ git tag $ git checkout llvmorg-3.4.0 3, 配置并编译llvm 使用 M…

EmoLLM(心理健康大模型)——探索心灵的深海,用智能的语言照亮情感的迷雾。

文章目录 介绍:应用地址:模型地址:Github地址:视频介绍:效果图: 介绍: EmoLLM是一个基于 InternLM 等模型微调的心理健康大模型,它涵盖了认知、情感、行为、社会环境、生理健康、心…

08 OpenCV 腐蚀和膨胀

文章目录 作用算子代码 作用 膨胀与腐蚀是数学形态学在图像处理中最基础的操作。其卷积操作非常简单,对于图像的每个像素,取其一定的邻域,计算最大值/最小值作为新图像对应像素位置的像素值。其中,取最大值就是膨胀,取最小值就是腐…

10 - 安装 image2df

1 背景 在使用 容器镜像 时可能遇到的场景: 我们想要通过已有的镜像来获取 Dockerfile,比如常用的使用 docker history 命令来查看镜像信息,然后分析生成 Dockerfile。但是,这个方法有些缺点:生成的 Dockerfile 少了 F…

奇安信发布《2024人工智能安全报告》,AI深度伪造欺诈激增30倍

2024年2月29日,奇安信集团对外发布《2024人工智能安全报告》(以下简称《报告》)。《报告》认为,人工智能技术的恶意使用将快速增长,在政治安全、网络安全、物理安全和军事安全等方面构成严重威胁。 《报告》揭示了基于…

就业班 2401--3.4 Linux Day10--软件管理

一、软件管理 导语: 安装软件 rpm yum 源码安装 ​ 卸载软件 rpm介绍 rpm软件包名称: 软件名称 版本号(主版本、次版本、修订号) 操作系统 -----90%的规律 #有依赖关系,不能自动解决依赖关系。 举例:openssh-6.6.1p1-31.el7.x86_64.rpm 数字前面的是名…