TensorFlow2实战-系列教程8:TFRecords数据源制作2

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

5、图像数据处理实例

5.1 读数据

import os
import glob
from datetime import datetimeimport cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
image_path = '../img/'
images = glob.glob(image_path + '*.jpg')for fname in images:image = mpimg.imread(fname)f, (ax1) = plt.subplots(1, 1, figsize=(8,8))f.subplots_adjust(hspace = .2, wspace = .05)ax1.imshow(image)ax1.set_title('Image', fontsize=20)image_labels = {'dog': 0,'kangaroo': 1,
}

5.2 制作TFRecord

# 读数据,binary格式
image_string = open('./img/dog.jpg', 'rb').read()
label = image_labels['dog']

打开一张图像和它对应的标签

def _bytes_feature(value):"""Returns a bytes_list from a string/byte."""if isinstance(value, type(tf.constant(0))):value = value.numpy() # BytesList won't unpack a string from an EagerTensor.return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))def _float_feature(value):"""Return a float_list form a float/double."""return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))def _int64_feature(value):"""Return a int64_list from a bool/enum/int/uint."""return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

前面3个处理字符、浮点数、整型的函数

# 创建图像数据的Example
def image_example(image_string, label):image_shape = tf.image.decode_jpeg(image_string).shapefeature = {'height': _int64_feature(image_shape[0]),'width': _int64_feature(image_shape[1]),'depth': _int64_feature(image_shape[2]),'label': _int64_feature(label),'image_raw': _bytes_feature(image_string),}return tf.train.Example(features=tf.train.Features(feature=feature))

定义一个函数,指定要保存的指标,以及要用什么格式存这批数据
image_shape 就是图像长、宽、通道数
feature 中定义了图像h、w、c、标签、矩阵特征
最后构建Example返回一条数据

#打印部分信息
image_example_proto = image_example(image_string, label)for line in str(image_example_proto).split('\n')[:15]:print(line)
print('...')

调用刚刚的函数,传进实际的数据和标签
把转换的数据打印出来

# 制作 `images.tfrecords`.image_path = './img/'
images = glob.glob(image_path + '*.jpg')
record_file = 'images.tfrecord'
counter = 0with tf.io.TFRecordWriter(record_file) as writer:for fname in images:with open(fname, 'rb') as f:image_string = f.read()label = image_labels[os.path.basename(fname).replace('.jpg', '')]tf_example = image_example(image_string, label)writer.write(tf_example.SerializeToString())counter += 1print('Processed {:d} of {:d} images.'.format(counter, len(images)))print(' Wrote {} images to {}'.format(counter, record_file))
  1. 指定所有数据的路径
  2. 指定最后保存的数据的路径和名称
  3. 计数器
  4. 打开TFRecordWriter,准备写数据
  5. 遍历所有的图像数据路径
  6. 打开当前路径的文件
  7. 读取图像
  8. 映射图像文件名(无扩展名)到标签
  9. 使用 image_example 函数(需要事先定义)创建 tf.Example 对象
  10. 将其序列化后写入 TFRecord 文件
  11. 每处理一个图像文件,计数器增加
  12. 并打印出已处理的图像数量
  13. 所有图像处理完成后,打印出写入 TFRecord 文件的总图像数量

打印结果:

Processed 1 of 2 images.
Processed 2 of 2 images.
Wrote 2 images to images.tfrecord

5.3 加载制作好的TFRecord

raw_train_dataset = tf.data.TFRecordDataset('images.tfrecord')
raw_train_dataset

<TFRecordDatasetV2 shapes: (), types: tf.string>

example数据都进行了序列化,还需要解析以下之前写入的序列化string,即反序列化

  • tf.io.parse_single_example(example_proto, feature_description)函数可以解析单条example

这个函数是专门用来解析图像数据的

# 解析的格式需要跟之前创建example时一致
image_feature_description = {'height': tf.io.FixedLenFeature([], tf.int64),'width': tf.io.FixedLenFeature([], tf.int64),'depth': tf.io.FixedLenFeature([], tf.int64),'label': tf.io.FixedLenFeature([], tf.int64),'image_raw': tf.io.FixedLenFeature([], tf.string),
}

现在看起来仅仅完成了一个样本的解析,实际数据不可能一个个来写吧,可以定义一个映射规则map函数

解析的格式,要和原来一致

def parse_tf_example(example_proto):parsed_example = tf.io.parse_single_example(example_proto, image_feature_description)x_train = tf.image.decode_jpeg(parsed_example['image_raw'], channels=3)x_train = tf.image.resize(x_train, (416, 416))x_train /= 255.lebel = parsed_example['label']y_train = lebelreturn x_train, y_traintrain_dataset = raw_train_dataset.map(parse_tf_example)
train_dataset
  1. 定义一个专门用来解析的函数
  2. 传进解析的样本、解析的对照关系image_feature_description
  3. 进行预处理操作,将原始的二进制 JPEG 图像数据解码为Tensor
  4. 将图像大小调整为 416x416 像素
  5. 将图像数据归一化到 0 到 1 的范围
  6. 提取 label 字段并赋值给 y_train
  7. 返回处理后的图像数据和标签
  8. 将 parse_tf_example 函数应用于原始的训练数据集 raw_train_dataset,map 函数会对数据集中的每个元素应用 parse_tf_example 函数

打印结果:

<MapDataset shapes: ((416, 416, 3), ()), types: (tf.float32, tf.int64)>

5.4 制作训练集

num_epochs = 10train_ds = train_dataset.shuffle(buffer_size=10000).batch(2).repeat(num_epochs)
for batch, (x, y) in enumerate(train_ds):print(batch, x.shape, y)

把数据转换成batch形式,然后把转换的数据打印出来
打印结果:

0 (2, 416, 416, 3) tf.Tensor([0 1], shape=(2,), dtype=int64)
1 (2, 416, 416, 3) tf.Tensor([0 1], shape=(2,), dtype=int64)

8 (2, 416, 416, 3) tf.Tensor([1 0], shape=(2,), dtype=int64)
9 (2, 416, 416, 3) tf.Tensor([1 0], shape=(2,), dtype=int64)

model = tf.keras.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(2, activation='softmax')
])
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(),metrics=['accuracy'])
model.fit(train_ds, epochs=num_epochs)

定义一个简单的模型、训练器,然后进行训练
打印结果:

Epoch 1/10 10/10 1s 51ms/step - loss: 55.1923 - accuracy: 0.6500

Epoch 9/10 10/10 0s 19ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Epoch 10/10 10/10 0s 19ms/step - loss: 0.0000e+00 - accuracy: 1.0000
<tensorflow.python.keras.callbacks.History at 0x274f2524400>

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/655723.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Qt】QInputDialog setGeometry: Unable to set geometry 问题

QInputDialog setGeometry: Unable to set geometry 问题 文章目录 I - 问题背景II - 解决办法III - 参考链接 I - 问题背景 创建了一个 QMainWindow 并在上边创建了布局&#xff0c;尝试调用 QInputDialog 的 getInt 静态方法&#xff0c;结果运行时出现了以下警告 QWindows…

电商API接口的应用|电商跨境电商商品采集高效解决方案

电商API接口的应用|电商跨境电商商品采集高效解决方案 面对数十万亿元的跨境电商市场&#xff0c;以阿里巴巴国际站为代表的跨境电商数字平台&#xff0c;在政策、需求以及供应链的驱动下&#xff0c;为中小企业提供了全产业链、全供应链一体化综合服务&#xff0c;让越来越多…

机器学习复习(1)——任务整理流程

目录 固定的随机数种子 定义predict功能 拆分数据集 定义trainer 超参数设置 数据集载入 固定的随机数种子 在大量的机器学习与深度学习实验中&#xff0c;如果不进行特殊设置&#xff0c;我们的结果将不可复现&#xff0c;固定的随机数种子将会解决这个问题 def same…

字符串相关函数和文件操作

文章目录 1. C/C 字符串概述1.1 字符串常量1.2 字符数组 2. 字符串函数2.1 拷贝赋值功能相关函数&#xff08;覆盖&#xff09;2.1.1 strcpy2.1.2 strncpy2.1.3 memcpy2.1.4 memmove2.1.5 memset2.1.6 注意小点2.1.7 【函数区别】 2.2 追加功能相关函数2.2.1 strcat2.2.2 strnc…

使用plotly dash 画3d圆柱(Python)

plotly3D &#xff08;3d charts in Python&#xff09;可以画3维图形 在做圆柱的3D装箱项目&#xff0c;需要装箱的可视化&#xff0c;但是Mesh &#xff08;3d mesh plots in Python&#xff09;只能画三角形&#xff0c;所以需要用多个三角形拼成一个圆柱&#xff08;想做立…

Python qt.qpa.xcb: could not connect to display解决办法

遇到问题&#xff1a;qt.qpa.xcb: could not connect to display 解决办法&#xff0c;在命令行输入&#xff1a; export DISPLAY:0 然后重新跑python程序&#xff0c;解决&#xff01; 参考博客&#xff1a;qt.qpa.xcb: could not connect to displayqt.qpa.plugin: Could …

Ubuntu搭建国标平台wvp-GB28181-pro

目录 简介安装和编译1.查看操作系统信息2.安装最新版的nodejs3.安装java环境4.安装mysql5.安装redis6.安装编译器7.安装cmake8.安装依赖库9.编译ZLMediaKit9.1.编译结果说明 10.编译wvp-GB28181-pro10.1.编译结果说明 配置1.WVP-PRO配置文件1.1.Mysql数据库配置1.2.REDIS数据库…

监听项目中指定属性数据,点击或模块显示时

当项目中&#xff0c;需要获取某个页面上、某个标签上、有指定自定义属性时&#xff0c;需要在点击该元素时进行公共逻辑处理&#xff0c;或该元素在显示的时候进行逻辑处理&#xff0c;这时可以定义一个公共的方法&#xff0c;在每个页面引用&#xff0c;并写入数据即可 &…

OpenHarmony RK3568 启动流程优化

目前rk3568的开机时间有21s&#xff0c;统计的是关机后从按下 power 按键到显示锁屏的时间&#xff0c;当对openharmony的系统进行了裁剪子系统&#xff0c;系统app&#xff0c;禁用部分服务后发现开机时间仅仅提高到了20.94s 优化微乎其微。在对init进程的log进行分析并解决其…

【Spring Boot 3】异步线程任务

【Spring Boot 3】异步线程任务 背景介绍开发环境开发步骤及源码工程目录结构总结背景 软件开发是一门实践性科学,对大多数人来说,学习一种新技术不是一开始就去深究其原理,而是先从做出一个可工作的DEMO入手。但在我个人学习和工作经历中,每次学习新技术总是要花费或多或…

面向云服务的GaussDB全密态数据库

前言 全密态数据库&#xff0c;顾名思义与大家所理解的流数据库、图数据库一样&#xff0c;就是专门处理密文数据的数据库系统。数据以加密形态存储在数据库服务器中&#xff0c;数据库支持对密文数据的检索与计算&#xff0c;而与查询任务相关的词法解析、语法解析、执行计划生…

【工具】raw与jpg互转python-cpp

在工作中常常需要将图像转化为raw数据或者yuv数据&#xff0c;这里将提供 cpp 版本和 python 版本的互转代码 代码链接见文档尾部。 cpp 版本 jpg2raw.cpp #include <fstream> #include <iostream> #include <opencv2/core.hpp> #include <opencv2/hig…

oracle版本号中的i,G,C代表什么含义

大家都熟悉的 Oracle 版本号有 9i、10G、11G、12C、19C 等&#xff0c;但在早期&#xff0c;Oracle 的版本号并不包含这些字母。 最初&#xff0c;Oracle 的版本号简单地是 1、2、3、4 等&#xff0c;一直发展到 1999 年发布的 8i 版本。20 世纪末是互联网爆发式发展的时代。 …

将一个excel文件里面具有相同参数的行提取后存入新的excel

功能描述&#xff1a; 一个excel里面有很多行数据&#xff0c;其中“交易时间”这一列有很多交易日期&#xff0c;有些行的交易日期是一样的&#xff0c;那么就把所有交易日期相同的行挑出来&#xff0c;形成一个新的以交易日期命名的文件。import pandas as pd import os# 读取…

跨境ERP定制趋势预测:数字化转型助您赢得市场先机

随着全球贸易的不断融合和发展&#xff0c;跨境业务已成为许多企业拓展市场的重要途径。在这个背景下&#xff0c;ERP定制正逐渐成为企业数字化转型的关键利器。本文将为您预测跨境ERP定制的趋势&#xff0c;并探讨数字化转型如何助您赢得市场先机。 ERP定制趋势预测 1. 数据…

命令行启动Android Studio模拟器

1、sdk路径查看&#xff08;打开Android Studio&#xff09; 以上前提是安装的Android Studio并添加了模拟器&#xff01;&#xff01;&#xff01; 2、复制路径在终端进入到 cd /Users/duxi/Library/Android/sdk目录&#xff08;命令行启动不用打开Android Studio就能运行模拟…

【Java程序设计】【C00182】基于SSM的高校成绩报送管理系统(论文+PPT)

基于SSM的高校成绩报送管理系统&#xff08;论文PPT&#xff09; 项目简介项目获取开发环境项目技术运行截图 项目简介 这是一个基于ssm的高校成绩报送系统 本系统分为前台系统、管理员、教师以及学生4个功能模块。 前台系统&#xff1a;当游客打开系统的网址后&#xff0c;首…

25考研北大软微该怎么做?

25考研想准备北大软微&#xff0c;那肯定要认真准备了 考软微需要多少实力 现在的软微已经不是以前的软微了&#xff0c;基本上所有考计算机的同学都知道&#xff0c;已经没有什么信息优势了&#xff0c;只有实打实的有实力的选手才建议报考。 因为软微的专业课也是11408&am…

PyTorch自动微分机制的详细介绍

PyTorch深度学习框架的官方文档确实提供了丰富的信息来阐述其内部自动微分机制。在PyTorch中&#xff0c;张量&#xff08;Tensor&#xff09;和计算图&#xff08;Computation Graph&#xff09;的设计与实现使得整个系统能够支持动态的、高效的自动求导过程。 具体来说&#…

掌握Java多线程利器:ConcurrentHashMap详解

在并发编程的世界里&#xff0c;每一个微小的延迟都可能积累成为性能瓶颈。今天&#xff0c;让我们一起揭开Java中ConcurrentHashMap的神秘面纱&#xff0c;这是一个在多线程环境中不可或缺的高性能组件。从它的设计理念到底层实现&#xff0c;我们将详细探讨ConcurrentHashMap…