基于MLP完成CIFAR-10数据集和UCI wine数据集的分类

基于MLP完成CIFAR-10数据集和UCI wine数据集的分类,使用到了sklearn和tensorflow,并对图片分类进行了数据可视化展示

数据集介绍

UCI wine数据集:

http://archive.ics.uci.edu/dataset/109/wine

这些数据是对意大利同一地区种植的葡萄酒进行化学分析的结果,但来自三个不同的品种。该分析确定了三种葡萄酒中每一种中发现的13种成分的数量。

CIFAR-10数据集:

https://www.cs.toronto.edu/~kriz/cifar.html

CIFAR-10 数据集由 10 类 60000 张 32x32 彩色图像组成,每类 6000 张图像。有 50000 张训练图像和 10000 张测试图像。

数据集分为 5 个训练批次和 1 个测试批次,每个批次有 10000 张图像。测试批次正好包含从每个类中随机选择的 1000 张图像。训练批次以随机顺序包含剩余的图像,但某些训练批次可能包含来自一个类的图像多于另一个类。在它们之间,训练批次正好包含来自每个类的 5000 张图像

MLP算法

MLP 代表多层感知器(Multilayer Perceptron),是一种基本的前馈神经网络(Feedforward Neural Network)模型。它由一个输入层、一个或多个隐藏层和一个输出层组成,其中每个层都包含多个神经元(或称为节点)。MLP 是一种强大的模型,常用于解决分类和回归问题。

MLP 的基本组成部分如下:

  • 输入层(Input Layer): 接收原始数据的输入层,每个输入节点对应输入特征。

  • 隐藏层(Hidden Layer):
    位于输入层和输出层之间的一层或多层神经元。每个神经元通过权重与前一层的所有节点相连接,并通过激活函数进行非线性变换。隐藏层的存在使得网络能够学习输入数据的复杂特征。

  • 输出层(Output Layer): 提供最终的网络输出。对于不同的问题,输出层的激活函数可能不同。例如,对于二分类问题,可以使用
    sigmoid 激活函数;对于多分类问题,可以使用 softmax 激活函数。

模型构建

UCI wine:

我们加载 sklearn.datasets 中的 load_wine作为训练数据,划分为数据集和测试集,并进行标准化操作

接着调用 MLPClassifier(hidden_layer_sizes=(100,), max_iter=1000, random_state=42) 创建模型

训练后在测试集上预测,最后评估模型
在这里插入图片描述

from sklearn.neural_network import MLPClassifier
from sklearn.datasets import load_iris
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.preprocessing import StandardScaler# 加载Iris数据集
# iris = load_iris()
# X = iris.data
# y = iris.targetwine = load_wine()
X = wine.data
y = wine.target# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 数据标准化
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)# 构建MLP模型
mlp = MLPClassifier(hidden_layer_sizes=(100,), max_iter=1000, random_state=42)# 训练模型
mlp.fit(X_train_scaled, y_train)# 在测试集上进行预测
y_pred = mlp.predict(X_test_scaled)# 评估模型性能
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred)# 打印结果
print("Accuracy:", accuracy)
print("\nConfusion Matrix:\n", conf_matrix)
print("\nClassification Report:\n", class_report)

CIFAR-10:

我们使用 tf.keras.datasets.cifar10中自带的数据进行训练

使用 tf.keras.Sequential() 这个函数创建模型,设置四层网络

接着对代码进行批量训练,评估和保留模型后对结果进行可视化处理

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

########cifar10数据集##########
###########保存模型############
########卷积神经网络##########
#train_x:(50000, 32, 32, 3), train_y:(50000, 1), test_x:(10000, 32, 32, 3), test_y:(10000, 1)
#60000条训练数据和10000条测试数据,32x32像素的RGB图像
#第一层两个卷积层16个3*3卷积核,一个池化层:最大池化法2*2卷积核,激活函数:ReLU
#第二层两个卷积层32个3*3卷积核,一个池化层:最大池化法2*2卷积核,激活函数:ReLU
#隐含层激活函数:ReLU函数
#输出层激活函数:softmax函数(实现多分类)
#损失函数:稀疏交叉熵损失函数
#隐含层有128个神经元,输出层有10个节点
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as npimport time
print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print(nowtime)#指定GPU
#import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# gpus = tf.config.experimental.list_physical_devices('GPU')
# tf.config.experimental.set_memory_growth(gpus[0],True)
#初始化
plt.rcParams['font.sans-serif'] = ['SimHei']#加载数据
cifar10 = tf.keras.datasets.cifar10
(train_x,train_y),(test_x,test_y) = cifar10.load_data()
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s'%(train_x.shape,train_y.shape,test_x.shape,test_y.shape))#数据预处理
X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32)     #归一化
y_train,y_test = tf.cast(train_y,tf.int16),tf.cast(test_y,tf.int16)#建立模型
model = tf.keras.Sequential()
##特征提取阶段
#第一层
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu,data_format='channels_last',input_shape=X_train.shape[1:]))  #卷积层,16个卷积核,大小(3,3),保持原图像大小,relu激活函数,输入形状(28,28,1)
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))   #池化层,最大值池化,卷积核(2,2)
#第二层
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))
##分类识别阶段
#第三层
model.add(tf.keras.layers.Flatten())    #改变输入形状
#第四层
model.add(tf.keras.layers.Dense(128,activation='relu'))     #全连接网络层,128个神经元,relu激活函数
model.add(tf.keras.layers.Dense(10,activation='softmax'))   #输出层,10个节点
print(model.summary())      #查看网络结构和参数信息#配置模型训练方法
#adam算法参数采用keras默认的公开参数,损失函数采用稀疏交叉熵损失函数,准确率采用稀疏分类准确率函数
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])#训练模型
#批量训练大小为64,迭代5次,测试集比例0.2(48000条训练集数据,12000条测试集数据)
print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print('训练前时刻:'+str(nowtime))history = model.fit(X_train,y_train,batch_size=64,epochs=5,validation_split=0.2)print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print('训练后时刻:'+str(nowtime))#评估模型
model.evaluate(X_test,y_test,verbose=2)     #每次迭代输出一条记录,来评价该模型是否有比较好的泛化能力#保存整个模型
model.save('CIFAR10_CNN_weights.h5')#结果可视化
print(history.history)
loss = history.history['loss']          #训练集损失
val_loss = history.history['val_loss']  #测试集损失
acc = history.history['sparse_categorical_accuracy']            #训练集准确率
val_acc = history.history['val_sparse_categorical_accuracy']    #测试集准确率plt.figure(figsize=(10,3))plt.subplot(121)
plt.plot(loss,color='b',label='train')
plt.plot(val_loss,color='r',label='test')
plt.ylabel('loss')
plt.legend()plt.subplot(122)
plt.plot(acc,color='b',label='train')
plt.plot(val_acc,color='r',label='test')
plt.ylabel('Accuracy')
plt.legend()#暂停5秒关闭画布,否则画布一直打开的同时,会持续占用GPU内存
#根据需要自行选择
#plt.ion()       #打开交互式操作模式
#plt.show()
#plt.pause(5)
#plt.close()#使用模型
plt.figure()
for i in range(10):num = np.random.randint(1,10000)plt.subplot(2,5,i+1)plt.axis('off')plt.imshow(test_x[num],cmap='gray')demo = tf.reshape(X_test[num],(1,32,32,3))y_pred = np.argmax(model.predict(demo))plt.title('标签值:'+str(test_y[num])+'\n预测值:'+str(y_pred))
#y_pred = np.argmax(model.predict(X_test[0:5]),axis=1)
#print('X_test[0:5]: %s'%(X_test[0:5].shape))
#print('y_pred: %s'%(y_pred))#plt.ion()       #打开交互式操作模式
plt.show()
#plt.pause(5)
#plt.close()

项目地址

https://gitee.com/yishangyishang/homeword.git

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/236758.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[SWPUCTF 2021 新生赛]jicao

首先打开环境 代码审计,他这儿需要进行GET传参和POST传参,需要进行POST请求 变量idwllmNB,进行GET请求变量json里需要含参数x以及jsonwllm 构造 得到flag

在线更换Proxmox VE超融合集群Ceph OSD磁盘

因为资源紧张的原因,担心一旦关机,虚拟机因为没有空闲的资源而被冻结,以致于不能漂移,导致部分服务停止,只好让机房帮忙热插拔。 幸运的是,插上去能够被系统所识别(/dev/sdf就是新插入的硬盘&am…

华为数通方向HCIP-DataCom H12-831题库(多选题:181-200)

第181题 如图所示,R1、R2、R3、R4都部署为SPF区域0,链路的cost值如图中标识。R1、R2R3、R4的Loopback0通告入OSPF。R1、R2、R3与R4使用Loopback0作为连接接口,建立BGP对等体关系,其中R4为RR设备,R1、R2、R3是R4的客户端。当R4的直连地址172.20,1,4/32通告入BGP后,以下关R…

mysql部署 --(docker)

先查找MySQL 镜像 Docker search mysql ; 拉取mysql镜像,默认拉取最新的; 创建mysql容器,-p 代表端口映射,格式为 宿主机端口:容器运行端口 -e 代表添加环境变量,MYSQL_ROOT_PASSWORD是root用户…

2024云渲染平台推荐!免费云渲染平台(注册账号)推荐

近期,由于信息科技和云计算的迅猛发展,云渲染服务也逐步蓬勃发展并已达到成熟阶段,它因其高效办公、成本效益优越以及操作简便易行,赢得了广大设计师和 CG 艺术家的热烈欢迎,在过去的岁月里,创造一张高清晰…

OpenCV技术应用(8)— 如何将视频分解

前言:Hello大家好,我是小哥谈。本节课就手把手教大家如何将一幅图像转化成热力图,希望大家学习之后能够有所收获~!🌈 目录 🚀1.技术介绍 🚀2.实现代码 🚀1.技术介绍 视频是…

charles和谷歌浏览器在Mac上进行软件安装,桌面上会显示一个虚拟磁盘,关掉页面推出磁盘内容都消失掉了 需要再次安装问题解决

其他软件也会有这种情况,这里我们以charles为例。绿色背景的内容是重点步骤。 1.如图,我下载了一个charles一个版本的dmg文件。 2.打开后,选择Agree 3.桌面会出现一个磁盘和如下页面 4.错误操作------可以不看 直接看第5步正确操作 常规情…

Socket.D 基于消息的响应式应用层网络协议

首先根据 Socket.D 官网的副标题,Socket.D 的自我定义是: 基于事件和语义消息流的网络应用协议。官网定义的特点是: 基于事件,每个消息都可事件路由所谓语义,通过元信息进行语义描述流关联性,有相关的消息…

Android Studio设置android:background 属性背景颜色

除了默认的颜色之外都要自己添加。 添加颜色的操作步骤: 打开res文件夹,找values,里面有个colors.xml的文件。然后在里面定义一些颜色。 完成

21、Web攻防——JavaWeb项目JWT身份攻击组件安全访问控制

文章目录 一、JavaWeb二、JWT攻击 一、JavaWeb webgoat 1.java web的配置文件,配置文件一般在META-INF目录下,文件名常为pom.xml或web.xml 2.如何通过请求,查看运行的java代码。 地址信息PathTraversal/profile-upload 直接找到以该字符P…

网络安全知识图谱 图数据库介绍及语法

本体构建: 资产: 系统,软件 威胁: 攻击: 建模: 3个本体 5个实体类型 CWE漏洞库 http://cwe.mitre.org/data/downloads.html CPECP攻击模式分类库 http://capec.mitre.org/data/downloads.html CPE通用组件库 http:…

Modbus RTU协议与S7 200 PLC通讯

一、Modbus RTU功能码 二、功能码使用与解析实例 01功能码 –读线圈状态 主机发送:01 01 00 01 00 08 6C 0C 从机回复: 01 01 01 2F 10 54 主机解析:01 地址(设备ID); 01 功能码;00 01 代表查询的起始线圈地址,即…

二叉树题目:输出二叉树

文章目录 题目标题和出处难度题目描述要求示例数据范围 前言解法一思路和算法代码复杂度分析 解法二思路和算法代码复杂度分析 题目 标题和出处 标题:输出二叉树 出处:655. 输出二叉树 难度 6 级 题目描述 要求 给定二叉树的根结点 root \textt…

链接未来:深入理解链表数据结构(一.c语言实现无头单向非循环链表)

在上一篇文章中,我们探索了顺序表这一基础的数据结构,它提供了一种有序存储数据的方法,使得数据的访 问和操作变得更加高效。想要进一步了解,大家可以移步于上一篇文章:探索顺序表:数据结构中的秩序之美 今…

06.仿简道云公式函数实战-前瞻

1.前言 在上篇文章中,我们介绍了QLExpress的进阶知识,扩展操作符,自定义操作符和自定义函数等内容。学了上面的内容后,目前对于QLExpress使用已经问题不大,从这篇文章,我们就进入我们的主题仿简道云公式函…

CentOS:Docker容器中安装vim

在使用docker容器时,里边没有安装vim时,敲vim命令时提示说:vim: command not found 这个时候就须要安装vim,安装命令: apt-get install vim 出现以下错误: 解决方法: apt-get update 这个命令的…

Spark中使用scala完成数据抽取任务 -- 总结

如题 任务二:离线数据处理,校赛题目需要使用spark框架将mysql数据库中ds_db01数据库的user_info表的内容抽取到Hive库的user_info表中,并且添加一个字段设置字段的格式 第二个任务和第一个的内容几乎一样。 在该任务中主要需要完成以下几个阶…

刷题记录第五十一天-去除重复字母

题目要求的是字典序最小的结果。只需要理解一点就是按大小顺序排列的字符串的字典序就是最小的,如“abcd”这种。 解题思路如下: 首先明确要使用栈结构,并且是从栈底到栈顶递增,要尽可能保证递增,这样就能保证字典序最…

前端项目常用函数封装(二)

文章目录 前端项目常用函数封装(一)判断两个数组是否有相同元素 返回相同元素(数组)判断hex颜色值是深色还是浅色随机生成深浅样色 js判断是手机端还是移动端使用UA判断使用媒体查询判断 fetch直接读文件内容,解决乱码问题下载文件将字符串下…

ansibe的脚本---playbook剧本(1)

playbook剧本组成部分: 1、task 任务: 主要是包含要在目标主机上的操作,使用模块定义操作。每个任务都是模块的调用。 2、variables变量:存储和传递数据。变量可自定义,可以在playbook中定义为全局变量,可…