TensorFlow中slim包的具体用法

TensorFlow中slim包的具体用法

  • 1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)
  • 3、模型训练
    • 1、数据集相关模块:
    • 2、设置网络模型模块
    • 3、数据预处理模块
    • 4、定义损失loss
    • 5、定义优化器模块

本次使用的TensorFlow版本是1.13.0
地址:https://github.com/tensorflow/models/tree/r1.13.0
到tensorflow-models的GitHub下载research下面的slim这个包到本地
在这里插入图片描述

TensorFlow中slim包的目录结构:

-- slim|-- BUILD|-- README.md|-- WORKSPACE|-- __init__.py|-- datasets|   |-- __init__.py|   |-- __pycache__|   |   |-- __init__.cpython-37.pyc|   |   |-- dataset_utils.cpython-37.pyc|   |   |-- download_and_convert_cifar10.cpython-37.pyc|   |   |-- download_and_convert_flowers.cpython-37.pyc|   |   `-- download_and_convert_mnist.cpython-37.pyc|   |-- build_imagenet_data.py|   |-- cifar10.py|   |-- dataset_factory.py|   |-- dataset_utils.py|   |-- download_and_convert_cifar10.py|   |-- download_and_convert_flowers.py|   |-- download_and_convert_imagenet.sh|   |-- download_and_convert_mnist.py|   |-- download_imagenet.sh|   |-- flowers.py|   |-- imagenet.py|   |-- imagenet_2012_validation_synset_labels.txt|   |-- imagenet_lsvrc_2015_synsets.txt|   |-- imagenet_metadata.txt|   |-- mnist.py|   |-- preprocess_imagenet_validation_data.py|   `-- process_bounding_boxes.py|-- deployment|   |-- __init__.py|   |-- model_deploy.py|   `-- model_deploy_test.py|-- download_and_convert_data.py    # 下载相应的数据集,并将数据打包成TF-record的格式|-- eval_image_classifier.py        # 测试模型分类效果|-- export_inference_graph.py|-- export_inference_graph_test.py|-- nets|   |-- __init__.py|   |-- alexnet.py|   |-- alexnet_test.py|   |-- cifarnet.py|   |-- cyclegan.py|   |-- cyclegan_test.py|   |-- dcgan.py|   |-- dcgan_test.py|   |-- i3d.py|   |-- i3d_test.py|   |-- i3d_utils.py|   |-- inception.py|   |-- inception_resnet_v2.py|   |-- inception_resnet_v2_test.py|   |-- inception_utils.py|   |-- inception_v1.py|   |-- inception_v1_test.py|   |-- inception_v2.py|   |-- inception_v2_test.py|   |-- inception_v3.py|   |-- inception_v3_test.py|   |-- inception_v4.py|   |-- inception_v4_test.py|   |-- lenet.py|   |-- mobilenet|   |   |-- README.md|   |   |-- __init__.py|   |   |-- conv_blocks.py|   |   |-- madds_top1_accuracy.png|   |   |-- mnet_v1_vs_v2_pixel1_latency.png|   |   |-- mobilenet.py|   |   |-- mobilenet_example.ipynb|   |   |-- mobilenet_v2.py|   |   `-- mobilenet_v2_test.py|   |-- mobilenet_v1.md|   |-- mobilenet_v1.png|   |-- mobilenet_v1.py|   |-- mobilenet_v1_eval.py|   |-- mobilenet_v1_test.py|   |-- mobilenet_v1_train.py|   |-- nasnet|   |   |-- README.md|   |   |-- __init__.py|   |   |-- nasnet.py|   |   |-- nasnet_test.py|   |   |-- nasnet_utils.py|   |   |-- nasnet_utils_test.py|   |   |-- pnasnet.py|   |   `-- pnasnet_test.py|   |-- nets_factory.py|   |-- nets_factory_test.py|   |-- overfeat.py|   |-- overfeat_test.py|   |-- pix2pix.py|   |-- pix2pix_test.py|   |-- resnet_utils.py|   |-- resnet_v1.py|   |-- resnet_v1_test.py|   |-- resnet_v2.py|   |-- resnet_v2_test.py|   |-- s3dg.py|   |-- s3dg_test.py|   |-- vgg.py|   `-- vgg_test.py|-- preprocessing|   |-- __init__.py|   |-- cifarnet_preprocessing.py|   |-- inception_preprocessing.py|   |-- lenet_preprocessing.py|   |-- preprocessing_factory.py|   `-- vgg_preprocessing.py|-- scripts                     # gqr:存储的是相关的模型训练脚本                |   |-- export_mobilenet.sh|   |-- finetune_inception_resnet_v2_on_flowers.sh|   |-- finetune_inception_v1_on_flowers.sh|   |-- finetune_inception_v3_on_flowers.sh|   |-- finetune_resnet_v1_50_on_flowers.sh|   |-- train_cifarnet_on_cifar10.sh|   `-- train_lenet_on_mnist.sh|-- setup.py|-- slim_walkthrough.ipynb`-- train_image_classifier.py    # 训练模型的脚本

1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)

scripts/finetune_resnet_v1_50_on_flowers.sh

#!/bin/bash
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# This script performs the following operations:
# 1. Downloads the Flowers dataset
# 2. Fine-tunes a ResNetV1-50 model on the Flowers training set.
# 3. Evaluates the model on the Flowers validation set.
#
# Usage:
# cd slim
# ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh
set -e# Where the pre-trained ResNetV1-50 checkpoint is saved to.
PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints   # gqr:预训练模型存放路径# Where the training (fine-tuned) checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/flowers-models/resnet_v1_50# Where the dataset is saved to.
DATASET_DIR=/tmp/flowers    # gqr:数据集存放路径# Download the pre-trained checkpoint.
if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; thenmkdir ${PRETRAINED_CHECKPOINT_DIR}
fi
if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; thenwget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gztar -xvf resnet_v1_50_2016_08_28.tar.gzmv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckptrm resnet_v1_50_2016_08_28.tar.gz
fi# Download the dataset
python download_and_convert_data.py \--dataset_name=flowers \--dataset_dir=${DATASET_DIR}# Fine-tune only the new layers for 3000 steps.
python train_image_classifier.py \--train_dir=${TRAIN_DIR} \--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50 \--checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \--checkpoint_exclude_scopes=resnet_v1_50/logits \--trainable_scopes=resnet_v1_50/logits \--max_number_of_steps=3000 \--batch_size=32 \--learning_rate=0.01 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004# Run evaluation.
python eval_image_classifier.py \--checkpoint_path=${TRAIN_DIR} \--eval_dir=${TRAIN_DIR} \--dataset_name=flowers \--dataset_split_name=validation \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50# Fine-tune all the new layers for 1000 steps.
python train_image_classifier.py \--train_dir=${TRAIN_DIR}/all \--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=${DATASET_DIR} \--checkpoint_path=${TRAIN_DIR} \--model_name=resnet_v1_50 \--max_number_of_steps=1000 \--batch_size=32 \--learning_rate=0.001 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004# Run evaluation.
python eval_image_classifier.py \--checkpoint_path=${TRAIN_DIR}/all \--eval_dir=${TRAIN_DIR}/all \--dataset_name=flowers \--dataset_split_name=validation \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50

以上文件以下载并打包flowers数据集为例会调用slim/datasets下的****download_and_convert_flowers.py
在这里插入图片描述
代码43行:_NUM_VALIDATION = 350值的意思的测试数据集的数量,我们一般2,8分数据集,这里只用填写测试集的数据代码会自动吧总数据集分成2部分
代码48行:_NUM_SHARDS = 1这个的意思是生成几个tfrecord文件,这个数量是根据你数据量来划分
在这里插入图片描述
代码190行:dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir) 函数为下载数据集函数,如果本地已经存在数据集,可将将其注释掉
在这里插入图片描述
代码210行:_clean_up_temporary_files(dataset_dir) 函数为打包完毕后删除下载的数据集文件,如果需要下载的数据集可以将其注释掉

上述文件执行完毕后,会得到以下文件
在这里插入图片描述

3、模型训练

模型训练文件为
在这里插入图片描述
以下是该文件中各个模块相关内容

1、数据集相关模块:

在这里插入图片描述

2、设置网络模型模块

在这里插入图片描述

3、数据预处理模块

在这里插入图片描述

4、定义损失loss

在这里插入图片描述

5、定义优化器模块

在这里插入图片描述

运行训练指令:

python train_image_classifier.py \--train_dir=./data/flowers-models/resnet_v1_50\--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=./data/flowers \--model_name=resnet_v1_50 \--checkpoint_path=./data/checkpoints/resnet_v1_50.ckpt \--checkpoint_exclude_scopes=resnet_v1_50/logits \--trainable_scopes=resnet_v1_50/logits \--max_number_of_steps=3000 \ --batch_size=32 \--learning_rate=0.01 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004

–dataset_name=指定模板
–model_name=指定预训练模板
–dataset_dir=指定训练集目录
–checkpoint_exclude_scopes=指定忘记那几层的参数,不带进训练里面,记住提取特征的部分
–train_dir=训练参数存放地址
–trainable_scopes=设定只对那几层变量进行调整,其他层都不进行调整,不设定就会对所有层训练(所以是必须要给定的)
–learning_rate=学习率
–optimizer=优化器
–max_number_of_steps=训练步数
–batch_size=一次训练所选取的样本数。 (Batch Size的大小影响模型的优化程度和速度。同时其直接影响到GPU内存的使用情况,假如你GPU内存不大,该数值最好设置小一点。)
–weight_decay=即模型中所有参数的二次正则化超参数(这个的加入就是为了防止过拟合加入正则项,weight_decay 是乘在正则项的前面,控制正则化项在损失函数中所占权重的)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/54321.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis五大数据类型

Redis五大数据类型 Redis-Key 官网:https://www.redis.net.cn/order/ 序号命令语法描述1DEL key该命令用于在 key 存在时删除 key2DUMP key序列化给定 key ,并返回被序列化的值3EXISTS key检查给定 key 是否存在,存在返回1,否则返…

yolov8热力图可视化

安装pytorch_grad_cam pip install grad-cam自动化生成不同层的bash脚本 # 循环10次,将i的值从0到9 for i in $(seq 0 13) doecho "Running iteration $i";python yolov8_heatmap.py $i; done热力图生成python代码 import warnings warnings.filterwarn…

vscode流程图插件使用

vscode流程图插件使用 1.在vscode中点击左下角设置然后选择扩展。 2.在扩展中搜索Draw.io Integration,安装上面第一个插件。 3.安装插件后在工程中创建一个后缀为drawio的文件并且双击打开即可绘制流程图

2023-08-26 LeetCode每日一题(汇总区间)

2023-08-26每日一题 一、题目编号 228. 汇总区间二、题目链接 点击跳转到题目位置 三、题目描述 给定一个 无重复元素 的 有序 整数数组 nums 。 返回 恰好覆盖数组中所有数字 的 最小有序 区间范围列表 。也就是说,nums 的每个元素都恰好被某个区间范围所覆盖…

如何在地图上寻找最密集点的位置?

最近我在工作中遇到了一个小的需求点,大概是需要在地图上展示出一堆点中的点密度最密集的位置。最开始没想到好的方法,就使用了一个非常简单的策略——所有点的坐标求平均值,这个方法大部分的时候好用,因为大部分城市所有点位基本…

深度学习4. 循环神经网络 – Recurrent Neural Network | RNN

目录 循环神经网络 – Recurrent Neural Network | RNN 为什么需要 RNN ?独特价值是什么? RNN 的基本原理 RNN 的优化算法 RNN 到 LSTM – 长短期记忆网络 从 LSTM 到 GRU RNN 的应用和使用场景 总结 百度百科维基百科 循环神经网络 – Recurre…

【手写promise——基本功能、链式调用、promise.all、promise.race】

文章目录 前言一、前置知识二、实现基本功能二、实现链式调用三、实现Promise.all四、实现Promise.race总结 前言 关于动机,无论是在工作还是面试中,都会遇到Promise的相关使用和原理,手写Promise也有助于学习设计模式以及代码设计。 本文主…

WPF基础入门-Class5-WPF命令

WPF基础入门 Class5-WPF命令 1、xaml编写一个button&#xff0c;Command绑定一个命令 <Grid><ButtonWidth"100"Height"40" Command"{Binding ShowCommand}"></Button> </Grid>2、编写一个model.cs namespace WPF_Le…

【LeetCode-面试经典150题-day15】

目录 104.二叉树的最大深度 100.相同的树 226.翻转二叉树 101.对称二叉树 105.从前序与中序遍历序列构造二叉树 106.从中序与后序遍历序列构造二叉树 117.填充每个节点的下一个右侧节点指针Ⅱ 104.二叉树的最大深度 题意&#xff1a; 给定一个二叉树 root &#xff0c;返回其…

STM32F103 4G Cat.1模块EC200S使用

一、简介 EC200S-CN 是移远通信最近推出的 LTE Cat 1 无线通信模块&#xff0c;支持最大下行速率 10Mbps 和最大上行速率 5Mbps&#xff0c;具有超高的性价比&#xff1b;同时在封装上兼容移远通信多网络制式 LTE Standard EC2x&#xff08;EC25、EC21、EC20 R2.0、EC20 R2.1&a…

用大白话来讲讲多线程的知识架构

感觉多线程的知识又多又杂&#xff0c;自从接触java&#xff0c;就在一遍一遍捋脉络和深入学习。现在将这次的学习成果展示如下。 什么是多线程&#xff1f; 操作系统运行一个程序&#xff0c;就是一个线程。同时运行多个程序&#xff0c;就是多线程。即在同一时间&#xff0…

基于FPGA的Lorenz混沌系统verilog开发,含testbench和matlab辅助测试程序

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 将vivado的仿真结果导入到matlab显示三维混沌效果&#xff1a; 2.算法运行软件版本 vivado2019.2 matlab2022a 3.部分核心程序 testbench如下所…

npm常用命令 + 前端常用的包管理工具 以及 npm淘宝镜像配置等

npm常用命令 前端常用的包管理工具 以及 npm淘宝镜像配置等 1. 前言1.1 NodeJs的下载安装1.2 windows上1.3 常用包管理工具 2. npm2.1 npm 的安装2.2 npm初始化包2.3 npm 安装、卸载包2.3.1 非全局安装2.3.1.1 单个包的安装2.3.1.1.1 默认版本安装2.3.1.1.2 指定版本安装 2.3.…

解除用户账户控制提醒

解决用户账户控制提醒 1. 前言2. 解决用户账户控制提醒2.1 控制面板2.2 注册表2.3 UAC服务 结束语 1. 前言 当我们使用电脑时&#xff0c;有时进行安装应用或者打开应用时&#xff0c;总会弹出一个提示框&#xff0c;要选择点击是否允许程序运行&#xff1b; 系统经常弹出用户…

【Git】测试持续集成——Git+Gitee+PyCharm

文章目录 概述一、使用Gitee1. 注册账号2. 绑定邮箱3. 新建仓库4. 查看项目地址 二、安装配置Git1. 下载安装包2. 校验是否安装成功。3. 配置Git4. Git命令5. Git实操 三、PyCharmGit1. 配置Git2. Clone项目3. 提交文件到服务器4. 从服务器拉取文件 概述 持续集成&#xff08;…

【javaweb】学习日记Day4 - Maven 依赖管理 Web入门

目录 一、Maven入门 - 管理和构建java项目的工具 1、IDEA如何构建Maven项目 2、Maven 坐标 &#xff08;1&#xff09;定义 &#xff08;2&#xff09;主要组成 3、IDEA如何导入和删除项目 二、Maven - 依赖管理 1、依赖配置 2、依赖传递 &#xff08;1&#xff09;查…

Docker容器学习:Dockerfile制作Web应用系统nginx镜像

目录 编写Dockerfile 1.文件内容需求&#xff1a; 2.编写Dockerfile&#xff1a; 3.开始构建镜像 4.现在我们运行一个容器&#xff0c;查看我们的网页是否可访问 推送镜像到私有仓库 1.把要上传的镜像打上合适的标签 2.登录harbor仓库 3.上传镜像 编写Dockerfile 1.文…

2000-2021年地级市产业升级、产业结构高级化面板数据

2000-2021年地级市产业升级、产业结构高级化面板数据 1、时间&#xff1a;2000-2021年 2、范围&#xff1a;地级市 3、指标&#xff1a;年份、地区、行政区划代码、地区、所属省份、地区生产总值、第一产业增加值、第二产业增加值、第三产业增加值、第一产业占GDP比重、第二…

常见的时序数据库

1.概念 时序数据库全称为时间序列数据库。时间序列数据库指主要用于处理带时间标签&#xff08;按照时间的顺序变化&#xff0c;即时间序列化&#xff09;的数据&#xff0c;带时间标签的数据也称为时间序列数据。 时间序列数据主要由电力行业、化工行业、气象行业、地理信息…

十四、pikachu之XSS

文章目录 1、XSS概述2、实战2.1 反射型XSS&#xff08;get&#xff09;2.2 反射型XSS&#xff08;POST型&#xff09;2.3 存储型XSS2.4 DOM型XSS2.5 DOM型XSS-X2.6 XSS之盲打2.7 XSS之过滤2.8 XSS之htmlspecialchars2.9 XSS之href输出2.10 XSS之JS输出 1、XSS概述 Cross-Site S…