【TensorFlow】实现简单的鸢尾花分类器

代码实现及说明

# python 3.6
# TensorFlow实现简单的鸢尾花分类器
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from sklearn import datasetssess = tf.Session()#导入数据
iris = datasets.load_iris()
# 是否是山鸢尾 0/1
binary_target = np.array([1. if x == 0 else 0. forx in iris.target])
# 选择两个特征:花瓣长度和宽度
iris_2d = np.array([[x[2],x[3]] for x in iris.data])# 声明批训练大小、占位符和变量
# tf.float32降低float字节数 可以提高算法性能
batch_size = 20
x1_data = tf.placeholder(shape=[None,1],dtype=tf.float32)
x2_data = tf.placeholder(shape=[None,1],dtype=tf.float32)
y_target = tf.placeholder(shape=[None,1],dtype=tf.float32)
# 声明变量 A 和 b (0 = x1 - A*x2 + b)
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))# 定义线性模型
# 线性模型的表达式为:x1=x2*A+b。
# 如果找到的数据点在直线以上,则将数据点代入x1-x2*A-b计算出的结果大于0;
# 同理找到的数据点在直线以下,则将数据点代入x1-x2*A-b计算出的结果小于0。
# 将公式x1-x2*A-b传入sigmoid函数,然后预测结果1或者0
# TensorFlow有内建的sigmoid损失函数,所以这里仅仅需要定义模型输出
my_mult = tf.matmul(x2_data, A)
my_add = tf.add(my_mult, b)
my_output = tf.subtract(x1_data,my_add)# 增加分类损失函数 这里用两类交叉熵损失函数 cross entropy
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output,labels=y_target)# 声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)# 循环
for i in range(1000):rand_index = np.random.choice(len(iris_2d),size=batch_size)rand_x = iris_2d[rand_index]rand_x1 = np.array([[x[0]] for x in rand_x])rand_x2 = np.array([[x[1]] for x in rand_x])rand_y = np.array([[y] for y in binary_target[rand_index]])sess.run(train_step, feed_dict={x1_data:rand_x1,x2_data:rand_x2,y_target:rand_y})if (i+1)%200 == 0:print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ', b = ' + str(sess.run(b)))# 结果可视化
[[slope]] = sess.run(A) # 斜率
# 因为A的shape是(1,1)所以要写成一行一列的形式
[[intercept]] = sess.run(b) # 截距# 创建拟合线
x = np.linspace(0, 3, num=50) # 0~3 50个均匀间隔的数字
ablineValues = []
for i in x:ablineValues.append(slope*i+intercept)# 绘图
setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==1]
setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==1]
non_setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==0]
non_setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==0]
plt.plot(setosa_x, setosa_y, 'rx', ms=10, mew=2, label='setosa')
plt.plot(non_setosa_x, non_setosa_y, 'ro', label='Non-setosa')
plt.plot(x, ablineValues, 'b-')
plt.xlim([0.0, 2.7])
plt.ylim([0.0, 7.1])
plt.xlabel('Petal Length')
plt.ylabel('Petal Width')
plt.legend(loc='lower right')
plt.show()

绘图结果

绘图结果

总结

这里利用花瓣长度和花瓣宽度的特征在山鸢尾与其他物种间拟合一条直线,然后通过该直线来分割两类目标(山鸢尾和非山鸢尾),直线是迭代1000次得到的线性分割,通过直线分割两个目标并不是最好的模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/481101.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

偏差-方差全面解释

偏差(Bias)与方差(Variance) 目录: 为什么会有偏差和方差?偏差、方差、噪声是什么?泛化误差、偏差和方差的关系?用图形解释偏差和方差。偏差、方差窘境。偏差、方差与过拟合、欠拟合…

「小算法」回文数与数值合法性检验

喵喵喵,小夕最近准备复习一下数学和基础算法,尽量每篇推送下面会附带点数学和基础算法的小文章。说不定哪天就用(考)到了呢( ̄∇ ̄)注意哦,与头条位的文章推送不同,「小公式」和「小算…

【TensorFlow】实现、训练并评估简单的回归模型和分类模型

1 回归模型 回归算法模型用来预测连续数值型,其目标不是分类值而是数字。为了评估这些回归预测值是否与实际目标相符,我们需要度量两者间的距离,打印训练过程中的损失,最终评估模型损失。 这里使用的例子是从均值为1、标准差为0…

史上最通熟易懂的检索式聊天机器人讲解

喵喵喵,一不小心又匿了三个月,突然诈尸害不害怕( ̄∇ ̄) 小夕从7月份开始收到第一场面试邀请,到9月初基本结束了校招(面够了面够了T_T),深深的意识到今年的对话系统/chatbot方向是真的…

Python pandas数据分析中常用方法

官方教程 读取写入文件 官方IO 读取 写入 read_csv       to_csv read_excel      to_excel read_hdf       to_hdf read_sql       to_sql read_json      to_json read_msgpack (experimental)   to_msgpack (experimental) read_html    …

小哥哥,检索式chatbot了解一下?

喵喵喵,一不小心又匿了三个月,突然诈尸害不害怕( ̄∇ ̄) 小夕从7月份开始收到第一场面试邀请,到9月初基本结束了校招(面够了面够了T_T),深深的意识到今年的对话系统/chatbot方向是真的…

领域应用 | 中医临床术语系统

本文转载自公众号中医药知识组织与标准。什么是中医药术语系统?它是干什么用的呢?中医药术语系统是运用计算机与信息技术等工具,对中医药学各领域中的事物、现象、特性、关系和过程进行标记和概括,并为每个概念赋予指称形成概念体…

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之Image图片组件

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之Image图片组件 一、操作环境 操作系统: Windows 10 专业版、IDE:DevEco Studio 3.1、SDK:HarmonyOS 3.1 二、Image组件 Image 用来加载并显示图片的基础组件,它支持从内存、本…

梁家卿 | 百科知识图谱同步更新

本文转载自公众号知识工场。 本文整理自复旦大学知识工场梁家卿博士在IJCAI 2017 会议上的论文报告,题目为《How to Keep a Knowledge Base Synchronized with Its Encyclopedia Source》,作者包括:梁家卿博士(复旦大学&#xff0…

别再搜集面经啦!小夕教你斩下NLP算法岗offer!

推完上一篇文章,订阅号和知乎后台有好多小伙伴跟小夕要面经(还有个要买简历的是什么鬼),然鹅小夕真的没有整理面经呀,真的木有时间(。 ́︿ ̀。)。不过话说回来,面经有多大用呢&#…

肖仰华 | 基于知识图谱的问答系统

本文转载自公众号知识工场。 本文整理自复旦大学知识工场肖仰华教授在VLDB 2017 会议上的论文报告,题目为《KBQA: Learning Question Answering over QA Corpora and Knowledge Bases》,作者包括:崔万云博士(现上海财经大学讲师&a…

【小夕精选】YJango 7分钟带你领略你未曾想过的线性代数+微积分

小夕很早之前就想转一些精彩的技术文章,这样哪怕没有时间写作的时候,也能把优质的干货分享给大家~然鹅,由于我也不知道是什么的原因,就不小心拖到了现在╮( ̄▽ ̄"")╭之前有不少粉丝希…

白硕 | 基于区块链的众包社区激励机制

本文整理自白硕老师在 YOCSEF 武汉专题论坛:“人工智能遇到区块链,是惊鸿一瞥还是天长地久?”的报告。 很高兴有这个机会跟大家交流。我先讲几个案例作为引子。第一个案例与知识图谱有关。这个公司做的是非常垂直的一个领域,安全教…

【小夕精选】多轮对话之对话管理(Dialog Management)

这一篇是一段时间之前小夕初入对话领域时刷到的徐阿衡小姐姐写的一篇文章,写的深入浅出,十分适合有一定基础的情况下想快速了解对话管理技术的童鞋阅读~另外顺手推一下阿衡小姐姐的订阅号「徐阿衡」,干货满满不要错过哦~这一篇想写一写对话管…

KD Tree的原理及Python实现

1. 原理篇我们用大白话讲讲KD-Tree是怎么一回事。1.1 线性查找假设数组A为[0, 6, 3, 8, 7, 4, 11],有一个元素x,我们要找到数组A中距离x最近的元素,应该如何实现呢?比较直接的想法是用数组A中的每一个元素与x作差,差的…

漆桂林 | 知识图谱的应用

本文作者为东南大学漆桂林老师,首发于知乎专栏知识图谱和智能问答 前面一篇文章“知识图谱之语义网络篇”已经提到了知识图谱的发展历史,回顾一下有以下几点: 1. 知识图谱是一种语义网络,即一个具有图结构的知识库,这里…

NLP预训练之路——从word2vec, ELMo到BERT

前言 还记得不久之前的机器阅读理解领域,微软和阿里在SQuAD上分别以R-Net和SLQA超过人类,百度在MS MARCO上凭借V-Net霸榜并在BLEU上超过人类。这些网络可以说一个比一个复杂,似乎“如何设计出一个更work的task-specific的网络"变成了NLP…

论文 | 信息检索结果Ranking的评价指标《RankDCG: Rank-Ordering Evaluation Measure》

未经允许,不得转载,谢谢~~ 一 文章简介 为什么要提出这个新的评价算法? 我们都知道ranking过程对于信息检索的结果是非常重要的,那么我们就需要有一些算法能评价ranking的结果到底如何。现有用来评价ranking的常用算法有&#xff…

肖仰华 | 基于知识图谱的用户理解

本文转载自公众号知识工场。 本文整理自肖仰华教授在三星电子中国研究院做的报告,题目为《Understanding users with knowldge graphs》。 今天,很高兴有这个机会来这里与大家交流。 前面两位老师把基于社会影响力的传播和推荐,以及跨领域的…

NLP的游戏规则从此改写?从word2vec, ELMo到BERT

前言还记得不久之前的机器阅读理解领域,微软和阿里在SQuAD上分别以R-Net和SLQA超过人类,百度在MS MARCO上凭借V-Net霸榜并在BLEU上超过人类。这些网络可以说一个比一个复杂,似乎“如何设计出一个更work的task-specific的网络"变成了NLP领…