3 Tensorflow构建模型详解

上一篇:2 用TensorFlow构建一个简单的神经网络-CSDN博客

本篇目标是介绍如何构建一个简单的线性回归模型,要点如下:

  • 了解神经网络原理
  • 构建模型的一般步骤
  • 模型重要参数介绍


1、神经网络概念

接上一篇,用tensorflow写了一个猜测西瓜价格的简单模型,理解代码前先了解下什么是神经网络。

下面是百度AI对神经网络的解释:

神经网络是一种运算模型,由大量的节点(或称神经元)之间相互联接构成,每个节点代表一种特定的输出函数,称为激励函数(activation function)。每两个节点间的连接都代表一个对于通过该连接信号的加权值,称之为权重,这相当于人工神经网络的记忆。网络的输出则依网络的连接方式,权重值和激励函数的不同而不同。而网络自身通常都是对自然界某种算法或者函数的逼近,也可能是对一种逻辑策略的表达。
神经网络是一种广泛并行互连的网络,它的组织能够模拟生物神经系统对真实世界物体所做出的交互反应。

首先我们要了解下密集层(也叫全连接层),密集层是一个深度连接的神经网络层,在神经网络中指的是每个神经元都与前一层的所有神经元相连的层。

在上一篇我们创建了预测价格模型,代码为:

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

其中Sequential是顺序的意思,Dense就是密集层。

看文字有点抽象,举个例子,如下图所示:神经元a1与所有输入层数据相连(X1,X2,X3),其他神经元也一样都与上一层神经元相连,这样形成的神经网络就是密集层。

它们之间的数学关系为:

某个神经元是由连接的上一层神经元分别乘上权重(w),再加上偏差(b)得到,例如计算a1:

权重w的数字下标可以按照顺序命名,比如第一个神经元计算的权重可以为w11、w12……,第二个神经元计算的权重可以为w21、w22……

a2、a3计算以此类推。

了解这些基本的原理后,我们就开始创建一个简单的费用预测模型。

3、西瓜费用预测模型详解

代码如下:

import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')history = model.fit(weight, total_cost, epochs=500)# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

上一篇西瓜费用计算公式 :费用=1.2元/斤*重量+0.5元

即:y=1.2x+0.5

这是一个一元线性回归问题,只有一个自变量x和一个因变量y,机器学习要推算出权重w=1.2, 偏差b=0.5,才能准确预测费用。

具体流程如下:

(1)训练数据准备

西瓜重量 weight=[1, 3, 4, 5, 6, 8]

对应的费用 total_cost=[1.7, 4.1, 5.3, 6.5, 7.7, 10.1]

(2)构建模型

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

  • tf.keras.layers.Dense(1, input_shape=[1]),参数1表示1个神经元,我们只要预测费用y,所以输出层只要一个神经元就可以了(注意:神经元不用包含输入层)。
  • input_shape=[1],表示输入数据的形状为单元素列表,即每个输入数据只有一个值。因为只有一个变量x(西瓜的重量),所以此处输入形状是[1]

该模型的示意图:

可以用model.summary()查看模型摘要,代码如下:

import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])# 查看模型摘要
model.summary()

运行结果:

可以看到可训练参数有2个,即公式中的w1和b1。

(3)设置损失函数和优化器
model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')
  • mean_squared_error是均方误差,指的是预测值与真实值差值的平方然后求和再平均。公式为:

                    MSE=1/n Σ(P-G)^2 (P为预测值,G为真实值)

  • SGD即随机梯度下降(Stochastic Gradient Descent),是一种迭代优化算法。

(4)设置训练数据
history = model.fit(weight, total_cost, epochs=500)
  • 设置训练数据的特征和标签,在上述代码中分别是西瓜的重量和费用:weight、total_cost
  • 设置训练轮次epochs=500,1个epochs是指使用所有样本训练一次。

(5) 查看训练结果

看下面的训练过程,第8个epoch的时候损失值loss已经很小了,训练轮次不需要设置到500就可以有很好的预测效果了。

刚开始loss很高,使用优化算法慢慢调整了权重,loss值可以很好地衡量我们的模型有多好。

我们把epoch的值调小,看看程序猜测的权重(w)和偏差(b)是多少,以及loss值的计算。

 

代码改动如下:

  •  epochs=5
  • 用model.get_weights()获取程序猜测的权重数据
import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')history = model.fit(weight, total_cost, epochs=5)# 获取权重数据
w = model.get_weights()[0]
b = model.get_weights()[1]print('w:')
print(w)
print('b: ')
print(b)# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

运行结果:

训练了5个epoch后,程序猜测w是1.1807659,b为0.33192113

            y=wx+b=1.1807659*10+0.33192113=12.139581

所以预测10斤西瓜的总费用是12.139581

                 

4、创建更复杂一点的模型

现实生活中我们要预测的东西影响因素可能有很多个,如房价预测,房价可能受到房屋面积、房间数量等等因素影响。思考一下,下面的神经网络图创建模型时要如何设置参数呢?

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=[3]),tf.keras.layers.Dense(1)
])


 

         

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/125410.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

菜单管理中icon图标回显

<el-table-column prop"icon" label"图标" show-overflow-tooltip algin"center"><template v-slot"{ row }"><el-icon :class"row.icon"></el-icon></template></el-table-column>

跟我学C++中级篇——右值引用和万能引用

一、右值引用 在C11中出现了右值引用&#xff0c;想知道右值引用&#xff0c;就必须知道右值。在前面分析过左右和消亡值等类型&#xff08;见“左值和右值再谈”&#xff09;&#xff0c;其实右值就是为了废物利用&#xff0c;而既然利用的好&#xff0c;就有了和左值一样的引…

Oracle数据库创建Sequence序列的基本使用

1.作用就是批量插入数据的时候可以给一个主键 sequence dose not exist_sequence not exist_拒—绝的博客-CSDN博客 Oracle创建Sequence序列_TheEzreal的博客-CSDN博客 Oracle序列&#xff08;sequence&#xff09;创建失败&#xff0c;无法取值&#xff08;.nextval&#x…

iOS iGameGuardian修改器检测方案

一直以来&#xff0c;iOS 系统的安全性、稳定性都是其与安卓竞争的主力卖点。这要归功于 iOS 系统独特的闭源生态&#xff0c;应用软件上架会经过严格审核与测试。所以&#xff0c;iOS端的作弊手段&#xff0c;总是在尝试绕过 App Store 的审查。 常见的 iOS 游戏作弊&#xf…

AlarmManager闹钟管理者

AlarmManager是Android提供的一个全局定时器&#xff0c;利用系统闹钟定时发送广播。这样做的好处是&#xff1a;如果App提前注册闹钟的广播接收器&#xff0c;即使App退出了&#xff0c;只要定时到达&#xff0c;App就会被唤醒响应广播事件。 AlarmManager设置的PendingInten…

0049【Edabit ★☆☆☆☆☆】【修改Bug代码】Buggy Code

0049【Edabit ★☆☆☆☆☆】【修改Bug代码】Buggy Code bugs language_fundamentals Instructions The challenge is to try and fix this buggy code, given the inputs true and false. See the examples below for the expected output. Examples has_bugs(true) // &qu…

Arrays,Arrays重载的sort方法

Arrays -1的原因.因为返回正数不就是表示存在只能是负数 Arrays重载的sort方法 //这个方法只能给引用数据类型排序 //如果是基本数据类型需要转化为对应的包装类 public class arrays {public static void main(String[] args) {Integer arr[]{2,1,4,6,3,5,8,7,9};Arrays.s…

嵌入式驱动开发之框架及调试技巧累积

框架准备 基本的框架app如何调用驱动机制字符设备驱动编写步骤1. 实现入口函数 XXX_init()和卸载函数 XXX_exit()2. 申请设备号 register_chrdev_region(与内核相关)3. 注册字符设备驱动 cdev_alloc / cdev_init /cdev_add(与内核相关)4. 利用udev/mdev机制创建设备文件(节点)…

C# 使用 AES 加解密文件

[作者:张赐荣] 对称加密是一种加密技术&#xff0c;它使用相同的密钥来加密和解密数据。换句话说&#xff0c;加密者和解密者需要共享同一个密钥&#xff0c;才能进行通信。 对称加密的优点是速度快&#xff0c;效率高&#xff0c;适合大量数据的加密。对称加密的缺点是密钥的管…

网络套接字编程(一)

网络套接字编程(一) 文章目录 网络套接字编程(一)预备知识源IP地址和目的IP地址端口号TCP/UDP协议特点网络字节序 socket编程socket常用APIsockaddr结构 简易UDP网络程序服务端创建套接字服务端绑定IP地址和端口号字符型IP地址VS整型IP地址服务端运行客户端创建套接字客户端绑定…

mfc140u.dll丢失怎么修复,mfc140u.dll文件有什么作用

今天我想和大家分享的是关于mfc140u.dll文件丢失的解决方法。在我们使用电脑的过程中&#xff0c;有时候会遇到一些错误提示&#xff0c;其中比较常见的就是“无法找到mfc140u.dll文件”。那么&#xff0c;这个文件是什么呢&#xff1f;它有什么作用呢&#xff1f; 首先&#…

【python代码】对图片进行数据增强(直方图均衡和加噪声)

文章目录 1. 代码 1. 代码 import os import cv2 import albumentations as A from tqdm import tqdm from glob import glob import numpy as np# 方法1&#xff1a;加入噪声 trans2 A.Compose([A.RandomBrightnessContrast(p0.5),A.HueSaturationValue(p0.5),A.OneOf([ A.A…

网络基础-2

IEEE制定了一个名为GARP的协议框架&#xff0c;该框架协议包含了两个具体协议&#xff0c;GMRP和GVRP。GVRP可以大大降低VLAN配置过程中的手工的工作量。 IP本身是一个协议文件的名称&#xff0c;该协议主要定义阐释了IP报文的格式。 类型网络号位数网络号个数主机号位数每个…

水溶性纳米银颗粒 纳米银颗粒 银纳米颗粒溶液

西&#xff09;产品名称&#xff1a;水溶性纳米银颗粒 安&#xff09;别名 &#xff1a;纳米银溶液 银纳米颗粒溶液 纳米银胶体等 瑞&#xff09;浓度&#xff1a;0.1mg/mL 其它均可定制 禧&#xff09;粒径&#xff1a;5nm 10nm 15nm 20nm 25nm 30nm 35nm 40nm 50nm 60nm 80…

Ceph入门到精通-文件条带化 stripe unit,chunk

文件条带化 以下文本描述了 Ceph 文件系统客户端中的文件是如何 存储在 RADOS 中的对象之间。 CEPH_FILE_LAYOUT Ceph 将给定文件的数据分布&#xff08;条带化&#xff09;到多个 的基础对象。文件数据映射到这些对象的方式 由ceph_file_layout结构定义。数据分布 是修改后…

0052【Edabit ★☆☆☆☆☆】Learn Lodash: _.drop, Drop the First Elements of an Array

0052【Edabit ★☆☆☆☆☆】Learn Lodash: _.drop, Drop the First Elements of an Array arrays Instructions According to the lodash documentation, _.drop creates a slice of an array with n elements dropped from the beginning. Your challenge is to write your…

RabbitMQ一条消息被多个消费者消费

前言&#xff1a;可略过 正常情况下交换机和queue绑定&#xff0c;消息经过交换机发送给指定的队列。队列的消息被监听消费后就被删除&#xff0c;queue的消息仅能被消费一次。 如何解决呢&#xff0c;如果单从queue的角度出发&#xff0c;可能会联想到fanout-广播模式&#xf…

1.6 基本安全设计准则

思维导图&#xff1a; 1.6 基本安全设计准则笔记 目标&#xff1a;理解和遵循一套广泛认可的安全设计准则&#xff0c;以指导保护机制的开发。 主要准则&#xff1a; 机制的经济性&#xff1a;安全机制应设计得简单、短小&#xff0c;便于测试和验证&#xff0c;减少漏洞和降…

【数据结构】顺序表实例探究

&#x1f497;个人主页&#x1f497; ⭐个人专栏——数据结构学习⭐ &#x1f4ab;点击关注&#x1f929;一起学习C语言&#x1f4af;&#x1f4ab; 目录 导读&#xff1a;1. 顺序表的基本内容1.1 概念及结构1.2 时间和空间复杂度1.3 基本操作1.4 顺序表的优缺点 2. 静态顺序表…

自动化测试注意事项

什么是自动化测&#xff1f; 做测试好几年了&#xff0c;真正学习和实践自动化测试一年&#xff0c;自我感觉这一个年中收获许多。一直想动笔写一篇文章分享自动化测试实践中的一些经验。终于决定花点时间来做这件事儿。 首先理清自动化测试的概念&#xff0c;广义上来讲&#…