TensorFlow2实战-系列教程2:神经网络分类任务

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

1、Mnist数据集

下载mnist数据集:

%matplotlib inline
from pathlib import Path
import requestsDATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():content = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)

制作数据:

import pickle
import gzipwith gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

简单展示数据:

from matplotlib import pyplot
import numpy as nppyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)
print(y_train[0])

打印结果:

(50000, 784)
5

在这里插入图片描述

2、模型构建

在这里插入图片描述
在这里插入图片描述
输入为784神经元,经过隐层提取特征后为10个神经元,10个神经元的输出值经过softmax得到10个概率值,取出10个概率值中最高的一个就是神经网络的最后预测值

构建模型代码:

import tensorflow as tf
from tensorflow.keras import layers
model = tf.keras.Sequential()
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

选择损失函数,损失函数是机器学习一个非常重要的部分,基本直接决定了这个算法的效果,这里是多分类任务,一般我们就直接选用多元交叉熵函数就好了:
TensorFlow损失函数API

编译模型:

model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss=tf.keras.losses.SparseCategoricalCrossentropy(),metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  1. adam优化器,学习率为0.001
  2. 多元交叉熵损失函数
  3. 评价指标

模型训练:

model.fit(x_train, y_train, epochs=5, batch_size=64, validation_data=(x_valid, y_valid))

训练数据,训练标签,训练轮次,batch_size,验证集

打印结果:

Train on 50000 samples, validate on 10000 samples
Epoch 1/5 50000/50000  1s 29us
sample-loss: 115566 - sparse_categorical_accuracy: 0.1122 - val_loss: 364928.5786 - val_sparse_categorical_accuracy: 0.1064
Epoch 2/5 50000/50000 1s 21us
sample - loss: 837104 - sparse_categorical_accuracy: 0.1136 - val_loss: 1323287.7028 - val_sparse_categorical_accuracy: 0.1064
Epoch 3/5 50000/50000 1s 20us
sample - loss: 1892431 - sparse_categorical_accuracy: 0.1136 - val_loss: 2448062.2680 - val_sparse_categorical_accuracy: 0.1064
Epoch 4/5 50000/50000 1s 20us
sample - loss: 3131130 - sparse_categorical_accuracy: 0.1136 - val_loss: 3773744.5348 - val_sparse_categorical_accuracy: 0.1064
Epoch 5/5 50000/50000 1s 20us
sample - loss: 4527781 - sparse_categorical_accuracy: 0.1136 - val_loss: 5207194.3728 - val_sparse_categorical_accuracy: 0.1064
<tensorflow.python.keras.callbacks.History at 0x1d3eb9015f8>

模型保存:

model.save('Mnist_model.h5')

3、TensorFlow常用模块

3.1 Tensor格式转换

创建一组数据

import numpy as np
input_data = np.arange(16)
input_data

打印结果:
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

转换成TensorFlow格式的数据:

dataset = tf.data.Dataset.from_tensor_slices(input_data)
for data in dataset:print (data)

将一个ndarray转换成
打印结果:
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)

tf.Tensor(14, shape=(), dtype=int32)
tf.Tensor(15, shape=(), dtype=int32)

3.2repeat操作

dataset = tf.data.Dataset.from_tensor_slices(input_data)
dataset = dataset.repeat(2)
for data in dataset:print (data)

tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)

tf.Tensor(14, shape=(), dtype=int32)
tf.Tensor(15, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)

tf.Tensor(14, shape=(), dtype=int32)
tf.Tensor(15, shape=(), dtype=int32)

会将当前的输出重复一遍

3.3 batch操作

dataset = tf.data.Dataset.from_tensor_slices(input_data)
dataset = dataset.repeat(2).batch(4)
for data in dataset:print (data)

tf.Tensor([0 1 2 3], shape=(4,), dtype=int32)
tf.Tensor([4 5 6 7], shape=(4,), dtype=int32)
tf.Tensor([ 8 9 10 11], shape=(4,), dtype=int32)
tf.Tensor([12 13 14 15], shape=(4,), dtype=int32)
tf.Tensor([0 1 2 3], shape=(4,), dtype=int32)
tf.Tensor([4 5 6 7], shape=(4,), dtype=int32)
tf.Tensor([ 8 9 10 11], shape=(4,), dtype=int32)
tf.Tensor([12 13 14 15], shape=(4,), dtype=int32)

将原来的数据按照4个为一个批次

3.4 shuffle操作

dataset = tf.data.Dataset.from_tensor_slices(input_data).shuffle(buffer_size=10).batch(4)
for data in dataset:print (data)

tf.Tensor([ 9 8 11 3], shape=(4,), dtype=int32)
tf.Tensor([ 5 6 1 13], shape=(4,), dtype=int32)
tf.Tensor([14 15 4 2], shape=(4,), dtype=int32)
tf.Tensor([12 7 0 10], shape=(4,), dtype=int32)

shuffle操作,直接翻译过来就是洗牌,把当前的数据进行打乱操作
buffer_size=10,就是缓存10来进行打乱取数据

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/649694.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用代码取大量2*2像素图片各通道均值,存于Excel文件中。

任务是取下图RGB各个通道的均值及标签&#xff08;R, G&#xff0c;B&#xff0c;Label&#xff09;,其中标签由图片存放的文件夹标识。由于2*2像素图片较多&#xff0c;所以将结果放置于Excel表格中&#xff0c;之后使用SVM对他们进行分类。 from PIL import Image import os …

【Linux】查看硬件信息和操作系统信息、安装的应用信息

【Linux】查看硬件信息和操作系统信息、安装的应用信息 一、硬件信息 1.1 CPU信息 cat /proc/cpuinfo #查看 processor : 0 // 逻辑处理器的唯一标识符 physical id : 0 // 硬件上真实存在的CPU siblings : 1 // 一个物理CPU有几个逻辑CPU cpu…

定向减免!函数计算让轻量 ETL 数据加工更简单,更省钱

作者&#xff1a;澈尔、墨飏 业内较为常见的高频短时 ETL 数据加工场景&#xff0c;即频率高时延短&#xff0c;一般均可归类为调用密集型场景。此场景有着高并发、海量调用的特性&#xff0c;往往会产生高额的计算费用&#xff0c;而业内推荐方案一般为攒批处理&#xff0c;业…

ChatGPT+Midjourney+闲鱼赚钱方法实战探索

最近天天在朋友群内看到朋友接单(出售提示词&#xff0c;图片&#xff09;&#xff0c;轻轻松松半小时就赚200-300&#xff0c;特意探索了一下相关玩法&#xff0c;总结出一套ChatGPTMidjourney闲鱼赚钱方法&#xff0c;主打的是易上手&#xff0c;有可操作性&#xff01; 具体…

项目性能优化之用compression-webpack-plugin插件开启gzip压缩

背景&#xff1a;vue项目打包发布后&#xff0c;部分js、css文件体积较大导致页面卡顿&#xff0c;于是使用webpack插件compression-webpack-plugin开启gzip压缩 前端配置vue.config.js 先通过npm下载compression-webpack-plugin包&#xff0c;npm i compression-webpack-plug…

C#使用RabbitMQ-2_详解工作队列模式

简介 &#x1f340;RabbitMQ中的工作队列模式是指将任务分配给多个消费者并行处理。在工作队列模式中&#xff0c;生产者将任务发送到RabbitMQ交换器&#xff0c;然后交换器将任务路由到一个或多个队列。消费者从队列中获取任务并进行处理。处理完成后&#xff0c;消费者可以向…

【图像分割】【深度学习】Windows10下UNet代码Pytorch实现与源码讲解

【图像分割】【深度学习】Windows10下UNet代码Pytorch实现与源码讲解 提示:最近开始在【医学图像分割】方面进行研究,记录相关知识点,分享学习中遇到的问题已经解决的方法。 文章目录 【图像分割】【深度学习】Windows10下UNet代码Pytorch实现与源码讲解前言UNet模型运行环境搭…

SQL语句创建一个简单的银行数据库

目录 一、银行业务E-R图 二、数据库模型图 转换关系模型后&#xff1a; 三、创建数据库 3.1 创建银行业务数据库 四、创建表 4.1 创建客户信息表 4.2 创建银行卡信息表 4.3 创建交易信息表 4.4 创建存款类型表 结果如下&#xff1a; ​编辑 五、插入适量数据 5.1…

java servlet果蔬产业监管系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 java Web果蔬产业监管系统是一套完善的java web信息管理系统 serlvetdaobean mvc 模式开发 &#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主 要采用B/S模式开发。开发环境为TOMCAT7.0,Myeclipse8.5…

Ps:将文件载入堆栈

Ps菜单&#xff1a;文件/脚本/将文件载入堆栈 Scripts/Load Files into Stack 将文件载入堆栈 Load Files into Stack脚本命令可用于将两个及以上的文件载入到同一个 Photoshop 新文档中。 载入的每个文件都将成为独立的图层&#xff0c;并使用其原始文件名作为图层名。 Photos…

AI编译器的后端优化策略

背景 工作领域是AI芯片工具链相关&#xff0c;很多相关知识的概念都是跟着项目成长建立起来&#xff0c;但是比较整个技术体系在脑海中都不太系统&#xff0c;比如项目参与中涉及到了很多AI编译器开发相关内容&#xff0c;东西比较零碎&#xff0c;工作中也没有太多时间去做复盘…

InforSuiteAS中创中间件windows环境部署

版本&#xff1a;InforSuiteAS_StE_V10.0.5.2.1 环境要求&#xff1a;Java环境 DK1.8版本&#xff0c; 内存2GB或以上 &#xff0c; 硬盘空间 10GB或以上&#xff0c; 监视器 图形界面安装需要256色以上&#xff0c;字符界面安装没有色彩要求 &#xff0c;浏览器 Microsoft …

【华为 ICT HCIA eNSP 习题汇总】——题目集9

1、缺省情况下&#xff0c;广播网络上 OSPF 协议 Hello 报文发送的周期和无效周期分别为&#xff08;&#xff09;。 A、10s&#xff0c;40s B、40s&#xff0c;10s C、30s&#xff0c;20s D、20s&#xff0c;30s 考点&#xff1a;①路由技术原理 ②OSPF 解析&#xff1a;&…

臻于至善,CodeArts Snap 二维绘图来一套不?

前言 我在体验 华为云的 CodeArts Snap 时&#xff0c;第一个例子就是绘制三角函数图像&#xff0c;功能注释写的也很简单。 业务场景中&#xff0c;有一类就是需要产出各种二维图形的&#xff0c;比如&#xff0c;折线图、散点图、柱状图等。 为了提前积累业务素材&#xf…

Docker数据卷挂载(以容器化Mysql为例)

数据卷 数据卷是一个虚拟目录&#xff0c;是容器内目录与****之间映射的桥梁 在执行docker run命令时&#xff0c;使用**-v 本地目录&#xff1a;容器目录**可以完成本地目录挂载 eg.Mysql容器的数据挂载 1.在根目录root下创建目录mysql及三个子目录&#xff1a; cd ~ pwd m…

GitBook可以搭建知识库吗?有无其他更好更方便的?

在一个现代化的企业中&#xff0c;知识是一项宝贵的资产。拥有一个完善的企业知识库&#xff0c;不仅可以加速员工的学习和成长&#xff0c;还能提高工作效率和团队协作能力。然而&#xff0c;随着企业不断发展和扩大规模&#xff0c;知识库的构建和管理变得更加复杂和耗时。 |…

PyTorch 中的nn.Conv2d 类

nn.Conv2d 是 PyTorch 中的一个类&#xff0c;代表二维卷积层&#xff08;2D Convolution Layer&#xff09;。这个类广泛用于构建卷积神经网络&#xff08;CNN&#xff09;&#xff0c;特别是在处理图像数据时。 基本概念 卷积: 在神经网络的上下文中&#xff0c;卷积是一种特…

llamaindex 集成本地大模型

从​​​​​​​​​​​​​​用llamaindex 部署本地大模型 - 知乎Customizing LLMs within LlamaIndex Abstractions 目的&#xff1a;llamaindex 是一个很好的应用框架&#xff0c;基于此搭建一个RAG应用是一个不错的选择&#xff0c;但是由于llamaindex默认设置是openai的…

FlashInternImage实战:使用FlashInternImage实现图像分类任务(一)

文章目录 摘要安装包安装timm 数据增强Cutout和MixupEMA项目结构编译安装DCNv4环境安装过程配置CUDAHOME解决权限不够的问题 按装ninja编译DCNv4 计算mean和std生成数据集 摘要 https://arxiv.org/pdf/2401.06197.pdf 论文介绍了Deformable Convolution v4&#xff08;DCNv4&…

【MQ02】基础简单消息队列应用

基础简单消息队列应用 在上一课中&#xff0c;我们已经学习到了什么是消息队列&#xff0c;有哪些消息队列&#xff0c;以及我们会用到哪个消息队列。今天&#xff0c;就直接进入主题&#xff0c;学习第一种&#xff0c;最简单&#xff0c;但也是最常用&#xff0c;最好用的消息…