【TensorFlow学习笔记:神经网络八股】(实现MNIST数据集手写数字识别分类以及FASHION数据集衣裤识别分类)

课程来源:人工智能实践:Tensorflow笔记2

文章目录

  • 前言
  • 一、搭建网络八股sequential
    • 1.函数介绍
    • 2.6步法实现鸢尾花分类
  • 二、搭建网络八股class
    • 1.创建自己的神经网络模板:
    • 2.调用自己创建的model对象
  • 三、MNIST数据集
    • 1.用sequential搭建网络实现手写数字识别
    • 2.用类搭建网络实现手写数字识别
  • 四、FASHION数据集
    • 用sequential搭建网络实现衣裤识别
  • 总结


前言

本讲目标:使用八股搭建神经网络 神经网络搭建八股 iris代码复现 MNIST数据集 训练MNIST数据集 Fashion数据集

一、搭建网络八股sequential

使用六步法,使用TensorFlow的API: tf.keras搭建网络八股
1、import 导入相关模块
2、train、test 告知要喂入网络的训练集、测试集是什么,也就是要指定训练集、测试集的输入特征和训练集的标签
3、model = tf.keras.models.Sequential 在sequential()中搭建网络结构,逐层描述每层网络,相当于走了一遍前向传播
4、model.compile 在compile中配置训练方法,告知训练时选择哪种优化器,选择哪个损失函数,选择哪种评测指标
5、model.fit 在fit中执行训练过程,告知训练集和测试集的输入特征和标签,告知每个batch是多少,告知要迭代多少次数据集
6、model.summary 用summary打印出网络的结构和参数统计

1.函数介绍

sequential()用法:
model = tf.keras.models.Sequential([网络结构]) #描述各层网络
网络结构举例:
拉直层:tf.keras.layers.Flatter()
全连接层:tf.keras.layers.Dense(神经元个数,activation=“激活函数”,kernel_regularizer=哪种正则化)
activation(字符串给出) 可选:relu、softmax、signoid、tanh
kernel_regularizer可选:kernel_regularizer.l1()、kernel_regularizer.l2()
卷积层:tf.keras.layers.Conv2D(filters=卷积核个数,kernel_size=卷积核尺寸,strides=卷积步长,padding=“vaild” or “same”)
LSTM层:tf.kreas.layers.LSTM()

compile() 用法:
model.compile(optimizer =优化器,loss =损失函数,metrics=[“准确率”])
optimizer 可选:
‘sgd’ or tf.keras.optimizers.SGD(lr=学习率,momentum=动量参数)
‘adagrad’ or tf.keras.optimizers.Adagrad(lr=学习率)
‘adadelta’ or tf.keras.optimizers.Adadelta(lr=学习率)
‘adam’ or tf.keras.optimizers.Adam(lr=学习率,beta_1=0.9,beta_2=0.999)
loss 可选:
‘mse’ or tf.keras.losses.MeanSquaredError()
‘sparse_categorical_crossentropy’ or tf.keras.losses.SparseCategoricalCrossentropy(from_logits =False)
(有的神经网络的输出是经过了softmax等函数的概率分布,有些则不经概率分布直接输出,from_logits 是在询问是否是原始输出)
Metrics 可选:
‘accuracy’:y_pred和y都是数值,如y_pred=[1] y=[1]
‘categorical_accuracy’:y_pred 和 y 都是独热码(概率分布),如y_pred=[0,1,0], y=[0.5,0.5,0.5]
‘sparse_categorical_accuracy’:y_pred是数值,y是独热编码,y_pred=[1],y=[0.5,0.5,0.5]

fit()用法:
model.fit(训练集的输入特征,训练集的标签,
​ batch_size = ,epochs = ,
​ validation_data =(测试集的输入特征,测试集的标签),
​ validation_split =从训练集划分多少比例给测试集,
​ validation_freq =多少次epoch测试一次)

model()用法:
model.summary()
在这里插入图片描述

2.6步法实现鸢尾花分类

代码如下:

import tensorflow as tf
from sklearn import datasets
import numpy as np#由于这里是选择从训练集划分出测试集,所以不需要单独导入test
x_train = datasets.load_iris().data
y_train = datasets.load_iris().target
#打乱顺序
np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)
#3个神经元,softmax激活,L2正则化
model = tf.keras.models.Sequential([tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
])
#SGD优化器、学习率0.1,使用SparseCategoricalCrossentropy作为损失函数,由于神经网络末端使用softmax函数,输出为概率分布,所以from_logits为false
#鸢尾花数据集给的标签为0,1,2,神经网络前向传播的输出是概率分布,使用sparse_categorical_accuracy作为准确率
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])
#输入训练数据,一次喂入32组数据,迭代500次,从训练集中划分出20%作为测试集,每迭代20次训练集就要在测试集中验证一次准确率
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
#打印网络结构和参数统计
model.summary()

打印结果如下:
在这里插入图片描述

二、搭建网络八股class

用sequential可以搭建出上层输出就是下层输入的顺序网络结构,但是无法写出一些带有跳连的非顺序网络结构。这时我们可以选择用类class搭建神经网络结构
使用六步法,使用TensorFlow的API: tf.keras搭建网络八股
1、import
2、train、test
3、class MyMode(Model) model=MyModel
4、model.compile
5、model.fit
6、model.summary

1.创建自己的神经网络模板:

伪代码如下:

class MyModel(Model):def _init_(self):super(MyModel,self).init_()定义网络结构块def call(self,x):调用网络结构块,实现前向传播return ymodel=MyModel()

代码如下:

class IrisModel(Model):def __init__(self):super(IrisModel, self).__init__()#鸢尾花分类的单层网络是含有3个神经元的全连接self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())def call(self, x):y = self.d1(x)return y
#实例化名为model的对象
model = IrisModel()

2.调用自己创建的model对象

代码如下:

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn import datasets
import numpy as npx_train = datasets.load_iris().data
y_train = datasets.load_iris().targetnp.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)class IrisModel(Model):def __init__(self):super(IrisModel, self).__init__()self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())def call(self, x):y = self.d1(x)return ymodel = IrisModel()model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()

打印结果如下:
在这里插入图片描述

三、MNIST数据集

MNIST数据集

提供6万张28x28像素点的0~9手写数字图片和标签,用于训练。
提供1万张28x28像素点的0~9手写数字图片和标签,用于测试。

导入数据集

mnist =tf.keras.datasets.mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data()

作为输入特征,输入神经网络时,将数据拉伸为一维数组

tf.keras.layers.Flatter()

1.用sequential搭建网络实现手写数字识别

code:

import tensorflow as tfmnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
#对输入网络的特征进行归一化,使原本0~255的灰度值转化为0~1的小数。
#把输入特征的值变小更有利于神经网络吸收
x_train, x_test = x_train / 255.0, x_test / 255.0
#用Sequential搭建网络
model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),                          #把输入特征拉直为1维数组,即784(28*28)个数值tf.keras.layers.Dense(128, activation='relu'),      #定义第一层网络有128个神经元,relu为激活函数tf.keras.layers.Dense(10, activation='softmax')     #定义第二层网络有10个神经元,softmax使输出符合概率分布
])
#用compile配置训练方法
model.compile(optimizer='adam',                         loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])
#每一轮训练集迭代,执行一次测试集评测,随着迭代轮数增加,手写数字识别准确率不断提升(使用测试集)
model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

print result:

Train on 60000 samples, validate on 10000 samples
Epoch 1/5
60000/60000 [] - 4s 62us/sample - loss: 0.2589 - sparse_categorical_accuracy: 0.9262 - val_loss: 0.1373 - val_sparse_categorical_accuracy: 0.9607
Epoch 2/5
60000/60000 [] - 2s 40us/sample - loss: 0.1114 - sparse_categorical_accuracy: 0.9676 - val_loss: 0.1027 - val_sparse_categorical_accuracy: 0.9699
Epoch 3/5
60000/60000 [] - 3s 43us/sample - loss: 0.0762 - sparse_categorical_accuracy: 0.9775 - val_loss: 0.0898 - val_sparse_categorical_accuracy: 0.9722
Epoch 4/5
60000/60000 [] - 2s 41us/sample - loss: 0.0573 - sparse_categorical_accuracy: 0.9822 - val_loss: 0.0851 - val_sparse_categorical_accuracy: 0.9752
Epoch 5/5
60000/60000 [] - 2s 41us/sample - loss: 0.0450 - sparse_categorical_accuracy: 0.9858 - val_loss: 0.0846 - val_sparse_categorical_accuracy: 0.9738
Model: “sequential”
Layer (type) Output Shape Param #
=================================================================
flatten (Flatten) multiple 0
dense (Dense) multiple 100480
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0


可以观察到,随着迭代轮数增加,准确率也不断提升。训练的参数也是极其多的,达到10万多个。

2.用类搭建网络实现手写数字识别

只是实例化model的方法不同,其他与用sequential搭建网络实现手写数字识别一致。
init函数中定义了call函数中所用到的层,call函数中从输入x到输出y走过一次前向传播,返回输出y

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Modelmnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0class MnistModel(Model):def __init__(self):super(MnistModel, self).__init__()self.flatten = Flatten()self.d1 = Dense(128, activation='relu')self.d2 = Dense(10, activation='softmax')def call(self, x):x = self.flatten(x)x = self.d1(x)y = self.d2(x)return ymodel = MnistModel()model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

四、FASHION数据集

FASHION数据集

提供6万张 28x28像素点的衣裤等图片和标签,用于训练.
提供1万张28x28像素点的衣裤等图片和标签,用于测试。
在这里插入图片描述

导入数据集

fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train),(x_test, y_test) = fashion.load_data()

用sequential搭建网络实现衣裤识别

加载数据需要较长时间,需耐心等待

import tensorflow as tffashion = tf.keras.datasets.fashion_mnist
(x_train, y_train),(x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

用类的方法也可以实现,这里不做重复展开,套用八股模板即可。


总结

这个单元将整个训练的构架走了一遍,并且以八股的形式做了总结,收获很大。

课程链接:MOOC人工智能实践:TensorFlow笔记2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/378292.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c语言 在执行区域没有空格,C语言上机操作指导之TurboC.doc

C语言上机操作指导之 -------- Turbo C程序设计是实践性很强的过程,任何程序都必须在计算机上运行,以检验程序的正确与否。因此在学习程序设计中,一定要重视上机实践环节,通过上机可以加深理解 C语言的有关概念,以巩固…

java 根据类名示例化类_Java即时类| from()方法与示例

java 根据类名示例化类即时类from()方法 (Instant Class from() method) from() method is available in java.time package. from()方法在java.time包中可用。 from() method is used to return a copy of the Instant from the given TemporalAccessor object. from()方法用于…

第十二章 图形用户界面

第十二章 图形用户界面 GUI就是包含按钮、文本框等控件的窗口 Tkinter是事实上的Python标准GUI工具包 创建GUI示例应用程序 初探 导入tkinter import tkinter as tk也可导入这个模块的所有内容 from tkinter import *要创建GUI,可创建一个将充当主窗口的顶级组…

Sqlserver 2005 配置 数据库镜像:数据库镜像期间可能出现的故障:镜像超时机制

数据库镜像期间可能出现的故障 SQL Server 2005其他版本更新日期: 2006 年 7 月 17 日 物理故障、操作系统故障或 SQL Server 故障都可能导致数据库镜像会话失败。数据库镜像不会定期检查 Sqlservr.exe 所依赖的组件来验证组件是在正常运行还是已出现故障。但对于某…

江西理工大学期末试卷c语言,2016年江西理工大学信息工程学院计算机应用技术(加试)之C语言程序设计复试笔试最后押题五套卷...

一、选择题1. 设有函数定义:( )。A. B. C. D. 答:A则以下对函数sub 的调用语句中,正确的是【解析】函数的参数有两个,第一个是整型,第二个是字符类型,在调用函数时,实参必须一个是整型&#xff…

第十三章 数据库支持

第十三章 数据库支持 本章讨论Python数据库API(一种连接到SQL数据库的标准化方式),并演示如何使用这个API来执行一些基本的SQL。最后,本章将讨论其他一些数据库技术。 关Python支持的数据库清单 Python数据库API 标准数据库API…

【神经网络八股扩展】:自制数据集

课程来源:人工智能实践:Tensorflow笔记2 文章目录前言1、文件一览2、将load_data()函数替换掉2、调用generateds函数4、效果总结前言 本讲目标:自制数据集,解决本领域应用 将我们手中的图片和标签信息制作为可以直接导入的npy文件。 1、文件一览 首先看…

java 批量处理 示例_Java中异常处理的示例

java 批量处理 示例Here, we will analyse some exception handling codes, to better understand the concepts. 在这里,我们将分析一些异常处理代码 ,以更好地理解这些概念。 Try to find the errors in the following code, if any 尝试在以下代码中…

hdu 1465 不容易系列之一

http://acm.hdu.edu.cn/showproblem.php?pid1465 今天立神和我们讲了错排,才知道错排原来很简单,从第n个推起: 当n个编号元素放在n个编号位置,元素编号与位置编号各不对应的方法数用M(n)表示,那么M(n-1)就表示n-1个编号元素放在n-1个编号位置…

第十四章 网络编程

第十四章 网络编程 本章首先概述Python标准库中的一些网络模块。然后讨论SocketServer和相关的类,并介绍同时处理多个连接的各种方法。最后,简单地说一说Twisted,这是一个使用Python编写网络程序的框架,功能丰富而成熟。 几个网…

c语言输出11258循环,c/c++内存机制(一)(转)

一:C语言中的内存机制在C语言中,内存主要分为如下5个存储区:(1)栈(Stack):位于函数内的局部变量(包括函数实参),由编译器负责分配释放,函数结束,栈变量失效。(2)堆(Heap):由程序员用…

【神经网络八股扩展】:数据增强

课程来源:人工智能实践:Tensorflow笔记2 文章目录前言TensorFlow2数据增强函数数据增强网络八股代码:总结前言 本讲目标:数据增强,增大数据量 关于我们为何要使用数据增强以及常用的几种数据增强的手法,可以看看下面的文章&#…

C++:从C继承的标准库

C从C继承了的标准库 &#xff0c; 这就意味着 C 中 可以使用的标准库函数 在C 中都可以使用 &#xff0c; 但是需要注意的是 &#xff0c; 这些标准库函数在C中不再以 <xxx.h> 命名 &#xff0c; 而是变成了 <cxxx> 。 例如 &#xff1a; 在C中操作字符串的…

分享WCF聊天程序--WCFChat

无意中在一个国外的站点下到了一个利用WCF实现聊天的程序&#xff0c;作者是&#xff1a;Nikola Paljetak。研究了一下&#xff0c;自己做了测试和部分修改&#xff0c;感觉还不错&#xff0c;分享给大家。先来看下运行效果&#xff1a;开启服务&#xff1a;客户端程序&#xf…

c# uri.host_C#| 具有示例的Uri.Equality()运算符

c# uri.hostUri.Equality()运算符 (Uri.Equality() Operator) Uri.Equality() Operator is overloaded which is used to compare two Uri objects. It returns true if two Uri objects contain the same Uri otherwise it returns false. Uri.Equality()运算符已重载&#xf…

第六章至第九章的单元测试

1,‌助剂与纤维作用力大于纤维分子之间的作用力,则该助剂最好用作() 纤维增塑膨化剂。 2,助剂扩散速率快,优先占领纤维上的染座,但助剂与纤维之间作用力小于染料与纤维之间作用力,该助剂可以作为() 匀染剂。 3,助剂占领纤维上的染座,但助剂与纤维之间作用力大于染…

【神经网络扩展】:断点续训和参数提取

课程来源&#xff1a;人工智能实践:Tensorflow笔记2 文章目录前言断点续训主要步骤参数提取主要步骤总结前言 本讲目标:断点续训&#xff0c;存取最优模型&#xff1b;保存可训练参数至文本 断点续训主要步骤 读取模型&#xff1a; 先定义出存放模型的路径和文件名&#xff0…

开发DBA(APPLICATION DBA)的重要性

开发DBA是干什么的&#xff1f; 1. 审核开发人员写的SQL&#xff0c;并且纠正存在性能问题的SQL ---非常重要 2. 编写复杂业务逻辑SQL&#xff0c;因为复杂业务逻辑SQL开发人员写出的SQL基本上都是有性能问题的&#xff0c;与其让开发人员写&#xff0c;不如DBA自己写。---非常…

javascript和var之间的区别?

You can define your variables in JavaScript using two keywords - the let keyword and the var keyword. The var keyword is the oldest way of defining and declaring variables in JavaScript whereas the let is fairly new and was introduced by ES15. 您可以使用两…

小米手环6NFC安装太空人表盘

以前看我室友峰哥、班长都有手环&#xff0c;一直想买个手环&#xff0c;不舍得&#xff0c;然后今年除夕的时候降价&#xff0c;一狠心&#xff0c;入手了&#xff0c;配上除夕的打年兽活动还有看春晚京东敲鼓领的红包和这几年攒下来的京东豆豆&#xff0c;原价279的小米手环6…