pytorch集智4-情绪分类器

1 目标

从中文文本中识别出句子里的情绪。和上一章节单车预测回归问题相比,这个问题是分类问题,不是回归问题

2 神经网络分类器

2.1 如何用神经网络分类

第二章节用torch.nn.Sequantial做的回归预测器,输出神经元只有一个。分类器和其区别如下

1 分类器输出单元有几个,就有几个分类

2 分类器各输出单元输出值加和为1(值表示某个预测分类的概率,概率求和为1)

3 回归预测最后一层可以用sigmoid,分类不行,因为sigmoid无法保证分类器输出单元值求和为1,可以用softmax,表达式如下

2.2 分类问题损失函数

分类问题损失函数一般用交叉熵:

N表示样本数量,Ci表示第i个样本的类别标签,y表示每个样本的损失值

3 词袋模型分类器

直白理解:就是统计句子里各词语出现数量,忽略语法,语义等因素

思想:1 将句子转化为词袋,然后可以onehot处理,方便用神经网络预测;2 大概思想:在词袋里寻找和输出相关关键词语,然后统计词语数量,哪个输出类别的敏感词语统计数量高,结果就是哪个输出类别

数据预处理:可以归一化处理,网络计算会更快

3.1 简单文本分类器

背景与数据:获取网购商城顾客评论信息,从评论信息获取句子情绪

神经网络组成:用前馈神经网络,有3层,分别是输入,中间,输出层,输出层数量有两个,分别代表情绪是正面还是负面

3.2 程序实现

3.2.1 数据处理

主要步骤:处理标点,句子分词,建立词表

标点处理:正则,涉及的标点全部删除

句子分词:使用jieba模块对句子分词

建立词表:python字典建立词表


def deal_punc(sentence):sentence = re.sub('[\s+\.\!\/_,$%^*(+\"\'“”《》?]+|[+——!,。?、~@#¥%……&*():)]+', '', sentence)return sentencedef prepare_data(good_file, bad_file, is_filter=True):print('prepare data begin')all_words = []pos_sen, neg_sen = [], []for path, sen in zip((good_file, bad_file), (pos_sen, neg_sen)):with open(path, 'r', encoding='utf-8') as f:for index, line in enumerate(f):if is_filter:line = deal_punc(line)words = jieba.lcut(line)if len(words) > 0:all_words += wordssen.append(words)print(f'{path} include {index} rows, all words:{len(all_words)}')print(f'pos_sen len:{len(pos_sen)}, neg_sen len:{len(neg_sen)}')diction = {}cnt = Counter(all_words)for word, freq in cnt.items():diction[word] = [len(diction), freq]print(f'diction len:{len(diction)}')return (pos_sen, neg_sen, diction)def word2index(word, diction):if word in diction:value = diction[word][0]else:value = -1return (value)def index2word(index, diction):for w, v in diction.items():if v[0] == index:return (w)return (None)

3.2.2 文本数据向量化

    pos_sen, neg_sen, diction = prepare_data_sample(good_file, bad_file)st = sorted([(v[1], w) for w, v in diction.items()])datasets, labels, sentences = [], [], []'''for sens, tag in zip((pos_sen, neg_sen), (0, 1)):for sen in sens:new_sen = []for l in sen:if l in diction:new_sen.append(word2index_sample(1, diction))datasets.append(sentence2vec_sample(new_sen, diction))labels.append(tag)sentences.append(sen)'''for sentence in pos_sen:new_sentence = []for l in sentence:if l in diction:new_sentence.append(word2index(l, diction))datasets.append(sentence2vec_sample(new_sentence, diction))labels.append(0) #正标签为0sentences.append(sentence)# 处理负向评论for sentence in neg_sen:new_sentence = []for l in sentence:if l in diction:new_sentence.append(word2index(l, diction))datasets.append(sentence2vec_sample(new_sentence, diction))labels.append(1) #负标签为1sentences.append(sentence)indices = np.random.permutation(len(datasets))datasets = [datasets[i] for i in indices]labels = [labels[i] for i in indices]sentences = [sentences[i] for i in indices]

3.3.3 划分数据集

训练集,校验集,测试集,比例一般为10:1:1

    # split train, test, verify datasetstest_size = int(len(datasets) // 10)train_data = datasets[2 * test_size :]train_label = labels[2 * test_size : ]valid_data = datasets[: test_size]valid_label = labels[: test_size]test_data = datasets[test_size : 2 * test_size]test_label = labels[test_size : 2 * test_size]

3.3.4 建立神经网络

结构:输入层(约7000个样本),一个中间层(10个隐单元),输出层(2个单元)

注意,此处用的是relu,不是sigmoid。原因:虽然x>0时没处理数据,但x<0时为0让relu有了非线性的能力,且运算比sigmoid简单,速度更快,方便反向误差传导


def plot(records):print('plot begin')a = [i[0] for i in records]b = [i[1] for i in records]c = [i[2] for i in records]pyplot.plot(a, label='train loss')pyplot.plot(b, label='verify loss')pyplot.plot(c, label='valid accuracy')pyplot.xlabel('step')pyplot.ylabel('losses & accuracy')pyplot.legend()pyplot.show()def main():model = torch.nn.Sequential(torch.nn.Linear(len(diction), 10),torch.nn.ReLU(),torch.nn.Linear(10, 2),torch.nn.LogSoftmax(dim=1),)cost = torch.nn.NLLLoss()optimizer = torch.optim.SGD(model.parameters(), lr=0.01)records = []losses = []for epoch in range(3):for i, data in enumerate(zip(train_data, train_label)):x, y = datax = torch.tensor(x, requires_grad=True, dtype=torch.float).view(1, -1)y = torch.tensor(np.array([y]), dtype=torch.long)optimizer.zero_grad()predict = model(x)loss = cost(predict, y)losses.append(loss.data.numpy())loss.backward()optimizer.step()val_losses = []rights = []for j, val in enumerate(zip(valid_data, valid_label)):x, y = valx = torch.tensor(x, requires_grad=True, dtype=torch.float).view(1, -1)y = torch.tensor(np.array([y]), dtype = torch.long)predict = model(x)right = rightness_sample(predict, y)rights.append(right)loss = cost(predict, y)val_losses.append(loss.data.numpy())right_ratio = 1.0 * np.sum([i[0] for i in rights]) / np.sum([i[1] for i in rights])print(f'No.{epoch}, train loss:{np.mean(losses):.4f}, verify loss:{np.mean(val_losses):.4f}, verify accuracy:{right_ratio}')records.append([np.mean(losses), np.mean(val_losses), right_ratio])# plotplot(records)

3.3 运行结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/623028.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QT——connect的第五个参数 Qt::ConnectionType (及qt和c++的多线程的区别)

一直对QT的多线程和c的多线程的区别有疑惑&#xff0c;直到看到文档中这一部分内容才豁然开朗 一.ConnectionType参数的类型和区别 首先是官方文档中对于该枚举值的区别介绍&#xff1a; 对于队列&#xff08;queued &#xff09;连接&#xff0c;参数必须是 Qt 元对象系统已知…

强化学习应用(四):基于Q-learning的物流配送路径规划研究(提供Python代码)

一、Q-learning算法简介 Q-learning是一种强化学习算法&#xff0c;用于解决基于马尔可夫决策过程&#xff08;MDP&#xff09;的问题。它通过学习一个值函数来指导智能体在环境中做出决策&#xff0c;以最大化累积奖励。 Q-learning算法的核心思想是使用一个Q值函数来估计每…

助力工业园区作业违规行为检测预警,基于YOLOv7【tiny/l/x】不同系列参数模型开发构建工业园区场景下作业人员违规行为检测识别系统

在很多工业园区生产作业场景下保障合规合法进行作业生产操作&#xff0c;对于保护工人生命安全降低安全隐患有着非常重要的作用&#xff0c;但是往往在实际的作业生产中&#xff0c;因为一个安全观念的淡薄或者是粗心大意&#xff0c;对于纪律约束等意思薄弱&#xff0c;导致在…

maven镜像源设置aliyun提升下载速度

一、打开pom.xml project下在添加 <repositories><repository><id>aliyunmaven</id><name>aliyun</name><url>https://maven.aliyun.com/repository/public</url></repository><repository><id>central2&l…

分布形态的度量_峰度系数的探讨

集中趋势和离散程度是数据分布的两个重要特征,但要全面了解数据分布的特点&#xff0c;还应掌握数据分布的形态。 描述数据分布形态的度量有偏度系数和峰度系数, 其中偏度系数描述数据的对称性,峰度系数描述与正态分布的偏离程度。 峰度系数反映分布峰的尖峭程度的重要指标. 当…

【ESP32接入语言大模型之智谱清言】

1. 智谱清言 讲解视频&#xff1a; 随着人工智能技术的不断发展&#xff0c;自然语言处理领域也得到了广泛的关注和应用。智谱清言作为千亿参数对话模型 基于ChatGLM2模型开发&#xff0c;支持多轮对话&#xff0c;具备内容创作、信息归纳总结等能力。可以快速注册体验中国版…

远程开发之vscode端口转发

远程开发之vscode端口转发 涉及的软件forwarded port 通过端口转发&#xff0c;实现在本地电脑上访问远程服务器上的内网的服务。 涉及的软件 vscode、ssh forwarded port 在ports界面中的port字段&#xff0c;填需要转发的IP:PORT&#xff0c;即可转发远程服务器中的内网端…

增强FAQ搜索引擎:发挥Elasticsearch中KNN的威力

英文原文地址&#xff1a;https://medium.com/nerd-for-tech/enhancing-faq-search-engines-harnessing-the-power-of-knn-in-elasticsearch-76076f670580 增强FAQ搜索引擎&#xff1a;发挥Elasticsearch中KNN的威力 2023 年 10 月 21 日 在一个快速准确的信息检索至关重要的…

基于MOD02/MYD02获得亮度温度再转冰温

用HEG处理MOD02/MYD02,提取里面的EV_1KM_Emissive波段,band为11和12(其实就是band 31和32)。注意这里的band和output dile type 1. 获得之后,转辐射亮度。 参考:https://www.cnblogs.com/enviidl/p/16539422.html radiance_scales,和radiance_offset这两项参数代表波段…

【生存技能】git操作

先下载git https://git-scm.com/downloads 我这里是win64&#xff0c;下载了相应的直接安装版本 64-bit Git for Windows Setup 打开git bash 设置用户名和邮箱 查看设置的配置信息 获取本地仓库 在git bash或powershell执行git init&#xff0c;初始化当前目录成为git仓库…

LeetCode讲解篇之216. 组合总和 III

文章目录 题目描述题解思路题解代码 题目描述 题解思路 使用递归回溯算法&#xff0c;当选择数字num后&#xff0c;在去选择大于num的合法数字&#xff0c;计算过程中的数字和&#xff0c;直到选择了k次&#xff0c;如果数组和等于n则加入结果集 从1开始选择数字&#xff0c;直…

ubuntu 2022.04 安装vcs2018和verdi2018

主要参考网站朋友们的作业。 安装时参考&#xff1a; ubuntu18.04安装vcs、verdi2018_ubuntu安装vcs-CSDN博客https://blog.csdn.net/qq_24287711/article/details/130017583 编译时参考&#xff1a; 【ASIC】VCS报Error-[VCS_COM_UNE] Cannot find VCS compiler解决方法_e…

平凡之路_2023年

平凡之路总结 思路总结&#xff0c;以XMIND 为形式&#xff0c;构建思维大厦&#xff0c;蛰伏与积累&#xff0c;下面补充对XMIND的描述 内功修炼问题意识&#xff08;输入&#xff09;与结构化思维&#xff08;输出&#xff09; – 同如何成为一个领域的专家 2024.1.14 最大的…

统计学-R语言-4.4

文章目录 前言双变量数据分类型数据对分类型数据--二维表分类对分类--复式条形图分类对数值--并列箱线图 数值型数据对数值型数据散点图相关系数 练习 前言 上一篇文章介绍的是单变量数据&#xff0c;本篇将介绍双变量数据。 双变量数据 描述分类数据对分类数据的描述方法&am…

(菜鸟自学)搭建虚拟渗透实验室——安装Kali Linux

安装Kali Linux Kali Linux 是一种基于 Debian 的专为渗透测试和网络安全应用而设计的开源操作系统。它提供了广泛的渗透测试工具和安全审计工具&#xff0c;使安全专业人员和黑客可以评估和增强网络的安全性。 安装KaliLinux可参考我的另一篇文章《Kali Linux的下载安装以及基…

python统计分析——操作案例(模拟抽样)

参考资料&#xff1a;用python动手学统计学 import numpy as np import pandas as pd from matplotlib import pyplot as plt import seaborn as snsdata_setpd.read_csv(r"C:\python统计学\3-4-1-fish_length_100000.csv")[length] #此处将文件路径改为自己的路…

数据结构(c)冒泡排序

本文除了最下面的代码是我写的&#xff0c;其余是网上抄写的。 冒泡排序 什么是冒泡排序&#xff1f; 冒泡排序&#xff08;Bubble Sort&#xff09;是一种简单的排序算法。它重复地走访过要排序的数列&#xff0c;一次比较两个元素&#xff0c;如果他们的顺序错误就把他们交…

【5G Modem】5G modem架构介绍

博主未授权任何人或组织机构转载博主任何原创文章&#xff0c;感谢各位对原创的支持&#xff01; 博主链接 本人就职于国际知名终端厂商&#xff0c;负责modem芯片研发。 在5G早期负责终端数据业务层、核心网相关的开发工作&#xff0c;目前牵头6G算力网络技术标准研究。 博客…

概率论与数理统计————1.随机事件与概率

一、随机事件 随机试验&#xff1a;满足三个特点 &#xff08;1&#xff09;可重复性&#xff1a;可在相同的条件下重复进行 &#xff08;2&#xff09;可预知性&#xff1a;每次试验的可能不止一个&#xff0c;事先知道试验的所有可能结果 &#xff08;3&#xff09;不确定…

matlab串口数据交互的使用

一、matlab将串口数据读取并储存到position中 delete(instrfindall);%注销系统之前已经打开的串口资源 clear s %清空s的数据 s serial(COM6,BaudRate,115200);%定义串口及波特率 fopen(s)%打开串口 fwrite(s,00AB,)%向串口写入读取电机位置指令 for i1:8 %共8个电机position…