TensorFlow 2.0 - TFRecord存储数据集、@tf.function图执行模式、tf.TensorArray、tf.config分配GPU

文章目录

    • 1. TFRecord 格式存储
    • 2. tf.function 高性能
    • 3. tf.TensorArray 支持计算图特性
    • 4. tf.config 分配GPU

学习于:简单粗暴 TensorFlow 2

1. TFRecord 格式存储

  • 使用该种格式,更高效地进行大规模的模型训练

import random
import os
import tensorflow as tf# 使用前一节 kaggle 上的 猫狗数据集
train_data_dir = "./dogs-vs-cats/train/"
test_data_dir = "./dogs-vs-cats/test/"# 训练文件路径
file_dir = [train_data_dir + filename for filename in os.listdir(train_data_dir)]
labels = [0 if filename[0] == 'c' else 1for filename in os.listdir(train_data_dir)]# 打包并打乱
f_l = list(zip(file_dir, labels))
random.shuffle(f_l)
file_dir, labels = zip(*f_l)# 切分训练集,验证集
valid_ratio = 0.1
idx = int((1 - valid_ratio) * len(file_dir))
train_files, valid_files = file_dir[:idx], file_dir[idx:]
train_labels, valid_labels = labels[:idx], labels[idx:]# tfrecord 格式数据存储路径
train_tfrecord_file = "./dogs-vs-cats/train.tfrecords"
valid_tfrecord_file = "./dogs-vs-cats/valid.tfrecords"# -------------------看下面代码-----------------------------
# 存储过程
# 预先定义一个写入器
with tf.io.TFRecordWriter(path=train_tfrecord_file) as writer:# 遍历原始数据for filename, label in zip(train_files, train_labels):img = open(filename, 'rb').read()  # 读取图片,img 是 Byte 类型的字符串# 建立 feature 的 字典 k : vfeature = {'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img])),'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}# feature 包裹成 exampleexample = tf.train.Example(features=tf.train.Features(feature=feature))# example 序列化为字符串,写入writer.write(example.SerializeToString())# -------------------看下面代码-----------------------------
# 读取过程
# 读取 tfrecord 数据,得到 tf.data.Dataset 对象
raw_train_dataset = tf.data.TFRecordDataset(train_tfrecord_file)
# 特征的格式、数据类型
feature_description = {'image': tf.io.FixedLenFeature(shape=[], dtype=tf.string),'label': tf.io.FixedLenFeature([], tf.int64),
}def _parse_example(example_string): # 解码每个example# tf.io.parse_single_example 反序列化feature_dict = tf.io.parse_single_example(example_string, feature_description)# 图像解码feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image'])# 返回数据 X, yreturn feature_dict['image'], feature_dict['label']# 处理数据集
train_dataset = raw_train_dataset.map(_parse_example)import matplotlib.pyplot as plt
for img, label in train_dataset:plt.title('cat' if label==0 else 'dog')plt.imshow(img.numpy())plt.show()

2. tf.function 高性能

  • TF 2.0 默认 即时执行模式(Eager Execution),灵活、易调试
  • 追求高性能、部署模型时,使用图执行模式(Graph Execution)
  • TF 2.0 的 tf.function 模块 + AutoGraph 机制,使用 @tf.function 修饰符,就可以将模型以图执行模式运行

注意:@tf.function修饰的函数内,尽量只用 tf 的内置函数,变量只用 tensor、numpy 数组

  • 被修饰的函数 F(X, y) 可以调用get_concrete_function 方法,获得计算图
graph = F.get_concrete_function(X, y)

3. tf.TensorArray 支持计算图特性

  • tf.TensorArray 支持计算图模式的 动态数组
arr = tf.TensorArray(dtype=tf.int64, size=1, dynamic_size=True)
arr = arr.write(index=1, value=512)
# arr.write(index=0, value=512) # 没有左值接受,会丢失
for i in range(arr.size()):print(arr.read(i))

4. tf.config 分配GPU

  • 列出设备 list_physical_devices
print('---device----')
gpus = tf.config.list_physical_devices(device_type='GPU')
cpus = tf.config.list_physical_devices(device_type='CPU')
print(gpus, "\n", cpus)
# 单个的 GPU, CPU
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')] [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]
  • 设置哪些可见 set_visible_devices
tf.config.set_visible_devices(devices=gpus[0:2], device_type='GPU')

或者

  • 终端输入 export CUDA_VISIBLE_DEVICES=2,3
  • or 代码中加入
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "2,3"

指定程序 只在 显卡 2, 3 上运行

  • 显存使用策略:
gpus = tf.config.list_physical_devices(device_type='GPU')
for gpu in gpus:# 仅在需要时申请显存tf.config.experimental.set_memory_growth(device=gpu, enable=True)
gpus = tf.config.list_physical_devices(device_type='GPU')
# 固定显存使用上限,超出报错
tf.config.set_logical_device_configuration(gpus[0],[tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
  • 单 GPU 模拟多 GPU 环境

在单GPU电脑上,写 多GPU 代码,可以模拟实现

gpus = tf.config.list_physical_devices('GPU')
tf.config.set_logical_device_configuration(gpus[0],[tf.config.LogicalDeviceConfiguration(memory_limit=2048),tf.config.LogicalDeviceConfiguration(memory_limit=2048)])
gpus = tf.config.list_logical_devices(device_type='GPU')
print(gpus)

输出:2个虚拟的GPU

[LogicalDevice(name='/device:GPU:0', device_type='GPU'), LogicalDevice(name='/device:GPU:1', device_type='GPU')]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/472994.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c++ qt qlistwidget清空_Qt编写控件属性设计器12-用户属性

一、前言用户属性是后面新增加的一个功能,自定义控件如果采用的Q_PROPERTY修饰的属性,会自动识别到属性栏中,这个一般称为控件属性,在组态设计软件中,光有控件本身的控件属性还是不够的,毕竟这些属性仅仅是…

TensorFlow 2.0 - tf.saved_model.save 模型导出

文章目录1. tf.saved_model.save2. Keras API 模型导出学习于:简单粗暴 TensorFlow 2 1. tf.saved_model.save tf.train.Checkpoint 可以保存和恢复模型中参数的权值导出模型:包含参数的权值,计算图 无须源码即可再次运行模型,适…

机器人动力学与控制_力控制与位置控制的区别

1.背景介绍目前已经广泛落地的力控制方案是在机械臂末端安装多轴力矩传感器,用以检测机械臂对外界环境施加的力反馈值,并配合适当的控制策略,已达到控制机械臂与环境的作用力。这篇文章所要探讨的力控制(上述力控制方案&#xff0…

基坑监测日报模板_基坑监测有多重要?实录基坑坍塌过程,不亲身经历,不知道现场有多恐怖!...

基坑整体坍塌不亲身经历,不知其恐怖▼前段时间,南宁绿地中央广场房地产项目D号地块(二期)基坑北侧约60米支护桩突然崩塌!所幸无人伤亡。深基坑施工安全生产管理要点一、基坑开挖 1、 临边防护(1)基坑施工必须按要求进行,具体临边防…

[转]asp.net导出数据到Excel的三种方法

原文出处:asp.net导出数据到Excel的几种方法(1/3) 、asp.net导出数据到Excel的几种方法(2/3)、asp.net导出数据到Excel的几种方法(3/3) asp.net导出到Excel也是个老生常谈的问题,在此归纳一下。 第一种是比较常用的方法。是利用控件的RenderControl功能…

LintCode 378. 将二叉树转换成双链表(非递归遍历)

文章目录1. 题目2. 解题1. 题目 将一个二叉树按照中序遍历转换成双向链表。 样例 样例 1&#xff1a; 输入:4/ \2 5/ \1 3 输出: 1<->2<->3<->4<->5样例 2&#xff1a; 输入:3/ \4 1输出:4<->3<->1https://www.lintcode.com/pro…

js 将图片置灰_艾叶灰千万别扔——艾叶灰的神奇功效

请 点 上面“经络技巧”免费关注每晚9点准时免费更新点击下面图片阅读↓↓↓—— 以下是正文 ——艾灰的妙用1、宝宝经常会有红屁股&#xff0c;做妈妈的当然心疼&#xff0c;用了不少膏啊霜啊油啊&#xff0c;效果也是反反复复&#xff0c;尤其害怕会有依赖性。在妈妈的提醒下…

LintCode 434. 岛屿的个数II(并查集)

文章目录1. 题目2. 解题1. 题目 给定 n, m, 分别代表一个二维矩阵的行数和列数, 并给定一个大小为 k 的二元数组A. 初始二维矩阵全0. 二元数组A内的k个元素代表k次操作, 设第 i 个元素为 (A[i].x, A[i].y), 表示把二维矩阵中下标为A[i].x行A[i].y列的元素由海洋变为岛屿. 问在…

jqprintsetup已经安装还会提示_Windows 10更新将修复困扰用户已久的循环安装问题...

对于某些设备的用户来说&#xff0c;过去一年一直深受 Windows Update 陷入循环更新的问题困扰&#xff0c;尤其是那些使用英特尔驱动程序的设备。问题在于 Windows Update 会提示错误地提供不适配的驱动或版本&#xff0c;并且强行覆盖安装。此外即便用户已经安装了更新更好的…

springboot设置运行内存_Docker 如何运行多个 Springboot?

docker 如何运行多个Springboot &#xff1f;第一个&#xff1a;端口映射第二个&#xff1a;指定内存大小第三个&#xff1a;读取、写入物理文件第四个&#xff1a;日志文件第五个&#xff1a;多个容器内部网络访问第六个&#xff1a;遇到的问题第一个&#xff1a;端口映射Ngin…

LintCode 1915. 举重(01背包)

文章目录1. 题目2. 解题1. 题目 奥利第一次来到健身房&#xff0c;她正在计算她能举起的最大重量。 杠铃所能承受的最大重量为maxCapacity&#xff0c;健身房里有 n 个杠铃片&#xff0c;第 i 个杠铃片的重量为 weights[i]。 奥利现在需要选一些杠铃片加到杠铃上&#xff0c;使…

python实现简单线性回归和多元线性回归算法

1、问题引入 在统计学中&#xff0c;线性回归是利用称为线性回归方程的最小二乘函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析。这种函数是一个或多个称为回归系数的模型参数的线性组合。一个带有一个自变量的线性回归方程代表一条直线。我们需要对线性回归结…

form表单通过checkbox_飞冰表单解决方案 - FormBinder

前言中后台业务场景中&#xff0c;表单是一种很常见的与用户交互的方式&#xff0c;从业务角度看&#xff0c;表单主要是收集用户的信息&#xff0c;而从技术角度看&#xff0c;作为一个通用型的组件&#xff0c;它要解决的问题无非就是三个&#xff1a;把一个初始数据对象扔给…

@data 重写set方法_C#中的类、方法和属性

这节讲C#中的类&#xff0c;方法&#xff0c;属性。这是面向对象编程中&#xff0c;我们最直接打交道的三个结构。类&#xff1a;类(class)是面向对象中最基本的单元&#xff0c;它是一种抽象&#xff0c;对现实世界中事物的抽象&#xff0c;在C#中使用class关键字声明一个类&a…

Docker安装+镜像拉取+容器+创建镜像+push to docker hub

文章目录1. 安装2. 镜像操作3. 容器4. docker hub本文参考&#xff1a;https://zhuanlan.zhihu.com/p/23599229 1. 安装 参考 https://www.runoob.com/docker/ubuntu-docker-install.html curl -fsSL https://get.docker.com | bash -s docker --mirror Aliyun测试&#xff…

css 百分比 怎么固定正方形_你未必知道的49个CSS知识点

本文的每一条&#xff0c;都是我曾经发过的掘金沸点&#xff0c;其中有很多条超过了百赞(窃喜)。鉴于时不时有童鞋翻我以前的沸点&#xff0c;因此&#xff0c;本文收集了个人目前发过的所有CSS知识点动图&#xff0c;以便阅读。需要说明的是&#xff0c;顺序仍是按当时发布顺序…

CSS 实现加载动画之五-光盘旋转

今天做的这个动画叫光盘旋转&#xff0c;名字自己取的。动画的效果估计很多人都很熟悉&#xff0c;就是微信朋友圈里的加载动画。做过前面几个动画&#xff0c;发现其实都一个原理&#xff0c;就是如何将动画的元素如何分离出来。这个动画的实现也很简单&#xff0c;关键点在于…

css hover变成手_web前端入门到实战:彻底掌握css动画「transition」

马上就2020年了&#xff0c;不知道小伙伴们今年学习了css3动画了吗&#xff1f;说起来css动画是一个很尬的事&#xff0c;一方面因为公司用css动画比较少&#xff0c;另一方面大部分开发者习惯了用JavaScript来做动画&#xff0c;所以就导致了许多程序员比较排斥来学习css动画(…

用Docker部署TensorFlow Serving服务

文章目录1. 安装 Docker2. 使用 Docker 部署3. 请求服务3.1 手写数字例子3.2 猫狗分类例子参考&#xff1a; https://tf.wiki/zh_hans/deployment/serving.html# https://tensorflow.google.cn/tfx/serving/docker 1. 安装 Docker 以下均为 centos7 环境 参考文章&#xff1a…

K-Means算法和K-Means++算法的聚类

在构成圆形的30000个随机样本点上&#xff0c;设置7个簇&#xff0c;使用K-Means算法聚类 from math import pi, sin, cos from collections import namedtuple from random import random, choice from copy import copy import matplotlib.pyplot as plt import numpy as np…