Pytorch-SGD算法解析

关注B站可以观看更多实战教学视频:肆十二-的个人空间-肆十二-个人主页-哔哩哔哩视频 (bilibili.com)

SGD,即随机梯度下降(Stochastic Gradient Descent),是机器学习中用于优化目标函数的迭代方法,特别是在处理大数据集和在线学习场景中。与传统的批量梯度下降(Batch Gradient Descent)不同,SGD在每一步中仅使用一个样本来计算梯度并更新模型参数,这使得它在处理大规模数据集时更加高效。

SGD算法的基本步骤

  1. 初始化参数:选择初始参数值,可以是随机的或者基于一些先验知识。
  2. 随机选择样本:从数据集中随机选择一个样本。
  3. 计算梯度:计算损失函数关于当前参数的梯度。
  4. 更新参数:沿着负梯度方向更新参数。
  5. 重复:重复步骤2-4,直到满足停止条件(如达到预设的迭代次数或损失函数的改变小于某个阈值)。

SGD的Python代码示例:

python实现

假设我们要使用SGD来优化一个简单的线性回归模型。

import numpy as np  # 目标函数(损失函数)和其梯度  
def loss_function(w, b, x, y):  return np.sum((y - (w * x + b)) ** 2) / len(x)  def gradient_function(w, b, x, y):  dw = -2 * np.sum((y - (w * x + b)) * x) / len(x)  db = -2 * np.sum(y - (w * x + b)) / len(x)  return dw, db  # SGD算法  
def sgd(x, y, learning_rate=0.01, epochs=1000):  # 初始化参数  w = np.random.rand()  b = np.random.rand()  # 存储每次迭代的损失值,用于可视化  losses = []  for i in range(epochs):  # 随机选择一个样本(在这个示例中,我们没有实际进行随机选择,而是使用了整个数据集。在大数据集上,你应该随机选择一个样本或小批量样本。)  # 注意:为了简化示例,这里我们实际上使用的是批量梯度下降。在真正的SGD中,你应该在这里随机选择一个样本。  # 计算梯度  dw, db = gradient_function(w, b, x, y)  # 更新参数  w = w - learning_rate * dw  b = b - learning_rate * db  # 记录损失值  loss = loss_function(w, b, x, y)  losses.append(loss)  # 每隔一段时间打印损失值(可选)  if i % 100 == 0:  print(f"Epoch {i}, Loss: {loss}")  return w, b, losses  # 示例数据(你可以替换为自己的数据)  
x = np.array([1, 2, 3, 4, 5])  
y = np.array([2, 4, 6, 8, 10])  # 运行SGD算法  
w, b, losses = sgd(x, y)  
print(f"Optimized parameters: w = {w}, b = {b}")

解析

  • 在上面的代码中,我们首先定义了损失函数和它的梯度。对于线性回归,损失函数通常是均方误差。
  • sgd函数实现了SGD算法。它接受输入数据x和标签y,以及学习率和迭代次数作为参数。
  • 在每次迭代中,我们计算损失函数关于参数wb的梯度,并使用这些梯度来更新参数。
  • 我们还记录了每次迭代的损失值,以便稍后可视化算法的收敛情况。
  • 最后,我们打印出优化后的参数值。在实际应用中,你可能还需要使用这些参数来对新数据进行预测。

在PyTorch中,SGD(随机梯度下降)是一种基本的优化器,用于调整模型的参数以最小化损失函数。下面是torch.optim.SGD的参数解析和一个简单的用例。

SGD的Pytorch代码示例:

参数解析

torch.optim.SGD的主要参数如下:

  1. params (iterable):待优化的参数,或者是定义了参数的模型的迭代器。
  2. lr (float):学习率。这是更新参数的步长大小。较小的值会导致更新更精细,而较大的值可能会导致训练过程不稳定。这是SGD优化器的一个关键参数。
  3. momentum (float, optional):动量因子 (default: 0)。该参数加速了SGD在相关方向上的收敛,并抑制了震荡。
  4. dampening (float, optional):动量的抑制因子 (default: 0)。增加此值可以减少动量的影响。在实际应用中,这个参数的使用较少。
  5. weight_decay (float, optional):权重衰减 (L2 penalty) (default: 0)。通过向损失函数添加与权重向量平方成比例的惩罚项,来防止过拟合。
  6. nesterov (bool, optional):是否使用Nesterov动量 (default: False)。Nesterov动量是标准动量方法的一个变种,它在计算梯度时使用了未来的近似位置。

用例

下面是一个使用SGD优化器的简单例子:

import torch  
import torch.nn as nn  
import torch.optim as optim  # 定义一个简单的模型  
model = nn.Sequential(  nn.Linear(10, 5),  nn.ReLU(),  nn.Linear(5, 2),  
)  # 定义损失函数  
criterion = nn.CrossEntropyLoss()  # 定义优化器  
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001)  # 假设有输入数据和目标  
input_data = torch.randn(1, 10)  
target = torch.tensor([1])  # 训练循环(这里只展示了一次迭代)  
for epoch in range(1):  # 通常会有多个 epochs  # 前向传播  output = model(input_data)  # 计算损失  loss = criterion(output, target)  # 反向传播  optimizer.zero_grad()  # 清除之前的梯度  loss.backward()  # 计算当前梯度  # 更新参数  optimizer.step()  # 应用梯度更新  # 打印损失  print(f'Epoch {epoch+1}, Loss: {loss.item()}')

在这个例子中,我们创建了一个简单的两层神经网络模型,并使用SGD优化器来更新模型的参数。在训练循环中,我们执行了前向传播来计算模型的输出,然后计算了损失,通过调用loss.backward()执行了反向传播来计算梯度,最后通过调用optimizer.step()更新了模型的参数。在每次迭代开始时,我们使用optimizer.zero_grad()来清除之前累积的梯度,这是非常重要的步骤,因为PyTorch默认会累积梯度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/695667.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言字符串函数strchr与strrchr

注意: 这两个函数的功能,都是在指定的字符串 s 中,试图找到字符 c。strchr() 从左往右找,strrchr() 从右往左找。字符串结束标记 ‘\0’ 被认为是字符串的一部分。 示例 char *p;p strchr("www.qq.com", .); // 从左…

python自动化接口测试

前几天,同组姐妹说想要对接口那些异常值进行测试,能否有自动化测试的方法。仔细想了一下,工具还挺多,大概分析了一下: 1、soapui:可以对接口参数进行异常值参数化,可以加断言,一般我们会加http…

如何申请试用Gemini 1.5 Pro

https://developers.googleblog.com/2024/02/gemini-15-available-for-private-preview-in-google-ai-studio.html 我开发的chatgpt网站: https://chat.xutongbao.top

设计模式的分类及Spring中用到的设计模式

设计模式的分类 在《设计模式:可复用面向对象软件的基础》(Design Patterns: Elements of Reusable Object-Oriented Software)一书中,提出了 23 种设计模式,通常称为 GoF(Gang of Four)设计模…

浅谈数字信号处理器的本质与作用:从定义、原理到应用场景

数字信号处理器(DSP)作为一种关键的电子元件,在通信、音频、图像处理等领域扮演着不可或缺的角色。然而,对于许多人来说,数字信号处理器的概念可能依然模糊,其作用和原理也许并不为人所熟知。因此&#xff…

”戏说“ 交换机 与 路由器

一般意义上说 老哥 这文章发表 的 东一榔头 西一锤 呵呵, 想到哪里就啰嗦到哪里 。 交换机: 其实就是在通道交换 路由器: 不光是在通道交换还要在协议上交换 下图你看懂了吗? (仅仅数据交换-交换机 协议…

kafka 生产者消费者设计思考

生产者 负载均衡 生产者直接发送消息给分区leader,而不需要通过中间者进行转发。 这意味着生产者需要知道哪些服务器是存活的,以及主题分区leader在哪里的元数据请求。同时这也意味着生产者可以根据情况决定发给哪个broker,那么既可以随机…

Bert基础(三)--位置编码

背景 还是以I am good(我很好)为例。 在RNN模型中,句子是逐字送入学习网络的。换言之,首先把I作为输入,接下来是am,以此类推。通过逐字地接受输入,学习网络就能完全理解整个句子。然而&#x…

Eclipse的Java Project的入口main函数

在使用Eclipse创建java project项目的时候,一个项目里面通常只有一个main,那么一个项目里面是否可以有多个main函数呢?其实可以的,但是运行java application的时候要选择执行哪个main函数。 下面举个例子: 1、创建一个…

(二十二)Flask之上下文管理第三篇【收尾—讲一讲g】

目录: 每篇前言:g到底是什么?生命周期在请求周期内保持数据需要注意的是: 拓展—面向对象的私有字段深入讲解一下那句: 每篇前言: 🏆🏆作者介绍:【孤寒者】—CSDN全栈领域…

Django使用Celery异步

安装包 pip install celerypip install eventlet 1.在项目文件的根目录下创建目录结果 2. 在main.py文件中 # !/usr/bin/env python # -*-coding:utf-8 -*-""" # Author :skyTree # version :python 3.11 # Description&#…

备战蓝桥杯---动态规划(应用2(一些十分巧妙的优化dp的手段))

好久不见,甚是想念,最近一直在看过河这道题(感觉最近脑子有点宕机QAQ),现在算是有点懂了,打算记录下这道又爱又恨的题。(如有错误欢迎大佬帮忙指出) 话不多说,直接看题&…

2024年【T电梯修理】最新解析及T电梯修理操作证考试

题库来源:安全生产模拟考试一点通公众号小程序 T电梯修理最新解析根据新T电梯修理考试大纲要求,安全生产模拟考试一点通将T电梯修理模拟考试试题进行汇编,组成一套T电梯修理全真模拟考试试题,学员可通过T电梯修理操作证考试全真模…

maven配置多仓库私服

经常见我们除了需要官方的仓库以外,更多是配置了国内的阿里云公共仓库。但很多的企业会有自己的公共组件,两者会结合起来使用,就需要配置公司的私服。 而经常性的,我们会在 apache-maven-3.8.6\conf\settings.xml 中,…

Django学习笔记-HTML实现服务器图片的下载

1.index编写代码,跳转下载页面 2.创建download界面 3.编写download路由 4.创建download函数 1).如果请求的方法是GET,imglist变量存储从models.imgModel模型中获取的所有对象,创建字典ctx,使用render函数来渲染download.htm 2).如果请求的方法是POST,获取要下载的文…

啤酒:精酿啤酒与沙拉的轻盈享受

在繁忙的生活中,我们总是在寻找一种简单而健康的美食享受。当Fendi Club啤酒与沙拉相遇,它们将为我们带来一场轻盈的味觉之旅。 Fendi Club啤酒,以其醇厚的口感和淡淡的麦芽香气而闻名。这款啤酒在酿造过程中采用了特别的工艺,使得…

MCU中断响应流程及注意事项

本文介绍MCU中断响应流程及注意事项。 1.中断响应流程 中断响应的一般流程为: 1)断点保护 硬件操作,将PC,PSR等相关寄存器入栈保护 2)识别中断源 硬件操作,识别中断的来源,如果多个中断同时发生,高优…

uniapp 如何嵌套H5 页面?

如何在 uniapp项目中 嵌套h5页面 在UniApp中可以通过使用 web-view 组件来嵌入H5页面。 首先需要安装uni-app的依赖包,然后创建一个新的页面(比如名为"WebPage.vue")作为容器页面,并将其放置于pages目录下。 接下来&…

【C++】封装

1.封装的意义 封装是C面向对象三大特性之一 实例化(通过一个类 创建一个对象的过程) 类中的属性和行为 我们统一称为 成员 属性 成员属性 成员变量 行为 成员函数 成员方法 封装的意义: 1.将属性和行为作为一个整体,表现生活中的事…

【Python】2019年蓝桥杯省赛真题——完全二叉树的权值

蓝桥杯 2019 省 A&B:完全二叉树的权值 题目描述 给定一棵包含 N N N 个节点的完全二叉树,树上每个节点都有一个权值,按从上到下、从左到右的顺序依次是 A 1 , A 2 , ⋯ A N A_1,A_2, \cdots A_N A1​,A2​,⋯AN​,如下图所…