3 Tensorflow构建模型详解

上一篇:2 用TensorFlow构建一个简单的神经网络-CSDN博客

本篇目标是介绍如何构建一个简单的线性回归模型,要点如下:

  • 了解神经网络原理
  • 构建模型的一般步骤
  • 模型重要参数介绍


1、神经网络概念

接上一篇,用tensorflow写了一个猜测西瓜价格的简单模型,理解代码前先了解下什么是神经网络。

下面是百度AI对神经网络的解释:

神经网络是一种运算模型,由大量的节点(或称神经元)之间相互联接构成,每个节点代表一种特定的输出函数,称为激励函数(activation function)。每两个节点间的连接都代表一个对于通过该连接信号的加权值,称之为权重,这相当于人工神经网络的记忆。网络的输出则依网络的连接方式,权重值和激励函数的不同而不同。而网络自身通常都是对自然界某种算法或者函数的逼近,也可能是对一种逻辑策略的表达。
神经网络是一种广泛并行互连的网络,它的组织能够模拟生物神经系统对真实世界物体所做出的交互反应。

首先我们要了解下密集层(也叫全连接层),密集层是一个深度连接的神经网络层,在神经网络中指的是每个神经元都与前一层的所有神经元相连的层。

在上一篇我们创建了预测价格模型,代码为:

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

其中Sequential是顺序的意思,Dense就是密集层。

看文字有点抽象,举个例子,如下图所示:神经元a1与所有输入层数据相连(X1,X2,X3),其他神经元也一样都与上一层神经元相连,这样形成的神经网络就是密集层。

它们之间的数学关系为:

某个神经元是由连接的上一层神经元分别乘上权重(w),再加上偏差(b)得到,例如计算a1:

权重w的数字下标可以按照顺序命名,比如第一个神经元计算的权重可以为w11、w12……,第二个神经元计算的权重可以为w21、w22……

a2、a3计算以此类推。

了解这些基本的原理后,我们就开始创建一个简单的费用预测模型。

3、西瓜费用预测模型详解

代码如下:

import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')history = model.fit(weight, total_cost, epochs=500)# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

上一篇西瓜费用计算公式 :费用=1.2元/斤*重量+0.5元

即:y=1.2x+0.5

这是一个一元线性回归问题,只有一个自变量x和一个因变量y,机器学习要推算出权重w=1.2, 偏差b=0.5,才能准确预测费用。

具体流程如下:

(1)训练数据准备

西瓜重量 weight=[1, 3, 4, 5, 6, 8]

对应的费用 total_cost=[1.7, 4.1, 5.3, 6.5, 7.7, 10.1]

(2)构建模型

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

  • tf.keras.layers.Dense(1, input_shape=[1]),参数1表示1个神经元,我们只要预测费用y,所以输出层只要一个神经元就可以了(注意:神经元不用包含输入层)。
  • input_shape=[1],表示输入数据的形状为单元素列表,即每个输入数据只有一个值。因为只有一个变量x(西瓜的重量),所以此处输入形状是[1]

该模型的示意图:

可以用model.summary()查看模型摘要,代码如下:

import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])# 查看模型摘要
model.summary()

运行结果:

可以看到可训练参数有2个,即公式中的w1和b1。

(3)设置损失函数和优化器
model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')
  • mean_squared_error是均方误差,指的是预测值与真实值差值的平方然后求和再平均。公式为:

                    MSE=1/n Σ(P-G)^2 (P为预测值,G为真实值)

  • SGD即随机梯度下降(Stochastic Gradient Descent),是一种迭代优化算法。

(4)设置训练数据
history = model.fit(weight, total_cost, epochs=500)
  • 设置训练数据的特征和标签,在上述代码中分别是西瓜的重量和费用:weight、total_cost
  • 设置训练轮次epochs=500,1个epochs是指使用所有样本训练一次。

(5) 查看训练结果

看下面的训练过程,第8个epoch的时候损失值loss已经很小了,训练轮次不需要设置到500就可以有很好的预测效果了。

刚开始loss很高,使用优化算法慢慢调整了权重,loss值可以很好地衡量我们的模型有多好。

我们把epoch的值调小,看看程序猜测的权重(w)和偏差(b)是多少,以及loss值的计算。

 

代码改动如下:

  •  epochs=5
  • 用model.get_weights()获取程序猜测的权重数据
import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')history = model.fit(weight, total_cost, epochs=5)# 获取权重数据
w = model.get_weights()[0]
b = model.get_weights()[1]print('w:')
print(w)
print('b: ')
print(b)# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

运行结果:

训练了5个epoch后,程序猜测w是1.1807659,b为0.33192113

            y=wx+b=1.1807659*10+0.33192113=12.139581

所以预测10斤西瓜的总费用是12.139581

                 

4、创建更复杂一点的模型

现实生活中我们要预测的东西影响因素可能有很多个,如房价预测,房价可能受到房屋面积、房间数量等等因素影响。思考一下,下面的神经网络图创建模型时要如何设置参数呢?

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=[3]),tf.keras.layers.Dense(1)
])


 

         

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/125410.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

菜单管理中icon图标回显

<el-table-column prop"icon" label"图标" show-overflow-tooltip algin"center"><template v-slot"{ row }"><el-icon :class"row.icon"></el-icon></template></el-table-column>

Oracle数据库创建Sequence序列的基本使用

1.作用就是批量插入数据的时候可以给一个主键 sequence dose not exist_sequence not exist_拒—绝的博客-CSDN博客 Oracle创建Sequence序列_TheEzreal的博客-CSDN博客 Oracle序列&#xff08;sequence&#xff09;创建失败&#xff0c;无法取值&#xff08;.nextval&#x…

iOS iGameGuardian修改器检测方案

一直以来&#xff0c;iOS 系统的安全性、稳定性都是其与安卓竞争的主力卖点。这要归功于 iOS 系统独特的闭源生态&#xff0c;应用软件上架会经过严格审核与测试。所以&#xff0c;iOS端的作弊手段&#xff0c;总是在尝试绕过 App Store 的审查。 常见的 iOS 游戏作弊&#xf…

Arrays,Arrays重载的sort方法

Arrays -1的原因.因为返回正数不就是表示存在只能是负数 Arrays重载的sort方法 //这个方法只能给引用数据类型排序 //如果是基本数据类型需要转化为对应的包装类 public class arrays {public static void main(String[] args) {Integer arr[]{2,1,4,6,3,5,8,7,9};Arrays.s…

网络套接字编程(一)

网络套接字编程(一) 文章目录 网络套接字编程(一)预备知识源IP地址和目的IP地址端口号TCP/UDP协议特点网络字节序 socket编程socket常用APIsockaddr结构 简易UDP网络程序服务端创建套接字服务端绑定IP地址和端口号字符型IP地址VS整型IP地址服务端运行客户端创建套接字客户端绑定…

mfc140u.dll丢失怎么修复,mfc140u.dll文件有什么作用

今天我想和大家分享的是关于mfc140u.dll文件丢失的解决方法。在我们使用电脑的过程中&#xff0c;有时候会遇到一些错误提示&#xff0c;其中比较常见的就是“无法找到mfc140u.dll文件”。那么&#xff0c;这个文件是什么呢&#xff1f;它有什么作用呢&#xff1f; 首先&#…

网络基础-2

IEEE制定了一个名为GARP的协议框架&#xff0c;该框架协议包含了两个具体协议&#xff0c;GMRP和GVRP。GVRP可以大大降低VLAN配置过程中的手工的工作量。 IP本身是一个协议文件的名称&#xff0c;该协议主要定义阐释了IP报文的格式。 类型网络号位数网络号个数主机号位数每个…

水溶性纳米银颗粒 纳米银颗粒 银纳米颗粒溶液

西&#xff09;产品名称&#xff1a;水溶性纳米银颗粒 安&#xff09;别名 &#xff1a;纳米银溶液 银纳米颗粒溶液 纳米银胶体等 瑞&#xff09;浓度&#xff1a;0.1mg/mL 其它均可定制 禧&#xff09;粒径&#xff1a;5nm 10nm 15nm 20nm 25nm 30nm 35nm 40nm 50nm 60nm 80…

1.6 基本安全设计准则

思维导图&#xff1a; 1.6 基本安全设计准则笔记 目标&#xff1a;理解和遵循一套广泛认可的安全设计准则&#xff0c;以指导保护机制的开发。 主要准则&#xff1a; 机制的经济性&#xff1a;安全机制应设计得简单、短小&#xff0c;便于测试和验证&#xff0c;减少漏洞和降…

【数据结构】顺序表实例探究

&#x1f497;个人主页&#x1f497; ⭐个人专栏——数据结构学习⭐ &#x1f4ab;点击关注&#x1f929;一起学习C语言&#x1f4af;&#x1f4ab; 目录 导读&#xff1a;1. 顺序表的基本内容1.1 概念及结构1.2 时间和空间复杂度1.3 基本操作1.4 顺序表的优缺点 2. 静态顺序表…

自动化测试注意事项

什么是自动化测&#xff1f; 做测试好几年了&#xff0c;真正学习和实践自动化测试一年&#xff0c;自我感觉这一个年中收获许多。一直想动笔写一篇文章分享自动化测试实践中的一些经验。终于决定花点时间来做这件事儿。 首先理清自动化测试的概念&#xff0c;广义上来讲&#…

华锐技术何志东:证券核心交易系统分布式改造将迎来规模化落地阶段

近年来&#xff0c;数字化转型成为证券业发展的下一战略高地&#xff0c;根据 2021 年证券业协会专项调查结果显示&#xff0c;71% 的券商将数字化转型列为公司战略任务。 在落地数字化转型战略过程中&#xff0c;证券业核心交易系统面临着不少挑战。构建新一代分布式核心交易…

06 MIT线性代数-线性无关,基和维数Independence, basis, and dimension

1. 线性无关 Independence Suppose A is m by n with m<n (more unknowns than equations) Then there are nonzero solutions to Ax0 Reason: there will be free variables! A中具有至少一个自由变量&#xff0c;那么Ax0一定具有非零解。A的列向量可以线性组合得到零向…

酷克数据出席永洪科技用户大会 携手驱动商业智能升级

10月27日&#xff0c;第7届永洪科技全国用户大会在北京召开。酷克数据作为国内云原生数仓代表企业&#xff0c;受邀出席本次大会&#xff0c;全面展示了云数仓领域最新前沿技术&#xff0c;并进行主题演讲。 携手合作 助力企业释放数据价值 数据仓库是商业智能&#xff08;BI…

Easy Javadoc插件的使用教程

目录 一、安装Easy Javadoc插件 二、配置注释模板 三、配置翻译 一、安装Easy Javadoc插件 在idea的File-Settings-Plugins中搜索Easy Javadoc插件&#xff0c;点击install进行安装&#xff0c;安装完成后需要restart IDE&#xff0c;重启后插件生效。 二、配置注释模板 …

openGauss学习笔记-111 openGauss 数据库管理-管理用户及权限-用户权限设置

文章目录 openGauss学习笔记-111 openGauss 数据库管理-管理用户及权限-用户权限设置111.1 给用户直接授予某对象的权限111.2 给用户指定角色111.3 回收用户权限 openGauss学习笔记-111 openGauss 数据库管理-管理用户及权限-用户权限设置 111.1 给用户直接授予某对象的权限 …

SIT3088E3.0V~5.5V 供电,ESD 15kV HBM,256 节点,14Mbps 半双工 RS485/RS422 收发器

SIT3088E 是一款 3.0V~5.5V 宽电源供电、总线端口 ESD 保护能力 HBM 达到 15kV 以上、总 线耐压范围达到 15V 、半双工、低功耗&#xff0c;功能完全满足 TIA/EIA-485 标准要求的 RS-485 收发器。 SIT3088E 包括一个驱动器和一个接收器&#xff0c;两者均可独立…

SpringCloud 微服务全栈体系(七)

第九章 Docker 一、什么是 Docker 微服务虽然具备各种各样的优势&#xff0c;但服务的拆分通用给部署带来了很大的麻烦。 分布式系统中&#xff0c;依赖的组件非常多&#xff0c;不同组件之间部署时往往会产生一些冲突。在数百上千台服务中重复部署&#xff0c;环境不一定一致…

MySQL系列-架构体系、日志、事务

MySQL架构 server 层 &#xff1a;层包括连接器、查询缓存、分析器、优化器、执行器等&#xff0c;涵盖 MySQL 的大多数核心服务功能&#xff0c;以及所有的内置函数&#xff08;如日期、时间、数学和加密函数等&#xff09;&#xff0c;所有跨存储引擎的功能都在这一层实现&am…

行情分析——加密货币市场大盘走势(10.31)

目前大饼依然在33000-36000这个位置震荡&#xff0c;需要等待指标修复&#xff0c;策略就是逢低做多&#xff0c;做短线。最近白天下跌&#xff0c;晚上涨回来&#xff0c;可以小仓位入场多单&#xff0c;晚上离场下车。 以太同样是震荡行情&#xff0c;看下来以太目前在补涨&a…