【TensorFlow学习笔记:神经网络八股】(实现MNIST数据集手写数字识别分类以及FASHION数据集衣裤识别分类)

课程来源:人工智能实践:Tensorflow笔记2

文章目录

  • 前言
  • 一、搭建网络八股sequential
    • 1.函数介绍
    • 2.6步法实现鸢尾花分类
  • 二、搭建网络八股class
    • 1.创建自己的神经网络模板:
    • 2.调用自己创建的model对象
  • 三、MNIST数据集
    • 1.用sequential搭建网络实现手写数字识别
    • 2.用类搭建网络实现手写数字识别
  • 四、FASHION数据集
    • 用sequential搭建网络实现衣裤识别
  • 总结


前言

本讲目标:使用八股搭建神经网络 神经网络搭建八股 iris代码复现 MNIST数据集 训练MNIST数据集 Fashion数据集

一、搭建网络八股sequential

使用六步法,使用TensorFlow的API: tf.keras搭建网络八股
1、import 导入相关模块
2、train、test 告知要喂入网络的训练集、测试集是什么,也就是要指定训练集、测试集的输入特征和训练集的标签
3、model = tf.keras.models.Sequential 在sequential()中搭建网络结构,逐层描述每层网络,相当于走了一遍前向传播
4、model.compile 在compile中配置训练方法,告知训练时选择哪种优化器,选择哪个损失函数,选择哪种评测指标
5、model.fit 在fit中执行训练过程,告知训练集和测试集的输入特征和标签,告知每个batch是多少,告知要迭代多少次数据集
6、model.summary 用summary打印出网络的结构和参数统计

1.函数介绍

sequential()用法:
model = tf.keras.models.Sequential([网络结构]) #描述各层网络
网络结构举例:
拉直层:tf.keras.layers.Flatter()
全连接层:tf.keras.layers.Dense(神经元个数,activation=“激活函数”,kernel_regularizer=哪种正则化)
activation(字符串给出) 可选:relu、softmax、signoid、tanh
kernel_regularizer可选:kernel_regularizer.l1()、kernel_regularizer.l2()
卷积层:tf.keras.layers.Conv2D(filters=卷积核个数,kernel_size=卷积核尺寸,strides=卷积步长,padding=“vaild” or “same”)
LSTM层:tf.kreas.layers.LSTM()

compile() 用法:
model.compile(optimizer =优化器,loss =损失函数,metrics=[“准确率”])
optimizer 可选:
‘sgd’ or tf.keras.optimizers.SGD(lr=学习率,momentum=动量参数)
‘adagrad’ or tf.keras.optimizers.Adagrad(lr=学习率)
‘adadelta’ or tf.keras.optimizers.Adadelta(lr=学习率)
‘adam’ or tf.keras.optimizers.Adam(lr=学习率,beta_1=0.9,beta_2=0.999)
loss 可选:
‘mse’ or tf.keras.losses.MeanSquaredError()
‘sparse_categorical_crossentropy’ or tf.keras.losses.SparseCategoricalCrossentropy(from_logits =False)
(有的神经网络的输出是经过了softmax等函数的概率分布,有些则不经概率分布直接输出,from_logits 是在询问是否是原始输出)
Metrics 可选:
‘accuracy’:y_pred和y都是数值,如y_pred=[1] y=[1]
‘categorical_accuracy’:y_pred 和 y 都是独热码(概率分布),如y_pred=[0,1,0], y=[0.5,0.5,0.5]
‘sparse_categorical_accuracy’:y_pred是数值,y是独热编码,y_pred=[1],y=[0.5,0.5,0.5]

fit()用法:
model.fit(训练集的输入特征,训练集的标签,
​ batch_size = ,epochs = ,
​ validation_data =(测试集的输入特征,测试集的标签),
​ validation_split =从训练集划分多少比例给测试集,
​ validation_freq =多少次epoch测试一次)

model()用法:
model.summary()
在这里插入图片描述

2.6步法实现鸢尾花分类

代码如下:

import tensorflow as tf
from sklearn import datasets
import numpy as np#由于这里是选择从训练集划分出测试集,所以不需要单独导入test
x_train = datasets.load_iris().data
y_train = datasets.load_iris().target
#打乱顺序
np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)
#3个神经元,softmax激活,L2正则化
model = tf.keras.models.Sequential([tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
])
#SGD优化器、学习率0.1,使用SparseCategoricalCrossentropy作为损失函数,由于神经网络末端使用softmax函数,输出为概率分布,所以from_logits为false
#鸢尾花数据集给的标签为0,1,2,神经网络前向传播的输出是概率分布,使用sparse_categorical_accuracy作为准确率
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])
#输入训练数据,一次喂入32组数据,迭代500次,从训练集中划分出20%作为测试集,每迭代20次训练集就要在测试集中验证一次准确率
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
#打印网络结构和参数统计
model.summary()

打印结果如下:
在这里插入图片描述

二、搭建网络八股class

用sequential可以搭建出上层输出就是下层输入的顺序网络结构,但是无法写出一些带有跳连的非顺序网络结构。这时我们可以选择用类class搭建神经网络结构
使用六步法,使用TensorFlow的API: tf.keras搭建网络八股
1、import
2、train、test
3、class MyMode(Model) model=MyModel
4、model.compile
5、model.fit
6、model.summary

1.创建自己的神经网络模板:

伪代码如下:

class MyModel(Model):def _init_(self):super(MyModel,self).init_()定义网络结构块def call(self,x):调用网络结构块,实现前向传播return ymodel=MyModel()

代码如下:

class IrisModel(Model):def __init__(self):super(IrisModel, self).__init__()#鸢尾花分类的单层网络是含有3个神经元的全连接self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())def call(self, x):y = self.d1(x)return y
#实例化名为model的对象
model = IrisModel()

2.调用自己创建的model对象

代码如下:

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn import datasets
import numpy as npx_train = datasets.load_iris().data
y_train = datasets.load_iris().targetnp.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)class IrisModel(Model):def __init__(self):super(IrisModel, self).__init__()self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())def call(self, x):y = self.d1(x)return ymodel = IrisModel()model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()

打印结果如下:
在这里插入图片描述

三、MNIST数据集

MNIST数据集

提供6万张28x28像素点的0~9手写数字图片和标签,用于训练。
提供1万张28x28像素点的0~9手写数字图片和标签,用于测试。

导入数据集

mnist =tf.keras.datasets.mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data()

作为输入特征,输入神经网络时,将数据拉伸为一维数组

tf.keras.layers.Flatter()

1.用sequential搭建网络实现手写数字识别

code:

import tensorflow as tfmnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
#对输入网络的特征进行归一化,使原本0~255的灰度值转化为0~1的小数。
#把输入特征的值变小更有利于神经网络吸收
x_train, x_test = x_train / 255.0, x_test / 255.0
#用Sequential搭建网络
model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),                          #把输入特征拉直为1维数组,即784(28*28)个数值tf.keras.layers.Dense(128, activation='relu'),      #定义第一层网络有128个神经元,relu为激活函数tf.keras.layers.Dense(10, activation='softmax')     #定义第二层网络有10个神经元,softmax使输出符合概率分布
])
#用compile配置训练方法
model.compile(optimizer='adam',                         loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])
#每一轮训练集迭代,执行一次测试集评测,随着迭代轮数增加,手写数字识别准确率不断提升(使用测试集)
model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

print result:

Train on 60000 samples, validate on 10000 samples
Epoch 1/5
60000/60000 [] - 4s 62us/sample - loss: 0.2589 - sparse_categorical_accuracy: 0.9262 - val_loss: 0.1373 - val_sparse_categorical_accuracy: 0.9607
Epoch 2/5
60000/60000 [] - 2s 40us/sample - loss: 0.1114 - sparse_categorical_accuracy: 0.9676 - val_loss: 0.1027 - val_sparse_categorical_accuracy: 0.9699
Epoch 3/5
60000/60000 [] - 3s 43us/sample - loss: 0.0762 - sparse_categorical_accuracy: 0.9775 - val_loss: 0.0898 - val_sparse_categorical_accuracy: 0.9722
Epoch 4/5
60000/60000 [] - 2s 41us/sample - loss: 0.0573 - sparse_categorical_accuracy: 0.9822 - val_loss: 0.0851 - val_sparse_categorical_accuracy: 0.9752
Epoch 5/5
60000/60000 [] - 2s 41us/sample - loss: 0.0450 - sparse_categorical_accuracy: 0.9858 - val_loss: 0.0846 - val_sparse_categorical_accuracy: 0.9738
Model: “sequential”
Layer (type) Output Shape Param #
=================================================================
flatten (Flatten) multiple 0
dense (Dense) multiple 100480
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0


可以观察到,随着迭代轮数增加,准确率也不断提升。训练的参数也是极其多的,达到10万多个。

2.用类搭建网络实现手写数字识别

只是实例化model的方法不同,其他与用sequential搭建网络实现手写数字识别一致。
init函数中定义了call函数中所用到的层,call函数中从输入x到输出y走过一次前向传播,返回输出y

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Modelmnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0class MnistModel(Model):def __init__(self):super(MnistModel, self).__init__()self.flatten = Flatten()self.d1 = Dense(128, activation='relu')self.d2 = Dense(10, activation='softmax')def call(self, x):x = self.flatten(x)x = self.d1(x)y = self.d2(x)return ymodel = MnistModel()model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

四、FASHION数据集

FASHION数据集

提供6万张 28x28像素点的衣裤等图片和标签,用于训练.
提供1万张28x28像素点的衣裤等图片和标签,用于测试。
在这里插入图片描述

导入数据集

fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train),(x_test, y_test) = fashion.load_data()

用sequential搭建网络实现衣裤识别

加载数据需要较长时间,需耐心等待

import tensorflow as tffashion = tf.keras.datasets.fashion_mnist
(x_train, y_train),(x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

用类的方法也可以实现,这里不做重复展开,套用八股模板即可。


总结

这个单元将整个训练的构架走了一遍,并且以八股的形式做了总结,收获很大。

课程链接:MOOC人工智能实践:TensorFlow笔记2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/378292.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第十二章 图形用户界面

第十二章 图形用户界面 GUI就是包含按钮、文本框等控件的窗口 Tkinter是事实上的Python标准GUI工具包 创建GUI示例应用程序 初探 导入tkinter import tkinter as tk也可导入这个模块的所有内容 from tkinter import *要创建GUI,可创建一个将充当主窗口的顶级组…

Sqlserver 2005 配置 数据库镜像:数据库镜像期间可能出现的故障:镜像超时机制

数据库镜像期间可能出现的故障 SQL Server 2005其他版本更新日期: 2006 年 7 月 17 日 物理故障、操作系统故障或 SQL Server 故障都可能导致数据库镜像会话失败。数据库镜像不会定期检查 Sqlservr.exe 所依赖的组件来验证组件是在正常运行还是已出现故障。但对于某…

【神经网络八股扩展】:自制数据集

课程来源:人工智能实践:Tensorflow笔记2 文章目录前言1、文件一览2、将load_data()函数替换掉2、调用generateds函数4、效果总结前言 本讲目标:自制数据集,解决本领域应用 将我们手中的图片和标签信息制作为可以直接导入的npy文件。 1、文件一览 首先看…

c语言输出11258循环,c/c++内存机制(一)(转)

一:C语言中的内存机制在C语言中,内存主要分为如下5个存储区:(1)栈(Stack):位于函数内的局部变量(包括函数实参),由编译器负责分配释放,函数结束,栈变量失效。(2)堆(Heap):由程序员用…

【神经网络八股扩展】:数据增强

课程来源:人工智能实践:Tensorflow笔记2 文章目录前言TensorFlow2数据增强函数数据增强网络八股代码:总结前言 本讲目标:数据增强,增大数据量 关于我们为何要使用数据增强以及常用的几种数据增强的手法,可以看看下面的文章&#…

分享WCF聊天程序--WCFChat

无意中在一个国外的站点下到了一个利用WCF实现聊天的程序,作者是:Nikola Paljetak。研究了一下,自己做了测试和部分修改,感觉还不错,分享给大家。先来看下运行效果:开启服务:客户端程序&#xf…

【神经网络扩展】:断点续训和参数提取

课程来源:人工智能实践:Tensorflow笔记2 文章目录前言断点续训主要步骤参数提取主要步骤总结前言 本讲目标:断点续训,存取最优模型;保存可训练参数至文本 断点续训主要步骤 读取模型: 先定义出存放模型的路径和文件名&#xff0…

小米手环6NFC安装太空人表盘

以前看我室友峰哥、班长都有手环,一直想买个手环,不舍得,然后今年除夕的时候降价,一狠心,入手了,配上除夕的打年兽活动还有看春晚京东敲鼓领的红包和这几年攒下来的京东豆豆,原价279的小米手环6…

为什么两层3*3卷积核效果比1层5*5卷积核效果要好?

目录1、感受野2、2层3 * 3卷积与1层5 * 5卷积3、2层3 * 3卷积与1层5 * 5卷积的计算量比较4、2层3 * 3卷积与1层5 * 5卷积的非线性比较5、2层3 * 3卷积与1层5 * 5卷积的参数量比较1、感受野 感受野:卷积神经网络各输出特征像素点,在原始图片映射区域大小。…

算法正确性和复杂度分析

算法正确性——循环不变式 算法复杂度的计算 方法一 代换法 —局部代换 这里直接对n变量进行代换 —替换成对数或者指数的情形 n 2^m —整体代换 这里直接对递推项进行代换 —替换成内部递推下标的形式 T(2^n) S(n) 方法二 递归树法 —用实例说明 —分析每一层的内容 —除了…

第十五章 Python和Web

第十五章 Python和Web 本章讨论Python Web编程的一些方面。 三个重要的主题:屏幕抓取、CGI和mod_python。 屏幕抓取 屏幕抓取是通过程序下载网页并从中提取信息的过程。 下载数据并对其进行分析。 从Python Job Board(http://python.org/jobs&#x…

【数据结构基础笔记】【图】

代码参考《妙趣横生的算法.C语言实现》 文章目录前言1、图的概念2、图的存储形式1、邻接矩阵:2、邻接表3、代码定义邻接表3、图的创建4、深度优先搜索DFS5、广度优先搜索BFS6、实例分析前言 本章总结:图的概念、图的存储形式、邻接表定义、图的创建、图…

如何蹭网

引言蹭网,在普通人的眼里,是一种很高深的技术活,总觉得肯定很难,肯定很难搞。还没开始学,就已经败给了自己的心里,其实,蹭网太过于简单。我可以毫不夸张的说,只要你会windows的基本操…

android对象缓存,Android简单实现 缓存数据

前言1、每一种要缓存的数据都是有对应的versionCode,通过versionCode请求网络获取是否需要更新2、提前将要缓存的数据放入assets文件夹中,打包上线。缓存设计代码实现/*** Created by huangbo on 2017/6/19.** 主要是缓存的工具类** 缓存设计&#xff1a…

通信原理.绪论

今天刚上通信原理的第一节课,没有涉及过多的讲解,只是讲了下大概的知识框架。现记录如下: 目录1、基本概念消息、信息与信号2、通信系统模型1、信息源2、发送设备3、信道4、接收设备5、信宿6、模拟通信系统模型7、数字通信系统模型8、信源编…

css rgba透明_rgba()函数以及CSS中的示例

css rgba透明Introduction: 介绍: Functions are used regularly while we are developing a web page or website. Therefore, to be a good developer you need to master as many functions as you can. This way your coding knowledge will increase as well …

犀牛脚本:仿迅雷的增强批量下载

迅雷的批量下载满好用。但是有两点我不太中意。在这个脚本里会有所增强 1、不能设置保存的文件名。2、不能单独设置这批下载的线程限制。 使用方法 // 下载从编号001到编号020的图片,保存名为猫咪写真*.jpg 使用6个线程 jdlp http://bizhi.zhuoku.com/bizhi/200804/…

android 服务端 漏洞,安卓漏洞 CVE 2017-13287 复现详解-

2018年4月,Android安全公告公布了CVE-2017-13287漏洞。与同期披露的其他漏洞一起,同属于框架中Parcelable对象的写入(序列化)与读出(反序列化)的不一致所造成的漏洞。在刚看到谷歌对于漏洞给出的补丁时一头雾水,在这里要感谢heeeeenMS509Team…

GAP(全局平均池化层)操作

转载的文章链接: 为什么使用全局平均池化层? 关于 global average pooling https://blog.csdn.net/qq_23304241/article/details/80292859 在卷积神经网络的初期,卷积层通过池化层(一般是 最大池化)后总是要一个或n个全…

zoj1245 Triangles(DP)

/* 动态三角形&#xff1a;每次DP时考虑的是两个子三角形的高度即可 注意&#xff1a; 三角形可以是倒置的。 */ View Code 1 #include <iostream> 2 #include <cstdlib> 3 #include <cstring> 4 #include <stdio.h> 5 6 using namespace std; 7 8…