Keras多分类鸢尾花DEMO

完整的一个小demo:

pandas==1.2.4

numpy==1.19.2

python==3.9.2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pandas import DataFrame
from scipy.io import loadmat
from sklearn.model_selection import train_test_split
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout
from sklearn import preprocessing
from sklearn.datasets import load_iris
# 映射函数iris_type: 将string的label映射至数字label
import os# data.to_csv('data.csv',index=False)  #cvs保存文件不会保存index列
# data = pd.read_csv('data.csv',index_col=0)  #读取csv文件的时候选择不读取第一列信息
def downLoad():path="../httdemo/"iris = load_iris()data = iris.data #获取特征数据target = iris.target#获取目标数据data_information = DataFrame(data, columns=['bcalyx', 'scalyx', 'length', 'width']) #重新定义特征数据的列名data_target = DataFrame(target, columns=['target'])#目标数据列名targetdata_csv = pd.concat([data_information, data_target], axis=1) #合并特征数据和目标数据到一个DataFrameif not os.path.exists(path):#把DataFrame数据保存到本地,以.CVS的格式保存os.makedirs(path)filename = path + 'iris.csv'  #定义保存路径data_csv.to_csv(filename,index=False) #index==False表示,序号下表列不做保存# 本地数据保存为excel文件# outputfile = "iris.xls"  # 保存文件路径名# column = list(data['feature_names'])# dd = pd.DataFrame(data.data, index=range(150), columns=column)# dt = pd.DataFrame(data.target, index=range(150), columns=['outcome'])# jj = dd.join(dt, how='outer')  # 用到DataFrame的合并方法,将data.data数据与data.target数据合并# jj.to_excel(outputfile)  # 将数据保存到outputfile文件中def readData(path):Data = pd.read_csv(path,names=['bcalyx', 'scalyx', 'length', 'width','target']) #读取本地保存的CVS数据Data.head(10)#展示前10# 变量初始化# 最后一列为y,其余为xcols = Data.shape[1]  # 获取列数 shape[0]行数 [1]列数X = Data.iloc[1:, 0:cols - 1].astype(float)  # 获取得到特征数据,转换为Float的格式,如果输入str,会报错的,取前cols-1列,即输入向量y = Data.iloc[1:, cols - 1:cols]  # 取最后一列,即目标变量X = np.array(X)y = np.array(y)print(y)return X,ydef startM():path = "../httdemo/iris.csv"X,y=readData(path)  #加载数据from sklearn.preprocessing import OneHotEncoder# 创建独热编码器对象encoder = OneHotEncoder() #sklearn创建热编码器对象# 训练独热编码器 (将目标数据进行训练)encoder.fit(y)# 转换特征向量 (将目标数据y转换为特征向量[[0,0,1][0,1,0][0,0,1]])格式encoded_data = encoder.transform(y).toarray()# shuffle = True 随机打乱后再进行分割数据X_train, X_test, y_train, y_test = train_test_split(X, encoded_data, test_size=0.3,shuffle=True)#构建网络模型model = Sequential()model.add(Dense(units=1024, activation='relu', input_dim=4))  # 输入层,1024个激活单元,激活函数为relu,输入数据维度为(4,)model.add(Dense(units=512, activation='relu'))  # 隐藏层,512个激活单元,激活函数为relumodel.add(Dense(units=256, activation='relu'))  # 隐藏层,256个激活单元,激活函数为relumodel.add(Dropout(0.1)) #丢到10%的数据model.add(Dense(units=3, activation='softmax'))  # 输出层,3个输出单元,激活函数为softmax)model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])#开始训练model.fit(X_train, y_train, batch_size=30, epochs=32)#预测测试集的结果result = model.predict(X_test)yTest=np.round(result, 2)#保留俩位小数print(yTest)#测试机准确率评估score = model.evaluate(X_test, y_test)print('loss值为:', score[0])print('准确率为:', score[1])if __name__=='__main__':startM()# downLoad()

 

特征数据是str,需要转换成float 

X = Data.iloc[1:, 0:cols - 1].astype(float)

 

target的数据打印:

热编码转换之后的数据:

测试集预测结果:表示的位概率值,那个数值比较大,就是哪一个类别,每一个数组表示A,B,C

[[0.01 0.28 0.71]
 [0.91 0.06 0.03]
 [0.01 0.28 0.71]
 [0.02 0.33 0.66]
 [0.01 0.28 0.71]
 [0.06 0.51 0.44]
 [0.92 0.05 0.02]
 [0.04 0.43 0.53]
 [0.02 0.38 0.6 ]
 [0.01 0.31 0.67]
 [0.03 0.42 0.55]
 [0.01 0.24 0.76]
 [0.01 0.32 0.67]
 [0.06 0.49 0.45]
 [0.01 0.25 0.74]
 [0.01 0.31 0.68]
 [0.08 0.51 0.42]
 [0.92 0.06 0.02]
 [0.88 0.09 0.04]
 [0.02 0.33 0.65]
 [0.01 0.29 0.7 ]
 [0.01 0.28 0.71]
 [0.04 0.47 0.49]
 [0.9  0.07 0.03]
 [0.91 0.06 0.03]
 [0.86 0.1  0.04]
 [0.18 0.5  0.31]
 [0.89 0.08 0.03]
 [0.91 0.07 0.03]
 [0.06 0.47 0.48]
 [0.02 0.37 0.61]
 [0.04 0.39 0.57]
 [0.87 0.09 0.04]
 [0.05 0.46 0.49]
 [0.01 0.27 0.72]
 [0.02 0.34 0.64]
 [0.05 0.45 0.5 ]
 [0.92 0.06 0.02]
 [0.09 0.53 0.38]
 [0.04 0.48 0.48]
 [0.95 0.04 0.02]
 [0.01 0.26 0.73]
 [0.   0.24 0.76]
 [0.78 0.15 0.07]
 [0.   0.21 0.79]]

运行结果:

训练完成之后保存模型,然后测试模型:

 

读取模型,开始预测:

from tensorflow.keras.models import load_model
import numpy as np
# 模型的导入
model = load_model('../httdemo/httmodel.h5')
# 对数据的预测输入分别为[花萼长,花萼宽,花瓣长,花瓣宽]
y_pred = model.predict([[2,1,5.5,2],[2.3,4.5,5.2,9]])
print(y_pred)
for i in y_pred:a = np.argmax(i)if a == 0 : print('该花为A')elif a == 1 : print('该花为B')elif a == 2 : print('该花为C')

测试结果:准确预测出来为C种类 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/582966.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【MYSQL】MYSQL 的学习教程(十)之 InnoDB 锁

数据库为什么需要加锁呢? 如果有多个并发请求存取数据,在数据就可能会产生多个事务同时操作同一行数据。如果并发操作不加控制,不加锁的话,就可能写入了不正确的数据,或者导致读取了不正确的数据,破坏了数…

MySQL数据库多版本并发控制(MVCC)

在数据库中,并发控制是确保多个事务能够同时执行,而不会导致数据不一致或冲突的关键机制。多版本并发控制(MVCC)是一种流行的并发控制方法,它可以允许多个事务同时读取同一数据项的不同版本,而不会相互阻塞。本文将讨论MVCC的原理…

在 iPhone 手机上恢复数据的 7 个有效应用程序

我们的生活离不开 iPhone。无论我们走到哪里,他们都陪伴着我们,让我们保持联系、拍摄照片和视频,并提供娱乐。与此同时,您将计算机安全地放在办公桌上,不受天气影响,也不受伤害。如果您要在任何地方丢失重要…

Redis实现滚动周榜|滚动榜单|直播间榜单|排行榜|Redis实现日榜04

上述文章主要探讨了实现滚动榜单的两种方式。第一种方式是同步写n天滚动榜单,但这种方式存在一个严重的缺点:每天都需要编写多个榜单。尽管在实现三天或七天滚动榜单时相对简单,但若要实现近30天的滚动榜单,这种方式显得不够智能。…

API集群负载统计 (100%用例)C卷 (JavaPythonNode.jsC语言C++)

某个产品的RESTful API集合部署在服务器集群的多个节点上, 近期对客户端访问日志进行了采集,需要统计各个API的访问频次, 根据热点信息在服务器节点之间做负载均衡,现在需要实现热点信息统计查询功能。 RESTful API的由多个层级构成,层级之间使用/连接,如/A/B/C/D这个地址…

【ES】Elasticsearch常见问题与解决(持续更新)

目录 Elasticsearch常见问题 1. 集群健康问题 2. 性能问题 3. 映射问题 4. 分片问题 5. 内存问题 6. 硬件问题 7. 配置问题 8. 安全问题 9. 网络问题 10. 版本不兼容 Elasticsearch日常使用小结 【Q】离线告警,有IP已离线 【Q】统计某个应用的某个索引…

Spring Boot笔记2

3. SpringBoot原理分析 3.1. 起步依赖原理解析 3.1.1. 分析spring-boot-starter-parent 按住Ctrl键,然后点击pom.xml中的spring-boot-starter-parent,跳转到了spring-boot-starter-parent的pom.xml,xml配置如下(只摘抄了部分重…

Mybatis Java API - SqlSession

正如前面提到的,​SqlSession​实例是MyBatis中最重要、最强大的类。它是您将找到执行语句、提交或回滚事务以及获取映射器实例的所有方法的地方。 SqlSession 类上有超过二十个方法,让我们将它们分成更易理解的组别。 Statement Execution Methods-语…

Android 13 - Media框架(28)- MediaCodec(三)

上一节我们了解到 ACodec 执行完 start 流程后,会把所有的 input buffer 都提交给 MediaCodec 层,MediaCodec 是如何处理传上来的 buffer 呢?这一节我们就来了解一下这部分内容。 1、ACodecBufferChannel::fillThisBuffer ACodec 通过调用 A…

Java 代理模式

一、代理模式概述 代理模式是一种比较好理解的设计模式。简单来说就是 我们使用代理对象来代替对真实对象(real object)的访问,这样就可以在不修改原目标对象的前提下,提供额外的功能操作,扩展目标对象的功能。 代理模式的主要作用是扩展目标…

C++ 383. 赎金信 (a b字符串计数比较)

给你两个字符串:ransomNote 和 magazine ,判断 ransomNote 能不能由 magazine 里面的字符构成。 如果可以,返回 true ;否则返回 false 。 magazine 中的每个字符只能在 ransomNote 中使用一次。 示例 1: 输入&…

7天玩转 Golang 标准库之 flag

在编写Golang命令行应用时,flag标准库无疑是一个很有价值的工具。它允许你以各种方式来定义和解析命令行参数。 基础示例:定义与解析参数 若想使用flag标准库,你必须首先定义你希望从命令行接收的参数。下面展示了几种常见的参数类型&#x…

Linux中proc文件系统相关介绍

proc虚拟文件系统的工作原理 linux 内核是一个非常庞大、非常复杂的一个单独的程序,对于这样一个程序来说调试是非常复杂的。像kernel这样庞大的项目,给里面添加或者修改一个功能是非常麻烦的,因为添加一个功能可能会影响其他已经有的功能。…

3D动态路障生成

3D动态路障生成 介绍设计实现1.路面创建2.空物体的创建3.Create.cs脚本创建 总结 介绍 上一篇文章介绍了Mathf.Lerp的底层实现原理,这里介绍一下跑酷类游戏的动态路障生成是如何实现的。 动态路障其实比较好生成,但是难点在哪里,如果都是平面…

6. C++的引用与指针

摘要:本文首先介绍 C 的内存模型和变量周期作为知识背景,接着对C中的引用和指针(原始指针和智能指针)进行介绍。 1. 对象生命周期 什么是对象生命周期?简单来说,对象生命周期指的是:对象从创建…

JMeter逻辑控制器之While控制器

JMeter逻辑控制器之While控制器 1. 背景2.目的3. 介绍4.While示例4.1 添加While控制器4.2 While控制器面板4.3 While控制器添加请求4.3 While控制器应用场景 1. 背景 存在一些使用场景,比如:某个请求必须等待上一个请求正确响应后才能开始执行。或者&…

Idea如何从磁盘中应用 下载好的插件流程,安装zip压缩包。

1、将下载的插件文件(通常是一个ZIP文件)复制到IntelliJ IDEA的“plugins”文件夹中。 IDEA版本 2、重启IntelliJ IDEA。 3、在设置窗口中,选择左侧的“Plugins”。 4、选择之前复制到“plugins”文件夹中的插件文件,点击“OK”按…

基于Wenet长音频分割降噪识别

Wenet是一个流行的语音处理工具,它专注于长音频的处理,具备分割、降噪和识别功能。它的长音频分割降噪识别功能允许对长时间录制的音频进行分段处理,首先对音频进行分割,将其分解成更小的段落或语音片段。接着进行降噪处理&#x…

springboot学习(八十五) 解决springboot3.2找不到资源无法抛出404错误的问题

前言 springboot3.2以下可以定义ErrorPageRegistrar将404错误转发到一个接口地址,但升级到springboot3.2(spring6.1)后,该配置不生效,抛出了500错误。 以前的错误页面处理如下: ConditionalOnClass(ErrorPageRegist…

深入理解二分查找算法(一)

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢…