tensorflow 读取cifar_TensorFlow实战4——TensorFlow实现Cifar10识别

1 importcifar10, cifar10_input2 importtensorflow as tf3 importnumpy as np4 importtime5 importmath6

7 max_steps = 3000

8 batch_size = 128

9 data_dir = '/tmp/cifar10_data/cifar-10-batches-bin'

10

11

12 defvariable_with_weight_loss(shape, stddev, w1):13 '''定义初始化weight函数,使用tf.truncated_normal截断的正态分布,但加上L2的loss,相当于做了一个L2的正则化处理'''

14 var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))15 '''w1:控制L2 loss的大小,tf.nn.l2_loss函数计算weight的L2 loss'''

16 if wl is notNone:17 weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name='weight_loss')18 '''tf.add_to_collection:把weight losses统一存到一个collection,名为losses'''

19 tf.add_to_collection('losses', weight_loss)20

21 returnvar22

23

24 #使用cifar10类下载数据集并解压展开到默认位置

25 cifar10.maybe_download_and_extract()26

27 '''distored_inputs函数产生训练需要使用的数据,包括特征和其对应的label,28 返回已经封装好的tensor,每次执行都会生成一个batch_size的数量的样本'''

29 images_train, labels_train = cifar10_input.distored_inputs(data_dir=data_dir,30 batch_size=batch_size)31

32 images_test, labels_test = cifar10_input.inputs(eval_data=True,33 data_dir=data_dir,34 batch_size=batch_size)35

36 image_holder = tf.placeholder(tf.float32, [batch_size, 24, 24, 3])37 label_holder =tf.placeholder(tf.int32, [batch_size])38

39 '''第一个卷积层:使用variable_with_weight_loss函数创建卷积核的参数并进行初始化。40 第一个卷积层卷积核大小:5x5 3:颜色通道 64:卷积核数目41 weight1初始化函数的标准差为0.05,不进行正则wl(weight loss)设为0'''

42 weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2, wl=0.0)43 #tf.nn.conv2d函数对输入image_holder进行卷积操作

44 kernel1 = tf.nn.conv2d(image_holder, weight1, [1, 1, 1, 1], padding='SAME')45

46 bias1 = tf.Variable(tf.constant(0.0, shape=[64]))47

48 conv1 =tf.nn.relu(tf.nn.bias_add(kernel1, bias1))49 #最大池化层尺寸为3x3,步长为2x2

50 pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1])51 #LRN层模仿生物神经系统的'侧抑制'机制

52 norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)53

54 '''第二个卷积层:'''

55 weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], stddev=5e-2, wl=0.0)56 kernel2 = tf.nn.conv2d(norm1, weight2, [1, 1, 1, 1], padding='SAME')57 #bias2初始化为0.1

58 bias2 = tf.Variable(tf.constant(0.1, shape=[64]))59

60 conv2 =tf.nn.relu(tf.nn.bias_add(kernel2, bias2))61 norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)62 pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')63

64 #全连接层

65 reshape = tf.reshape(pool2, [batch_size, -1])66 dim = reshape.get_shape()[1].value67 weight3 = variable_with_weight_loss(shape=[dim, 384], stddev=0.04, wl=0.004)68 bias3 = tf.Variable(tf.constant(0.1, shape=[384]))69 local3 = tf.nn.relu(tf.matmul(reshape, weight3) +bias3)70

71 #全连接层,隐含层节点数下降了一半

72 weight4 = variable_with_weight_loss(shape=[384, 182], stddev=0.04, wl=0.004)73 bias4 = tf.Variable(tf.constant(0.1, shape=[192]))74 local4 = tf.nn.relu(tf.matmul(local3, weight4) +bias4)75

76 '''正态分布标准差设为上一个隐含层节点数的倒数,且不计入L2的正则'''

77 weight5 = variable_with_weight_loss(shape=[192, 10], stddev=1 / 192.0, wl=0.0)78 bias5 = tf.Variable(tf.constant(0.0, shape=[10]))79 logits =tf.add(tf.matmul(local4, weight5), bias5)80

81

82 defloss(logits, labels):83 '''计算CNN的loss84 tf.nn.sparse_softmax_cross_entropy_with_logits作用:85 把softmax计算和cross_entropy_loss计算合在一起'''

86 labels =tf.cast(labels, tf.int64)87 cross_entropy =tf.nn.sparse_softmax_cross_entropy_with_logits(88 logits=logits, labels=labels, name='cross_entropy_per_example')89 #tf.reduce_mean对cross entropy计算均值

90 cross_entropy_mean =tf.reduce_mean(cross_entropy,91 name='cross_entropy')92 #tf.add_to_collection:把cross entropy的loss添加到整体losses的collection中

93 tf.add_to_collection('losses', cross_entropy_mean)94 #tf.add_n将整体losses的collection中的全部loss求和得到最终的loss

95 return tf.add_n(tf.get_collection('losses'), name='total_loss')96

97

98 #将logits节点和label_holder传入loss计算得到最终loss

99 loss =loss(logits, label_holder)100

101 train_op = tf.trian.AdamOptimizer(1e-3).minimize(loss)102 #求输出结果中top k的准确率,默认使用top 1(输出分类最高的那一类的准确率)

103 top_k_op = tf.nn.in_top_k(logits, label_holder, 1)104

105 sess =tf.InteractiveSession()106 tf.global_variables_initializer().run()107 tf.trian.start_queue_runners()108

109 for step inrange(max_steps):110 '''training:'''

111 start_time =time.time()112 #获得一个batch的训练数据

113 image_batch, label_batch =sess.run([images_train, labels_train])114 #将batch的数据传入train_op和loss的计算

115 _, loss_value =sess.run([train_op, loss],116 feed_dict={image_holder: image_batch, label_holder: label_batch})117

118 duration = time.time() -start_time119 if step % 10 ==0:120 #每秒能训练的数量

121 examples_per_sec = batch_size /duration122 #一个batch数据所花费的时间

123 sec_per_batch =float(duration)124

125 format_str = ('step %d, loss=%.2f (%.1f examples/sec; %.3f sec/batch)')126 print(format_str %(step, loss_value, examples_per_sec, sec_per_batch))127 #样本数

128 num_examples = 10000

129 num_iter = int(math.ceil(num_examples /batch_size))130 true_count =0131 total_sample_count = num_iter *batch_size132 step =0133 while step <134 labels_test>

135 image_batch, label_batch =sess.run([images_test, labels_test])136 #计算这个batch的top 1上预测正确的样本数

137 preditcions = sess.run([top_k_op], feed_dict={image_holder: image_batch,138 label_holder: label_batch139 })140 #全部测试样本中预测正确的数量

141 true_count +=np.sum(preditcions)142 step += 1

143 #准确率

144 precision = true_count /total_sample_count145 print('precision @ 1 = %.3f' % precision)

134>

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/488224.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

​CPU将进入新时代:押注计算芯片的极限协同设计

来源&#xff1a;内容由半导体行业观察&#xff08;icbank&#xff09;编译自「nextplatform」&#xff0c;作者&#xff1a;Timothy Prickett Morgan&#xff0c;谢谢。我们现在进入了一个时代&#xff0c;那就是IT行业的计算引擎将需要比以往任何时候都更需要更低的价格&…

rk3288 android编译环境搭建,RK3288系统编译及环境搭建

准备工作编译 Android 对机器的配置要求较高&#xff1a;64 位 CPU16GB 物理内存交换内存30GB 空闲的磁盘空间用于构建&#xff0c;源码树另外占用大约 8GB官方推荐 Ubuntu 12.04 操作系统&#xff0c;实际上也可以采用更新的操作系统版本&#xff0c;只需要满足 http://source…

50 days before NOI2017

2017.5.31 今天开了这个博客&#xff0c;打算每天来写点东西&#xff0c;嗯。。。毕竟要NOI了嘛。。。 第一天跑到常州里集训&#xff0c;打开题目一看湖南集训题。。。 T1刷一下写完&#xff0c;然后交了然后发现错了。。。赶紧改过来&#xff0c;大概1h吧。。。 T2刷一下发现…

2020城市大脑与超级智能建设规范研究报告(附下载)

报告下载地址: https://pan.baidu.com/s/1x85xZrAG6df4BcVkJqtcqw提取码: 6ytv21世纪以来&#xff0c;21世纪以来&#xff0c;眼花缭乱的前沿科技新概念喷薄而出&#xff0c;从Web2.0、社交网络、物联网、移动互联网、大数据、工业4.0、工业互联网到云机器人、深度学习、边缘计…

android按钮点击无响应时间,AndroidStudio下的点击事件不响应

本来是测试自定义Toast&#xff0c;发现implements View.OnClickListener的Onclick事件怎么都不响应&#xff0c;开始以为是自定义的问题。结果当然不是&#xff1b;需要clean项目就好了&#xff0c;AndroidStudio的问题还是很多~Overridepublic void onBackPressed() {ToastUt…

同步带周长计算公式_同步带的长度计算和常见问题

同步带的长度计算公式如下&#xff1a;lb ((df dm) 1.5708 ) (2 lfm)其中lb是同步带的长度&#xff0c;df是大同步带轮的直径&#xff0c;dm是小同步带轮的直径&#xff0c;lfm是大同步带轮的中心和小同步带轮中心的距离。从上述同步带长度的计算公式可以看出同步带轮的直径对…

nodejs中处理回调函数的异常

假设是使用nodejsexpress3这个经典的组合。那么有一种非常方面的处理回调函数异常的方法&#xff1a; 1. 安装模块&#xff1a;express-domain-middleware 2. 增加例如以下的代码&#xff1a; app.use(require(express-domain-middleware)); app.use(function errorHandler(err…

5G新标准将延迟3个月发布,但5G“新战场”已经明确

来源&#xff1a;雷锋网2019年&#xff0c;5G开启商用元年。2020年开年&#xff0c;5G智能手机的发布就迎来了一个发布高峰&#xff0c;2月份至今&#xff0c;国内就有10多款5G手机发布。5G手机的数量和销量也迅速增长&#xff0c;根据工信部副部长辛国斌给出的数据&#xff0c…

python copy deepcopy_python-copy-deepcopy

1.结论&#xff1a;—–我们寻常意义的复制就是深复制&#xff0c;即将被复制对象完全再复制一遍作为独立的新个体单独存在。所以改变原有被复制对象不会对已经复制出来的新对象产生影响。—–而浅复制并不会产生一个独立的对象单独存在&#xff0c;他只是将原有的数据块打上一…

android 2个界面抽屉,Android使用DrawerLayout创建左右两个抽屉菜单

在Android support.v4 中有一个抽屉视图控件DrawerLayout。使用这个控件&#xff0c;可以生成通过在屏幕上水平滑动打开或者关闭菜单&#xff0c;能给用户一个不错的体验效果。最近在项目中&#xff0c;设计中有用到这个效果&#xff0c;但是是左右两边都能划出这样的一个菜单效…

报告:100家AI初创公司榜单 这五大趋势不得不看!

来源&#xff1a; 网易智能用新药治疗一系列慢性疾病&#xff1b;抵御各种网络攻击&#xff1b;让城市更加智能&#xff1b;更精准地预报天气和野火&#xff0c;从而提高安全性并降低风险。此外&#xff0c;还有深度伪造技术&#xff08;deepfakes&#xff09;的商业化。这些看…

android默认exported_AndroidManifest.xml文件中exported属性解析

4、目标Activity的属性Android:exported”true”如果组件包含有intent-filter则 exported默认值为true;没有intent-filter则exported默认值为false。当exported为 true时可以被外部其他App所调用当exported为 false时可以被外部其他App所调用5、目标Activity具有相应的IntentFi…

android外接键盘打汉字,Android在外接物理键盘时,如何强制调用系统软键盘

Android在外接物理键盘时&#xff0c;如何强制调用系统软键盘&#xff1f;第一次写&#xff0c;写的不好请见谅参考:物理键盘映射过程&#xff1a;手机/system/usr/keylayout/*.kl &#xff1a;内核将keyCode映射成有含义的字符串KeycodeLabels.h &#xff1a; framework 将字符…

20155204 2016-2017-2《Java程序设计》课程总结

20155204 2016-2017-2《Java程序设计》课程总结 目录 作业链接汇总作业总结实验报告链接汇总代码托管链接课堂项目实践学习经验问卷调查链接二维码&#xff08;按顺序&#xff09;每周作业链接汇总 预备作业1&#xff1a;我对师生关系的思考预备作业2&#xff1a;做中学感悟预备…

android网络测试上传速度慢,Android:如何获得互联网连接上传速度和延迟?

要获取当前网络连接类型&#xff1a;TelephonyManager telephonyManager (TelephonyManager) getSystemService(Context.TELEPHONY_SERVICE);int networkType telephonyManager.getNetworkType();并为延迟&#xff1a;String host "172.16.0.2";int timeOut 3000…

复杂性科学与还原论

来源&#xff1a;陶勇科学网博客1984年&#xff0c;两位诺贝尔物理学奖得主盖尔曼&#xff08;Murray Gell-mann&#xff09;、安德森&#xff08;Philip Anderson&#xff09;和诺贝尔经济学奖得主阿罗&#xff08;Kenneth Arrow&#xff09;聚集了一批从事物理、经济、生物、…

ios math 那个头文件_C++ 头文件系列(ios)

1 简介我们都知道&#xff0c;平时常用的那些标准流&#xff0c;诸如iostream、ofstream、ifstream等等&#xff0c;其实都是对应的basic_XXX模版的实例类。 而这些basic_XXX类模版又都是继承自同一个基类模版----basic_ios。2 basic_ios模版定义这个基类模版应该是出于可重用的…

Nim游戏(初谈博弈)

通常的Nim游戏的定义是这样的&#xff1a;有若干堆石子&#xff0c;每堆石子的数量都是有限的&#xff0c;合法的移动是“选择一堆石子并拿走若干颗&#xff08;不能不拿&#xff09;”&#xff0c; 如果轮到某个人时所有的石子堆都已经被拿空了&#xff0c;则判负&#xff08;…

android 如何使用aar,Android Studio如何使用aar依赖包?

ps:2013-12-25 号更新,升级到0.4以后 这种方法已经完美使用&#xff01;因为项目里面要用到actionbarsherlock&#xff0c;所以研究了一下如何导入到android studio中。arr(Android Archive)&#xff1a;名字是谷歌到的&#xff0c;至于中文叫什么我也不知道。不过好像依赖都要…

第二百七十九节,MySQL数据库-pymysql模块操作数据库

MySQL数据库-pymysql模块操作数据库 pymysql模块是python操作数据库的一个模块 connect()创建数据库链接,参数是连接数据库需要的连接参数使用方式&#xff1a;   模块名称.connect()   参数&#xff1a;   host数据库ip   port数据库端口   user数据库用户名   pa…