【深度学习】使用tensorflow实现VGG19网络

【深度学习】使用tensorflow实现VGG19网络

 

 

 

本文章向大家介绍【深度学习】使用tensorflow实现VGG19网络,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

 

 

 

 

VGG网络与AlexNet类似,也是一种CNN,VGG在2014年的 ILSVRC localization and classification 两个问题上分别取得了第一名和第二名。VGG网络非常深,通常有16-19层,卷积核大小为 3 x 3,16和19层的区别主要在于后面三个卷积部分卷积层的数量。第二个用tensorflow独立完成的小玩意儿......

 

 

模型结构

可以看到VGG的前几层为卷积和maxpool的交替,每个卷积包含多个卷积层,后面紧跟三个全连接层。激活函数采用Relu,训练采用了dropout,但并没有像AlexNet一样采用LRN(论文给出的理由是加LRN实验效果不好)。

模型定义

def maxPoolLayer(x, kHeight, kWidth, strideX, strideY, name, padding = "SAME"):"""max-pooling"""return tf.nn.max_pool(x, ksize = [1, kHeight, kWidth, 1],strides = [1, strideX, strideY, 1], padding = padding, name = name)def dropout(x, keepPro, name = None):"""dropout"""return tf.nn.dropout(x, keepPro, name)def fcLayer(x, inputD, outputD, reluFlag, name):"""fully-connect"""with tf.variable_scope(name) as scope:w = tf.get_variable("w", shape = [inputD, outputD], dtype = "float")b = tf.get_variable("b", [outputD], dtype = "float")out = tf.nn.xw_plus_b(x, w, b, name = scope.name)if reluFlag:return tf.nn.relu(out)else:return outdef convLayer(x, kHeight, kWidth, strideX, strideY,featureNum, name, padding = "SAME"):"""convlutional"""channel = int(x.get_shape()[-1]) #获取channel数with tf.variable_scope(name) as scope:w = tf.get_variable("w", shape = [kHeight, kWidth, channel, featureNum])b = tf.get_variable("b", shape = [featureNum])featureMap = tf.nn.conv2d(x, w, strides = [1, strideY, strideX, 1], padding = padding)out = tf.nn.bias_add(featureMap, b)return tf.nn.relu(tf.reshape(out, featureMap.get_shape().as_list()), name = scope.name)

定义了卷积、pooling、dropout、全连接五个模块,使用了上一篇AlexNet中的代码,其中卷积模块去除了group参数,因为网络没有像AlexNet一样分成两部分。接下来定义VGG19。

class VGG19(object):"""VGG model"""def __init__(self, x, keepPro, classNum, skip, modelPath = "vgg19.npy"):self.X = xself.KEEPPRO = keepProself.CLASSNUM = classNumself.SKIP = skipself.MODELPATH = modelPath#build CNNself.buildCNN()def buildCNN(self):"""build model"""conv1_1 = convLayer(self.X, 3, 3, 1, 1, 64, "conv1_1" )conv1_2 = convLayer(conv1_1, 3, 3, 1, 1, 64, "conv1_2")pool1 = maxPoolLayer(conv1_2, 2, 2, 2, 2, "pool1")conv2_1 = convLayer(pool1, 3, 3, 1, 1, 128, "conv2_1")conv2_2 = convLayer(conv2_1, 3, 3, 1, 1, 128, "conv2_2")pool2 = maxPoolLayer(conv2_2, 2, 2, 2, 2, "pool2")conv3_1 = convLayer(pool2, 3, 3, 1, 1, 256, "conv3_1")conv3_2 = convLayer(conv3_1, 3, 3, 1, 1, 256, "conv3_2")conv3_3 = convLayer(conv3_2, 3, 3, 1, 1, 256, "conv3_3")conv3_4 = convLayer(conv3_3, 3, 3, 1, 1, 256, "conv3_4")pool3 = maxPoolLayer(conv3_4, 2, 2, 2, 2, "pool3")conv4_1 = convLayer(pool3, 3, 3, 1, 1, 512, "conv4_1")conv4_2 = convLayer(conv4_1, 3, 3, 1, 1, 512, "conv4_2")conv4_3 = convLayer(conv4_2, 3, 3, 1, 1, 512, "conv4_3")conv4_4 = convLayer(conv4_3, 3, 3, 1, 1, 512, "conv4_4")pool4 = maxPoolLayer(conv4_4, 2, 2, 2, 2, "pool4")conv5_1 = convLayer(pool4, 3, 3, 1, 1, 512, "conv5_1")conv5_2 = convLayer(conv5_1, 3, 3, 1, 1, 512, "conv5_2")conv5_3 = convLayer(conv5_2, 3, 3, 1, 1, 512, "conv5_3")conv5_4 = convLayer(conv5_3, 3, 3, 1, 1, 512, "conv5_4")pool5 = maxPoolLayer(conv5_4, 2, 2, 2, 2, "pool5")fcIn = tf.reshape(pool5, [-1, 7*7*512])fc6 = fcLayer(fcIn, 7*7*512, 4096, True, "fc6")dropout1 = dropout(fc6, self.KEEPPRO)fc7 = fcLayer(dropout1, 4096, 4096, True, "fc7")dropout2 = dropout(fc7, self.KEEPPRO)self.fc8 = fcLayer(dropout2, 4096, self.CLASSNUM, True, "fc8")def loadModel(self, sess):"""load model"""wDict = np.load(self.MODELPATH, encoding = "bytes").item()#for layers in modelfor name in wDict:if name not in self.SKIP:with tf.variable_scope(name, reuse = True):for p in wDict[name]:if len(p.shape) == 1:#bias 只有一维sess.run(tf.get_variable('b', trainable = False).assign(p))else:#weightssess.run(tf.get_variable('w', trainable = False).assign(p)) 

buildCNN函数完全按照VGG的结构搭建网络。

loadModel函数从模型文件中读取参数,采用的模型文件见github上的readme说明。 至此,我们定义了完整的模型,下面开始测试模型。

模型测试

ImageNet训练的VGG有很多类,几乎包含所有常见的物体,因此我们随便从网上找几张图片测试。比如我直接用了之前做项目的图片,为了避免审美疲劳,我们不只用渣土车,还要用挖掘机、采沙船:

然后编写测试代码:

parser = argparse.ArgumentParser(description='Classify some images.')
parser.add_argument('mode', choices=['folder', 'url'], default='folder')
parser.add_argument('path', help='Specify a path [e.g. testModel]')
args = parser.parse_args(sys.argv[1:])if args.mode == 'folder': #测试方式为本地文件夹#get testImagewithPath = lambda f: '{}/{}'.format(args.path,f)testImg = dict((f,cv2.imread(withPath(f))) for f in os.listdir(args.path) if os.path.isfile(withPath(f)))
elif args.mode == 'url': #测试方式为URLdef url2img(url): #获取URL图像'''url to image'''resp = urllib.request.urlopen(url)image = np.asarray(bytearray(resp.read()), dtype="uint8")image = cv2.imdecode(image, cv2.IMREAD_COLOR)return imagetestImg = {args.path:url2img(args.path)}if testImg.values():#some paramsdropoutPro = 1classNum = 1000skip = []imgMean = np.array([104, 117, 124], np.float)x = tf.placeholder("float", [1, 224, 224, 3])model = vgg19.VGG19(x, dropoutPro, classNum, skip)score = model.fc8softmax = tf.nn.softmax(score)with tf.Session() as sess:sess.run(tf.global_variables_initializer())model.loadModel(sess) #加载模型for key,img in testImg.items():#img preprocessresized = cv2.resize(img.astype(np.float), (224, 224)) - imgMean #去均值maxx = np.argmax(sess.run(softmax, feed_dict = {x: resized.reshape((1, 224, 224, 3))})) #网络输入为224*224res = caffe_classes.class_names[maxx]font = cv2.FONT_HERSHEY_SIMPLEXcv2.putText(img, res, (int(img.shape[0]/3), int(img.shape[1]/3)), font, 1, (0, 255, 0), 2) #在图像上绘制结果print("{}: {}n----".format(key,res)) #输出测试结果cv2.imshow("demo", img)cv2.waitKey(0)

如果你看完了我AlexNet的博客,那么一定会发现我这里的测试代码做了一些小的修改,增加了URL测试的功能,可以测试网上的图像 ,测试结果如下:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/686005.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

wayland(xdg_wm_base) + egl + opengles——dma_buf 作为纹理数据源(五)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、EGL dma_buf import 相关的数据结构和函数1. EGLImageKHR2. eglCreateImageKHR()3. glEGLImageTargetTexture2DOES()二、egl 中 import dma_buf 作为纹理的代码实例1. egl_wayland_dmabuf_…

Why Not Http?

游戏服务器开发主要是基于socket,或者websocket,很少采用http(可能有部分非常轻量级的服务器选择http)。这是什么原因呢?我们先来看看socket与http之间的区别。 socket与http之间的区别 socket与http对比 sockethttpT…

【c++】list 模拟

> 作者简介:დ旧言~,目前大二,现在学习Java,c,c,Python等 > 座右铭:松树千年终是朽,槿花一日自为荣。 > 目标:能手撕list模拟 > 毒鸡汤:不为模糊…

红队ATKCK|红日靶场Write-Up(附下载链接)

网络拓扑图 下载地址 在线下载: http://vulnstack.qiyuanxuetang.net/vuln/detail/2/ 百度网盘 链接:https://pan.baidu.com/s/1nlAZAuvni3EefAy1SGiA-Q?pwdh1e5 提取码:h1e5 环境搭建 通过上述图片,web服务器vm1既能用于外…

学习Android的第十三天

目录 Android TextClock 文本时钟控件 TextClock 控件主要属性和方法 简单的 TextClock 参考文档 Android AnalogClock 控件 AnalogClock 属性 Android Chronometer 计时器 Chronometer 属性 Chronometer 主要方法 范例: 完整的计时器 范例: …

Rust 学习笔记 - Hello world

前言 本文将讲解如何完成一个 Rust 项目的开发流程,从编写 “Hello, World!” 开始,到使用 Cargo 管理和运行项目。 编写 Hello world 开始一个新项目很简单,首先,创建一个包含 main.rs 文件的 hello_world 文件夹,…

基于Doris构建亿级数据实时数据分析系统

转载至我的博客 https://www.infrastack.cn ,公众号:架构成长指南 背景 随着公司业务快速发展,对业务数据进行增长分析的需求越来越迫切,与此同时我们的业务数据量也在快速激增、每天的数据新增量大概在30w 左右,一年…

BUUCTF misc 专题(47)[SWPU2019]神奇的二维码

下载附件,得到一张二维码图片,并用工具扫描(因为图片违规了,所以就不放了哈。工具的话,一般的二维码扫描都可以) swpuctf{flag_is_not_here},(刚开始出了点小差错对不住各位师傅&am…

代码随想录第32天|● 122.买卖股票的最佳时机II ● 55. 跳跃游戏 ● 45.跳跃游戏II

文章目录 买卖股票思路一:贪心代码: 思路:动态规划代码: 跳跃游戏思路:贪心找最大范围代码: 跳跃游戏②思路:代码: 方法二:处理方法一的特殊情况 买卖股票 思路一&#x…

C++类和对象-多态->多态的基本语法、多态的原理剖析、纯虚函数和抽象类、虚析构和纯虚析构

#include<iostream> using namespace std; //多态 //动物类 class Animal { public: //Speak函数就是虚函数 //函数前面加上virtual关键字&#xff0c;变成虚函数&#xff0c;那么编译器在编译的时候就不能确定函数调用了。 virtual void speak() { …

论全人类大脑潜在联系的可能性与现实意义

随着科技和神经科学研究的深入&#xff0c;越来越多的理论与实践表明&#xff0c;尽管人的大脑在物理形态上各自独立&#xff0c;但通过思维、情感、信息交流等多种方式&#xff0c;全人类的大脑之间存在着广泛的、潜在的联系。本文旨在探讨这种普遍联系的可能性及其对人类社会…

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之NavRouter组件

鸿蒙&#xff08;HarmonyOS&#xff09;项目方舟框架&#xff08;ArkUI&#xff09;之NavRouter组件 一、操作环境 操作系统: Windows 10 专业版、IDE:DevEco Studio 3.1、SDK:HarmonyOS 3.1 二、NavRouter组件 导航组件&#xff0c;默认提供点击响应处理&#xff0c;不需要…

微信公众号扫码登录

1.设计 我们采用的是个人号登录方式&#xff0c;这样拿不到我们的userInfo用户信息&#xff0c;然后我们将用户发来的消息&#xff08;xml消息体&#xff09;中的FromUser作为我们唯一的openId 整体流程: 1.用户扫码公众号码&#xff0c;然后发一条消息&#xff1a;验证码&…

2.13日学习打卡----初学RocketMQ(四)

2.13日学习打卡 目录&#xff1a; 2.13日学习打卡一.RocketMQ之Java ClassDefaultMQProducer类DefaultMQPushConsumer类Message类MessageExt类 二.RocketMQ 消费幂消费过程幂等消费速度慢的处理方式 三.RocketMQ 集群服务集群特点单master模式多master模式多master多Slave模式-…

C语言希尔排序详解!!!速过

目录 希尔排序是什么&#xff1f; 关于时间复杂度 希尔排序的源代码 希尔排序源代码的详解 希尔排序是什么&#xff1f; 之前我们说了三个排序&#xff08;插入排序&#xff0c;选择排序&#xff0c;冒泡排序&#xff09;有需要的铁铁可以去看看之前的讲解。 但因为之前的…

基于Python的信息加密解密网站设计与实现【源码+论文+演示视频+包运行成功】

博主介绍&#xff1a;✌csdn特邀作者、博客专家、java领域优质创作者、博客之星&#xff0c;擅长Java、微信小程序、Python、Android等技术&#xff0c;专注于Java、Python等技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; …

CSS 圆形的时钟秒针状的手柄绕中心点旋转的效果

<template><!-- 创建一个装载自定义加载动画的容器 --><view class="cloader"><!-- 定义加载动画主体部分 --><view class="clface"><!-- 定义类似秒针形状的小圆盘 --><view class="clsface"><!-…

docker (五)-docker存储-数据持久化

将数据存储在容器中&#xff0c;一旦容器被删除&#xff0c;数据也会被删除。同时也会使容器变得越来越大&#xff0c;不方便恢复和迁移。 将数据存储到容器之外&#xff0c;这样删除容器也不会丢失数据。一旦容器故障&#xff0c;我们可以重新创建一个容器&#xff0c;将数据挂…

Spring Boot与Kafka集成教程

当然可以&#xff0c;这里为您提供一个简化版的Spring Boot与Kafka集成教程&#xff1a; 新建Spring Boot项目 使用Spring Initializr或您喜欢的IDE&#xff08;如IntelliJ IDEA, Eclipse等&#xff09;新建一个Spring Boot项目。 添加依赖 在项目的pom.xml文件中&#xff0c;…

Linux常见指令(一)

一、基本指令 1.1ls指令 语法 &#xff1a; ls [ 选项 ][ 目录或文件 ] 功能&#xff1a;对于目录&#xff0c;该命令列出该目录下的所有子目录与文件。对于文件&#xff0c;将列出文件名以及其他信息。 常用选项&#xff1a; -a 列出目录下的所有文件&#xff0c;包括以 .…