深度学习基础之《TensorFlow框架(9)—案例:实现线性回归》

一、线性回归原理复习

1、什么是线性回归
(1)有个假设函数,假定特征值和目标值满足这样的关系
w1x1 + w2x2 + ... + wnxn + b = y
(2)构造损失函数
均方误差、最小二乘法
(3)优化损失
正规方程和梯度下降
(4)当梯度下降到一定程度,使得损失函数比较小的时候,所对应的权重和偏置,就是我们要求的模型参数

二、案例:实现线性回归的训练

1、案例背景
(1)假设随机指定100个点,只有一个特征
(2)数据本身的分布为 y = 0.8 * x + 0.7
(3)这里将数据分布的规律确定,是为了使我们训练出的参数跟真实的参数(即0.8和0.7)比较是否训练准确

2、准备真实数据
x:特征值
y_true:目标值
y_true = 0.8 * x + 0.7
假定x和y之间的关系,满足y = kx + b
经过线性回归,求出来k≈0.8,b≈0.7

3、流程分析
(1)准备100个样本
(100, 1) * (1, 1) = (100, 1)
100行1列,乘以1行1列,得出100行1列

y_predict = x * 权重(1, 1) + 偏置(1, 1)
y_predict = tf.matmul(x, weights) + bias

(2)构造损失函数
error = tf.reduce_mean(tf.square(y_predict - y_true))
求均方误差

(3)优化损失函数
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(error)

(4)训练
反复的run这个optimizer,不断的更新迭代

4、用到的API
矩阵乘法:tf.matmul(x, w)
平方:tf.square(error)
均值:tf.reduce_mean(error)

梯度下降优化:
tf.train.GradientDescentOptimizer(learning_rate)
说明:
(1)梯度下降优化器
(2)learning_rate:学习率,一般为0-1之间比较小的值
(3)梯度下降优化器实例化后的方法
    minimize(loss):让error最小化
(4)return:梯度下降op

tensorflow2.0版本用tf.keras.optimizers.SGD
优化器对照:https://blog.csdn.net/u013587606/article/details/105138271

5、代码实现(tensorflow2.0版本写法)

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tfdef tensorflow_demo():"""TensorFlow的基本结构"""# TensorFlow实现加减法运算a_t = tf.constant(2)b_t = tf.constant(3)c_t = a_t + b_tprint("TensorFlow加法运算结果:\n", c_t)print(c_t.numpy())# 2.0版本不需要开启会话,已经没有会话模块了return Nonedef graph_demo():"""图的演示"""# TensorFlow实现加减法运算a_t = tf.constant(2)b_t = tf.constant(3)c_t = a_t + b_tprint("TensorFlow加法运算结果:\n", c_t)print(c_t.numpy())# 查看默认图# 方法1:调用方法default_g = tf.compat.v1.get_default_graph()print("default_g:\n", default_g)# 方法2:查看属性# print("a_t的图属性:\n", a_t.graph)# print("c_t的图属性:\n", c_t.graph)# 自定义图new_g = tf.Graph()# 在自己的图中定义数据和操作with new_g.as_default():a_new = tf.constant(20)b_new = tf.constant(30)c_new = a_new + b_newprint("c_new:\n", c_new)print("a_new的图属性:\n", a_new.graph)print("b_new的图属性:\n", b_new.graph)# 开启new_g的会话with tf.compat.v1.Session(graph=new_g) as sess:c_new_value = sess.run(c_new)print("c_new_value:\n", c_new_value)print("我们自己创建的图为:\n", sess.graph)# 可视化自定义图# 1)创建一个writerwriter = tf.summary.create_file_writer("./tmp/summary")# 2)将图写入with writer.as_default():tf.summary.graph(new_g)return Nonedef session_run_demo():"""feed操作"""tf.compat.v1.disable_eager_execution()# 定义占位符a = tf.compat.v1.placeholder(tf.float32)b = tf.compat.v1.placeholder(tf.float32)sum_ab = tf.add(a, b)print("a:\n", a)print("b:\n", b)print("sum_ab:\n", sum_ab)# 开启会话with tf.compat.v1.Session() as sess:print("占位符的结果:\n", sess.run(sum_ab, feed_dict={a: 1.1, b: 2.2}))return Nonedef tensor_demo():"""张量的演示"""tensor1 = tf.constant(4.0)tensor2 = tf.constant([1, 2, 3, 4])linear_squares = tf.constant([[4], [9], [16], [25]], dtype=tf.int32)print("tensor1:\n", tensor1)print("tensor2:\n", tensor2)print("linear_squares:\n", linear_squares)# 张量类型的修改l_cast = tf.cast(linear_squares, dtype=tf.float32)print("before:\n", linear_squares)print("l_cast:\n", l_cast)return Nonedef variable_demo():"""变量的演示"""a = tf.Variable(initial_value=50)b = tf.Variable(initial_value=40)c = tf.add(a, b)print("a:\n", a)print("b:\n", b)print("c:\n", c)with tf.compat.v1.variable_scope("my_scope"):d = tf.Variable(initial_value=30)e = tf.Variable(initial_value=20)f = tf.add(d, e)print("d:\n", d)print("e:\n", e)print("f:\n", f)return Nonedef linear_regression():"""自实现一个线性回归"""# 1、准备数据x = tf.random.normal(shape=[100,1])y_true = tf.matmul(x, [[0.8]]) + 0.7# 2、构造模型# 定义模型参数,用变量weights = tf.Variable(initial_value=tf.random.normal(shape=[1, 1]))bias = tf.Variable(initial_value=tf.random.normal(shape=[1, 1]))y_predict = tf.matmul(x, weights) + bias# 3、构造损失函数error = tf.reduce_mean(tf.square(y_predict - y_true))# 4、优化损失#optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(error)optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)# 5、查看初始化模型参数之后的值print("训练前模型参数为:权重%f,偏置%f,损失%f" % (weights, bias, error))# 6、开始训练num_epoch = 10000 # 定义迭代次数for e in range(num_epoch): # 迭代多次with tf.GradientTape() as tape:y_predict = tf.matmul(x, weights) + biaserror = tf.reduce_mean(tf.square(y_predict - y_true))grads = tape.gradient(error, [weights, bias]) # 求损失关于参数weights、bias的梯度optimizer.apply_gradients(grads_and_vars=zip(grads, [weights, bias])) # 自动根据梯度更新参数,即利用梯度信息修改weights与bias,使得损失减小print("训练后模型参数为:权重%f,偏置%f,损失%f" % (weights, bias, error))return Noneif __name__ == "__main__":# 代码1:TensorFlow的基本结构# tensorflow_demo()# 代码2:图的演示#graph_demo()# feed操作#session_run_demo()# 代码4:张量的演示#tensor_demo()# 代码5:变量的演示#variable_demo()# 代码6:自实现一个线性回归linear_regression()

运行结果:

训练前模型参数为:权重-0.941857,偏置-0.845241,损失4.750606
训练后模型参数为:权重0.799998,偏置0.699998,损失0.000000

训练后,权重无限接近0.8,偏置无限接近0.7

参考资料:
https://stackoverflow.com/questions/68879963/valueerror-tape-is-required-when-a-tensor-loss-is-passed
https://blog.csdn.net/AwesomeP/article/details/123787448
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/759308.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

electron-builder打包

打包配置: "build": {"appId": "cc11001100.electron.example-001", // 程序包名"copyright": "CC11001100", // 版权相关信息"productName": "example-001", // 安装包文件名"direct…

easyExcel 读取excel(按条读取)

MAVEN <!-- https://mvnrepository.com/artifact/com.alibaba/easyexcel --><dependency><groupId>com.alibaba</groupId><artifactId>easyexcel</artifactId><version>3.0.5</version></dependency>代码 import com.al…

Python 解析CSV文件 使用Matplotlib绘图

数据存储在CSV文件中&#xff0c;使用Matplotlib实现数据可视化。 CSV文件&#xff1a;comma-separated values&#xff0c;是在文件中存储一系列以‘&#xff0c;’分隔的值。 例如&#xff1a;"0.0","2016-01-03","1","3","20…

电子电工基础-二极管

二极管&#xff1a;单向导电性 工作区域&#xff1a;截止区、放大区、饱和区、反向击穿区 相关计算题 注意点&#xff1a;正向压降为0.7V&#xff0c;但是电流小&#xff0c;可以设为0.6V 在对其进行静态分析 可以得出静态直流的电流大小Id 根据二极管电流为26ma的特性&…

力扣刷题Days23-35.搜索插入的位置(js)

1&#xff0c;题目 给定一个排序数组和一个目标值&#xff0c;在数组中找到目标值&#xff0c;并返回其索引。如果目标值不存在于数组中&#xff0c;返回它将会被按顺序插入的位置。请必须使用时间复杂度为 O(log n) 的算法。 2&#xff0c;代码 /*** param {number[]} nums*…

Vue+SpringBoot打造智慧家政系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块三、系统展示四、核心代码4.1 查询家政服务4.2 新增单条服务订单4.3 新增留言反馈4.4 小程序登录4.5 小程序数据展示 五、免责说明 一、摘要 1.1 项目介绍 基于微信小程序JAVAVueSpringBootMySQL的智慧家政系统&#xff0…

鸿蒙 Launcher与android Launcher的开发区别

鸿蒙&#xff08;HarmonyOS&#xff09;Launcher与Android Launcher在某些方面相似&#xff0c;但也存在一些明显的区别。尽管鸿蒙Launcher和Android Launcher都是用于用户与设备交互的界面&#xff0c;但由于底层架构、生态系统、开发语言和工具等方面的差异&#xff0c;它们在…

2024年亚洲图像处理趋势会议(ATIP 2024)即将召开!

2024年亚洲图像处理趋势会议&#xff08;简称&#xff1a;ATIP 2024&#xff09;将于2024年6月21日至23日在英国伦敦举行。在会议上我们将与相关领域的研究人员和知名专业人士共同讨论关于图像处理学科的最新研究方向及进展&#xff0c;评估当前最先进的技术和未来研究的关键领…

使用el-cascader组件写下拉级联多选并且具有全选功能

样式 说明&#xff1a; 级联选择器中加上全选的按钮&#xff0c; 并且保证数据响应式。 思路 因为是有全选的功能&#xff0c;所以不能直接使用el-cascader组件&#xff0c; 而是选择使用el-select组件&#xff0c; 在此组件内部使用el-cascader-panel级联面板全选按钮也是…

Win11配置WSL(Ubuntu)环境

一&#xff0c;什么是WSL WSL:Windows Subsystem for Linux&#xff0c;是用于Windows系统之上的Linux子系统。作用很简单&#xff0c;可以在Windows系统中获得Linux系统环境&#xff0c;并完全直连计算机硬件&#xff0c;无需通过虚拟机虚拟硬件 简而言之: Windows10的WSL功能…

【Android】图解View事件分发机制

文章目录 View事件分发机制dispartchTouchEvent()dispatchTouchEvent() 方法主要负责什么&#xff1f; onTouchEvent(event) 点击事件分发的传递规则自上而下自下而上 View事件分发机制 View的事件分发机制是Android中非常核心的一个概念&#xff0c;它负责处理触摸事件&#…

Request failed with status code 504,Gateway time out

问题描述&#xff1a; 部署在测试环境的项目在执行某功能时&#xff0c;后台程序在执行过程中&#xff0c;前端控制台在一分钟左右会报出Request failed with status code 504&#xff0c;Gateway time out异常。但是在本地开发环境会正常运行&#xff0c;并不会报出异常。 问题…

LeetCode 678:有效的括号字符串 ← 贪心算法

【题目来源】https://leetcode.cn/problems/valid-parenthesis-string/description/【题目描述】 给你一个只包含三种字符的字符串&#xff0c;支持的字符类型分别是 (、) 和 *。请你检验这个字符串是否为有效字符串&#xff0c;如果是有效字符串返回 true 。 有效字符串符合如…

【黄金手指】windows操作系统环境下使用jar命令行解压和打包Springboot项目jar包

一、背景 项目中利用maven将Springboot项目打包成生产环境jar包。名为 prod_2024_1.jar。 需求是 修改配置文件中的某些参数值&#xff0c;并重新发布。 二、解压 jar -xvf .\prod_2024_1.jar释义&#xff1a; 这段命令是用于解压缩名为"prod_2024_1.jar"的Java归…

用Vmware创建并运行Ubuntu64虚拟机,安装配置跳坑记录

起因&#xff1a; 为了学习正点原子的Linux开发板&#xff0c;按照教程用Vmware创建并运行Ubuntu64虚拟机。本以为很简单的步骤&#xff0c;结果跳了一些坑。以下是按照先后顺序遇到的问题与解决方法&#xff1a; 遇到的问题与解决方法&#xff1a; 1、需要用到一个新的空硬盘…

OpenGL学习笔记【2】——开发环境配置(GLFW,VS,Cmake),创建第一个项目

学OpenGL的都会知道&#xff0c;OpenGL只提供了绘图功能&#xff0c;创建窗口是需要自己完成的。这就需要学习相应操作系统的创建窗口方法&#xff0c;为简化创建窗口的过程&#xff0c;可以使用专门的窗口库&#xff0c;例如GLFW。使用GLFW之前需要先进行配置&#xff0c;那怎…

css实现的3D立体视觉效果鸡蛋动画特效

这是一个基于纯css实现的3D立体视觉效果鸡蛋动画特效&#xff0c;喜欢的朋友可以拿来使用演示动态效果 css实现的3D立体视觉效果鸡蛋动画特效

spark RDD 创建及相关算子

RDD编程入口 RDD编程入口对象是SparkContext对象&#xff0c;想要调用相关的计算api都需要通过构造出的sparkcontext对象调用 RDD的创建 通过并行化集合创建RDD&#xff08;本地集合转为分布式&#xff09;&#xff0c;api如下 rdd sc.parrallize(param1, param2)参数1是本…

修复 Java 错误 Java.Net.SocketException: Permission Denied

本篇文章介绍了 Java 中的 java.net.SocketException&#xff1a;Permission denied 错误。 Java中出现 java.net.SocketException: Permission returned 错误的原因 SocketException 通常在网络连接出现问题时发生。 它可以是权限被拒绝、连接重置或其他任何情况。 当网络没…

Linux:点命令source

相关阅读 Linuxhttps://blog.csdn.net/weixin_45791458/category_12234591.html?spm1001.2014.3001.5482 source命令用于读取一个文件的内容并在当前Shell环境&#xff08;包括交互式Shell或是非交互式Shell&#xff09;执行里面的命令。它被称为点命令是因为命令名source也可…