python怎么使用预训练的模型_Keras使用ImageNet上预训练的模型方式

我就废话不多说了,大家还是直接看代码吧!

import keras

import numpy as np

from keras.applications import vgg16, inception_v3, resnet50, mobilenet

#Load the VGG model

vgg_model = vgg16.VGG16(weights='imagenet')

#Load the Inception_V3 model

inception_model = inception_v3.InceptionV3(weights='imagenet')

#Load the ResNet50 model

resnet_model = resnet50.ResNet50(weights='imagenet')

#Load the MobileNet model

mobilenet_model = mobilenet.MobileNet(weights='imagenet')

在以上代码中,我们首先import各种模型对应的module,然后load模型,并用ImageNet的参数初始化模型的参数。

如果不想使用ImageNet上预训练到的权重初始话模型,可以将各语句的中'imagenet'替换为'None'。

补充知识:keras上使用alexnet模型来高准确度对mnist数据进行分类

纲要

本文有两个特点:一是直接对本地mnist数据进行读取(假设事先已经下载或从别处拷来)二是基于keras框架(网上多是基于tf)使用alexnet对mnist数据进行分类,并获得较高准确度(约为98%)

本地数据读取和分析

很多代码都是一开始简单调用一行代码来从网站上下载mnist数据,虽然只有10来MB,但是现在下载速度非常慢,而且经常中途出错,要费很大的劲才能拿到数据。

(X_train, y_train), (X_test, y_test) = mnist.load_data()

其实可以单独来获得这些数据(一共4个gz包,如下所示),然后调用别的接口来分析它们。

2020052321014681.jpg

mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #导入已经下载好的数据集,"./MNIST_data"为存放mnist数据的目录

x_train = mnist.train.images

y_train = mnist.train.labels

x_test = mnist.test.images

y_test = mnist.test.labels

这里面要注意的是,两种接口拿到的数据形式是不一样的。 从网上直接下载下来的数据 其image data值的范围是0~255,且label值为0,1,2,3...9。 而第二种接口获取的数据 image值已经除以255(归一化)变成0~1范围,且label值已经是one-hot形式(one_hot=True时),比如label值2的one-hot code为(0 0 1 0 0 0 0 0 0 0)

所以,以第一种方式获取的数据需要做一些预处理(归一和one-hot)才能输入网络模型进行训练 而第二种接口拿到的数据则可以直接进行训练。

Alexnet模型的微调

按照公开的模型框架,Alexnet只有第1、2个卷积层才跟着BatchNormalization,后面三个CNN都没有(如有说错,请指正)。如果按照这个来搭建网络模型,很容易导致梯度消失,现象就是 accuracy值一直处在很低的值。 如下所示。

2020052321014683.jpg

在每个卷积层后面都加上BN后,准确度才迭代提高。如下所示

2020052321014685.jpg

完整代码

import keras

from keras.datasets import mnist

from keras.models import Sequential

from keras.layers.core import Dense, Dropout, Activation, Flatten

from keras.layers.convolutional import Conv2D, MaxPooling2D, ZeroPadding2D

from keras.layers.normalization import BatchNormalization

from keras.callbacks import ModelCheckpoint

import numpy as np

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data #tensorflow已经包含了mnist案例的数据

batch_size = 64

num_classes = 10

epochs = 10

img_shape = (28,28,1)

# input dimensions

img_rows, img_cols = 28,28

# dataset input

#(x_train, y_train), (x_test, y_test) = mnist.load_data()

mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #导入已经下载好的数据集,"./MNIST_data"为存放mnist数据的目录

print(mnist.train.images.shape, mnist.train.labels.shape)

print(mnist.test.images.shape, mnist.test.labels.shape)

print(mnist.validation.images.shape, mnist.validation.labels.shape)

x_train = mnist.train.images

y_train = mnist.train.labels

x_test = mnist.test.images

y_test = mnist.test.labels

# data initialization

x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)

x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)

input_shape = (img_rows, img_cols, 1)

# Define the input layer

inputs = keras.Input(shape = [img_rows, img_cols, 1])

#Define the converlutional layer 1

conv1 = keras.layers.Conv2D(filters= 64, kernel_size= [11, 11], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(inputs)

# Define the pooling layer 1

pooling1 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv1)

# Define the standardization layer 1

stand1 = keras.layers.BatchNormalization(axis= 1)(pooling1)

# Define the converlutional layer 2

conv2 = keras.layers.Conv2D(filters= 192, kernel_size= [5, 5], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand1)

# Defien the pooling layer 2

pooling2 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv2)

# Define the standardization layer 2

stand2 = keras.layers.BatchNormalization(axis= 1)(pooling2)

# Define the converlutional layer 3

conv3 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand2)

stand3 = keras.layers.BatchNormalization(axis=1)(conv3)

# Define the converlutional layer 4

conv4 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand3)

stand4 = keras.layers.BatchNormalization(axis=1)(conv4)

# Define the converlutional layer 5

conv5 = keras.layers.Conv2D(filters= 256, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand4)

pooling5 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv5)

stand5 = keras.layers.BatchNormalization(axis=1)(pooling5)

# Define the fully connected layer

flatten = keras.layers.Flatten()(stand5)

fc1 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(flatten)

drop1 = keras.layers.Dropout(0.5)(fc1)

fc2 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(drop1)

drop2 = keras.layers.Dropout(0.5)(fc2)

fc3 = keras.layers.Dense(10, activation= keras.activations.softmax, use_bias= True)(drop2)

# 基于Model方法构建模型

model = keras.Model(inputs= inputs, outputs = fc3)

# 编译模型

model.compile(optimizer= tf.train.AdamOptimizer(0.001),

loss= keras.losses.categorical_crossentropy,

metrics= ['accuracy'])

# 训练配置,仅供参考

model.fit(x_train, y_train, batch_size= batch_size, epochs= epochs, validation_data=(x_test,y_test))

以上这篇Keras使用ImageNet上预训练的模型方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持WEB开发者。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/349475.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

NOIP模拟测试49·50「养花·折射·画作·施工·蔬菜·联盟」

一套题 养花 题解 分块\主席树 这里我用的是主席树 查询分段$1-(k-1)$找最大的,能向右找就向右找 for(ll nowl1,nowrk-1;nowl<maxx;nowlk,nowrk,nowrmin(nowr,maxx)){if(ansmod-1) break;chose(rt[r],rt[l-1],nowl,nowr,1,maxx);} 复杂度分析,调和级数$√n*log(n)$ 代码 #in…

宏任务和微任务执行顺序_确保任务的执行顺序

宏任务和微任务执行顺序有时有必要对线程池中的任务施加一定的顺序。 JavaSpecialists通讯的第206期提出了一种这样的情况&#xff1a;我们使用NIO从多个连接中读取数据。 我们需要确保来自给定连接的事件按顺序执行&#xff0c;但是不同连接之间的事件可以自由混合。 我想提出…

c语言中aver是什么意思_Linux系统top命令中的io使用率,到底是什么意思?

最近在做连续数据流的缓冲系统&#xff0c;C语言代码实现后&#xff0c;粗略测试了下&#xff0c;功能上应该没有问题。那么&#xff0c;接下来就该测试性能了。输入 top 命令&#xff0c;的确可以看到一系列 cpu 使用率&#xff0c;其中一个值得注意的子项就是 io 使用率了&am…

wireshark捕获选项不能用_wireshark的一些基础用法,欢迎收藏

About WiresharkWireshark是世界上最重要和使用最广泛的网络协议分析器。它让您在微观层次上看到网络上正在发生的事情&#xff0c;并且是许多商业和非营利性企业、政府机构和教育机构事实上(通常也是法律上)的标准。Wireshark的发展得益于全球网络专家的志愿贡献&#xff0c;并…

管理沟通-沟通框架

背景 管理三明治的承托&#xff0c;管理沟通。离开了沟通&#xff0c;所有的工作都将搁浅而无法前进。 常见话题&#xff1a; 向上沟通员工激励团队凝聚力提升向下沟通工作特点 工作职责说明技术开发计算机&#xff0c;编程语言&#xff0c;设计算法&#xff0c;开发功能&#…

t’触发器真值表和状态方程_清写出触发器按逻辑特性的分类;写出T触发器的状态方程。...

下列对配电所的说法有误的一项是()。A&#xff0e;市区10kV公用配电所的供电半径一般不大于300m&#xff0c;在郊区的供成功的基础设施服务的提供者都首先是按照商业化的原则经营的&#xff0c;并至少具有几个基本特点&#xff0c;这些基本特我国幅员辽阔&#xff0c;能源分布不…

NetBeans 9抢先体验

Java 9即将来临&#xff0c;NetBeans 9也即将来临。在本文中&#xff0c;我们将看到NetBeans 9 Early Access为开发人员提供的支持&#xff0c;以帮助他们构建Java 9兼容的应用程序。 Java 9提供了许多&#xff08;大约90种&#xff09; 新功能&#xff0c;包括Modules和JShel…

块裁剪后的矩形边界如何去掉_手持拍摄画面太抖?这节课教你如何快速稳定抖动的画面...

手持相机进行拍摄&#xff0c;画面会有较为明显的抖动&#xff0c;这节课就教大家如何稳定视频画面。素材导入到PR后&#xff0c;为素材添加变形稳定器效果&#xff0c;软件会自动开始分析。当前素材上方会显示在后台分析&#xff0c;这时候我们可以剪辑其他部分&#xff0c;并…

怎么把空字符串去掉_Python知识点字符串转整数需注意

↑↑↑关注后"星标"简说Python人人都可以简单入门Python、爬虫、数据分析简说Python严选 来源&#xff1a;简说Python 作者&#xff1a;老表One old watch, like brief python大家好&#xff0c;我是老表&#xff5e;Python知识点系列&#xff0c;学习了记得点赞、…

tree

随机走,看期望 由于zzn过于sb,考试推出来式子因为统计时间不对没有$AC$(应dfs前统计) zzn实在过于sb,式子和题解完全不一样,所以看题解的可以走了 记录tofa[x]表示当前点走到父亲期望步数 可以直接走到父亲 贡献$\frac{1}{deg[x]}$ 走到儿子再走到父亲$\frac{1}{deg[x]}*(1tofa…

android-x86 镜像iso下载_Windows 10(1909)最新12月更新版MSDN官方简体中文原版ISO镜像下载+激huo工ju...

微软已于11月中旬开始大规模推送Windows 10操作系统的最新版本1909。此次更新官方未放出具体更新日志&#xff0c;但没有太多大功能更新&#xff0c;主要还是“修修补补”为主。现在&#xff0c;为大家带来本次官方最新原版ISO镜像下载&#xff0c;具体内部版本号为18363&#…

32查运行内存的map文件_Spark Shuffle调优之调节map端内存缓冲与reduce端内存占比

本文首先介绍Spark中的两个配置参数: spark.shuffle.file.buffer map端内存缓冲 spark.shuffle.memoryFraction reduce端内存占比很多博客会说上面这两个参数是调节Spark shuffle性能的利器&#xff0c;实际上并不是这样的。以实际的生产经验来说&#xff0c;这两个参数没…

element ui 表单验证为正整数

很多时候都会用表单中输入正整数的情况&#xff0c;在element ui中可以用el-input-number 标签来显示输入框是number类型的&#xff0c;或者type"number"也可以的&#xff0c;但是正整数就需要判断了 可以利用正则来判断 代码如下 <el-form ref"checkData&qu…

python用input输入列表_Python如何使用input函数获取输入

所谓输入&#xff0c;就是用代码获取用户通过键盘输入的信息。 例如&#xff1a;去银行取钱&#xff0c;在 ATM 上输入密码。 在 Python 中&#xff0c;如果要获取用户在键盘上的输入信息&#xff0c;需要使用到input()函数。 函数input()让程序暂停运行&#xff0c;等待用户输…

web.xml.jsf_使用JSF 2.0可以更轻松地进行多字段验证

web.xml.jsf开发应用程序表单时最常见的需求之一是多字段验证&#xff08;或跨字段验证&#xff0c;但我没有使用此术语&#xff0c;因为当我将其放在Google上时&#xff0c;实际上得到了一些战后图片&#xff09;。 我正在谈论的情况是&#xff0c;我们需要比较初始日期是早于…

odoo self.ensure_one()

源码&#xff1a; def ensure_one(self): """ Verifies that the current recorset holds a single record. Raises an exception otherwise. """ try: # unpack to ensure there is only one value is faster than len when…

模板 字段_劲爆新功能:轻流文字识别(OCR)功能支持自定义识别模板啦

Hi&#xff0c;又和大家见面啦&#xff5e;前段时间我们的文字识别(OCR)功能推出后&#xff0c;由于只支持系统提供的固定识别模板&#xff0c;很多客户跟我们反馈说&#xff1a;希望可以自定义识别模板&#xff01;现应大家的要求&#xff0c;轻流「文字识别(OCR)」的「自定义…

中文字符频率统计python_python统计字符串出现最多的字母及其出现次数

统计字符串出现最多的字母及其出现次数 另外如果次数相同按字母顺序排序。 方法1 可以使用自定义键对c.most_common()进行排序&#xff0c;该键首先考虑频率的降序&#xff0c;然后考虑字母的降序&#xff08;请注意lambda x: (-x[1], x[0]) &#xff09;&#xff1a; from col…

c#float取小数点后两位_C# 保留小数点后两位(方法总结)

最简单使用: float i=1.6667f; string show=i.ToString("0.00"); //结果1.67(四舍五入) 其他类似方法: string show=i.ToString("F");//"F2","f" 不区分大小写 string show=String.Format("{0:F}",i);//也可以为F2,或者&q…

Java 9中的进程处理

一直以来&#xff0c;用Java管理操作系统进程都是一项艰巨的任务。 这样做的原因是可用的工具和API较差。 老实说&#xff0c;这并非没有道理&#xff1a;Java并非出于此目的。 如果要管理OS进程&#xff0c;则可以使用所需的Shell&#xff0c;Perl脚本。 对于面临更复杂任务的…