TensorFlow-slim包进行图像数据集分类---具体流程

TensorFlow中slim包的具体用法

  • 1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)
  • 3、模型训练
    • 1、数据集相关模块:
    • 2、设置网络模型模块
    • 3、数据预处理模块
    • 4、定义损失loss
    • 5、定义优化器模块

本次使用的TensorFlow版本是1.13.0
地址:https://github.com/tensorflow/models/tree/r1.13.0
到tensorflow-models的GitHub下载research下面的slim这个包到本地
在这里插入图片描述

TensorFlow中slim包的目录结构:

-- slim|-- BUILD|-- README.md|-- WORKSPACE|-- __init__.py|-- datasets|   |-- __init__.py|   |-- __pycache__|   |   |-- __init__.cpython-37.pyc|   |   |-- dataset_utils.cpython-37.pyc|   |   |-- download_and_convert_cifar10.cpython-37.pyc|   |   |-- download_and_convert_flowers.cpython-37.pyc|   |   `-- download_and_convert_mnist.cpython-37.pyc|   |-- build_imagenet_data.py|   |-- cifar10.py|   |-- dataset_factory.py|   |-- dataset_utils.py|   |-- download_and_convert_cifar10.py|   |-- download_and_convert_flowers.py|   |-- download_and_convert_imagenet.sh|   |-- download_and_convert_mnist.py|   |-- download_imagenet.sh|   |-- flowers.py|   |-- imagenet.py|   |-- imagenet_2012_validation_synset_labels.txt|   |-- imagenet_lsvrc_2015_synsets.txt|   |-- imagenet_metadata.txt|   |-- mnist.py|   |-- preprocess_imagenet_validation_data.py|   `-- process_bounding_boxes.py|-- deployment|   |-- __init__.py|   |-- model_deploy.py|   `-- model_deploy_test.py|-- download_and_convert_data.py    # 下载相应的数据集,并将数据打包成TF-record的格式|-- eval_image_classifier.py        # 测试模型分类效果|-- export_inference_graph.py|-- export_inference_graph_test.py|-- nets|   |-- __init__.py|   |-- alexnet.py|   |-- alexnet_test.py|   |-- cifarnet.py|   |-- cyclegan.py|   |-- cyclegan_test.py|   |-- dcgan.py|   |-- dcgan_test.py|   |-- i3d.py|   |-- i3d_test.py|   |-- i3d_utils.py|   |-- inception.py|   |-- inception_resnet_v2.py|   |-- inception_resnet_v2_test.py|   |-- inception_utils.py|   |-- inception_v1.py|   |-- inception_v1_test.py|   |-- inception_v2.py|   |-- inception_v2_test.py|   |-- inception_v3.py|   |-- inception_v3_test.py|   |-- inception_v4.py|   |-- inception_v4_test.py|   |-- lenet.py|   |-- mobilenet|   |   |-- README.md|   |   |-- __init__.py|   |   |-- conv_blocks.py|   |   |-- madds_top1_accuracy.png|   |   |-- mnet_v1_vs_v2_pixel1_latency.png|   |   |-- mobilenet.py|   |   |-- mobilenet_example.ipynb|   |   |-- mobilenet_v2.py|   |   `-- mobilenet_v2_test.py|   |-- mobilenet_v1.md|   |-- mobilenet_v1.png|   |-- mobilenet_v1.py|   |-- mobilenet_v1_eval.py|   |-- mobilenet_v1_test.py|   |-- mobilenet_v1_train.py|   |-- nasnet|   |   |-- README.md|   |   |-- __init__.py|   |   |-- nasnet.py|   |   |-- nasnet_test.py|   |   |-- nasnet_utils.py|   |   |-- nasnet_utils_test.py|   |   |-- pnasnet.py|   |   `-- pnasnet_test.py|   |-- nets_factory.py|   |-- nets_factory_test.py|   |-- overfeat.py|   |-- overfeat_test.py|   |-- pix2pix.py|   |-- pix2pix_test.py|   |-- resnet_utils.py|   |-- resnet_v1.py|   |-- resnet_v1_test.py|   |-- resnet_v2.py|   |-- resnet_v2_test.py|   |-- s3dg.py|   |-- s3dg_test.py|   |-- vgg.py|   `-- vgg_test.py|-- preprocessing|   |-- __init__.py|   |-- cifarnet_preprocessing.py|   |-- inception_preprocessing.py|   |-- lenet_preprocessing.py|   |-- preprocessing_factory.py|   `-- vgg_preprocessing.py|-- scripts                     # gqr:存储的是相关的模型训练脚本                |   |-- export_mobilenet.sh|   |-- finetune_inception_resnet_v2_on_flowers.sh|   |-- finetune_inception_v1_on_flowers.sh|   |-- finetune_inception_v3_on_flowers.sh|   |-- finetune_resnet_v1_50_on_flowers.sh|   |-- train_cifarnet_on_cifar10.sh|   `-- train_lenet_on_mnist.sh|-- setup.py|-- slim_walkthrough.ipynb`-- train_image_classifier.py    # 训练模型的脚本

1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)

scripts/finetune_resnet_v1_50_on_flowers.sh

#!/bin/bash
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# This script performs the following operations:
# 1. Downloads the Flowers dataset
# 2. Fine-tunes a ResNetV1-50 model on the Flowers training set.
# 3. Evaluates the model on the Flowers validation set.
#
# Usage:
# cd slim
# ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh
set -e# Where the pre-trained ResNetV1-50 checkpoint is saved to.
PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints   # gqr:预训练模型存放路径# Where the training (fine-tuned) checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/flowers-models/resnet_v1_50# Where the dataset is saved to.
DATASET_DIR=/tmp/flowers    # gqr:数据集存放路径# Download the pre-trained checkpoint.
if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; thenmkdir ${PRETRAINED_CHECKPOINT_DIR}
fi
if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; thenwget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gztar -xvf resnet_v1_50_2016_08_28.tar.gzmv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckptrm resnet_v1_50_2016_08_28.tar.gz
fi# Download the dataset
python download_and_convert_data.py \--dataset_name=flowers \--dataset_dir=${DATASET_DIR}# Fine-tune only the new layers for 3000 steps.
python train_image_classifier.py \--train_dir=${TRAIN_DIR} \--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50 \--checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \--checkpoint_exclude_scopes=resnet_v1_50/logits \--trainable_scopes=resnet_v1_50/logits \--max_number_of_steps=3000 \--batch_size=32 \--learning_rate=0.01 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004# Run evaluation.
python eval_image_classifier.py \--checkpoint_path=${TRAIN_DIR} \--eval_dir=${TRAIN_DIR} \--dataset_name=flowers \--dataset_split_name=validation \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50# Fine-tune all the new layers for 1000 steps.
python train_image_classifier.py \--train_dir=${TRAIN_DIR}/all \--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=${DATASET_DIR} \--checkpoint_path=${TRAIN_DIR} \--model_name=resnet_v1_50 \--max_number_of_steps=1000 \--batch_size=32 \--learning_rate=0.001 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004# Run evaluation.
python eval_image_classifier.py \--checkpoint_path=${TRAIN_DIR}/all \--eval_dir=${TRAIN_DIR}/all \--dataset_name=flowers \--dataset_split_name=validation \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50

以上文件以下载并打包flowers数据集为例会调用slim/datasets下的****download_and_convert_flowers.py
在这里插入图片描述
代码43行:_NUM_VALIDATION = 350值的意思的测试数据集的数量,我们一般2,8分数据集,这里只用填写测试集的数据代码会自动吧总数据集分成2部分
代码48行:_NUM_SHARDS = 1这个的意思是生成几个tfrecord文件,这个数量是根据你数据量来划分
在这里插入图片描述
代码190行:dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir) 函数为下载数据集函数,如果本地已经存在数据集,可将将其注释掉
在这里插入图片描述
代码210行:_clean_up_temporary_files(dataset_dir) 函数为打包完毕后删除下载的数据集文件,如果需要下载的数据集可以将其注释掉

上述文件执行完毕后,会得到以下文件
在这里插入图片描述

3、模型训练

模型训练文件为
在这里插入图片描述
以下是该文件中各个模块相关内容

1、数据集相关模块:

在这里插入图片描述

2、设置网络模型模块

在这里插入图片描述

3、数据预处理模块

在这里插入图片描述

4、定义损失loss

在这里插入图片描述

5、定义优化器模块

在这里插入图片描述

运行训练指令:

python train_image_classifier.py \--train_dir=./data/flowers-models/resnet_v1_50\--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=./data/flowers \--model_name=resnet_v1_50 \--checkpoint_path=./data/checkpoints/resnet_v1_50.ckpt \--checkpoint_exclude_scopes=resnet_v1_50/logits \--trainable_scopes=resnet_v1_50/logits \--max_number_of_steps=3000 \ --batch_size=32 \--learning_rate=0.01 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004

–dataset_name=指定模板
–model_name=指定预训练模板
–dataset_dir=指定训练集目录
–checkpoint_exclude_scopes=指定忘记那几层的参数,不带进训练里面,记住提取特征的部分
–train_dir=训练参数存放地址
–trainable_scopes=设定只对那几层变量进行调整,其他层都不进行调整,不设定就会对所有层训练(所以是必须要给定的)
–learning_rate=学习率
–optimizer=优化器
–checkpoint_path:预训练模型存放地址
–max_number_of_steps=训练步数
–batch_size=一次训练所选取的样本数。 (Batch Size的大小影响模型的优化程度和速度。同时其直接影响到GPU内存的使用情况,假如你GPU内存不大,该数值最好设置小一点。)
–weight_decay=即模型中所有参数的二次正则化超参数(这个的加入就是为了防止过拟合加入正则项,weight_decay 是乘在正则项的前面,控制正则化项在损失函数中所占权重的)

注意:在模型训练前,需要下载预训练模型,
wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz

解压后存放在相应目录

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/58690.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

open cv快速入门系列---数字图像基础

目录 一、数字图像基础 1.1 数字图像和图像单位 1.2 区分图片分辨率与屏幕分辨率 1.3 图像的灰度与灰度级 1.4 图像的深度 1.5 二值图像、灰度图像与彩色图像 1.6 通道数 二、数字图像处理 2.1 图像噪声及其消除 2.2 数字图像处理技术 2.2.1 图像变换 2.2.2 图像增强…

爬虫逆向实战(二十七)--某某招标投标网站招标公告

一、数据接口分析 主页地址:某网站 1、抓包 通过抓包可以发现数据接口是page 2、判断是否有加密参数 请求参数是否加密? 通过查看“载荷”模块可以发现,请求参数是一整个密文 请求头是否加密? 无响应是否加密? 通…

springboot集成es 插入和查询的简单使用

第一步&#xff1a;引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-elasticsearch</artifactId><version>2.2.5.RELEASE</version></dependency>第二步&#xff1a;…

Tomcat安装及基本使用

1. 什么是Web服务器 Web服务器是一种应用程序&#xff08;软件&#xff09;&#xff0c;它封装了对HTTP协议的操作&#xff0c;使得开发人员无需直接操作协议&#xff0c;从而简化了Web开发。其主要功能是提供网上信息浏览服务。 Web服务器安装在服务器端&#xff0c;我们可以…

C++ 异常

一、异常概念 异常是一种处理错误的方式&#xff0c;当一个函数发现自己无法处理的错误时就可以抛出异常&#xff0c;让函数的直接或间接 的调用者处理这个错误。 throw: 当问题出现时&#xff0c;程序会抛出一个异常。这是通过使用 throw 关键字来完成的。 catch: 在您想要…

Hystrix: Dashboard流监控

接上两张服务熔断 开始搭建Dashboard流监控 pom依赖 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocat…

Git 简单介绍

Git 是一个开源的分布式版本控制系统&#xff0c;用于敏捷高效地处理任何或小或大的项目。 一、Git 安装 windows安装&#xff1a;进入网站 https://git-scm.com/ 安装&#xff0c;ubuntu配置&#xff1a;apt install git。当前于 Win 下已安装 Git 版本 2.40.1。 二、配置 设…

一台服务器上部署 Redis 伪集群

哈喽大家好&#xff0c;我是咸鱼 今天这篇文章介绍如何在一台服务器&#xff08;以 CentOS 7.9 为例&#xff09;上通过 redis-trib.rb 工具搭建 Redis cluster &#xff08;三主三从&#xff09; redis-trib.rb 是一个基于 Ruby 编写的脚本&#xff0c;其功能涵盖了创建、管…

flutter高德地图大头针

1、效果图 2、pub get #地图定位 amap_flutter_map: ^3.0.0 amap_flutter_location: ^3.0.0 3、上代码 import dart:async; import dart:io;import package:amap_flutter_location/amap_flutter_location.dart; import package:amap_flutter_location/amap_location_option…

网络安全研究和创新:探讨网络安全领域的最新研究成果、趋势和创新技术,以及如何参与其中。

第一章&#xff1a;引言 随着数字化时代的到来&#xff0c;网络安全变得比以往任何时候都更加重要。无论是个人、企业还是国家&#xff0c;都面临着日益复杂和隐蔽的网络威胁。为了确保我们的信息和资产的安全&#xff0c;网络安全研究变得至关重要。本文将深入探讨网络安全领…

安防监控/磁盘阵列存储/视频汇聚平台EasyCVR调用rtsp地址返回的IP不正确是什么原因?

安防监控/云存储/磁盘阵列存储/视频汇聚平台EasyCVR可拓展性强、视频能力灵活、部署轻快&#xff0c;可支持的主流标准协议有GB28181、RTSP/Onvif、RTMP等&#xff0c;以及厂家私有协议与SDK接入&#xff0c;包括海康Ehome、海大宇等设备的SDK等&#xff0c;能对外分发RTSP、RT…

LNMP架构之搭建Discuz论坛

LNMP 一、编译安装Nginx1&#xff09;前置准备2&#xff09;开始编译安装3&#xff09;添加到系统服务&#xff08;systemd启动&#xff09; 二、编译安装MySQL服务1&#xff09;前置准备2&#xff09;编译安装3&#xff09;编辑配置文件4&#xff09;更改mysql安装目录和配置文…

macOS使用命令行连接Oracle(SQL*Plus)

Author: histonevonzohomail.com Date: 2023/08/25 文章目录 SQL\*Plus安装下载环境配置 SQL\*Plus远程连接数据库参考文献 原文地址&#xff1a;https://histonevon.top/archives/oracle-mac-sqlplus数据库安装&#xff1a;Docker安装Oracle数据库 (histonevon.top) SQL*Plus…

Ansys Zemax | 手机镜头设计 - 第 2 部分:使用 OpticsBuilder 实现光机械封装

本文是3篇系列文章的一部分&#xff0c;该系列文章将讨论智能手机镜头模块设计的挑战&#xff0c;从概念、设计到制造和结构变形的分析。本文是三部分系列的第二部分。概括介绍了如何在 CAD 中编辑光学系统的光学元件以及如何在添加机械元件后使用 Zemax OpticsBuilder 分析系统…

二级MySQL(十)——单表查询

这里我们只在一个表内查询&#xff0c;用到的是较为简单的SELECT函数形式 1、查询指定的字段&#xff1a; 用到的数据库是之前提到的S、P、SP数据库 S表格用到的总数据&#xff1a; 首先我们查询所有供应商的序号和名字 这时都是独立的&#xff0c;没有关系&#xff0c;我们找…

android多屏触摸相关的详解方案-安卓framework开发手机车载车机系统开发课程

背景 直播免费视频课程地址&#xff1a;https://www.bilibili.com/video/BV1hN4y1R7t2/ 在做双屏相关需求开发过程中&#xff0c;经常会有对两个屏幕都要求可以正确触摸的场景。但是目前我们模拟器默认创建的双屏其实是没有办法进行触摸的 修改方案1 静态修改方案 使用命令…

对class文件进行base64编码

使用以下代码 package org.springframework.cloud.gateway.sample;import org.springframework.util.Base64Utils;import java.io.*; import java.nio.charset.StandardCharsets;public class EncodeShell {public static void main(String[] args){byte[] data null;try {In…

2021年09月 C/C++(五级)真题解析#中国电子学会#全国青少年软件编程等级考试

第1题:抓牛 农夫知道一头牛的位置,想要抓住它。农夫和牛都位于数轴上,农夫起始位于点N(0<=N<=100000),牛位于点K(0<=K<=100000)。农夫有两种移动方式: 1、从X移动到X-1或X+1,每次移动花费一分钟 2、从X移动到2*X,每次移动花费一分钟 假设牛没有意识到农夫的…

2023第七届蓝帽杯 初赛 web LovePHP

LovePHP 直接给了源码。 network查看到&#xff0c;PHP版本是7.4.33 题目要求我们GET一个my_secret.flag参数&#xff0c;根据PHP字符串解析特性&#xff0c;PHP需要将所有参数转换为有效的变量名&#xff0c;因此在解析查询字符串时&#xff0c;它会做两件事&#xff1a; 删…

python+TensorFlow实现人脸识别智能小程序的项目(包含TensorFlow版本与Pytorch版本)(一)

pythonTensorFlow实现人脸识别智能小程序的项目&#xff08;包含TensorFlow版本与Pytorch版本&#xff09;&#xff08;一&#xff09; 一&#xff1a;TensorFlow基础知识内容部分&#xff08;简明扼要&#xff0c;快速适应&#xff09;1、下载Cifar10数据集&#xff0c;并进行…