朴素贝叶斯--文档分类

原文:http://ihoge.cn/2018/MultinomialNB.html

把文档转换成向量

TF-IDF是一种统计方法,用以评估一个词语对于一份文档的重要程度。

  • TF表示词频, 即:词语在一片文档中出现的次数 ÷ 词语总数
  • IDF表示一个词的逆向文档频率指数, 即:对(总文档数目÷包含该词语的文档的数目)的商取对数 log(m/miinm)log(m/mi−in−m)

基础原理:词语的重要性随着它在文档中出现的次数成正比例增加,但同时会随着它在语料库中出现的频率呈反比下降。

sklearn中有包实现了把文档转换成向量的过程,首先把训练用额语料库读入内存:

from time import time 
from sklearn.datasets import load_filest = time()
news_train = load_files('code/datasets/mlcomp/379/train')
print(len(news_train.data), "\n",len(news_train.target_names))
print("done in {} seconds".format(time() - t))
13180 20
done in 6.034918308258057 seconds

news_train.data是一个数组,包含了所有文档的文本信息。
news_train.target_names也是一个数组,包含了所有文档的属性类别,对应的是读取train文件夹时,train文件夹下所有的子文件夹名称。

该语料库总共有13180个文档,其中分成20个类别,接着需要转换成由TF-IDF表达的权重信息构成向量。

from sklearn.feature_extraction.text import TfidfVectorizert = time()
vectorizer  = TfidfVectorizer(encoding = 'latin-1')
X_train = vectorizer.fit_transform((d for  d in news_train.data))
print("文档 [{0}]特征值的非零个数:{1}".format(news_train.filenames[0] , X_train[0].getnnz()))
print("训练集:",X_train.shape)
print("耗时: {0} s.".format(time() - t))
文档 [code/datasets/mlcomp/379/train/talk.politics.misc/17860-178992]特征值的非零个数:108
训练集: (13180, 130274)
耗时: 3.740567207336426 s.

TfidfVectorizer类是用来把所有的文档转换成矩阵,该矩阵每一行都代表一个文档,一行中的每个元素代表一个对应的词语的重要性,词语的重要性由TF-IDF来表示。其fit_transform()方法是fit()transform()的结合,fit()先完成语料库分析,提取词典等操作transform()把每篇文档转换为向量,最终构成一个矩阵,保存在X_train里。

程序输出可以看到该词典总共有130274个词语,即每篇文档都可以转换成一个13274维的向量组。第一篇文档中只有108个非零元素,即这篇文档由108个不重复的单词组成,在这篇文档中出现的这108个单词次的TF-IDF会被计算出来,保存在向量的指定位置。这里的到X_train是一个纬度为12180 x 130274的系数矩阵。

训练模型

from sklearn.naive_bayes import MultinomialNBt = time()
y_train = news_train.target
clf = MultinomialNB(alpha=0.001)  #alpga表示平滑参数,越小越容易造成过拟合;越大越容易欠拟合。
clf.fit(X_train, y_train)print("train_score:", clf.score(X_train, y_train))
print("耗时:{0}s".format(time() - t))
train_score: 0.9974203338391502
耗时:0.23757004737854004s
# 加载测试集检验结果
news_test = load_files('code/datasets/mlcomp/379/test')
print(len(news_test.data))
print(len(news_test.target_names))
5648
20
# 把测试集文档数学向量化
t = time()
# vectorizer  = TfidfVectorizer(encoding = 'latin-1')  # 这里注意vectorizer这条语句上文已经生成执行,这里不可重复执行
X_test = vectorizer.transform((d for  d in news_test.data))
y_test = news_test.targetprint("测试集:",X_test.shape)
print("耗时: {0} s.".format(time() - t))
测试集: (5648, 130274)
耗时: 1.64164400100708 s.
import numpy as np
from sklearn import metrics y_pred = clf.predict(X_test)
print("Train_score:", clf.score(X_train, y_train))
print("Test_score:", clf.score(X_test, y_test))for i in range(10):r = np.random.randint(X_test.shape[0])if clf.predict(X_test[r]) == y_test[r]:print("√:{0}".format(r))else:print("X:{0}".format(r))
Train_score: 0.9974203338391502
Test_score: 0.9123583569405099
√:1874
√:2214
√:2579
√:1247
√:375
√:5384
√:5029
√:1951
√:4885
√:1980

评价模型:

classification_report()查看查准率、召回率、F1

使用classification_report()函数查看针对每个类别的预测准确性:

from sklearn.metrics import classification_reportprint(clf)
print("查看针对每个类别的预测准确性:")
print(classification_report(y_test, y_pred, target_names = news_test.target_names))
MultinomialNB(alpha=0.001, class_prior=None, fit_prior=True)
查看针对每个类别的预测准确性:precision    recall  f1-score   supportalt.atheism       0.90      0.92      0.91       245comp.graphics       0.80      0.90      0.84       298comp.os.ms-windows.misc       0.85      0.80      0.82       292
comp.sys.ibm.pc.hardware       0.81      0.82      0.81       301comp.sys.mac.hardware       0.90      0.92      0.91       256comp.windows.x       0.89      0.88      0.88       297misc.forsale       0.88      0.82      0.85       290rec.autos       0.93      0.93      0.93       324rec.motorcycles       0.97      0.97      0.97       294rec.sport.baseball       0.97      0.96      0.97       315rec.sport.hockey       0.97      0.99      0.98       302sci.crypt       0.96      0.95      0.96       297sci.electronics       0.91      0.85      0.88       313sci.med       0.96      0.96      0.96       277sci.space       0.95      0.97      0.96       305soc.religion.christian       0.93      0.96      0.94       293talk.politics.guns       0.90      0.96      0.93       246talk.politics.mideast       0.95      0.98      0.97       296talk.politics.misc       0.91      0.89      0.90       236talk.religion.misc       0.89      0.77      0.82       171avg / total       0.91      0.91      0.91      5648

confusion_matrix混淆矩阵

通过confusion_matrix函数生成混淆矩阵,观察每种类别别错误分类的情况。例如,这些被错误分类的文档是被错误分类到哪些类别里。

from sklearn.metrics import confusion_matrixcm = confusion_matrix(y_test, y_pred)
print(cm)# 第一行表示类别0的文档被正确分类的由255个,其中有2、5、13个错误分类被分到了14、15、19类中了。
[[225   0   0   0   0   0   0   0   0   0   0   0   0   0   2   5   0   0   0  13][  1 267   6   4   2   8   1   1   0   0   0   2   3   2   1   0   0   0   0   0][  1  12 233  26   4   9   3   0   0   0   0   0   2   1   0   0   0   0   1   0][  0   9  16 246   7   3  10   1   0   0   1   0   8   0   0   0   0   0   0   0][  0   2   3   5 236   2   2   1   0   0   0   3   1   0   1   0   0   0   0   0][  0  22   6   3   0 260   0   0   0   2   0   1   0   0   1   0   2   0   0   0][  0   2   5  11   3   1 238   9   2   3   1   0   7   0   1   0   2   2   3   0][  0   1   0   0   1   0   7 302   4   1   0   0   1   2   3   0   2   0   0   0][  0   0   0   0   0   2   2   3 285   0   0   0   1   0   0   0   0   0   0   1][  0   1   0   0   1   1   1   2   0 302   6   0   0   1   0   0   0   0   0   0][  0   0   0   0   0   0   0   0   2   1 299   0   0   0   0   0   0   0   0   0][  0   1   2   1   1   1   2   0   0   0   0 283   1   0   0   0   2   1   2   0][  0  11   2   6   5   2   4   5   1   1   1   3 267   1   3   0   0   0   1   0][  1   1   0   1   1   1   0   0   0   0   0   1   1 265   2   1   0   0   2   0][  0   3   0   0   1   0   0   0   0   0   0   1   1   1 296   0   1   0   1   0][  3   1   0   1   0   0   0   0   0   0   1   0   0   2   0 281   0   1   2   1][  1   0   1   0   0   0   0   0   1   0   0   0   0   0   0   0 237   1   4   1][  1   0   0   0   0   1   0   0   0   0   0   0   0   0   0   3   0 290   1   0][  1   1   0   0   1   1   0   1   0   0   0   0   0   0   0   1  12   7 210   1][ 16   1   0   0   0   0   0   0   0   0   0   0   0   0   0  12   5   2   4 131]]
%matplotlib inline
from matplotlib import pyplot as pltplt.figure(figsize=(6, 6), dpi=120)
plt.title('Confusion matrix of the classifier')
ax = plt.gca()                                  
ax.spines['right'].set_color('none')            
ax.spines['top'].set_color('none')
ax.spines['bottom'].set_color('none')
ax.spines['left'].set_color('none')
ax.xaxis.set_ticks_position('none')
ax.yaxis.set_ticks_position('none')
ax.set_xticklabels([])
ax.set_yticklabels([])
plt.matshow(cm, fignum=1, cmap='gray')
plt.colorbar();# 除对角线外,颜色越浅说明错误越多

# 上图不直观,重新画图
import random
from pyecharts import HeatMapx_axis = np.arange(20)
y_axis = np.arange(20)
data = [[i, j, cm[i][j]] for i in range(20) for j in range(20)]
heatmap = HeatMap()
heatmap.add("混淆矩阵", x_axis, y_axis, data, is_visualmap=True,visual_text_color="#fff", visual_orient='horizontal')
# heatmap.render()
# heatmap


原文:http://ihoge.cn/2018/MultinomialNB.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/293061.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

USENIX 最佳论文奖:擦除 Windows Azure 存储编码

我们发表了一篇介绍Windows Azure 存储如何用编码方式擦除数据的论文,此论文在 2012 年 6 月的 USENIX 技术年会上荣获最佳论文奖。这是 MicrosoftResearch 和 Windows Azure 存储团队共同努力的成果。 您可以在此处找到此论文。 Windows …

Linux I/O 模型(待修改)

2019独角兽企业重金招聘Python工程师标准>>> 最近看到“服务器并发处理能力”章节,被里面的“I/O模型“搞得有点头晕,所以这里希望通过概念的辨析和对比,能更好的理解Linux的 I/O模型。 同步(synchronous)…

linux之ls只显示文件或者文件夹

只显示文件夹 ls -l | grep ^d 只显示文件 ls -l | grep ^- 解释一下: ls -l 之后会得到下面的内容 drwx------ 4 jinwang users 4096 2012-02-09 15:00 .xchat2 -rw-r--r-- 1 jinwang users 1690399 2012-06-04 12:16 45s.txt 文件是以 &q…

git代码提交流程

从master创建任务分支1.需要先将master分支代码更新到最新然后再切新分支;2.新需求和hotfix需要从master切分支,若是在QA测试阶段或者预发布阶段的bug,则需要再该功能分支上进行修改;提交代码到自己的任务分支commit之后一定要pus…

PCA主成分分析+SVM实现人脸识别

原文地址: http://ihoge.cn/2018/PCASVM人脸识别.html 加载数据 这里使用的测试数据共包含40位人员照片,每个人10张照片。也可登陆http://www.cl.cam.ac.uk/research/dtg/attarchive/facesataglance.html 查看400张照片的缩略图。 import time impo…

Lua 学习笔记(一)

Lua学习笔记 1、lua的优势 a、可扩张性 b、简单 c、高效率 d、和平台无关 2、注释 a、单行注释 -- b、多行注释 --[[ --]] 3、类型和值 8个基本类型,检测变量类型用type a、nil print(type(nil)) -->nil 全局变量没有…

python inspect模块解析

来源:https://my.oschina.net/taisha/blog/55597 inspect模块主要提供了四种用处: (1) 对是否是模块,框架,函数等进行类型检查。 (2) 获取源码 (3) 获取类或函数的参数的信息 (4) 解析堆栈 使用inspect模块可以提供自省功能&#…

龙芯发布.NET 6.0.100开发者内测版

龙芯在龙芯开源社区发布了LoongArch64-.NET-SDK-6.0.100开发者内测版的新闻 ,龙芯.NET基于上游社区 版本 适配支持龙芯平台架构。目前支持LoongArch64架构和MIPS64架构,LoongArch64架构的.NET-SDK-3.1已完成,安装包下载地址LoongArch64-.NET …

Redis系统性介绍

虽然Redis已经很火了,相信还是有很多同学对Redis只是有所听闻或者了解并不全面,下面是一个比较系统的Redis介绍,对Redis的特性及各种数据类型及操作进行了介绍。是一个很不错的Redis入门教程。 1.介绍 1.1 Redis是什么 REmote DIctionary Ser…

mysql 不支持 select into

替代方案 insert into newTableName(column1,column2) select * from oldTableName INSERT INTO aw_daily_call_task(tenantId,sub_product_name, call_type,call_date,expire_date,username,customer_type,start_date,end_date,budget,product_type) SELECT t1.tenant_id …

linux之wc命令

Linux系统中的wc(Word Count)命令的功能为统计指定文件中的字节数、字数、行数,并将统计结果显示输出。 1.命令格式: wc [选项]文件... 2.命令功能: 统计指定文件中的字节数、字数、行数,并将统计结果显示输…

数据挖掘的9大成熟技术和应用

http://ihoge.cn/2018/DataMining.html 数据挖掘的9大成熟技术和应用 基于数据挖掘的9大主要成熟技术以及在数据化运营中的主要应用: 1、决策树 2、神经网络 3、回归 4、关联规则 5、聚类 6、贝叶斯分类 7、支持向量机 8、主成分分析 9、假设检验 1 决…

LVS:三种负载均衡方式与八种均衡算法

1、什么是LVS? 首先简单介绍一下LVS (Linux Virtual Server)到底是什么东西,其实它是一种集群(Cluster)技术,采用IP负载均衡技术和基于内容请求分发技术。调度器具有很好的吞吐率,将请求均衡地转移到不同的服务器上执行&#xff0…

排查 .NET开发的工厂MES系统 内存泄漏分析

一:背景 1. 讲故事上个月有位朋友加微信求助,说他的程序跑着跑着就内存爆掉了,寻求如何解决,截图如下:从聊天内容看,这位朋友压力还是蛮大的,话说这貌似是我分析的第三个 MES 系统了&#xff0c…

DataGirdView 常用操作

1、将数据源的某列添加到已有DataGirdView的列 例如:将文件夹下所有文件名添加到DataGirdView 的文件名一列,图片如下: 首先在datagridview把文件名列的DATAPROPERTYNAME设为你要显示的数据列的名字.此处我绑定的是folder.Name,所以直接在DAT…

Android之Android studio Gradle sync failed: Unknown host ‘services.gradle.org

错误描述: Gradle sync failed: Unknown host services.gradle.org. You may need to adjust the proxy settings in Gradle.Consult IDE log for more details (Help | Show Log)解决办法: 下载gradlectrlalts 然后输入gradle;在project-…

使用aconda3-5.1.0(Python3.6.4) 搭建pyspark远程部署

参考:http://ihoge.cn/2018/anacondaPyspark.html 前言 首次安装的环境搭配是这样的: jdk8 hadoop2.6.5 spark2.1 scala2.12.4 Anaconda3-5.1.0 一连串的报错让人惊喜无限,尽管反复调整配置始终无法解决。 坑了一整天后最后最终发现…

Entity Framework在Asp.net MVC中的实现One Context Per Request(附源码)

上篇中"Entity Framework中的Identity map和Unit of Work模式", 由于EF中的Identity map和Unit of Work模式,EF体现出来如下特性: 唯一性: 在一个Context的生命周期中,一个Entity只会有一个实例,任何对该实例的修改&…