使用八股搭建神经网络

神经网络搭建八股

使用tf.keras

六步法搭建模型

1.import

2.train, test 指定输入特征/标签

3.model = tf.keras.model.Sequential

在Squential,搭建神经网络

4.model.compile

配置训练方法,选择哪种优化器、损失函数、评测指标

5.model.fit 执行训练过程,告知训练集输入特征,batch,epoch

6.model.summary打印网络结构和参数统计

model = tf.keras.model.Sequential

Sequential是个容器,封装了网络结构

网络结构例子:

拉直层:tf.keras.layers.Flatten()

全连接层:tf.keras.layers.Dense(神经元个数,activetion="激活函数",kernel_regularizer=那种正则化)

卷积层:

tf.keras.layers.Conv2D(filters= 卷积核个数,kernel_size=卷积核尺寸,strides=卷积步长,padding="valid"or"same"

LSTM层:

tf.keras.layers.LSTM()

model.compile

model.compile(optimizer=优化器,loss=损失函数,metrics=["准确率"]

后期可通过tensorflow官网查询函数的具体用法,调节超参数

有些网络输出经过softmax概率分布输出,有些不经过概率分布输出 

当网络评测指标和蒙的概率一样,例如十分类概率为.1/10.可能概率分布错了

独热码y_和y是[010]网络输出则为[0.xx, 0.xx, 0.xx]

 第三种方法 y_= [1] y =[0.2xx,0xx,0xx]

model.fit

model.fit(训练集的输入特征,训练集的标签,batch_size, epochs=, 

validation_data=(测试集的输入特征,标签),

validation_split=从训练集划分多少比例给测试集,

validation_freq=多少次epoch测试一次)

model.summary

重构Iris分类

import tensorflow as tf
from sklearn import datasets
import numpy as npx_train = datasets.load_iris().data
y_train = datasets.load_iris().targetnp.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)model = tf.keras.models.Sequential([tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
])model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)model.summary()

自定义搭建模型

swquential可以搭建上层输出就是下层输入的网络结构,但是无法搭建带有跳连特征的非顺序网络结构

class MyModel(Model)

        def __init__(self):

                super(MyModel, self) __init()

                定义网络结构块

        def call(self, x): #写出前向传播

               调用网络结构块,实现前向传播

        return y     

model = MyModel

__init__定义出积木

call调用积木,实现前向传播

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn import datasets
import numpy as npx_train = datasets.load_iris().data
y_train = datasets.load_iris().targetnp.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)class IrisModel(Model):def __init__(self):super(IrisModel, self).__init__()self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())def call(self, x):y = self.d1(x)return ymodel = IrisModel()model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()

每循环一次train,计算一次test的测试指标

MNIST数据集

1.导入MNIST数据集

mnist=tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) =  mnist.load_data(

2.作为输入特征,输入神经网络时,将数据拉伸成一维数组:

tf.keras.layers.Flatten()

把784个像素点的灰度值作为输入特征放入神经网络

plt.imshow(x_train[0], cmap='gray')#绘制灰度图

plt.show()

0表示纯黑色255表示纯白色

需要对测试集和数据集进行归一化处理,把数值变小,更适合神经网络吸收,使用sequental训练模型,由于输入特征为数组,输出为概率分布,所以我们选择sparse_categorical_accuracy

import tensorflow as tfmnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

准确率是测试集的准确率

自定义Model实现 __init__中定义cell函数中用到的层

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Modelmnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0class MnistModel(Model):def __init__(self):super(MnistModel, self).__init__()self.flatten = Flatten()self.d1 = Dense(128, activation='relu')self.d2 = Dense(10, activation='softmax')def call(self, x):x = self.flatten(x)x = self.d1(x)y = self.d2(x)return ymodel = MnistModel()model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

FASHION数据集

import tensorflow as tffashion = tf.keras.datasets.fashion_mnist
(x_train, y_train),(x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/42947.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

送给我亲爱的Python

亲爱的 Python, 在万物皆代码的世界里,你是我最优雅、最高效的算法。自从第一次遇见你,在那行“Hello, World!”之后,我的世界就被点亮了。你的简洁性和强大的功能,让我深深着迷,就像一个精心设计的函数&am…

数据结构双向循环链表

主程序 #include "fun.h" int main(int argc, const char *argv[]) { double_p Hcreate_head(); insert_head(H,10); insert_head(H,20); insert_head(H,30); insert_head(H,40); insert_tail(H,50); show_link(H); del_tail(H); …

Python 传递参数和返回值

Python是一种功能强大的编程语言,它以其简洁和易用性而广受欢迎。在Python编程中,参数传递和返回值是函数调用中两个非常重要的概念。理解这些概念对于编写高效且可维护的代码至关重要。 一、参数传递 在Python中,函数参数可以通过以下几种…

Linux 网络时间同步:NTP 与 Chrony 的终极对决

Linux 网络时间同步:NTP 与 Chrony 的终极对决 在网络世界中,时间同步是一项至关重要的任务。无论是确保分布式系统的一致性,还是维护安全协议的完整性,准确的时间同步都是必不可少的。网络时间协议(NTP)和…

Golang期末作业之电子商城(源码)

作品介绍 1.网页作品简介方面 :主要有:首页 商品详情 购物车 订单 评价 支付 总共 5个页面 2.作品使用的技术:这个作品基于Golang语言,并且结合一些前端的知识,例如:HTML、CSS、JS、AJAX等等知识点,同时连接数据库的&…

统信UOS软件包标识化工具deepin-sbom-tools使用

原文链接:统信UOS上使用软件包标识化工具deepin-sbom-tools Hello,大家好啊!今天给大家带来一篇关于在统信UOS上使用软件包标识化工具deepin-sbom-tools的文章。deepin-sbom-tools是一个强大的工具,可以帮助开发者和系统管理员更好…

Linux初始化新的git仓库

1.在git服务器上找到项目常部署的git地址可以根据其他项目的git地址确认 例如ssh://git192.168.10.100/opt/git/repository.git 用户名:git(前面的是用户) 服务器地址:192.168.10.100 git仓库路径:/opt/git/ 2.在服务器…

数据结构之折半查找

折半查找的算法思想: 折半查找又称二分查找,它仅仅适用于有序的顺表。 折半查找的基本思想:首先将给定值key与表中中间位置的元素(mid的指向元素)比较。midlowhigh/2(向下取整) 若key与中间元…

C#—Json序列化和反序列化

C#—Json序列化和反序列化 在C#中,可以使用System.Web.Script.Serialization.JavaScriptSerializer类来序列化和反序列化JSON数据。 可以使用Newtonsoft.Json库进行JSON的序列化。 可以使用.NET内置的System.Text.Json库来进行JSON的序列化。 json文件格式 [ { …

搜索引擎优化培训机构怎么选?这篇文章告诉你答案

搜索引擎优化(SEO)已成为网络生存必备技能。然而面对众多培训机构,如何选择优秀者?本文将为您揭晓此事,助您找到腾飞之地。 一、培训机构的多样性:琳琅满目的选择 当前SEO培训市场繁芜复杂,既…

C++ 八股(1)

C语言中strcpy为什么不安全?如何解决? 主要原因是缺乏对输入长度的边界检查,容易导致缓冲区溢出漏洞。 解决:可以使用strncpy函数替代,或者在程序最顶端加入代码段 #define _CRT_SECURE_NO_WARNINGS 缓冲区溢出 …

javascript高级部分笔记

javascript高级部分 Function方法 与 函数式编程 call 语法:call([thisObj[,arg1[, arg2[, [,.argN]]]]]) 定义:调用一个对象的一个方法,以另一个对象替换当前对象。 说明:call 方法可以用来代替另一个对象调用一个方法。cal…

MySQL运维实战之ProxySQL(9.5)proxysql和MySQL Group Replication配合使用

作者:俊达 如果后端MySQL使用了Group Replication,可通过配置mysql_group_replication_hostgroups表来实现高可用 1 mysql_group_replication_hostgroups 字段描述writer_hostgroup写hostgroup。read_only和super_read_only OFF的节点。backup_writer…

Vue3 pdf.js将二进制文件流转成pdf预览

好久没写东西,19年之前写过一篇Vue2将pdf二进制文件流转换成pdf文件,如果Vue2换成Vue3了,顺带来一篇文章,pdf.js这个东西用来解决内网pdf预览,是个不错的选择。 首先去pdfjs官网,下载需要的文件 然后将下载…

第4章 IT服务规划设计

第4章 IT服务规划设计 4.1 概述 规划设计处于整个IT服务生命周期中的前端,可以帮助IT服务供方了解客户的需求,并对其进行全面的需求分析,然后通过对服务要素(包括人员、资源、技术和过程)、服务模式和服务方案的具体…

OpenHarmony4.x 系统模拟器环境

先下载源码和编译程序: 首先查看 OpenHarmony4.1源码下载、编译,生成OHOS_Image可执行文件的最简易流程 准备在QEMU模拟器中运行ARM Cortex-M4的轻型开源鸿蒙系统 官方支持的开发板和模拟器种类-编译形态整体说明OpenAtom OpenHarmony 已支持的示例工…

ArduPilot开源代码之AP_MSP

ArduPilot开源代码之AP_MSP 1. 源由2. Library设计2.1 启动代码2.2 支持特性2.3 MSP DisplayPort v.s. DJI FPV OSD 3. 重要例程3.1 AP_MSP::init3.2 AP_MSP::loop3.3 AP_MSP::init_backend 4. 实例理解5. 总结6. 参考资料 1. 源由 AP_MSP是处理MSP协议格式的报文数据应用类。…

反向业务判断逻辑

业务功能需求: 根据id扣减用户余额 包括:判断用户状态是否正常判断用户余额是否充足 正向逻辑: 判断用户为正常下,判断用户余额充足,进行余额扣减; 》正向逻辑,多重嵌套,代码不美观…

✈️一文带你入门【NestJS】

✈️引言 在现代Web开发领域,框架和技术的迭代速度令人咋舌。其中,NestJS作为一款基于Node.js的后端框架,以其卓越的设计理念和强大的功能集,迅速吸引了众多开发者的眼球。本文将带你深入了解NestJS的起源、发展,以及…

SpringIOC原理

SpringIOC原理 1.概念 Spring通过一个配置文件描述Bean及Bean之间的依赖关系,利用Java语言的反射功能实例化Bean并建立Bean之间的依赖关系。Spring的IOC容器在完成这些底层工作的基础上,还提供了Bean实例缓存、生命周期管理、Bean实例代理、事件发布、…