【人工智能Ⅰ】实验4:贝叶斯分类

实验4  贝叶斯分类

一、实验目的

1. 了解并学习机器学习相关库的使用。

2. 熟悉贝叶斯分类原理和方法,并对MNIST数据集进行分类。

二、实验内容

1. 使用贝叶斯方法对mnist或mnist variation数据集进行分类,并计算准确率。数据集从网上下载(如百度飞桨平台)。

2. 改变算法参数,观察对识别准确率的影响。

三、实验环境

平台

Jupyter Notebook (anaconda3)

Python版本

python 3.9

第三方依赖

numpy、scikit-learn、matplotlib

四、方法流程

1:配置第三方依赖;

2:下载mnist数据集;

3:划分数据集为训练集train和测试集test;

4:调用机器学习库的贝叶斯分类模型;

5:设置pred为模型预测的结果;

6:比较test和pred,计算准确率;

7:输出测试集test的混淆矩阵;

8:按照上述方法继续对其他模型进行调参并对比模型的训练结果。

五、实验展示(训练过程和训练部分结果进行可视化)

1:三种贝叶斯模型的训练结果对比

模型

名称

分类准确率

GaussianNB

高斯Bayes算法

55.58%

MultinomialNB

多项式Bayes算法

83.65%

BernoulliNB

伯努利Bayes算法

84.13%

可视化训练结果(混淆矩阵):


GaussianNB


MultinomialNB


BernoulliNB

训练数据:


X = mnist.data(像素结果)


y = mnist.target(标签结果)

2:高斯Bayes算法(GaussianNB)在不同平滑参数下的训练结果对比

平滑参数var_smoothing

分类准确率

0.95

74.97%

0.8

75.89%

0.5

77.91%

0.3

79.77%

0.05

81.40%

不采用

55.58%

 

3:高斯Bayes算法(GaussianNB)在设置均等先验概率下的训练结果


可以发现,与未设置priors参数时候的结果一致,说明模型预先就均分了先验概率。

4:伯努利Bayes算法(BernoulliNB)在是否禁用学习类的先验概率的训练结果对比

先验概率使用fit_prior

分类准确率

True

84.13%

False

84.15%

5:伯努利Bayes算法(BernoulliNB)在不同加法(Laplace/Lidstone)平滑参数下的训练结果对比

平滑参数alpha

分类准确率

0.1

84.15%

0.5

84.14%

0.9

84.14%

6:伯努利Bayes算法(BernoulliNB)在不同二值化阈值参数下的训练结果对比

二值化阈值参数binarize

分类准确率

0.0

84.13%

0.5

84.13%

1.0

84.13%

7:多项式Bayes算法(MultinomialNB)在不同加法(Laplace/Lidstone)平滑参数下的训练结果对比

平滑参数alpha

分类准确率

0.1

83.67%

0.5

83.66%

0.9

83.65%

8:伯努利Bayes算法(BernoulliNB)在是否禁用学习类的先验概率的训练结果对比

先验概率使用fit_prior

分类准确率

True

83.65%

False

83.65%

9:多项式Bayes算法(MultinomialNB)的单张图片测试结果

10:对数据集进行主成分分析、标准化和归一化处理后的训练结果

主成分分析:

PCA_Transfer=PCA(n_components=0.95)

PCA_x=PCA_Transfer.fit_transform(x)

标准化:

Standard_Transfer=StandardScaler()

standard_x_train=Standard_Transfer.fit_transform(x_train)

standard_x_test=Standard_Transfer.fit_transform(x_test)

归一化:

MinMax_Transfer=MinMaxScaler()

Guiyihua_x_train=MinMax_Transfer.fit_transform(x_train)

Guiyihua_x_test=MinMax_Transfer.fit_transform(x_test)

经过上述操作后的模型训练结果:

模型

名称

分类准确率

GaussianNB

高斯Bayes算法

77.49%

MultinomialNB

多项式Bayes算法

85.45%

BernoulliNB

伯努利Bayes算法

74.14%

六、实验结论

1:在GaussianNB中,我们可以改变的主要参数是priors和var_smoothing。priors参数可以手动设置每个类别的先验概率,默认让模型根据数据自动计算先验概率。var_smoothing参数用于控制对类别不确定性的处理方式。var_smoothing值过大,可能会让模型对实际数据中的变化不敏感,导致模型性能降低;var_smoothing值过小,可能会对数据中的噪声和异常值更敏感,导致训练数据过拟合。

2:在MultinomialNB中,我们可以改变的主要参数是alpha、fit_prior、class_prior和min_categories。加法平滑参数alpha用于处理因数据稀疏而在学习数据中未观察到的特征,防止概率计算时出现0值。是否学习类的先验概率fit_prior如果为假,使用均匀的先验概率,即认为所有输出类别的可能性相等。类别的先验概率class_prior如果指定,则不根据数据调整先验概率。指定每个特征的最小类别数min_categories帮助防止在特征维度非常高但训练样本相对较少时出现过拟合。

3:在BernoulliNB中,我们可以改变的主要参数是alpha、binarize、fit_prior和class_prior。与MultinomialNB相比,binarize是用于二值化输入特征的阈值。如果设定,则输入特征大于这个阈值的将会被二值化为1,否则二值化为0。

4:在模型选择方面,MultinomialNB和BernoulliNB更适合处理MNIST数据集,因为MNIST的图像可以表示为像素强度计数(适合多项式分布)或二值化的像素存在与否(适合伯努利分布)。GaussianNB可能在没有适当预处理(如归一化)的情况下表现不佳。

5:在使用贝叶斯分类器后,进行误差分析(比如混淆矩阵)可以揭示某些数字更难区分。数字之间难以区分通常是由贝叶斯模型的独立性假设造成的。

七、遇到的问题及其解决方案

问题1:安装第三方依赖时,显示pip有更新。

解决1:采用【pip install --upgrade pip】命令,升级pip即可。

问题2:Jupyter Notebook只显示ipykernel为python 3,不显示具体的python版本。

解决2:采用【!pip3 -V】命令进行查看,结果如下图所示。

八、附件

1:三类贝叶斯模型的基本调用源代码

import numpy as np 

import matplotlib.pyplot as plt 

from sklearn.datasets import fetch_openml 

 

# 加载MNIST或MNIST Variation数据集 

mnist = fetch_openml('mnist_784', version=1) 

X, y = mnist.data, mnist.target 

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:] 

# 训练贝叶斯分类器 

from sklearn.naive_bayes import GaussianNB 

gnb = GaussianNB() 

gnb.fit(X_train, y_train) 

 

# 预测测试集 

y_pred = gnb.predict(X_test) 

 

# 计算准确率

acc = np.sum(y_pred == y_test) / len(y_test) 

print("Accuracy:", acc) 

 

# 可视化结果 

plt.scatter(y_pred, y_test, s=20) 

plt.xlabel("Predicted label") 

plt.ylabel("True label") 

plt.title("Confusion Matrix") 

plt.show()

plt.savefig('ret.png')

# 训练贝叶斯分类器 

from sklearn.naive_bayes import MultinomialNB 

gnb1 = MultinomialNB()

gnb1.fit(X_train, y_train) 

 

# 预测测试集 

y_pred = gnb1.predict(X_test) 

 

# 计算准确率

acc = np.sum(y_pred == y_test) / len(y_test) 

print("Accuracy:", acc) 

 

# 可视化结果 

plt.scatter(y_pred, y_test, s=20) 

plt.xlabel("Predicted label") 

plt.ylabel("True label") 

plt.title("Confusion Matrix") 

plt.show()

plt.savefig('ret1.png')

# 训练贝叶斯分类器 

from sklearn.naive_bayes import BernoulliNB

gnb2 = BernoulliNB()

gnb2.fit(X_train, y_train) 

 

# 预测测试集 

y_pred = gnb2.predict(X_test) 

 

# 计算准确率

acc = np.sum(y_pred == y_test) / len(y_test) 

print("Accuracy:", acc) 

 

# 可视化结果 

plt.scatter(y_pred, y_test, s=20) 

plt.xlabel("Predicted label") 

plt.ylabel("True label") 

plt.title("Confusion Matrix") 

plt.show()

plt.savefig('ret2.png')

2:单张图片对模型进行测试的代码

import matplotlib.pyplot as plt

from sklearn.datasets import fetch_openml

from sklearn.model_selection import train_test_split

from sklearn.naive_bayes import GaussianNB

from sklearn.metrics import accuracy_score

from skimage import io, color, filters

from skimage.transform import resize

import numpy as np

# 训练贝叶斯分类器 

from sklearn.naive_bayes import MultinomialNB 

gnb = MultinomialNB()

gnb.fit(X_train, y_train)

# 读取图像文件

image_path = r'C:\Users\86158\Desktop\train_labels\train_examples_labels\train_new\5.png'

image = io.imread(image_path)

# 将图像缩放到28x28像素

scaled_image = resize(image, (28, 28), anti_aliasing=True)

# 将图像数据转换为一维数组

flat_image = scaled_image.flatten()

# 将像素值缩放到0到1(如果你的模型是在这个范围的数据上训练的)

processed_image = flat_image / 255.0

# 确保这里的processed_image具有与模型训练数据相同的形状

# 如果是MNIST,它应该有784个特征

processed_image = processed_image.reshape(1, -1)

# 使用模型进行预测

prediction = gnb.predict(processed_image)

print(f"Predicted class for the input image: {prediction[0]}")

# 可选:显示图像

plt.imshow(scaled_image, cmap='gray')

plt.title(f'Predicted Class: {prediction[0]}')

plt.show()

3:对数据集进行主成分分析、标准化和归一化处理的完整代码

from sklearn.datasets import fetch_openml

from sklearn.model_selection import train_test_split

from sklearn.preprocessing import StandardScaler

from sklearn.decomposition import PCA

from sklearn.neighbors import KNeighborsClassifier

from sklearn.naive_bayes import MultinomialNB,GaussianNB,BernoulliNB

import matplotlib.pyplot as plt

from sklearn.preprocessing import MinMaxScaler

from sklearn.model_selection import GridSearchCV

from sklearn.tree import DecisionTreeClassifier

from sklearn.tree import plot_tree

from sklearn.tree import export_graphviz

from sklearn.ensemble import RandomForestClassifier

from sklearn.linear_model import LogisticRegression

from sklearn.model_selection import learning_curve

from sklearn.metrics import classification_report

import numpy as np

mnist=fetch_openml("mnist_784",version=1,cache=True)

x=mnist.data

y=mnist.target

# 对输入特征进行主成分分析降维

print("降维前的特征数:",x.shape[1])

PCA_Transfer=PCA(n_components=0.95)

PCA_x=PCA_Transfer.fit_transform(x)

print("降维后的特征数:",PCA_x.shape[1])

# 划分训练集和测试集

x_train,x_test,y_train,y_test=train_test_split(PCA_x,y,test_size=0.2,random_state=1)

# 对输入特征值进行标准化处理

Standard_Transfer=StandardScaler()

standard_x_train=Standard_Transfer.fit_transform(x_train)

standard_x_test=Standard_Transfer.fit_transform(x_test)

# 对输入特征值进行归一化处理

MinMax_Transfer=MinMaxScaler()

Guiyihua_x_train=MinMax_Transfer.fit_transform(x_train)

Guiyihua_x_test=MinMax_Transfer.fit_transform(x_test)

Bayes_estimator1=MultinomialNB()

param_dic={"alpha":[0.5,0.6,0.7,0.8,0.9,1,1.1,1.2]}

Bayes_estimator1=GridSearchCV(Bayes_estimator1,param_grid=param_dic,cv=10,n_jobs=-1)

Bayes_estimator1.fit(Guiyihua_x_train,y_train)

print("多项式Bayes算法在测试集上的平均预测成功率:",Bayes_estimator1.score(Guiyihua_x_test,y_test))

Bayes_estimator2=GaussianNB()

Bayes_estimator2=GridSearchCV(Bayes_estimator2,cv=10,param_grid={},n_jobs=-1)

Bayes_estimator2.fit(x_train,y_train)

print("高斯Bayes算法在测试集上的平均预测成功率:",Bayes_estimator2.score(x_test,y_test))

Bayes_estimator3=BernoulliNB()

Bayes_estimator2=GridSearchCV(Bayes_estimator2,cv=10,param_grid={},n_jobs=-1)

Bayes_estimator3.fit(x_train,y_train)

print("伯努利Bayes算法在测试集上的平均预测成功率:",Bayes_estimator3.score(x_test,y_test))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/186465.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue.js ——Vuex

基本概念 vue进行开发过程中有没有遇到这样一种场景,就是有些时候一些数据是一种通用的共享数据(比如登录信息),那么这类数据在各个组件模块中可能都会用到,如果每个组件中都去后台重新获取那么势必会造成性能浪费&am…

websocket 消息包粗解

最近在搞websocket解析,记录一下: 原始字符串 �~�{"t":"d","d":{"b":{"p":"comds/comdssqmosm7k","d":{"comdss":{"cmdn":"success",…

免费使用GPT的网站

登录ChatGPT系统 登录ChatGPT系统 登录ChatGPT系统

ArkTs变量类型、数据类型基础语法

可以参考官网学习路径学习HarmonyOS第一课|应用开发视频教程学习|HarmonyOS应用开发官网 ArkTS是华为自研的开发语言。它在TypeScript(简称TS)的基础上,匹配ArkUI框架,扩展了声明式UI、状态管理等相应的能力,让开发者以…

浅谈安科瑞ASJ继电器在菲律宾矿厂的应用

摘要:对电气线路进行接地故障保护,方式接地故障电流引起的设备和电气火灾事故越来越成为日常所需。针对用户侧主要的用能节点,设计安装剩余电流继电器,实时监控各用能回路的剩余电流状态。通过实时监控用能以及相关电力参数、提高…

ESP32-Web-Server编程- 通过 Highcharts 创建图表(Chart)实时显示设备信息

ESP32-Web-Server编程- 通过 Highcharts 创建图表(Chart)实时显示设备信息 概述 上节讲述了通过 Server-Sent Events(以下简称 SSE) 实现在网页实时更新 ESP32 Web 服务器的传感器数据,并通过表格显示传感器的数据。…

操作系统--中断异常

操作系统第一章易错总结 1.操作系统的功能 ⭐ 编译器是操作系统的上层软件,不是操作系统需要提供的功能。 ⭐注意: 1.批处理的主要缺点是缺乏交互性 2.输入/输出指令需要中断操作,中断必须在核心态下执行 3.多道性是为了提高系统利用率和…

【Spring MVC】Filter 过滤器异常处理 HandlerExceptionResolver 分析

文章目录 前言版本说明测试 Demo1、自定义过滤器 DemoFilter2、自定义业务异常 ServiceException3、自定义异常处理类 DemoExceptionHandler4、DemoController5、请求测试 问题分析1、日志打印记录2、Debug 方法 解决方案1、修改自定义过滤器2、请求测试 解决方案分析1、日志打…

提升技能素养,AMCAP做出合适的决策

近年来,智能配置投资与理财逐渐受到关注并走俏。这是一种简单快捷的智慧化理财方式,通过将个人和家族的闲置资金投入到低风险高流动性的产品中。 国际财富管理投资机构AMCAP集团金融分析师表示:智能配置投资与理财之所以持续走俏&#xff0c…

6.3 Windows驱动开发:内核枚举IoTimer定时器

内核I/O定时器(Kernel I/O Timer)是Windows内核中的一个对象,它允许内核或驱动程序设置一个定时器,以便在指定的时间间隔内调用一个回调函数。通常,内核I/O定时器用于周期性地执行某个任务,例如检查驱动程序…

在Linux上安装KVM虚拟机

一、搭建KVM环境 KVM(Kernel-based Virtual Machine)是一个基于内核的系统虚拟化模块,从Linux内核版本2.6.20开始,各大Linux发行版就已经将其集成于发行版中。KVM与Xen等虚拟化相比,需要硬件支持的完全虚拟化。KVM由内…

使用 kubeadm 部署 Kubernetes 集群(一)linux环境准备

一、 初始化集群环境 准备三台 rocky8.8 操作系统的 linux 机器。每台机器配置:4VCPU/4G 内存/60G 硬盘 环境说明: IP 主机名 角色 内存 cpu 192.168.1.63 xuegod63 master 4G 4vCPU 192.168.1.64 xuegod64 worker 4G 4vCPU 192.168.1.62 xuegod62 work…

Python 异常处理(try except)

文章目录 1 概述1.1 异常示例 2 异常处理2.1 捕获异常 try except2.2 抛出异常 raise 3 异常类型3.1 内置异常3.2 自定义异常 1 概述 1.1 异常示例 异常:程序执行中出现错误,若不处理,则程序终止 示例代码: v 6 / 0 # 除数不…

基于matlab的图像去噪算法设计与实现

摘 要 随着我们生活水平的提高,科技产品飞速更新换代,在信息传输中,图像传输所占的比重越来越大。但自然噪声会在图像传输时干扰其传输过程,甚至会使图片不能表达其原来的意义。去噪处理就是为了去除图像中的噪声,从而…

【数据清洗 | 数据规约】数据类别型数据 编码最佳实践,确定不来看看?

🤵‍♂️ 个人主页: AI_magician 📡主页地址: 作者简介:CSDN内容合伙人,全栈领域优质创作者。 👨‍💻景愿:旨在于能和更多的热爱计算机的伙伴一起成长!!&…

tex2D使用学习

1. 背景&#xff1a; 项目中使用到了纹理进行插值的加速&#xff0c;因此记录一些自己在学习tex2D的一些过程 2. 代码&#xff1a; #include "cuda_runtime.h" #include "device_launch_parameters.h" #include <assert.h> #include <stdio.h>…

Maven的安装和使用

Maven是一个基于项目对象模型&#xff08;POM&#xff09;&#xff0c;可以管理项目构建、依赖管理、项目报告等的工具&#xff0c;使构建Java项目更容易。可以说Maven是一个项目管理和构建工具&#xff0c;它可以从管理项目的角度出发&#xff0c;将开发过程中的需求纳入进来&…

FFmpeg架构全面分析

一、简介 它的官网为&#xff1a;https://ffmpeg.org/&#xff0c;由Fabrice Bellard&#xff08;法国著名程序员Born in 1972&#xff09;于2000年发起创建的开源项目。该人是个牛人&#xff0c;在很多领域都有很大的贡献。 FFmpeg是多媒体领域的万能工具。只要涉及音视频领…

软文推广如何自然融入品牌?媒介盒子有妙招

软文推广作为一种柔性推广方式&#xff0c;能将品牌信息融入到用户日常浏览的内容中&#xff0c;让用户不知不觉接触品牌&#xff0c;从而产生好感&#xff0c;这种方式既可以避免广告带来的反感&#xff0c;又可以提高广告的有效性。那么在推广中应该如何自然融入品牌信息呢&a…

leetCode 78.子集 + 回溯算法 + 图解

给你一个整数数组 nums &#xff0c;数组中的元素 互不相同 。返回该数组所有可能的子集&#xff08;幂集&#xff09;。解集 不能 包含重复的子集。你可以按 任意顺序 返回解集 示例 1&#xff1a; 输入&#xff1a;nums [1,2,3] 输出&#xff1a;[[],[1],[2],[1,2],[3],[1…