利用Inception-V3训练的权重微调,实现猫狗分类(基于keras)

利用Inception-V3训练的权重微调实现猫狗的分类,其中权重的下载在我的博客下载资源处,https://download.csdn.net/download/fanzonghao/10566634

第一种权重不改变直接用mixed7层(mixed7呆会把打印结果一放就知道了)进行特征提取,然后在拉平,连上两层神经网络

def define_model():InceptionV3_weight_path='./model_weight/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5'pre_trained_model=InceptionV3(input_shape=(150,150,3),include_top=False,#不包含全连接层weights=None)pre_trained_model.load_weights(InceptionV3_weight_path)#下面两种取其一#仅仅用其做特征提取 不需要更新权值for layer in pre_trained_model.layers:print(layer.name)layer.trainable=False#微调权值# unfreeze=False# for layer in pre_trained_model.layers:#     if unfreeze:#         layer.trainable=True#     if layer.name=='mixed6':#         unfreeze=Truelast_layer=pre_trained_model.get_layer('mixed7')print(last_layer.output_shape)last_output=last_layer.output#以下是在模型的基础上增加的x=layers.Flatten()(last_output)x=layers.Dense(1024,activation='relu')(x)x=layers.Dropout(0.2)(x)x=layers.Dense(1,activation='sigmoid')(x)model=Model(inputs=pre_trained_model.input,outputs=x)return model

第一种完全利用Inception-V3训练的权重代码

import os
import tensorflow as tf
import matplotlib.pyplot as pltfrom keras.applications.inception_v3 import InceptionV3
from keras import  layers
from keras.models import Model
from keras.optimizers import RMSprop
from keras import backend as K
from keras.preprocessing.image import ImageDataGenerator
import data_read
"""
#获得所需求的图片--进行了图像增强
"""
def data_deal_overfit():# 获取数据的路径train_dir, validation_dir, next_cat_pix, next_dog_pix = data_read.read_data()#图像增强train_datagen=ImageDataGenerator(rescale=1./255,rotation_range=40,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,fill_mode='nearest')test_datagen=ImageDataGenerator(rescale=1./255)#从文件夹获取所需要求的图片train_generator=train_datagen.flow_from_directory(train_dir,target_size=(150,150),batch_size=20,class_mode='binary')test_generator = test_datagen.flow_from_directory(validation_dir,target_size=(150, 150),batch_size=20,class_mode='binary')return train_generator,test_generator
"""
#定义模型并加入了dropout
"""
def define_model():InceptionV3_weight_path='./model_weight/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5'pre_trained_model=InceptionV3(input_shape=(150,150,3),include_top=False,#不包含全连接层weights=None)pre_trained_model.load_weights(InceptionV3_weight_path)#下面两种取其一#仅仅用其做特征提取 不需要更新权值for layer in pre_trained_model.layers:print(layer.name)layer.trainable=False#微调权值# unfreeze=False# for layer in pre_trained_model.layers:#     if unfreeze:#         layer.trainable=True#     if layer.name=='mixed6':#         unfreeze=Truelast_layer=pre_trained_model.get_layer('mixed7')print(last_layer.output_shape)last_output=last_layer.output#以下实在模型的基础上增加的x=layers.Flatten()(last_output)x=layers.Dense(1024,activation='relu')(x)x=layers.Dropout(0.2)(x)x=layers.Dense(1,activation='sigmoid')(x)model=Model(inputs=pre_trained_model.input,outputs=x)return model"""
训练模型
"""
def train_model():model=define_model()model.compile(optimizer=RMSprop(lr=0.001), loss='binary_crossentropy', metrics=['accuracy'])train_generator, test_generator = data_deal_overfit()# verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录# 训练模型 返回history包含各种精度和损失history = model.fit_generator(train_generator,steps_per_epoch=100,  # 2000 images=batch_szie*stepsepochs=50,validation_data=test_generator,validation_steps=50,  # 1000=20*50verbose=2)#精度acc=history.history['acc']val_acc=history.history['val_acc']#损失loss=history.history['loss']val_loss=history.history['val_loss']#epochs的数量epochs=range(len(acc))plt.plot(epochs,acc)plt.plot(epochs, val_acc)plt.title('training and validation accuracy')plt.figure()plt.plot(epochs, loss)plt.plot(epochs, val_loss)plt.title('training and validation loss')plt.show()if __name__ == '__main__':train_model()
打印结果:其中这些代表每一层的名字,直接利用mixed7的特征,(none,7,7,768)就是该层的shape, 直接拉平,添加两层神经网络进行分类。

打印结果:这是每一层的名字,mixed7层的shape是(None,7,7,768)第一种做法就是直接利用该层及之前层的权重进行训练分类的。

第二种:进行微调要不是需要对整个权重都进行重新赋值,因为前面层数学习到的特征是一些简单的特征,只是随着层数增强才更加具有针对性,故把mixed7层的卷积层权重 重新训练,代码:

unfreeze=False
for layer in pre_trained_model.layers:if unfreeze:layer.trainable=Trueif layer.name=='mixed6':unfreeze=True

也就是把我上段完整的代码注释替换一下即可。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/493570.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

刘锋:互联网左右大脑结构与钱学森开放复杂巨系统

作者:刘锋 互联网进化论作者 计算机博士前言:1990年,钱学森提出了开放的复杂巨系统理论,并提出以人为主,人机结合,从定性到定量的综合集成研究方法,他也预见性的提出“因特网正好生动地体现了…

DL中常用的numpy

读txt文件 按行读取有三种方式,注意readlines和readline的区别。open是python自带打开方式,如果打不开,可以使用encoding"UTF-8"指定解码方案。 读取得到一行之后,行首行尾可能存在一些不需要的字符,就可以…

paip.获取proxool的配置 xml读取通过jdk xml 初始化c3c0在代码中总结

paip.获取proxool的配置 xml读取通过jdk xml 初始化c3c0在代码中xml读取通过jdk xml初始化c3c0在代码中。。。。。作者Attilax 艾龙, EMAIL:1466519819qq.com来源:attilax的专栏地址:http://blog.csdn.net/attilaxproxoolController.ini()…

手写字母数据集转换为.pickle文件

首先是数据集,我上传了相关的资源,https://download.csdn.net/download/fanzonghao/10566701 转换代码如下: import numpy as np import os import matplotlib.pyplot as plt import matplotlib.image as mpig import imageio import pickle…

一文看懂谷歌的AI芯片布局,边缘端TPU将大发神威

来源:新电子2018年7月Google在其云端服务年会Google Cloud Next上正式发表其边缘(Edge)技术,与另两家国际公有云服务大厂Amazon/AWS、Microsoft Azure相比,Google对于边缘技术已属较晚表态、较晚布局者,但其技术主张却与前两业者有…

JS学习笔记-1--基本知识和注意事项

1、JS开始的目的主要是验证表单的输入验证 2、是一种具有面向对象能力的、解释型语言。是基于事件驱动的相对较安全的客户端脚本语言 3、JS 特点:松散型:变量不具备一个明确的类型; 对象属性:把属性名可以映射成任意的属性值&a…

opencv图像处理中的一些滤波器+利用滤波器提取条形码(解析二维码)+公交卡倾斜矫正+物体尺寸丈量

一般来说,图像的能量主要集中在其低频部分,噪声所在的频段主要在高频段,同时图像中的细节信息也主要集中在其高频部分,因此,如何去掉高频干扰同时又保持细节信息是关键。为了去除噪声,有必要对图像进行平滑,可以采用低通滤波的方法去除高频干扰。图像平滑包括空域法和频域法两大…

智联汽车:复盘国内巨头布局

来源:申万宏源摘要:从今年阿里9月云栖大会、华为10月全联接大会、百度11月世界大会、腾讯11月合作伙伴大会可以发现BATH均高调展示了各自在汽车科技领域的研发成果;而京东、滴滴两家公司近两年来关于汽车科技领域的动态亦在频频更新。▌车联网:车载OS竞争…

即插即用+任意blur的超分辨率重建——DPSR

计算机视觉中存在许多的不适定问题ill-posed problem。先来看什么是适定问题well-posed problem,适定问题必须同时满足三个条件: 1. a solution exists 解必须存在2. the solution is unique 解必须唯一3. the solutions behavior changes c…

Tomcat基础教程(一)

Tomcat, 是Servlet和JSP容器,其是实现了JSP规范的servlet容器。它在servlet生命周期内包容,装载,运行,和停止servlet容器。 Servlet容器的三种工作模式: 1. 独立的Servlet容器 Servlet容器与基于JAVA技术的Web服务器集…

opencv--图像金字塔

一,高斯金字塔--图片经过高斯下采样 """ 高斯金字塔 """ def gauss_pyramid():img cv2.imread(./data/img4.png)lower_reso cv2.pyrDown(img)lower_reso2 cv2.pyrDown(lower_reso)plt.subplot(131), plt.imshow(img)plt.title(In…

中国移动:5G蜂窝IoT关键技术分析

来源:5G本文讨论了蜂窝物联网的技术现状,针对增强机器类通信和窄带物联网技术标准,提出了2种现网快速部署方案,并进一步指出了C-IoT面向5G的演进路径。该路径充分考虑了5G网络中网络功能虚拟化、软件定义网络、移动边缘计算和大数…

dataframe常用操作总结

初始化 可以使用arraycolumns的格式, dpd.DataFrame(np.arange(10).reshape(2,5)) df1 pd.DataFrame([[Snow,M,22],[Tyrion,M,32],[Sansa,F,18],[Arya,F,14]], columns[name,gender,age]) 也可以使用字典大括号的格式: df pd.DataFrame({a: [1, 2…

DEDE无简略标题时显示完整标题

新闻的标题需要进行字数限制,这就需要加入一个title属性,让鼠标放上去的时候显示完整标题。另外目前的调用只能同时调用一种标题方式,不过可 以采用以下方法,进行判断,无简略标题显示完整标题。例如dede早期版本中的”…

清华大学发布:人脸识别最全知识图谱

来源:智东西摘要:本期我们推荐来自清华大学副教授唐杰领导的学者大数据挖掘项目Aminer的研究报告,讲解人脸识别技术及其应用领域,介绍人脸识别领域的国内玩人才并预测该技术的发展趋势。自20世纪下半叶,计算机视觉技术…

图像变换dpi(tif->jpg),直方图均衡化,腐蚀膨胀,分水岭,模板匹配,直线检测

一.图像变换dpi 1.示例1 import numpy as np from PIL import Image import cv2 def test_dp():path./gt_1.tif# imgImage.open(path)# print(img.size)# print(img.info)imgcv2.imread(path)imgImage.fromarray(img)print(img.size)print(img.info)img.save(test.jpg, dpi(3…

CV中的经典网络模型

目标检测 目标检测,不仅要识别目标是什么(分类),还要知道目标的具体位置(可以当作回归来做)。 RCNN Selective Search 算法获得候选框,Alexnet提取特征,SVM对每个候选框区域打分。…

无表头单链表增删改查操作

1、返回单链表中第pos个结点中的元素,若pos超出范围,则返回02、把单链表中第pos个结点的值修改为x的值,若修改成功返回1,否则返回03、向单链表的表头插入一个元素 4、向单链表的末尾添加一个元素…

JBU联合双边上采样

很多图像处理算法,如立体视觉中的深度估计,图像上色,高动态范围HDR中的tone mapping,图像分割,都有一个共性的问题:寻找一个全局的解,这个解是指一个分段的piecewise平滑含糊,描述了…

技术阅读周刊第十一期

技术阅读周刊,每周更新。 历史更新 20231124:第七期20231201:第八期20231215:第十‍期 A Comprehensive guide to Spring Boot 3.2 with Java 21, Virtual Threads, Spring Security, PostgreSQL, Flyway, Caching, Micrometer, O…