TensorFlow的可训练变量和自动求导机制

文章目录

  • 一些概念、函数、用法
  • TensorFlow实现一元线性回归
  • TensorFlow实现多元线性回归


一些概念、函数、用法

对象Variable

创建对象Variable:

tf.Variable(initial_value,dtype)

利用这个方法,默认整数为int32,浮点数为float32,注意Numpy默认的浮点数类型是float64,如果想和Numpy数据进行对比,则需要修改与numpy一致,否则在机器学习中float32位够用了。
将张量封装为可训练变量

print(tf.Variable(tf.random. normal([2,2])))

<tf.Variable ‘Variable:0’ shape=(2, 2) dtype=float32, numpy=array([[-1.2848959 , -0.22805293],[-0.79079854, 0.7035335 ]], dtype=float32)>

trainalbe属性
用来检查Variable变量是否可训练

x.trainalbe

可训练变量赋值,注意x是Variable对象类型,不是tensor类型

x.assign()
x.assign_add()
x.assign_sub()

用isinstance()方法来判断是tensor还是Variable
在这里插入图片描述
自动求导

with GradientTape() as tape:
函数表达式
grad=tape.gradient(函数,自变量)

x=tf.Variable(3.)
with tf.GradientTape() as tape:y=tf.square(x)
dy_dx = tape.gradient(y,x)
print(y)
print(dy_dx)

tf.Tensor(9.0, shape=(), dtype=float32)
tf.Tensor(6.0, shape=(), dtype=float32)

GradientTape函数

GradientTape(persistent,watch_accessed_variables)
第一个参数默认为false,表示梯度带只使用一次,使用完就销毁了,若为true则表明梯度带可以多次使用,但在循环最后要记得把它销毁
第二个参数默认为true,表示自动添加监视

tape.watch()函数
用来添加监视非可训练变量
多元函数求一阶偏导数

x=tf.Variable(3.)
y=tf.Variable(4.)
with tf.GradientTape(persistent=True) as tape:f=tf.square(x)+2*tf.square(y)+1
df_dx,df_dy = tape.gradient(f,[x,y])
first_grade = tape.gradient(f,[x,y])
print(f)
print(df_dx)
print(df_dy)
print(first_grade)
del tape

tf.Tensor(42.0, shape=(), dtype=float32)
tf.Tensor(6.0, shape=(), dtype=float32)
tf.Tensor(16.0, shape=(), dtype=float32)
[<tf.Tensor: id=36, shape=(), dtype=float32, numpy=6.0>, <tf.Tensor: id=41, shape=(), dtype=float32, numpy=16.0>]

多元函数求二阶偏导数
在这里插入图片描述

x=tf.Variable(3.)
y=tf.Variable(4.)
with tf.GradientTape(persistent=True) as tape2:with tf.GradientTape(persistent=True) as tape1:f=tf.square(x)+2*tf.square(y)+1first_grade = tape1.gradient(f,[x,y])
second_grade = [tape2.gradient(first_grade,[x,y])]
print(f)
print(first_grade)
print(second_grade)
del tape1
del tape2

tf.Tensor(42.0, shape=(), dtype=float32)
[<tf.Tensor: id=27, shape=(), dtype=float32, numpy=6.0>, <tf.Tensor: id=32, shape=(), dtype=float32, numpy=16.0>]
[[<tf.Tensor: id=39, shape=(), dtype=float32, numpy=2.0>, <tf.Tensor: id=41, shape=(), dtype=float32, numpy=4.0>]]

TensorFlow实现一元线性回归

import numpy as np
import tensorflow as tf 
import matplotlib.pyplot as plt 
#设置字体
plt.rcParams['font.sans-serif'] =['SimHei']
#加载样本数据
x=np.array([137.97,104.50,100.00,124.32,79.20,99.00,124.00,114.00,106.69,138.05,53.75,46.91,68.00,63.02,81.26,86.21])
y=np.array([145.00,110.00,93.00,116.00,65.32,104.00,118.00,91.00,62.00,133.00,51.00,45.00,78.50,69.65,75.69,95.30])
#设置超参数,学习率
learn_rate=0.0001
#迭代次数
iter=100
#每10次迭代显示一下效果
display_step=10
#设置模型参数初值
np.random.seed(612)
w=tf.Variable(np.random.randn())
b=tf.Variable(np.random.randn())
#训练模型
#存放每次迭代的损失值
mse=[]
for i in range(0,iter+1):with tf.GradientTape() as tape:pred=w*x+bLoss=0.5*tf.reduce_mean(tf.square(y-pred))mse.append(Loss)#更新参数dL_dw,dL_db = tape.gradient(Loss,[w,b])w.assign_sub(learn_rate*dL_dw)b.assign_sub(learn_rate*dL_db)#plt.plot(x,pred)if i%display_step==0:print("i:%i,Loss:%f,w:%f,b:%f"%(i,mse[i],w.numpy(),b.numpy()))

TensorFlow实现多元线性回归

import numpy as np
import tensorflow as tf #=======================【1】加载样本数据===============================================
area=np.array([137.97,104.50,100.00,124.32,79.20,99.00,124.00,114.00,106.69,138.05,53.75,46.91,68.00,63.02,81.26,86.21])
room=np.array([3,2,2,3,1,2,3,2,2,3,1,1,1,1,2,2])
price=np.array([145.00,110.00,93.00,116.00,65.32,104.00,118.00,91.00,62.00,133.00,51.00,45.00,78.50,69.65,75.69,95.30])
num=len(area) #样本数量
#=======================【2】数据处理===============================================
x0=np.ones(num)
#归一化处理,这里使用线性归一化
x1=(area-area.min())/(area.max()-area.min())
x2=(room-room.min())/(room.max()-room.min())
#堆叠属性数组,构造属性矩阵
#从(16,)到(16,3),因为新出现的轴是第二个轴所以axis为1
X=np.stack((x0,x1,x2),axis=1)
print(X)
#得到形状为一列的数组
Y=price.reshape(-1,1)
print(Y)
#=======================【3】设置超参数===============================================
learn_rate=0.001
#迭代次数
iter=500
#每10次迭代显示一下效果
display_step=50
#=======================【4】设置模型参数初始值===============================================
np.random.seed(612)
W=tf.Variable(np.random.randn(3,1))
#=======================【4】训练模型=============================================
mse=[]
for i in range(0,iter+1):with tf.GradientTape() as tape:PRED=tf.matmul(X,W)Loss=0.5*tf.reduce_mean(tf.square(Y-PRED))mse.append(Loss)#更新参数dL_dw = tape.gradient(Loss,W)W.assign_sub(learn_rate*dL_dw)#plt.plot(x,pred)if i % display_step==0:print("i:%i,Loss:%f"%(i,mse[i]))

喜欢的话点个赞和关注呗!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/378345.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

django第二个项目--使用模板做一个站点访问计数器

上一节讲述了django和第一个项目HelloWorld&#xff0c;这节我们讲述如何使用模板&#xff0c;并做一个简单的站点访问计数器。 1、建立模板 在myblog模块文件夹&#xff08;即包含__init__.py的文件夹)下面新建一个文件夹templates&#xff0c;用于存放HTML模板&#xff0c;在…

c语言math乘法,JavaScript用Math.imul()方法进行整数相乘

1. 基本概念Math.imul()方法用于计算两个32位整数的乘积&#xff0c;它的结果也是32位的整数。JavaScript的Number类型同时包含了整数和浮点数&#xff0c;它没有专门的整型和浮点型。因此&#xff0c;Math.imul()方法能提供类似C语言的整数相乘的功能。我们将Math.imul()方法的…

梯度下降法预测波士顿房价以及简单的模型评估

目录原理代码关于归一化的思考原理 观察数据可知属性之间差距很大&#xff0c;为了平衡所有的属性对模型参数的影响&#xff0c;首先进行归一化处理。 每一行是一个记录&#xff0c;每一列是个属性&#xff0c;所以对每一列进行归一化。 二维数组归一化&#xff1a;1、循环方式…

Windows Phone 内容滑动切换实现

在新闻类的APP中&#xff0c;有一个经常使用的场景&#xff1a;左右滑动屏幕来切换上一条或下一条新闻。 那么通常我们该使用哪种方式去实现呢&#xff1f;可以参考一下Demo的实现步骤。 1&#xff0c;添加Windows Phone用户自定义控件。例如&#xff1a; 这里我为了演示的方便…

使用鸢尾花数据集实现一元逻辑回归、多分类问题

目录鸢尾花数据集逻辑回归原理【1】从线性回归到广义线性回归【2】逻辑回归【3】损失函数【4】总结TensorFlow实现一元逻辑回归多分类问题原理独热编码多分类的模型参数损失函数CCETensorFlow实现多分类问题独热编码计算准确率计算交叉熵损失函数使用花瓣长度、花瓣宽度将三种鸢…

【神经网络计算】——神经网络实现鸢尾花分类

本blog为观看MOOC视频与网易云课堂所做的笔记 课堂链接&#xff1a; 人工智能实践:TensorFlow笔记 吴恩达机器学习 疑问与思考 为什么按照batch喂入数据 之前看的视频里面处理数据都是一次性将所有数据喂入&#xff0c;现在看的这个视频对数据进行了分组投入。这是为何&#…

c# xaml语言教程,c#学习之30分钟学会XAML

1.狂妄的WPF相对传统的Windows图形编程&#xff0c;需要做很多复杂的工作&#xff0c;引用许多不同的API。例如&#xff1a;WinForm(带控件表单)、GDI(2D图形)、DirectXAPI(3D图形)以及流媒体和流文档等&#xff0c;都需要不同的API来构建应用程序。WPF就是看着上面的操作复杂和…

.NET 小结之内存模型

.NET 小结之内存模型 为什么要解.NET 的内存模型 在.NET下的内存管理、垃圾回收其实大部分不需要我们操心&#xff0c;因为大部分.NET已经帮我们做了&#xff0c;通常情况下也不需要考虑这些。但是如果想要了解一些.NET一些稍微“底层”的原理&#xff0c;如&#xff1a;“装箱…

【电设控制与图像训练题】【激光打靶】【openmv测试代码以及效果】

9.4加入串口通讯,送出靶心坐标、激光坐标、激光所在环数、方位;加入防误判操作 博主联系方式: QQ:1540984562 QQ交流群:892023501 群里会有往届的smarters和电赛选手,群里也会不时分享一些有用的资料,有问题可以在群里多问问。 目录 规则坐标系代码总结相关openmv使用文…

MVC3中的视图文件

在MVC3中的视图部分&#xff0c;Razor视图引擎是与以往不同的地方之一&#xff0c;使用Razor的视图文件再也不是以往的ASPX文件了&#xff0c;是cshtml文件&#xff0c;在新建视图的时候也会发现增加多了几类文件 由上到下分别是 MVC 3 Layout Page&#xff1a;与原来Web Form的…

C语言 链表拼接 PTA,PTA实验 链表拼接 (20point(s))

本题要求实现一个合并两个有序链表的简单函数。链表结点定义如下&#xff1a;struct ListNode {int data;struct ListNode *next;};函数接口定义&#xff1a;struct ListNode *mergelists(struct ListNode *list1, struct ListNode *list2);其中list1和list2是用户传入的两个按…

【TensorFlow学习笔记:神经网络优化(6讲)】

目录【1】NN复杂度【2】指数衰减学习率【3】激活函数优秀激活函数所具有的特点常见的激活函数对于初学者的建议【4】损失函数【5】缓解过拟合——正则化【6】参数优化器【1】SGD【2】SGDM(SGD基础上增加了一阶动量)【3】Adagrade(SGD基础上增加了二阶动量)【4】RMSProp(SGD基础…

第十章 开箱即用

第十章 开箱即用 “开箱即用”&#xff08;batteries included&#xff09;最初是由Frank Stajano提出的&#xff0c;指的是Python丰富的标准库。 模块 使用import将函数从外部模块导入到程序中。 import math math.sin(0)#结果为&#xff1a;0.0模块就是程序 在文件夹中创…

Openmv通过串口接收数据、发送数据与stm32通信

博主联系方式: QQ:1540984562 QQ交流群:892023501 群里会有往届的smarters和电赛选手,群里也会不时分享一些有用的资料,有问题可以在群里多问问。 目录 参考接线星瞳教程openmv传送数据STM32解码程序参考 接线 星瞳教程

c语言尹宝林答案,c程序设计导引 尹宝林

《C程序设计导引》特别适合作为计算机和非计算机专业学生学习高级语言程序设计的教材&#xff0c;也可供计算机等级考试者和其他各类学习者使用参考。17.40定价&#xff1a;44.75(3.89折)/2013-05-01《大学计算机优秀教材系列&#xff1a;C程序设计导引》是一本讲解C程序设计的…

第十一章 文件

第十一章 文件 打开文件 当前目录中有一个名为beyond.txt的文本文件&#xff0c;打开该文件 调用open时&#xff0c;原本可以不指定模式&#xff0c;因为其默认值就是’r’。 import io f open(beyond.txt)文件模式 值描述‘r’读取模式&#xff08;默认值&#xff09;‘w…

【TensorFlow学习笔记:神经网络八股】(实现MNIST数据集手写数字识别分类以及FASHION数据集衣裤识别分类)

课程来源&#xff1a;人工智能实践:Tensorflow笔记2 文章目录前言一、搭建网络八股sequential1.函数介绍2.6步法实现鸢尾花分类二、搭建网络八股class1.创建自己的神经网络模板&#xff1a;2.调用自己创建的model对象三、MNIST数据集1.用sequential搭建网络实现手写数字识别2.用…

第十二章 图形用户界面

第十二章 图形用户界面 GUI就是包含按钮、文本框等控件的窗口 Tkinter是事实上的Python标准GUI工具包 创建GUI示例应用程序 初探 导入tkinter import tkinter as tk也可导入这个模块的所有内容 from tkinter import *要创建GUI&#xff0c;可创建一个将充当主窗口的顶级组…

Sqlserver 2005 配置 数据库镜像:数据库镜像期间可能出现的故障:镜像超时机制

数据库镜像期间可能出现的故障 SQL Server 2005其他版本更新日期&#xff1a; 2006 年 7 月 17 日 物理故障、操作系统故障或 SQL Server 故障都可能导致数据库镜像会话失败。数据库镜像不会定期检查 Sqlservr.exe 所依赖的组件来验证组件是在正常运行还是已出现故障。但对于某…

【神经网络八股扩展】:自制数据集

课程来源&#xff1a;人工智能实践:Tensorflow笔记2 文章目录前言1、文件一览2、将load_data()函数替换掉2、调用generateds函数4、效果总结前言 本讲目标:自制数据集&#xff0c;解决本领域应用 将我们手中的图片和标签信息制作为可以直接导入的npy文件。 1、文件一览 首先看…