pytorch如何定义损失函数_对比PyTorch和TensorFlow的自动差异和动态模型

使用自定义模型类从头开始训练线性回归,比较PyTorch 1.x和TensorFlow 2.x之间的自动差异和动态模型子类化方法,

这篇简短的文章重点介绍如何在PyTorch 1.x和TensorFlow 2.x中分别使用带有模块/模型API的动态子类化模型,以及这些框架在训练循环中如何使用AutoDiff获得损失的梯度并从头开始实现 一个非常幼稚的渐变后代实现。

ede8987aef184437843e4e4d501b9d92

生成噪声的线性数据

为了专注于自动差异/自动渐变功能的核心,我们将使用最简单的模型,即线性回归模型,然后我们将首先使用numpy生成一些线性数据,以添加随机级别的噪声。

def generate_data(m=0.1, b=0.3, n=200):  x = np.random.uniform(-10, 10, n)  noise = np.random.normal(0, 0.15, n)  y = (m * x + b ) + noise  return x.astype(np.float32), y.astype(np.float32)x, y = generate_data()plt.figure(figsize = (12,5))ax = plt.subplot(111)ax.scatter(x,y, c = "b", label="samples")
1e69bd7abe1c4ad38a63e49733d7c859

模型

然后,我们将在TF和PyTorch中实现从零开始的线性回归模型,而无需使用任何层或激活器,而只需定义两个张量w和b,分别代表线性模型的权重和偏差,并简单地实现线性函数即可:y = wx + b

正如您在下面看到的,我们的模型的TF和PyTorch类定义基本上完全相同,但在一些api名称上只有很小的差异。

唯一值得注意的区别是,PyTorch明确地使用Parameter对象定义权重和要由图形"捕获"的偏置张量,而TF似乎在这里更"神奇",而是自动捕获用于图形的参数。

确实在PyTorch参数中是Tensor子类,当与Module api一起使用时,它们具有非常特殊的属性,可以自动将自身添加到Module参数列表中,并会出现在在parameters()迭代器中。

无论如何,两个框架都能够从此类定义和执行方法(callforward ),参数和图形定义中提取信息,以便向前执行图形执行,并且正如我们将看到的那样,通过自动可微分获得梯度功能,以便能够执行反向传播。

TensorFlow动态模型

class LinearRegressionKeras(tf.keras.Model):  def __init__(self):    super().__init__()    self.w = tf.Variable(tf.random.uniform(shape=[1], -0.1, 0.1))    self.b = tf.Variable(tf.random.uniform(shape=[1], -0.1, 0.1))      def __call__(self,x):     return x * self.w + self.b

PyTorch动态模型

class LinearRegressionPyTorch(torch.nn.Module):   def __init__(self):     super().__init__()     self.w = torch.nn.Parameter(torch.Tensor(1, 1).uniform_(-0.1, 0.1))    self.b = torch.nn.Parameter(torch.Tensor(1).uniform_(-0.1, 0.1))    def forward(self, x):      return x @ self.w + self.b

训练循环,反向传播和优化器

现在我们已经实现了简单的TensorFlow和PyTorch模型,我们可以定义TF和PyTorch api来实现均方误差的损失函数,最后实例化我们的模型类并运行训练循环。

同样,本着眼于自动差异/自动渐变功能核心的目的,我们将使用TF和PyTorch特定的自动差异实现方式实现自定义训练循环,以便为我们的简单线性函数提供渐变并手动优化权重和偏差参数以及临时和朴素的渐变后代优化器。

在TensorFlow训练循环中,我们将特别明确地使用GradientTape API来记录模型的正向执行和损失计算,然后从该GradientTape中获得用于优化权重和偏差参数的梯度。

相反,在这种情况下,PyTorch提供了一种更"神奇"的自动渐变方法,隐式捕获了对参数张量的任何操作,并为我们提供了相同的梯度以用于优化权重和偏置参数,而无需使用任何特定的api。

一旦我们有了权重和偏差梯度,就可以在PyTorch和TensorFlow上实现我们的自定义梯度派生方法,就像将权重和偏差参数减去这些梯度乘以恒定的学习率一样简单。

此处的最后一个微小区别是,当PyTorch在向后传播中更新权重和偏差参数时,以更隐蔽和"魔术"的方式实现自动差异/自动graf时,我们需要确保不要继续让PyTorch从最后一次更新操作中提取grad,这次明确调用no_grad api,最后将权重和bias参数的梯度归零。

TensorFlow训练循环

def squared_error(y_pred, y_true):  return tf.reduce_mean(tf.square(y_pred - y_true))tf_model = LinearRegressionKeras()[w, b] = tf_model.trainable_variablesfor epoch in range(epochs):  with tf.GradientTape() as tape:    predictions = tf_model(x)    loss = squared_error(predictions, y)          w_grad, b_grad = tape.gradient(loss, tf_model.trainable_variables)  w.assign(w - w_grad * learning_rate)  b.assign(b - b_grad * learning_rate)  if epoch % 20 == 0:    print(f"Epoch {epoch} : Loss {loss.numpy()}")

PyTorch训练循环

def squared_error(y_pred, y_true):  return torch.mean(torch.square(y_pred - y_true))torch_model = LinearRegressionPyTorch()[w, b] = torch_model.parameters()for epoch in range(epochs):  y_pred = torch_model(inputs)  loss = squared_error(y_pred, labels)  loss.backward()    with torch.no_grad():    w -= w.grad * learning_rate    b -= b.grad * learning_rate    w.grad.zero_()    b.grad.zero_()      if epoch % 20 == 0:    print(f"Epoch {epoch} : Loss {loss.data}")

结论

正如我们所看到的,TensorFlow和PyTorch自动区分和动态子分类API非常相似,当然,两种模型的训练也给我们非常相似的结果。

在下面的代码片段中,我们将分别使用Tensorflow和PyTorch trainable_variables和parameters方法来访问模型参数并绘制学习到的线性函数的图。

绘制结果

[w_tf, b_tf] = tf_model.trainable_variables[w_torch, b_torch] = torch_model.parameters()with torch.no_grad():  plt.figure(figsize = (12,5))  ax = plt.subplot(111)  ax.scatter(x, y, c = "b", label="samples")  ax.plot(x, w_tf * x + b_tf, "r", 5.0, "tensorflow")  ax.plot(x, w_torch * inputs + b_torch, "c", 5.0, "pytorch")  ax.legend()  plt.xlabel("x1")  plt.ylabel("y",rotation = 0)
36885df4cac44a2a9e2c96ae8e65a271

作者:Jacopo Mangiavacchi

本文代码:github/JacopoMangiavacchi/TF-VS-PyTorch

deephub翻译组

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/365964.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Gradle命令行便利

在我的《用Gradle构建Java的gradle tasks 》一文中,我简要地提到了使用Gradle的“ gradle tasks ”命令来查看特定Gradle构建的可用任务。 在这篇文章中,我将对这一简短提及进行更多的扩展,并查看一些相关的Gradle命令行便利。 Gradle可以轻松…

精读《setState 做了什么》

1 引言 setState 是 React 框架最常用的命令,它是用来更新状态的,这也是 React 框架划时代的功能。 但是 setState 函数是 react 包导出的,他们又是如何与 react-dom react-native react-art 这些包结合的呢? 通过 how-does-setst…

java封装实现Excel建表读写操作

对 Excel 进行读写操作是生产环境下常见的业务,网上搜索的实现方式都是基于POI和JXL第三方框架,但都不是很全面。小编由于这两天刚好需要用到,于是就参考手写了一个封装操作工具,基本涵盖了Excel表(分有表头和无表头&a…

c语言程序中注释的格式化,格式化C语言命令indent

indent是linux下一个能力极强的代码整理软件,使用他,可以轻松的写出代码风格十分精良的代码。但是indent的参数太多,使用起来不是很容易,怎么办呢?查看/usr/src/linux-headers-/scripts/Lindent文件 ,可以看…

argmax函数_1.4 TensorFlow2.1常用函数

1.4 TF常用函数tf.cast(tensor,dtypedatatype)可以进行强制类型转换。tf.reduce_min(tensor)和tf.reduce_max(tensor)将计算出张量中所有元素的最大值和最小值。import tensorflow as tfx1 tf.constant([1., 2., 3.], dtypetf.float64)print("x1:", x1)x2 tf.cast(…

休眠:DDL模式生成

不久前,我必须使用内存数据库。 该活动与集成测试有关。 如您所知,通常将内存数据库用于集成测试。 造成这种情况的原因有很多:可移植性,完善的环境基础结构,高性能,原始数据库的一致性。 问题在于如何将生…

分析jQuery源码时记录的一点感悟

分析jQuery源码时记录的一点感悟 1. 链式写法 这是jQuery语法上的最大特色,也许该改改POJO里的set方法,和其他的非get方法什么的,可以把多行代码合并,减去每次敲打对象变量的麻烦 2. 动态参数 偶尔使用Java…

设计模式---数据结构模式之迭代器模式(Iterate)

一:概念 迭代模式是行为模式之一,它把对容器中包含的内部对象的访问委让给外部类,使用Iterator(遍历)按顺序进行遍历访问的设计模式。 在应用Iterator模式之前,首先应该明白Iterator模式用来解决什么问题。…

识别Gradle约定

通过约定进行配置具有许多优点,尤其是在简洁方面,因为开发人员不需要显式配置通过约定隐式配置的内容。 但是,在利用约定进行配置时,需要了解约定。 这些约定可能已经记录在案,但是当我可以编程方式确定约定时&#xf…

jQuery函数的等价原生函数代码示例

选择器 jQuery的核心之一就是能非常方便的取到DOM元素。我们只需输入CSS选择字符串,便可以得到匹配的元素。但在大多数情况下,我们可以用简单的原生代码达到同样的效果。 .代码如下://----得到页面的所有div--------- /* jQuery */ $("div") …

高校c语言题库,C语言-中国大学mooc-题库零氪

第1 周 程序设计与C语言简介1.1 程序设计基础随堂测验1、计算机只能处理由人们编写的、解决某些问题的、事先存储在计算机存储器中的二进制指令序列。第1周单元测验1、通常把高级语言源程序翻译成目标程序的程序称为( )。A、编辑程序B、解释程序C、汇编程序D、编译程序2、一个算…

python图形化编程实验_转换图像RGB-实验室与python

自2010年以来, linked question被问到相应的代码从scipy移动到一个单独的工具包: http://scikit-image.org/ 所以这里是我实际寻找的代码: from skimage import io,color rgb io.imread(filename) lab color.rgb2lab(rgb) 还应该注意&#…

一个页面同时发起多个ajax请求,会出现阻塞情况

ajax请求设置为同步解决转载于:https://www.cnblogs.com/johnblogs/p/10245218.html

场景法设计测试用例

在面向对象的软件开发中,事件触发机制是编程中经常遇到的。 (一)场景法原理 现在的软件几乎都是用事件触发来控制流程的。像GUI软件、游戏等。事件触发时的情景形成了场景,而同一事件不同的触发顺序和处理结果就形成了事件流。这种…

JQuery让input从disabled变成enabled

设置input框可用:0.document.getElementById("removeButton").disabled false; //普通Js写法 1.$("#input").attr("disabled",true) 2.$("#input").removeAttr("disabled") 3.$("#input").attr(&q…

python中range函数是什么意思_python里range是什么

python range() 函数可创建一个整数列表,一般用在 for 循环中。函数语法(推荐学习:Python视频教程)range(start, stop[, step]) 参数说明: start: 计数从 start 开始。默认是从 0 开始。例如range(5&#x…

android 7.0编译报错,编译android7.0 sdk错误解决方法

编译时最后报错:SDK: warning: including GNU target out/target/product/generic/system/lib/libext2fs.so SDK: warning: including GNU target out/target/product/generic/system/lib/libiprouteutil.soSDK: warning: including GNU target out/target/product/…

为什么我喜欢Spring bean别名

Spring框架被广泛用作依赖项注入容器,这是有充分理由的。 首先,它促进了集成测试,并赋予了我们自定义Bean创建和初始化功能的能力(例如Autowired用于List类型 )。 但是还有一个非常有用的功能,可能会被忽略…

SYS.AUD$无法扩容导致无法登录的问题

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/bisal/article/details/19068663昨天同事说有个测试库无法登录了,用PLSQL Developer登陆后提示: ERROR: ORA-00604: error occurred at recursive SQL…

Jquery——hover与toggle

hover方法的语法结构为&#xff1a;hover&#xff08;enter&#xff0c;leave&#xff09;hover()当鼠标移动到元素上时&#xff0c;会触发第一个方法&#xff0c;当鼠标移开的时候会触发第二个方法复制代码<html><head><title>测试用</title><scri…