基于MLP完成CIFAR-10数据集和UCI wine数据集的分类

基于MLP完成CIFAR-10数据集和UCI wine数据集的分类,使用到了sklearn和tensorflow,并对图片分类进行了数据可视化展示

数据集介绍

UCI wine数据集:

http://archive.ics.uci.edu/dataset/109/wine

这些数据是对意大利同一地区种植的葡萄酒进行化学分析的结果,但来自三个不同的品种。该分析确定了三种葡萄酒中每一种中发现的13种成分的数量。

CIFAR-10数据集:

https://www.cs.toronto.edu/~kriz/cifar.html

CIFAR-10 数据集由 10 类 60000 张 32x32 彩色图像组成,每类 6000 张图像。有 50000 张训练图像和 10000 张测试图像。

数据集分为 5 个训练批次和 1 个测试批次,每个批次有 10000 张图像。测试批次正好包含从每个类中随机选择的 1000 张图像。训练批次以随机顺序包含剩余的图像,但某些训练批次可能包含来自一个类的图像多于另一个类。在它们之间,训练批次正好包含来自每个类的 5000 张图像

MLP算法

MLP 代表多层感知器(Multilayer Perceptron),是一种基本的前馈神经网络(Feedforward Neural Network)模型。它由一个输入层、一个或多个隐藏层和一个输出层组成,其中每个层都包含多个神经元(或称为节点)。MLP 是一种强大的模型,常用于解决分类和回归问题。

MLP 的基本组成部分如下:

  • 输入层(Input Layer): 接收原始数据的输入层,每个输入节点对应输入特征。

  • 隐藏层(Hidden Layer):
    位于输入层和输出层之间的一层或多层神经元。每个神经元通过权重与前一层的所有节点相连接,并通过激活函数进行非线性变换。隐藏层的存在使得网络能够学习输入数据的复杂特征。

  • 输出层(Output Layer): 提供最终的网络输出。对于不同的问题,输出层的激活函数可能不同。例如,对于二分类问题,可以使用
    sigmoid 激活函数;对于多分类问题,可以使用 softmax 激活函数。

模型构建

UCI wine:

我们加载 sklearn.datasets 中的 load_wine作为训练数据,划分为数据集和测试集,并进行标准化操作

接着调用 MLPClassifier(hidden_layer_sizes=(100,), max_iter=1000, random_state=42) 创建模型

训练后在测试集上预测,最后评估模型
在这里插入图片描述

from sklearn.neural_network import MLPClassifier
from sklearn.datasets import load_iris
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.preprocessing import StandardScaler# 加载Iris数据集
# iris = load_iris()
# X = iris.data
# y = iris.targetwine = load_wine()
X = wine.data
y = wine.target# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 数据标准化
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)# 构建MLP模型
mlp = MLPClassifier(hidden_layer_sizes=(100,), max_iter=1000, random_state=42)# 训练模型
mlp.fit(X_train_scaled, y_train)# 在测试集上进行预测
y_pred = mlp.predict(X_test_scaled)# 评估模型性能
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred)# 打印结果
print("Accuracy:", accuracy)
print("\nConfusion Matrix:\n", conf_matrix)
print("\nClassification Report:\n", class_report)

CIFAR-10:

我们使用 tf.keras.datasets.cifar10中自带的数据进行训练

使用 tf.keras.Sequential() 这个函数创建模型,设置四层网络

接着对代码进行批量训练,评估和保留模型后对结果进行可视化处理

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

########cifar10数据集##########
###########保存模型############
########卷积神经网络##########
#train_x:(50000, 32, 32, 3), train_y:(50000, 1), test_x:(10000, 32, 32, 3), test_y:(10000, 1)
#60000条训练数据和10000条测试数据,32x32像素的RGB图像
#第一层两个卷积层16个3*3卷积核,一个池化层:最大池化法2*2卷积核,激活函数:ReLU
#第二层两个卷积层32个3*3卷积核,一个池化层:最大池化法2*2卷积核,激活函数:ReLU
#隐含层激活函数:ReLU函数
#输出层激活函数:softmax函数(实现多分类)
#损失函数:稀疏交叉熵损失函数
#隐含层有128个神经元,输出层有10个节点
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as npimport time
print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print(nowtime)#指定GPU
#import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# gpus = tf.config.experimental.list_physical_devices('GPU')
# tf.config.experimental.set_memory_growth(gpus[0],True)
#初始化
plt.rcParams['font.sans-serif'] = ['SimHei']#加载数据
cifar10 = tf.keras.datasets.cifar10
(train_x,train_y),(test_x,test_y) = cifar10.load_data()
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s'%(train_x.shape,train_y.shape,test_x.shape,test_y.shape))#数据预处理
X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32)     #归一化
y_train,y_test = tf.cast(train_y,tf.int16),tf.cast(test_y,tf.int16)#建立模型
model = tf.keras.Sequential()
##特征提取阶段
#第一层
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu,data_format='channels_last',input_shape=X_train.shape[1:]))  #卷积层,16个卷积核,大小(3,3),保持原图像大小,relu激活函数,输入形状(28,28,1)
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))   #池化层,最大值池化,卷积核(2,2)
#第二层
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))
##分类识别阶段
#第三层
model.add(tf.keras.layers.Flatten())    #改变输入形状
#第四层
model.add(tf.keras.layers.Dense(128,activation='relu'))     #全连接网络层,128个神经元,relu激活函数
model.add(tf.keras.layers.Dense(10,activation='softmax'))   #输出层,10个节点
print(model.summary())      #查看网络结构和参数信息#配置模型训练方法
#adam算法参数采用keras默认的公开参数,损失函数采用稀疏交叉熵损失函数,准确率采用稀疏分类准确率函数
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])#训练模型
#批量训练大小为64,迭代5次,测试集比例0.2(48000条训练集数据,12000条测试集数据)
print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print('训练前时刻:'+str(nowtime))history = model.fit(X_train,y_train,batch_size=64,epochs=5,validation_split=0.2)print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print('训练后时刻:'+str(nowtime))#评估模型
model.evaluate(X_test,y_test,verbose=2)     #每次迭代输出一条记录,来评价该模型是否有比较好的泛化能力#保存整个模型
model.save('CIFAR10_CNN_weights.h5')#结果可视化
print(history.history)
loss = history.history['loss']          #训练集损失
val_loss = history.history['val_loss']  #测试集损失
acc = history.history['sparse_categorical_accuracy']            #训练集准确率
val_acc = history.history['val_sparse_categorical_accuracy']    #测试集准确率plt.figure(figsize=(10,3))plt.subplot(121)
plt.plot(loss,color='b',label='train')
plt.plot(val_loss,color='r',label='test')
plt.ylabel('loss')
plt.legend()plt.subplot(122)
plt.plot(acc,color='b',label='train')
plt.plot(val_acc,color='r',label='test')
plt.ylabel('Accuracy')
plt.legend()#暂停5秒关闭画布,否则画布一直打开的同时,会持续占用GPU内存
#根据需要自行选择
#plt.ion()       #打开交互式操作模式
#plt.show()
#plt.pause(5)
#plt.close()#使用模型
plt.figure()
for i in range(10):num = np.random.randint(1,10000)plt.subplot(2,5,i+1)plt.axis('off')plt.imshow(test_x[num],cmap='gray')demo = tf.reshape(X_test[num],(1,32,32,3))y_pred = np.argmax(model.predict(demo))plt.title('标签值:'+str(test_y[num])+'\n预测值:'+str(y_pred))
#y_pred = np.argmax(model.predict(X_test[0:5]),axis=1)
#print('X_test[0:5]: %s'%(X_test[0:5].shape))
#print('y_pred: %s'%(y_pred))#plt.ion()       #打开交互式操作模式
plt.show()
#plt.pause(5)
#plt.close()

项目地址

https://gitee.com/yishangyishang/homeword.git

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/236758.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[SWPUCTF 2021 新生赛]jicao

首先打开环境 代码审计,他这儿需要进行GET传参和POST传参,需要进行POST请求 变量idwllmNB,进行GET请求变量json里需要含参数x以及jsonwllm 构造 得到flag

在线更换Proxmox VE超融合集群Ceph OSD磁盘

因为资源紧张的原因,担心一旦关机,虚拟机因为没有空闲的资源而被冻结,以致于不能漂移,导致部分服务停止,只好让机房帮忙热插拔。 幸运的是,插上去能够被系统所识别(/dev/sdf就是新插入的硬盘&am…

ip静态好还是dhcp好?

选择使用静态 IP 还是 DHCP(动态主机配置协议)取决于您的网络需求和环境。下面是它们的一些特点和适用场景: 静态 IP: 固定的 IP 地址:静态 IP 是手动配置在设备上的固定 IP 地址,不会随时间或网络变化而改…

指针(3)计算最长的字符串长度(本题要求实现一个函数,用于计算有n个元素的指针数组s中最长的字符串的长度)以及就题讲解malloc函数

题目要求 本题要求实现一个函数,用于计算有n个元素的指针数组s中最长的字符串的长度。 函数接口定义: int max_len( char *s[], int n ); 其中n个字符串存储在s[]中,函数max_len应返回其中最长字符串的长度。 裁判测试程序样例&#xff…

Redis多路复用在不同操作系统的性能

Redis多路复用 在redis中在不同系统中会使用不同的模型,但是多路复用的原理还是一个进程或者线程可以同时处理多个连接,不需要为每个连接都创建一个单独的线程或者进程。不同的多路复用机制在具体实现有一些差异,但是核心思想还是单个进程或…

C++ 如何使用二维map(二维map使用简单例子)

在C中&#xff0c;可以使用std::map来创建二维map。 首先&#xff0c;你需要包含<map>头文件&#xff1a; #include <map>然后&#xff0c;你可以声明一个map对象&#xff0c;其中一个维度用作键&#xff0c;另一个维度用作值。例如&#xff0c;如果你想创建一个…

华为数通方向HCIP-DataCom H12-831题库(多选题:181-200)

第181题 如图所示,R1、R2、R3、R4都部署为SPF区域0,链路的cost值如图中标识。R1、R2R3、R4的Loopback0通告入OSPF。R1、R2、R3与R4使用Loopback0作为连接接口,建立BGP对等体关系,其中R4为RR设备,R1、R2、R3是R4的客户端。当R4的直连地址172.20,1,4/32通告入BGP后,以下关R…

mysql部署 --(docker)

先查找MySQL 镜像 Docker search mysql &#xff1b; 拉取mysql镜像&#xff0c;默认拉取最新的&#xff1b; 创建mysql容器&#xff0c;-p 代表端口映射&#xff0c;格式为 宿主机端口&#xff1a;容器运行端口 -e 代表添加环境变量&#xff0c;MYSQL_ROOT_PASSWORD是root用户…

2024云渲染平台推荐!免费云渲染平台(注册账号)推荐

近期&#xff0c;由于信息科技和云计算的迅猛发展&#xff0c;云渲染服务也逐步蓬勃发展并已达到成熟阶段&#xff0c;它因其高效办公、成本效益优越以及操作简便易行&#xff0c;赢得了广大设计师和 CG 艺术家的热烈欢迎&#xff0c;在过去的岁月里&#xff0c;创造一张高清晰…

OpenCV技术应用(8)— 如何将视频分解

前言&#xff1a;Hello大家好&#xff0c;我是小哥谈。本节课就手把手教大家如何将一幅图像转化成热力图&#xff0c;希望大家学习之后能够有所收获~&#xff01;&#x1f308; 目录 &#x1f680;1.技术介绍 &#x1f680;2.实现代码 &#x1f680;1.技术介绍 视频是…

charles和谷歌浏览器在Mac上进行软件安装,桌面上会显示一个虚拟磁盘,关掉页面推出磁盘内容都消失掉了 需要再次安装问题解决

其他软件也会有这种情况&#xff0c;这里我们以charles为例。绿色背景的内容是重点步骤。 1.如图&#xff0c;我下载了一个charles一个版本的dmg文件。 2.打开后&#xff0c;选择Agree 3.桌面会出现一个磁盘和如下页面 4.错误操作------可以不看 直接看第5步正确操作 常规情…

Socket.D 基于消息的响应式应用层网络协议

首先根据 Socket.D 官网的副标题&#xff0c;Socket.D 的自我定义是&#xff1a; 基于事件和语义消息流的网络应用协议。官网定义的特点是&#xff1a; 基于事件&#xff0c;每个消息都可事件路由所谓语义&#xff0c;通过元信息进行语义描述流关联性&#xff0c;有相关的消息…

【Java JVM】JVM 分析工具

在 $JAVA_HOME/bin 的目录下, 存在着许多小工具, 除了编译和运行 Java 程序外, 打包, 部署, 签名, 调试, 监控, 运维等各种场景都可能会用到它们。 1 常用的命令行工具 1.1 jps (JVM Process Status Tool) - 虚拟机进程状况工具 列出正在运行的虚拟机进程, 并显示虚拟机执行…

Android Studio设置android:background 属性背景颜色

除了默认的颜色之外都要自己添加。 添加颜色的操作步骤&#xff1a; 打开res文件夹&#xff0c;找values&#xff0c;里面有个colors.xml的文件。然后在里面定义一些颜色。 完成

java 10. 注解

java 10. 注解 文章目录 10.1 什么是注解10.2 内置的注解10.2.1 Override 10.3 自定义注解10.4 元注解 10.1 什么是注解 注解&#xff0c;一种元数据形式&#xff0c;提供有关程序的数据&#xff0c;该数据不属于程序本身。注释对其注释的代码的操作没有直接影响。 注解的语…

Spring和Spring Boot的主要区别

Spring 和 Spring Boot 是两种不同的 Java 框架&#xff0c;尽管它们具有许多相似之处&#xff0c;但它们在目的、配置方式、起步依赖、运行方式和适应性等方面仍然存在许多不同。 首先&#xff0c;他们的目的不同&#xff0c;Spring 的主要目的是为 Java 应用程序提供一个基础…

21、Web攻防——JavaWeb项目JWT身份攻击组件安全访问控制

文章目录 一、JavaWeb二、JWT攻击 一、JavaWeb webgoat 1.java web的配置文件&#xff0c;配置文件一般在META-INF目录下&#xff0c;文件名常为pom.xml或web.xml 2.如何通过请求&#xff0c;查看运行的java代码。 地址信息PathTraversal/profile-upload 直接找到以该字符P…

网络安全知识图谱 图数据库介绍及语法

本体构建: 资产&#xff1a; 系统&#xff0c;软件 威胁&#xff1a; 攻击&#xff1a; 建模&#xff1a; 3个本体 5个实体类型 CWE漏洞库 http://cwe.mitre.org/data/downloads.html CPECP攻击模式分类库 http://capec.mitre.org/data/downloads.html CPE通用组件库 http:…

spring使用@Scheduled配置定时任务

schedule的使用 在Spring中&#xff0c;你可以使用Scheduled注解来标记一个方法&#xff0c;使其成为一个定时任务。一般情况下Scheduled注解修饰的任务方法没有返回值也没有入参。 开启schedule Configuration EnableScheduling public class ScheduleConfig {}在配置类上使…

B树和B+树的区别

1.节点结构&#xff1a; a.B树&#xff1a; B树的每个节点包含键和对应的值&#xff0c;子节点的数量和键的数量相等。在B树中&#xff0c;每个节点都存储键和值&#xff0c;并且非叶子节点的键值对应于其子节点的范围。 b.B树&#xff1a; B树的非叶子节点只包含键&#xff0…