python模拟猫狗大战_tensorflow实现猫狗大战(分类算法)-阿里云开发者社区

from __future__ importabsolute_importfrom __future__ importdivisionfrom __future__ importprint_functionimportosimporttensorflow as tf

flags=tf.app.flags

flags.DEFINE_integer(flag_name='batch_size', default_value=16, docstring='Batch 大小')

flags.DEFINE_string(flag_name='data_dir', default_value='./tfrecords', docstring='数据存放位置')

flags.DEFINE_string(flag_name='model_dir', default_value='./cat&dog_model', docstring='模型存放位置')

flags.DEFINE_integer(flag_name='steps', default_value=1000, docstring='训练步数')

flags.DEFINE_integer(flag_name='classes', default_value=2, docstring='类别数量')

FLAGS=flags.FLAGS

MODES=[tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT]def input_fn(mode, batch_size=1):"""输入函数"""

defparser(serialized_example):"""如何处理数据集中的每一个数据"""

#解析单个example对象

features =tf.parse_single_example(

serialized_example,

features={'image/height': tf.FixedLenFeature([], tf.int64),'image/width': tf.FixedLenFeature([], tf.int64),'image/depth': tf.FixedLenFeature([], tf.int64),'image/encoded': tf.FixedLenFeature([], tf.string),'image/class/label': tf.FixedLenFeature([], tf.int64),

})#获取参数

height = tf.cast(features['image/height'], tf.int32)

width= tf.cast(features['image/width'], tf.int32)

depth= tf.cast(features['image/depth'], tf.int32)#还原image

image = tf.decode_raw(features['image/encoded'], tf.float32)

image=tf.reshape(image, [height, width, depth])

image= image - 0.5

#还原label

label = tf.cast(features['image/class/label'], tf.int32)returnimage, tf.one_hot(label, FLAGS.classes)if mode inMODES:

tfrecords_file= os.path.join(FLAGS.data_dir, mode + '.tfrecords')else:raise ValueError("Mode 未知")assert tf.gfile.Exists(tfrecords_file), ('TFRrecords 文件不存在')#创建数据集

dataset =tf.data.TFRecordDataset([tfrecords_file])#创建映射

dataset = dataset.map(parser, num_parallel_calls=1)#设置batch

dataset =dataset.batch(batch_size)#如果是训练,那么就永久循环下去

if mode ==tf.estimator.ModeKeys.TRAIN:

dataset=dataset.repeat()#创建迭代器

iterator =dataset.make_one_shot_iterator()#获取 feature 和 label

images, labels =iterator.get_next()returnimages, labelsdefmy_model(inputs, mode):"""写一个网络"""net= tf.reshape(inputs, [-1, 224, 224, 1])

net= tf.layers.conv2d(net, 32, [3, 3], padding='same', activation=tf.nn.relu)

net= tf.layers.max_pooling2d(net, [2, 2], strides=2)

net= tf.layers.conv2d(net, 32, [3, 3], padding='same', activation=tf.nn.relu)

net= tf.layers.max_pooling2d(net, [2, 2], strides=2)

net= tf.layers.conv2d(net, 64, [3, 3], padding='same', activation=tf.nn.relu)

net= tf.layers.conv2d(net, 64, [3, 3], padding='same', activation=tf.nn.relu)

net= tf.layers.max_pooling2d(net, [2, 2], strides=2)#print(net)

net = tf.reshape(net, [-1, 28 * 28 * 64])

net= tf.layers.dense(net, 1024, activation=tf.nn.relu)

net= tf.layers.dropout(net, 0.4, training=(mode ==tf.estimator.ModeKeys.TRAIN))

net=tf.layers.dense(net, FLAGS.classes)returnnetdefmy_model_fn(features, labels, mode):"""模型函数"""

#可视化输入

tf.summary.image('images', features)#创建网络

logits =my_model(features, mode)

predictions={'classes': tf.argmax(input=logits, axis=1),'probabilities': tf.nn.softmax(logits, name='softmax_tensor')

}#如果是PREDICT,那么只需要predictions就够了

if mode ==tf.estimator.ModeKeys.PREDICT:return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)#创建Loss

loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits, scope='loss')

tf.summary.scalar('train_loss', loss)#设置如何训练

if mode ==tf.estimator.ModeKeys.TRAIN:

optimizer= tf.train.AdamOptimizer(learning_rate=1e-3)

train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step())else:

train_op=None#获取训练精度

accuracy =tf.metrics.accuracy(

tf.argmax(labels, axis=1), predictions['classes'],

name='accuracy')

accuracy_topk=tf.metrics.mean(

tf.nn.in_top_k(predictions['probabilities'], tf.argmax(labels, axis=1), 2),

name='accuracy_topk')

metrics={'test_accuracy': accuracy,'test_accuracy_topk': accuracy_topk

}#可视化训练精度

tf.summary.scalar('train_accuracy', accuracy[1])

tf.summary.scalar('train_accuracy_topk', accuracy_topk[1])returntf.estimator.EstimatorSpec(

mode=mode,

predictions=predictions,

loss=loss,

train_op=train_op,

eval_metric_ops=metrics)defmain(_):#监视器

logging_hook =tf.train.LoggingTensorHook(

every_n_iter=100,

tensors={'accuracy': 'accuracy/value','accuracy_topk': 'accuracy_topk/value','loss': 'loss/value'},

)#创建 Estimator

model =tf.estimator.Estimator(

model_fn=my_model_fn,

model_dir=FLAGS.model_dir)for i in range(20):#训练

model.train(

input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN, FLAGS.batch_size),

steps=FLAGS.steps,

hooks=[logging_hook])#测试并输出结果

print("=" * 10, "Testing", "=" * 10)

eval_results=model.evaluate(

input_fn=lambda: input_fn(tf.estimator.ModeKeys.EVAL))print('Evaluation results:\n\t{}'.format(eval_results))print("=" * 30)if __name__ == '__main__':

tf.logging.set_verbosity(tf.logging.INFO)

tf.app.run()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/520296.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用Spring整合Quartz轻松完成定时任务

文章目录1. 必不可少jar包依赖2. 编写任务调度类013. 编写任务调度类024. quartz配置文件5. 测试service6. 效果图开发环境版本jdk1.8Maven3.6.1springSpring 4.2.6.RELEASEIdea2019 1. 必不可少jar包依赖 <dependency><groupId>org.springframework</groupId&g…

太真实了:程序员等级图鉴

戳蓝字“CSDN云计算”关注我们哦&#xff01;作者 | 肥又君责编&#xff5c;阿秃程序员是一个非常神奇的工种&#xff0c;他们对技术有特殊的崇拜&#xff0c;有着严格的等级划分&#xff0c;不同级别的程序员有什么不同之处呢&#xff1f;Let us 瞅瞅 &#xff5e;日常工作日常…

windows故障转移群集和mysql_Windows 2016 无域故障转移群集部署方法 超详细图文教程...

故障转移群集是一个很实用的功能,而windows在2016版本开始,终于支持不用域做故障转移群集.在群集中,我们可以设定一个"群集IP"而客户端只需要根据这个"群集IP"就能连接当前群集的主服务器.而不必关心群集服务器之间的替换.而更棒的是,它是"去中心&quo…

Dubbo下一站:Apache顶级项目

近日&#xff0c;在Apache Dubbo开发者沙龙杭州站的活动中&#xff0c;阿里巴巴中间件技术专家曹胜利(展图)向开发者们分享了Dubbo2.7版本的规划。 本文将为你探秘 Dubbo 2.7背后的思考和实现方式。 Dubbo 2.7 将围绕 异步支持优化、元数据改造&#xff0c;引入JDK8的特性、Net…

Java 中判断连接Oracle数据库连接成功

import java.sql.Connection; import java.sql.DriverManager; import java.sql.SQLException; public class Application {public static Connection getConnection() {Connection conn null;try { //连接driver为&#xff1a;oracle.jdbc.driver.OracleDriver//或者oracle…

从内部自用到对外服务,配置管理的演进和设计优化实践

本文整理自阿里巴巴中间件技术专家彦林在中国开源年会上的分享&#xff0c;通过此文&#xff0c;您将了解到&#xff1a; 微服务给配置管理所带来的变化配置管理演进过程中的设计思考配置管理开源后的新探索配置中心控制台设计实践“为什么相对于传统的软件开发模式&#xff0…

12 种主流编程语言输出“ Hello World ”,把我给难住了!

作为一名程序员&#xff0c;在初步学习编程想必都绕不开一个最为基础的入门级示例“Hello World”&#xff0c;那么&#xff0c;你真的了解各个语言“Hello World”的正确写法吗&#xff1f;在我们刚开始打开编程世界的时候&#xff0c;往往写的第一个程序都是简单的文本输出&a…

mysql 截取域名_sql 截取域名的问题

sql 截取域名的几种方法总结&#xff0c;需要的朋友可以参考一下最近由于对数据库的域名要排重&#xff0c;因为sql直接使用起来方便一些&#xff0c;就整理下A.截取从字符串左边开始N个字符代码如下:Declare S1 varchar(100)Select S1http://www.jb51.netSelect Left(S1,4)---…

优秀工程师必备的三大思维,你拥有哪些?

不同岗位、不同职责的技术人对工程师思维的深度要求是不一样的&#xff0c;但从多维度去思考却应是每个技术人都应该具备的素养。本文整理自阿里巴巴高级技术专家至简在团队内部的个人分享&#xff0c;希望通过对工程师思维的分析和解读&#xff0c;让大家能正确对待那些在现实…

看完这篇还不了解Nginx,那我就哭了!

戳蓝字“CSDN云计算”关注我们哦&#xff01;作者 | 蔷薇Nina责编 | 阿秃想必大家一定听说过 Nginx&#xff0c;若没听说过它&#xff0c;那么一定听过它的"同行"Apache 吧&#xff01;Nginx 的产生Nginx 同 Apache 一样都是一种 Web 服务器。基于 REST 架构风格&…

迁移学习NLP:BERT、ELMo等直观图解

2018年是自然语言处理的转折点&#xff0c;能捕捉潜在意义和关系的方式表达单词和句子的概念性理解正在迅速发展。此外&#xff0c;NLP社区已经出现了非常强大的组件&#xff0c;你可以在自己的模型和管道中自由下载和使用&#xff08;它被称为NLP的ImageNet时刻&#xff09;。…

python如何仿写文章_python,python3.x_求助,用python仿写以下代码,python,python3.x,java - phpStudy...

求助&#xff0c;用python仿写以下代码public static void main(String[] args) {Scanner scnew Scanner(System.in);int nsc.nextInt();int[] flagnew int[29];float[] anew float[29];for(int i0;ia[i](float)1.0/(float)(i2);}for(int j1;j<Math.pow(2, 29);j){int tempj…

开发函数计算的正确姿势 —— 爬虫

在 《函数计算本地运行与调试 - Fun Local 基本用法》 中&#xff0c;我们介绍了利用 Fun Local 本地运行、调试函数的方法。但如果仅仅这样简单的介绍&#xff0c;并不能展现 Fun Local 对函数计算开发的巨大效率的提升。 这一次&#xff0c;我们拿一个简单的场景来举例子——…

SonarQube 规则的挂起与激活

文章目录规则添加规则挂起规则添加 规则挂起

内存性能的正确解读

一台服务器&#xff0c;不管是物理机还是虚拟机&#xff0c;必不可少的就是内存&#xff0c;内存的性能又是如何来衡量呢。 1. 内存与缓存 现在比较新的CPU一般都有三级缓存&#xff0c;L1 Cache&#xff08;32KB-256KB&#xff09;&#xff0c;L2 Cache&#xff08;128KB-2M…

2019年技术盘点云数据库篇(一):UCloud专家谈云数据库:千锤百炼 云之重器

作者 | 刘丹 出品 | CSDN云计算&#xff08;ID&#xff1a;CSDNcloud&#xff09; 公有云逐渐成为企业运行 IT 设施的新趋势&#xff0c;那么作为企业最核心的系统—数据库&#xff0c;数据上云也成为大数据时代的必然选择。对企业来说&#xff0c;数据可视为其命脉&#xff0…

wxpython富文本_去除富文本编辑器中的标签

public static String deRegularExpression(String content) {content deRegularScript(content); // 过滤script标签String regEx_style "/* String regEx_inStyle "style\"([^\";];?)\""; */String regEx_html "<[^>]>&quo…

专访阿里云专有云马劲,一个理性的理想主义者

“我的故事都是和团队技术相关的&#xff0c;自己还真没有什么引人入胜的故事。”当马劲被问到能不能多分享些个人经历故事时他笑着说&#xff0c;我们就干脆怀着好奇聊了聊他和阿里云专有云一路走来的故事。 马劲&#xff0c;花名隆猫&#xff0c;阿里云专有云事业部兼企业应用…

80后阿里P10,“关老板”如何带着MaxCompute一路升级?

我是个幸运的人。虽然幸运不能被复制&#xff0c;但是眼光和努力可以。 关涛/关老板&#xff0c;80后的阿里P10&#xff0c;阿里巴巴通用计算平台负责人&#xff0c;阿里巴巴计算平台研究员。12年职场人生&#xff0c;微软和阿里的选择。 关涛的花名取自谐音&#xff1a;观涛。…