纯Python实现鸢尾属植物数据集神经网络模型

摘要: 本文以Python代码完成整个鸾尾花图像分类任务,没有调用任何的数据包,适合新手阅读理解,并动手实践体验下机器学习方法的大致流程。

       尝试使用过各大公司推出的植物识别APP吗?比如微软识花、花伴侣等这些APP。当你看到一朵不知道学名的花时,只需要打开植物识别APP,拍摄一张你所想辨认的植物照片并上传,APP会自动识别出该花的品种及详细介绍,感觉手机中装了一个知识渊博的生物学家,是不是很神奇?其实,背后的原理很简单,是一个图像分类的过程,将上传的图像与手机中预存的数据集或联网数据进行匹配,将其分类到对应的类别即可。随着深度学习方法的应用,图像分类的精度越来越高,在部分数据集上已经超越了人眼的能力。
       相对于传统神经网络的方法而言,深度学习方法一般对数据集规模、硬件平台有着比较高的要求,如果只是单纯的想尝试了解图像分类任务的基本流程,建议采用小数据集样本及传统的神经网络方法实现。本文将带领读者采用鸢尾属植物数据集(Iris Data Set)来实现一个分类任务,整个鸢尾属植物数据集是机器学习中历史悠久的数据集,比现在常用的数字手写体数据集(Mnist Data Set)数据集还要早得多,该数据集来源于英国著名的统计学家、生物学家Ronald Fiser。本文在不使用相关软件库的情况下,从头开始构建针对鸢尾属植物数据的神经网络模型,对其进行训练并获得好的结果。

1


       鸢尾属植物数据集是用于测试机器学习算法的最常用数据集。该数据包含四种特征,萼片长度、萼片宽度、花瓣长度和花瓣宽度,用于鸢尾属植物的不同物种(versicolorvirginicasetosa)。此外,每个物种有50个实例(数据行),下面让我们看看样本数据分布情况。

2


       我们将在这个数据集上使用神经网络构建分类模型。为了简单起见,使用花瓣长度和花瓣宽度作为特征,且只有两类物种:versicolor和virginica。下面就让我们在Python中逐步训练针对该样本数据集的神经网络:

步骤1:准备鸢尾属植物数据集

       将Iris数据集导入python并对数据进行子集划分以保留行之间的相关性:

#import libraries
import os
import pandas as pd
#Set working directory and load data
os.chdir('C:\\Users\\rohan\\Documents\\Analytics\\Data')
iris = pd.read_csv('iris.csv')
#Create numeric classes for species (0,1,2) 
iris.loc[iris['Name']=='virginica','species']=0
iris.loc[iris['Name']=='versicolor','species']=1
iris.loc[iris['Name']=='setosa','species'] = 2
iris = iris[iris['species']!=2]
#Create Input and Output columns
X = iris[['PetalLength', 'PetalWidth']].values.T
Y = iris[['species']].values.T
Y = Y.astype('uint8')
#Make a scatter plot
plt.scatter(X[0, :], X[1, :], c=Y[0,:], s=40, cmap=plt.cm.Spectral);
plt.title("IRIS DATA | Blue - Versicolor, Red - Virginica ")
plt.xlabel('Petal Length')
plt.ylabel('Petal Width')
plt.show()

 

3


       蓝色点代表Versicolor物种,红色点代表Virginica物种。本文构建的神经网络将在这些数据上进行训练,以期最后能正确地分类物种。

步骤2:初始化参数(权重和偏置)

       下面构建一个具有单个隐藏层的神经网络。此外,将隐藏图层的大小设置为6:

def initialize_parameters(n_x, n_h, n_y):np.random.seed(2) # we set up a seed so that our 
output matches ours although the initialization is random.W1 = np.random.randn(n_h, n_x) * 0.01 #weight matrix of shape (n_h, n_x)b1 = np.zeros(shape=(n_h, 1))  #bias vector of shape (n_h, 1)W2 = np.random.randn(n_y, n_h) * 0.01   #weight matrix of shape (n_y, n_h)b2 = np.zeros(shape=(n_y, 1))  #bias vector of shape (n_y, 1)#store parameters into a dictionary    parameters = {"W1": W1,"b1": b1,"W2": W2,"b2": b2}return parameters

步骤3:前向传播(forward propagation)

       在前向传播过程中,使用tanh激活函数作为第一层的激活函数,使用sigmoid激活函数作为第二层的激活函数:

def forward_propagation(X, parameters):
#retrieve intialized parameters from dictionary    W1 = parameters['W1']b1 = parameters['b1']W2 = parameters['W2']b2 = parameters['b2']# Implement Forward Propagation to calculate A2 (probability)Z1 = np.dot(W1, X) + b1A1 = np.tanh(Z1)  #tanh activation functionZ2 = np.dot(W2, A1) + b2A2 = 1/(1+np.exp(-Z2))  #sigmoid activation functioncache = {"Z1": Z1,"A1": A1,"Z2": Z2,"A2": A2}return A2, cache

步骤4:计算代价函数(cost function)

       目标是使得计算的代价函数小化,本文采用交叉熵(cross-entropy)作为代价函数:

def compute_cost(A2, Y, parameters):m = Y.shape[1] # number of training examples# Retrieve W1 and W2 from parametersW1 = parameters['W1']W2 = parameters['W2']# Compute the cross-entropy costlogprobs = np.multiply(np.log(A2), Y) + np.multiply((1 - Y), np.log(1 - A2))cost = - np.sum(logprobs) / mreturn cost

步骤5:反向传播(back propagation)

       计算反向传播过程,主要是计算代价函数的导数:

def backward_propagation(parameters, cache, X, Y):
# Number of training examplesm = X.shape[1]# First, retrieve W1 and W2 from the dictionary "parameters".
W1 = parameters['W1']W2 = parameters['W2']### END CODE HERE #### Retrieve A1 and A2 from dictionary "cache".A1 = cache['A1']A2 = cache['A2']# Backward propagation: calculate dW1, db1, dW2, db2. dZ2= A2 - YdW2 = (1 / m) * np.dot(dZ2, A1.T)db2 = (1 / m) * np.sum(dZ2, axis=1, keepdims=True)dZ1 = np.multiply(np.dot(W2.T, dZ2), 1 - np.power(A1, 2))dW1 = (1 / m) * np.dot(dZ1, X.T)db1 = (1 / m) * np.sum(dZ1, axis=1, keepdims=True)
grads = {"dW1": dW1,"db1": db1,"dW2": dW2,"db2": db2}return grads

步骤6:更新参数

       使用反向传播过程中计算的梯度来更新权重和偏置:

def update_parameters(parameters, grads, learning_rate=1.2):
# Retrieve each parameter from the dictionary "parameters"
W1 = parameters['W1']b1 = parameters['b1']W2 = parameters['W2']b2 = parameters['b2']# Retrieve each gradient from the dictionary "grads"dW1 = grads['dW1']db1 = grads['db1']dW2 = grads['dW2']db2 = grads['db2']# Update rule for each parameterW1 = W1 - learning_rate * dW1b1 = b1 - learning_rate * db1W2 = W2 - learning_rate * dW2b2 = b2 - learning_rate * db2parameters = {"W1": W1,"b1": b1,"W2": W2,"b2": b2}return parameters

步骤7:建立神经网络

       将以上所有函数组合起来以创建设计的神经网络模型。总而言之,下面是模型函数的整体顺序:

  • 初始化参数
  • 前向传播
  • 计算代价函数
  • 反向传播
  • 更新参数
def nn_model(X, Y, n_h, num_iterations=10000, print_cost=False):
np.random.seed(3)n_x = layer_sizes(X, Y)[0]n_y = layer_sizes(X, Y)[2]# Initialize parameters, then retrieve W1, b1, W2, b2. Inputs: "n_x, n_h, n_y". 
Outputs = "W1, b1, W2, b2, parameters".
parameters = initialize_parameters(n_x, n_h, n_y)W1 = parameters['W1']b1 = parameters['b1']W2 = parameters['W2']b2 = parameters['b2']# Loop (gradient descent)
for i in range(0, num_iterations):# Forward propagation. Inputs: "X, parameters". Outputs: "A2, cache".A2, cache = forward_propagation(X, parameters)# Cost function. Inputs: "A2, Y, parameters". Outputs: "cost".cost = compute_cost(A2, Y, parameters)# Backpropagation. Inputs: "parameters, cache, X, Y". Outputs: "grads".grads = backward_propagation(parameters, cache, X, Y)# Gradient descent parameter update. Inputs: "parameters, grads". Outputs: "parameters".parameters = update_parameters(parameters, grads)### END CODE HERE #### Print the cost every 1000 iterationsif print_cost and i % 1000 == 0:print ("Cost after iteration %i: %f" % (i, cost))
return parameters,n_h

步骤8:跑动模型

       将隐藏层节点设置为6,最大迭代次数设置为10,000次,并每隔1000次打印出训练的结果:

parameters = nn_model(X,Y , n_h = 6, num_iterations=10000, print_cost=True)

 

40

步骤9:画出分类边界

def plot_decision_boundary(model, X, y):# Set min and max values and give it some paddingx_min, x_max = X[0, :].min() - 0.25, X[0, :].max() + 0.25y_min, y_max = X[1, :].min() - 0.25, X[1, :].max() + 0.25h = 0.01# Generate a grid of points with distance h between themxx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))# Predict the function value for the whole gridZ = model(np.c_[xx.ravel(), yy.ravel()])Z = Z.reshape(xx.shape)# Plot the contour and training examplesplt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)plt.ylabel('x2')plt.xlabel('x1')plt.scatter(X[0, :], X[1, :], c=y, cmap=plt.cm.Spectral)
plot_decision_boundary(lambda x: predict(parameters, x.T), X, Y[0,:])
plt.title("Decision Boundary for hidden layer size " + str(6))
plt.xlabel('Petal Length')
plt.ylabel('Petal Width')

 

4


       从图中可以观察到,只有四个点被错误分类。虽然我们可以调整模型来进一步地提高模型训练精度,但该些操作显然会导致过拟合现象的出现。

资源

  • https://www.coursera.org/specializations/deep-learning

数十款阿里云产品限时折扣中,赶紧点击领劵开始云上实践吧!

作者信息

Rohan Joseph,数据科学家
个人主页:https://www.linkedin.com/in/rohan-joseph-b39a86aa/
本文由阿里云云栖社区组织翻译。
文章原标题《Neural network on iris data》,译者:海棠,审校:Uncle_LLD。
文章为简译,更为详细的内容,请查看原文。

原文链接

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/521433.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【明人不说暗话】我就只讲进程与线程

戳蓝字“CSDN云计算”关注我们哦!作者 | 阮一峰责编 | 阿秃进程(process)和线程(thread)是操作系统的基本概念,但是它们比较抽象,不容易掌握。最近,我读到一篇材料,发现有…

(需求实战_进阶_05)SSM集成RabbitMQ 通配符模式 关键代码讲解、开发、测试

接上一篇: 文章目录一、RabbitMQ 配置文件1. RabbitMQ 生产者配置文件更新二、启动项目2.1. 启动项目2.2. 清空控制台三、管控台总览3.1. 登录管控台3.2. 交换机中查看绑定队列总览四、验证测试4.4. 生产者①请求4.5. 生产者②请求五、启动RabbitMQ5.1. 进入sbin目录…

两台邮件服务器共用一个公网地址,两个不同域邮件服务器的互通

两个不同域的邮件服务的互通如图,有两个不同域的邮件服务器(postfix)通过一个DNS服务器实现互通。首先说明一下IP分配情况服务器1qq.cometh0(VMnet2): ip:192.168.2.2 netmask:255.255.255.0 gw 192.168.2.1 hostname:mail.qq.com服务器2(qq.neteht0VMne…

希捷银河声音大_【推仔说新闻】那款硬盘它终于来了 希捷推出首款双磁臂硬盘...

经常关注科技新闻的朋友们应该都知道,现在机械硬盘领域可以说是被固态硬盘冲击的不清,而对于我们广大用户们来说,HDD这一个储存介质就被我们更多的用来充当仓库盘使用,毕竟现在的固态已经下探到白菜级别的价格了。但是对于那些HDD…

(需求实战_进阶_06)SSM集成RabbitMQ 订阅模式 关键代码讲解、开发、测试

背景: 为了减轻服务器的压力,现在原有项目的基础上集成消息队列来异步处理消息! 此项目是企业真实需求,项目的代码属于线上生产代码,直接用于生产即可! 此项目采用MQ发送消息模式为:订阅模式,如果对RabbitM…

【目瞪口呆】通信机房内部长这样

戳蓝字“CSDN云计算”关注我们哦!作者 | 小枣君责编 | 刘晶晶大家好,我是小枣君。一直以来,我都在努力给大家做通信知识科普,也写了很多有趣的文章。不过,文章再有趣也只是文字,不是实物。现实生活中&#…

NLP的ImageNet时代已经到来

摘要: NLP领域即将巨变,你准备好了吗? 自然语言处理(NLP)领域正在发生变化。 作为NLP的核心表现技术——词向量,其统治地位正在被诸多新技术挑战,如:ELMo,ULMFiT及Open…

mysql字段分隔符拆分_面试题Mysql数据库优化之垂直分表

在日常的开发工作中,除了JAVA相关的技术,打交道最多的就是Mysql数据库,当数据积累到一定程度,比如500W时就会难免出现一些慢sql,对数据库的优化方式有很多,比如通过增加合理的索引,今天我们来说…

python print用法不换行_python3让print输出不换行的方法

python 3.x版本print输出不换行的格式如下: print(x, end"") 其中,end"" 可使输出不换行,不能省略。 举例:输出结果:内容扩展: python3.x中如何实现print不换行 大家应该知道python中p…

使用Numpy和Opencv完成图像的基本数据分析(Part II)

摘要: 使用Numpy和Opencv完成图像的基本数据分析后续部分,主要包含逻辑运算符操作、掩膜以及卫星图像数据分析等操作 在上一节中,主要是介绍了图像的基本知识以及OpenCV的基本操作,具体内容参见“使用Numpy和Opencv完成基本图像的…

(需求实战_进阶_07)SSM集成RabbitMQ 订阅模式 关键代码讲解、开发、测试

接上一篇:(企业内部需求实战_进阶_06)SSM集成RabbitMQ 订阅模式 关键代码讲解、开发、测试 https://gblfy.blog.csdn.net/article/details/104219096 此项目采用MQ发送消息模式为:订阅模式,如果对RabbitMQ不熟悉,请学习…

分布式事务方案这么多,到底应该如何选型?

戳蓝字“CSDN云计算”关注我们哦!作者 | 温卫斌责编 | 刘晶晶源自 | dbaplus社群作者介绍温卫斌,就职于中国民生银行信息科技部,目前负责分布式技术平台设计与研发,主要关注分布式数据相关领域。微服务兴起的这几年涌现出不少分布…

造大专计算机学历,广昌县职业技术学校计算机应用专业助您 掌握一技之长获大专学历...

——专题宣传报道之四:计算机应用专业计算机应用专业一直是广昌县职业技术学校开设的特色专业。该专业由一批经验丰富、专业优秀的教师任教,主要学习计算机操作、组装、网络应用、影视后期制作、平面设计、文档管理等理论知识和实训课程。特色一&#xf…

阿里云正式推出消息队列Kafka:全面融合开源生态

摘要: 在全面兼容Apache Kafka生态的基础上,消息队列Kafka彻底解决Apache Kafka稳定性不足的长期痛点,并且支持消息无缝迁移到云上。 近日,阿里云宣布正式推出消息队列Kafka,全面融合开源生态。在全面兼容Apache Kafk…

异常将上下文初始化事件发送到类的侦听器实例_Spring的Bean实例化原理,这一次彻底搞懂了!...

前言之前分析了Spring XML和注解的解析原理,并将其封装为BeanDefinition对象存放到IOC容器中,而这些只是refresh方法中的其中一个步骤——obtainFreshBeanFactory,接下来就将围绕这这些BeanDefinition对象进行一系列的处理,如Bean…

(需求实战_01) SpringBoot2.x 整合RabbitMQ_生产端

文章目录一、依赖配置引入1. 引入SpringBoot整合RabbitMQ依赖2. 生产者配置文件3. 主配置二、代码Conding2.1. 生产者代码2.2. 实体对象2.3. 测试类一、依赖配置引入 1. 引入SpringBoot整合RabbitMQ依赖 <!--springboot整合RabbitMQ依赖--><dependency><groupI…

全域图像搜索给你更精准的搜索体验

摘要&#xff1a; 2018飞天技术汇&#xff0c;阿里巴巴机器智能技术实验室的刘磊带来题为全域精准图像搜索介绍的演讲&#xff0c;主要从四个方面进行了阐述&#xff0c;第一部分介绍了图像搜索的基本概念&#xff0c;第二部分主要是讲解了图像搜索的技术架构及其优势&#xff…

【这些都不知道你就是个弟弟】Docker常用命令

戳蓝字“CSDN云计算”关注我们哦&#xff01;作者 | 程序员欣宸转自 | 企业博客责编 | 阿秃除了基本的docker pull、docker image、docker ps&#xff0c;还有一些命令及参数也很重要&#xff0c;在此记录下来避免遗忘。环境信息以下是本次操作的环境&#xff1a;操作系统&…

php 模数 指数 公钥生成_php实现JWT认证

什么是JWTJWT(json web token)是为了在网络应用环境间传递声明而执行的一种基于JSON的开放标准。JWT的声明一般被用来在身份提供者和服务提供者间传递被认证的用户身份信息&#xff0c;以便于从资源服务器获取资源。比如用在用户登录上。JWT定义了一种用于简洁&#xff0c;自包…

SpringBoot2.x 整合RabbitMQ_消费端

这一篇讲解消费者 文章目录一、依赖配置1. 引入依赖2. 配置文件3. 主配置二、代码Conding2.1. 消费者代码一、依赖配置 1. 引入依赖 <!--springboot整合RabbitMQ依赖--><dependency><groupId>org.springframework.boot</groupId><artifactId>sp…