TensorFlow-slim包进行图像数据集分类---具体流程

TensorFlow中slim包的具体用法

  • 1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)
  • 3、模型训练
    • 1、数据集相关模块:
    • 2、设置网络模型模块
    • 3、数据预处理模块
    • 4、定义损失loss
    • 5、定义优化器模块

本次使用的TensorFlow版本是1.13.0
地址:https://github.com/tensorflow/models/tree/r1.13.0
到tensorflow-models的GitHub下载research下面的slim这个包到本地
在这里插入图片描述

TensorFlow中slim包的目录结构:

-- slim|-- BUILD|-- README.md|-- WORKSPACE|-- __init__.py|-- datasets|   |-- __init__.py|   |-- __pycache__|   |   |-- __init__.cpython-37.pyc|   |   |-- dataset_utils.cpython-37.pyc|   |   |-- download_and_convert_cifar10.cpython-37.pyc|   |   |-- download_and_convert_flowers.cpython-37.pyc|   |   `-- download_and_convert_mnist.cpython-37.pyc|   |-- build_imagenet_data.py|   |-- cifar10.py|   |-- dataset_factory.py|   |-- dataset_utils.py|   |-- download_and_convert_cifar10.py|   |-- download_and_convert_flowers.py|   |-- download_and_convert_imagenet.sh|   |-- download_and_convert_mnist.py|   |-- download_imagenet.sh|   |-- flowers.py|   |-- imagenet.py|   |-- imagenet_2012_validation_synset_labels.txt|   |-- imagenet_lsvrc_2015_synsets.txt|   |-- imagenet_metadata.txt|   |-- mnist.py|   |-- preprocess_imagenet_validation_data.py|   `-- process_bounding_boxes.py|-- deployment|   |-- __init__.py|   |-- model_deploy.py|   `-- model_deploy_test.py|-- download_and_convert_data.py    # 下载相应的数据集,并将数据打包成TF-record的格式|-- eval_image_classifier.py        # 测试模型分类效果|-- export_inference_graph.py|-- export_inference_graph_test.py|-- nets|   |-- __init__.py|   |-- alexnet.py|   |-- alexnet_test.py|   |-- cifarnet.py|   |-- cyclegan.py|   |-- cyclegan_test.py|   |-- dcgan.py|   |-- dcgan_test.py|   |-- i3d.py|   |-- i3d_test.py|   |-- i3d_utils.py|   |-- inception.py|   |-- inception_resnet_v2.py|   |-- inception_resnet_v2_test.py|   |-- inception_utils.py|   |-- inception_v1.py|   |-- inception_v1_test.py|   |-- inception_v2.py|   |-- inception_v2_test.py|   |-- inception_v3.py|   |-- inception_v3_test.py|   |-- inception_v4.py|   |-- inception_v4_test.py|   |-- lenet.py|   |-- mobilenet|   |   |-- README.md|   |   |-- __init__.py|   |   |-- conv_blocks.py|   |   |-- madds_top1_accuracy.png|   |   |-- mnet_v1_vs_v2_pixel1_latency.png|   |   |-- mobilenet.py|   |   |-- mobilenet_example.ipynb|   |   |-- mobilenet_v2.py|   |   `-- mobilenet_v2_test.py|   |-- mobilenet_v1.md|   |-- mobilenet_v1.png|   |-- mobilenet_v1.py|   |-- mobilenet_v1_eval.py|   |-- mobilenet_v1_test.py|   |-- mobilenet_v1_train.py|   |-- nasnet|   |   |-- README.md|   |   |-- __init__.py|   |   |-- nasnet.py|   |   |-- nasnet_test.py|   |   |-- nasnet_utils.py|   |   |-- nasnet_utils_test.py|   |   |-- pnasnet.py|   |   `-- pnasnet_test.py|   |-- nets_factory.py|   |-- nets_factory_test.py|   |-- overfeat.py|   |-- overfeat_test.py|   |-- pix2pix.py|   |-- pix2pix_test.py|   |-- resnet_utils.py|   |-- resnet_v1.py|   |-- resnet_v1_test.py|   |-- resnet_v2.py|   |-- resnet_v2_test.py|   |-- s3dg.py|   |-- s3dg_test.py|   |-- vgg.py|   `-- vgg_test.py|-- preprocessing|   |-- __init__.py|   |-- cifarnet_preprocessing.py|   |-- inception_preprocessing.py|   |-- lenet_preprocessing.py|   |-- preprocessing_factory.py|   `-- vgg_preprocessing.py|-- scripts                     # gqr:存储的是相关的模型训练脚本                |   |-- export_mobilenet.sh|   |-- finetune_inception_resnet_v2_on_flowers.sh|   |-- finetune_inception_v1_on_flowers.sh|   |-- finetune_inception_v3_on_flowers.sh|   |-- finetune_resnet_v1_50_on_flowers.sh|   |-- train_cifarnet_on_cifar10.sh|   `-- train_lenet_on_mnist.sh|-- setup.py|-- slim_walkthrough.ipynb`-- train_image_classifier.py    # 训练模型的脚本

1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)

scripts/finetune_resnet_v1_50_on_flowers.sh

#!/bin/bash
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# This script performs the following operations:
# 1. Downloads the Flowers dataset
# 2. Fine-tunes a ResNetV1-50 model on the Flowers training set.
# 3. Evaluates the model on the Flowers validation set.
#
# Usage:
# cd slim
# ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh
set -e# Where the pre-trained ResNetV1-50 checkpoint is saved to.
PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints   # gqr:预训练模型存放路径# Where the training (fine-tuned) checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/flowers-models/resnet_v1_50# Where the dataset is saved to.
DATASET_DIR=/tmp/flowers    # gqr:数据集存放路径# Download the pre-trained checkpoint.
if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; thenmkdir ${PRETRAINED_CHECKPOINT_DIR}
fi
if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; thenwget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gztar -xvf resnet_v1_50_2016_08_28.tar.gzmv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckptrm resnet_v1_50_2016_08_28.tar.gz
fi# Download the dataset
python download_and_convert_data.py \--dataset_name=flowers \--dataset_dir=${DATASET_DIR}# Fine-tune only the new layers for 3000 steps.
python train_image_classifier.py \--train_dir=${TRAIN_DIR} \--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50 \--checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \--checkpoint_exclude_scopes=resnet_v1_50/logits \--trainable_scopes=resnet_v1_50/logits \--max_number_of_steps=3000 \--batch_size=32 \--learning_rate=0.01 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004# Run evaluation.
python eval_image_classifier.py \--checkpoint_path=${TRAIN_DIR} \--eval_dir=${TRAIN_DIR} \--dataset_name=flowers \--dataset_split_name=validation \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50# Fine-tune all the new layers for 1000 steps.
python train_image_classifier.py \--train_dir=${TRAIN_DIR}/all \--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=${DATASET_DIR} \--checkpoint_path=${TRAIN_DIR} \--model_name=resnet_v1_50 \--max_number_of_steps=1000 \--batch_size=32 \--learning_rate=0.001 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004# Run evaluation.
python eval_image_classifier.py \--checkpoint_path=${TRAIN_DIR}/all \--eval_dir=${TRAIN_DIR}/all \--dataset_name=flowers \--dataset_split_name=validation \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50

以上文件以下载并打包flowers数据集为例会调用slim/datasets下的****download_and_convert_flowers.py
在这里插入图片描述
代码43行:_NUM_VALIDATION = 350值的意思的测试数据集的数量,我们一般2,8分数据集,这里只用填写测试集的数据代码会自动吧总数据集分成2部分
代码48行:_NUM_SHARDS = 1这个的意思是生成几个tfrecord文件,这个数量是根据你数据量来划分
在这里插入图片描述
代码190行:dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir) 函数为下载数据集函数,如果本地已经存在数据集,可将将其注释掉
在这里插入图片描述
代码210行:_clean_up_temporary_files(dataset_dir) 函数为打包完毕后删除下载的数据集文件,如果需要下载的数据集可以将其注释掉

上述文件执行完毕后,会得到以下文件
在这里插入图片描述

3、模型训练

模型训练文件为
在这里插入图片描述
以下是该文件中各个模块相关内容

1、数据集相关模块:

在这里插入图片描述

2、设置网络模型模块

在这里插入图片描述

3、数据预处理模块

在这里插入图片描述

4、定义损失loss

在这里插入图片描述

5、定义优化器模块

在这里插入图片描述

运行训练指令:

python train_image_classifier.py \--train_dir=./data/flowers-models/resnet_v1_50\--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=./data/flowers \--model_name=resnet_v1_50 \--checkpoint_path=./data/checkpoints/resnet_v1_50.ckpt \--checkpoint_exclude_scopes=resnet_v1_50/logits \--trainable_scopes=resnet_v1_50/logits \--max_number_of_steps=3000 \ --batch_size=32 \--learning_rate=0.01 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004

–dataset_name=指定模板
–model_name=指定预训练模板
–dataset_dir=指定训练集目录
–checkpoint_exclude_scopes=指定忘记那几层的参数,不带进训练里面,记住提取特征的部分
–train_dir=训练参数存放地址
–trainable_scopes=设定只对那几层变量进行调整,其他层都不进行调整,不设定就会对所有层训练(所以是必须要给定的)
–learning_rate=学习率
–optimizer=优化器
–checkpoint_path:预训练模型存放地址
–max_number_of_steps=训练步数
–batch_size=一次训练所选取的样本数。 (Batch Size的大小影响模型的优化程度和速度。同时其直接影响到GPU内存的使用情况,假如你GPU内存不大,该数值最好设置小一点。)
–weight_decay=即模型中所有参数的二次正则化超参数(这个的加入就是为了防止过拟合加入正则项,weight_decay 是乘在正则项的前面,控制正则化项在损失函数中所占权重的)

注意:在模型训练前,需要下载预训练模型,
wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz

解压后存放在相应目录

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/58690.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

open cv快速入门系列---数字图像基础

目录 一、数字图像基础 1.1 数字图像和图像单位 1.2 区分图片分辨率与屏幕分辨率 1.3 图像的灰度与灰度级 1.4 图像的深度 1.5 二值图像、灰度图像与彩色图像 1.6 通道数 二、数字图像处理 2.1 图像噪声及其消除 2.2 数字图像处理技术 2.2.1 图像变换 2.2.2 图像增强…

爬虫逆向实战(二十七)--某某招标投标网站招标公告

一、数据接口分析 主页地址:某网站 1、抓包 通过抓包可以发现数据接口是page 2、判断是否有加密参数 请求参数是否加密? 通过查看“载荷”模块可以发现,请求参数是一整个密文 请求头是否加密? 无响应是否加密? 通…

springboot集成es 插入和查询的简单使用

第一步&#xff1a;引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-elasticsearch</artifactId><version>2.2.5.RELEASE</version></dependency>第二步&#xff1a;…

TypeScript的变量声明的各种方式

TypeScript是一种静态类型的JavaScript超集&#xff0c;它为JavaScript代码提供了类型检查和更好的代码组织结构。在TypeScript中&#xff0c;变量声明是非常重要的&#xff0c;因为它们定义了变量的类型和范围。本文将详细介绍TypeScript的变量声明&#xff0c;并通过代码案例…

Tomcat安装及基本使用

1. 什么是Web服务器 Web服务器是一种应用程序&#xff08;软件&#xff09;&#xff0c;它封装了对HTTP协议的操作&#xff0c;使得开发人员无需直接操作协议&#xff0c;从而简化了Web开发。其主要功能是提供网上信息浏览服务。 Web服务器安装在服务器端&#xff0c;我们可以…

C++ 异常

一、异常概念 异常是一种处理错误的方式&#xff0c;当一个函数发现自己无法处理的错误时就可以抛出异常&#xff0c;让函数的直接或间接 的调用者处理这个错误。 throw: 当问题出现时&#xff0c;程序会抛出一个异常。这是通过使用 throw 关键字来完成的。 catch: 在您想要…

L1-043 阅览室(Python实现) 测试点全过

题目 天梯图书阅览室请你编写一个简单的图书借阅统计程序。当读者借书时&#xff0c;管理员输入书号并按下S键&#xff0c;程序开始计时&#xff1b;当读者还书时&#xff0c;管理员输入书号并按下E键&#xff0c;程序结束计时。书号为不超过1000的正整数。当管理员将0作为书号…

国际腾讯云账号云服务器网络访问丢包问题解决办法!!

本文主要介绍可能引起云服务器网络访问丢包问题的主要原因&#xff0c;及对应排查、解决方法。下面一起了解腾讯云国际云服务器网络访问丢包问题解决办法&#xff1a; 可能原因 引起云服务器网络访问丢包问题的可能原因如下&#xff1a; 1.触发限速导致 TCP 丢包 2.触发限速导致…

linux下vi或vim操作Found a swap file by the name的原因及解决方法--九五小庞

在linux下用vi或vim打开Test.java文件时 [rootlocalhost tmp]# vi Test.java出现了如下信息&#xff1a; E325: ATTENTION Found a swap file by the name ".Test.java.swp" owned by: root dated: Wed Dec 7 13:52:56 2011 file name: /var/tmp/Test.java modif…

Hystrix: Dashboard流监控

接上两张服务熔断 开始搭建Dashboard流监控 pom依赖 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocat…

华为复合vlan(mux vlan)

一、概念&#xff1a; Multiplex vlan&#xff1a;实现网络资源控制的的机制。 / Principle vlan&#xff1a;port 可以和mux vlan内所有接口进行通信&#xff0c;限制128个 < /Separate vlan&#xff1a;隔离型从vlan&#xff0c;只能和…

Git 简单介绍

Git 是一个开源的分布式版本控制系统&#xff0c;用于敏捷高效地处理任何或小或大的项目。 一、Git 安装 windows安装&#xff1a;进入网站 https://git-scm.com/ 安装&#xff0c;ubuntu配置&#xff1a;apt install git。当前于 Win 下已安装 Git 版本 2.40.1。 二、配置 设…

一台服务器上部署 Redis 伪集群

哈喽大家好&#xff0c;我是咸鱼 今天这篇文章介绍如何在一台服务器&#xff08;以 CentOS 7.9 为例&#xff09;上通过 redis-trib.rb 工具搭建 Redis cluster &#xff08;三主三从&#xff09; redis-trib.rb 是一个基于 Ruby 编写的脚本&#xff0c;其功能涵盖了创建、管…

flutter高德地图大头针

1、效果图 2、pub get #地图定位 amap_flutter_map: ^3.0.0 amap_flutter_location: ^3.0.0 3、上代码 import dart:async; import dart:io;import package:amap_flutter_location/amap_flutter_location.dart; import package:amap_flutter_location/amap_location_option…

网络安全研究和创新:探讨网络安全领域的最新研究成果、趋势和创新技术,以及如何参与其中。

第一章&#xff1a;引言 随着数字化时代的到来&#xff0c;网络安全变得比以往任何时候都更加重要。无论是个人、企业还是国家&#xff0c;都面临着日益复杂和隐蔽的网络威胁。为了确保我们的信息和资产的安全&#xff0c;网络安全研究变得至关重要。本文将深入探讨网络安全领…

vue PDF或Word转换为HTML并保留原有样式

方法一 要将PDF或Word转换为HTML并保留原有样式&#xff0c;可以使用pdfjs-dist和mammoth.js这两个库。首先需要安装这两个库&#xff1a; npm install pdfjs-dist mammoth.js然后在Vue项目中使用这两个库进行转换&#xff1a; import * as pdfjsLib from pdfjs-dist; impor…

【机器学习】鸢尾花分类-逻辑回归示例

这段代码是一个完整的示例&#xff0c;展示了如何使用逻辑回归对鸢尾花数据集进行训练、保存模型&#xff0c;并允许用户输入数据进行预测。以下是对这段代码的总结&#xff1a;功能&#xff1a; 这段代码演示了如何使用逻辑回归对鸢尾花数据集进行训练&#xff0c;并将训练好的…

安防监控/磁盘阵列存储/视频汇聚平台EasyCVR调用rtsp地址返回的IP不正确是什么原因?

安防监控/云存储/磁盘阵列存储/视频汇聚平台EasyCVR可拓展性强、视频能力灵活、部署轻快&#xff0c;可支持的主流标准协议有GB28181、RTSP/Onvif、RTMP等&#xff0c;以及厂家私有协议与SDK接入&#xff0c;包括海康Ehome、海大宇等设备的SDK等&#xff0c;能对外分发RTSP、RT…

LNMP架构之搭建Discuz论坛

LNMP 一、编译安装Nginx1&#xff09;前置准备2&#xff09;开始编译安装3&#xff09;添加到系统服务&#xff08;systemd启动&#xff09; 二、编译安装MySQL服务1&#xff09;前置准备2&#xff09;编译安装3&#xff09;编辑配置文件4&#xff09;更改mysql安装目录和配置文…

【深度学习】神经网络中 Batch 和 Epoch 之间的区别是什么?我们该如何理解?

文章目录 一、问题的引入1.1 随机梯度下降1.2 主要参数 二、Batch三、Epoch四、两者之间的联系和区别 一、问题的引入 1.1 随机梯度下降 随机梯度下降&#xff08;Stochastic Gradient Descent&#xff0c;SGD&#xff09;是一种优化算法&#xff0c;用于在机器学习和深度学习…