采用自动微分进行模型的训练

 自动微分训练模型

 简单代码实现:

import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的线性回归模型
class LinearRegression(nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.linear = nn.Linear(1, 1)  # 输入维度是1,输出维度也是1def forward(self, x):return self.linear(x)# 准备训练数据
x_train = torch.tensor([[1.0], [2.0], [3.0]])
y_train = torch.tensor([[2.0], [4.0], [6.0]])# 实例化模型、损失函数和优化器
model = LinearRegression()
criterion = nn.MSELoss()  # 均方误差损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器# 训练模型
epochs = 1000
for epoch in range(epochs):# 前向传播outputs = model(x_train)loss = criterion(outputs, y_train)# 反向传播optimizer.zero_grad()  # 清空之前的梯度loss.backward()  # 自动计算梯度optimizer.step()  # 更新模型参数if (epoch+1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')# 测试模型
x_test = torch.tensor([[4.0]])
predicted = model(x_test)
print(f'预测值: {predicted.item():.4f}')

代码分解:

1.定义一个简单的线性回归模型:

  • LinearRegression 类继承自nn.Module,这是所有神经网络模型的基类
  • 在 __init__ 方法中,定义了一个线性层 self.linear,它的输入维度是1,输出维度也是1。
  • forward 方法定义了数据在模型中的传播路径,即输入 x 经过 self.linear 层后得到输出。
    class LinearRegression(nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.linear = nn.Linear(1, 1)  # 输入维度是1,输出维度也是1def forward(self, x):return self.linear(x)
    

2.准备训练数据:

  • x_train 和 y_train 分别是输入和目标输出的训练数据。每个张量表示一个样本,x_train 中的每个元素是一个维度为1的张量,因为模型的输入维度是1。
    x_train = torch.tensor([[1.0], [2.0], [3.0]])
    y_train = torch.tensor([[2.0], [4.0], [6.0]])
    

3.实例化模型,损失函数和优化器:

  • model 是我们定义的 LinearRegression 类的一个实例,即我们要训练的线性回归模型。
  • criterion 是损失函数,这里选择了均方误差损失(MSE Loss),用于衡量预测值与实际值之间的差异。
  • optimizer 是优化器,这里选择了随机梯度下降(SGD),用于更新模型参数以最小化损失。
    model = LinearRegression()
    criterion = nn.MSELoss()  # 均方误差损失函数
    optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器
    

4.训练模型:

  • 这里进行了1000次迭代的训练过程。
  • 在每个迭代中,首先进行前向传播,计算模型对 x_train 的预测输出 outputs,然后计算损失 loss
  • 调用 optimizer.zero_grad() 来清空之前的梯度,然后调用 loss.backward() 自动计算梯度,最后调用 optimizer.step() 来更新模型参数
    epochs = 1000
    for epoch in range(epochs):# 前向传播outputs = model(x_train)loss = criterion(outputs, y_train)# 反向传播optimizer.zero_grad()  # 清空之前的梯度loss.backward()  # 自动计算梯度optimizer.step()  # 更新模型参数if (epoch+1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
    

5.测试模型:

  • x_test 是用来测试模型的输入数据,这里表示输入为4.0。
  • model(x_test) 对 x_test 进行前向传播,得到预测结果 predicted
  • predicted.item() 取出预测结果的标量值并打印出来。
    x_test = torch.tensor([[4.0]])
    predicted = model(x_test)
    print(f'预测值: {predicted.item():.4f}')
    

运行结果:

运行结果如下:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/44963.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】数据流重定向

数据流重定向(redirect)由字面上的意思来看,好像就是将【数据给它定向到其他地方去】的样子? 没错,数据流重定向就是将某个命令执行后应该要出现在屏幕上的数据,给它传输到其他的地方,例如文件或…

[图解]企业应用架构模式2024新译本讲解26-层超类型2

1 00:00:00,510 --> 00:00:03,030 这个时候,如果再次查找所有人员 2 00:00:03,040 --> 00:00:03,750 我们会发现 3 00:00:05,010 --> 00:00:06,370 这一次所有的对象 4 00:00:06,740 --> 00:00:08,690 都是来自标识映射的 5 00:00:10,540 --> 00…

2024辽宁省数学建模C题【改性生物碳对水中洛克沙胂和砷离子的吸附】原创论文分享

大家好呀,从发布赛题一直到现在,总算完成了2024 年辽宁省大学数学建模竞赛C题改性生物碳对水中洛克沙胂和砷离子的吸附完整的成品论文。 本论文可以保证原创,保证高质量。绝不是随便引用一大堆模型和代码复制粘贴进来完全没有应用糊弄人的垃…

【JavaScript】解决 JavaScript 语言报错:Uncaught SyntaxError: Unexpected identifier

文章目录 一、背景介绍常见场景 二、报错信息解析三、常见原因分析1. 缺少必要的标点符号2. 使用了不正确的标识符3. 关键词拼写错误4. 变量名与保留字冲突 四、解决方案与预防措施1. 检查和添加必要的标点符号2. 使用正确的标识符3. 检查关键词拼写4. 避免使用保留字作为变量名…

全栈 Discord 克隆:Next.js 13、React、Socket.io、Prisma、Tailwind、MySQL笔记(一)

前言 阅读本文你需要有 Next.js 基础 React 基础 Prisma 基础 tailwind 基础 MySql基础 准备工作 打开网站 https://ui.shadcn.com/docs 这不是一个组件库。它是可重用组件的集合,您可以将其复制并粘贴到应用中。 打开installation 选择Next.js 也就是此页面…

智慧校园服务监控功能

智慧校园系统中的服务监控功能,扮演着维护整个校园数字化生态系统稳定与高效运作的重要角色。它如同一位全天候的守护者,通过实时跟踪、分析并响应系统各层面的运行状况,确保教学、管理等核心业务流程的顺畅进行。 服务监控功能覆盖了智慧校园…

开发个人Ollama-Chat--6 OpenUI

开发个人Ollama-Chat–6 OpenUI Open-webui Open WebUI 是一种可扩展、功能丰富且用户友好的自托管 WebUI,旨在完全离线运行。它支持各种 LLM 运行器,包括 Ollama 和 OpenAI 兼容的 API。 功能 由于总所周知的原由,OpenAI 的接口需要密钥才…

知识图谱与 LLM:微调与检索增强生成

Midjourney 的知识图谱聊天机器人的想法。 大型语言模型 (LLM) 的第一波炒作来自 ChatGPT 和类似的基于网络的聊天机器人,这些模型在理解和生成文本方面非常出色,这让人们(包括我自己)感到震惊。 我们中的许多人登录并测试了它写…

微信视频号的视频怎么下载到本地?快速教你下载视频号视频

天来说说市面上常见的微信视频号视频下载工具,教大家快速下载视频号视频! 方法一:缓存方法 该方法来源早期视频技术,因早期无法将大量视频通过网络存储,故而会有缓存视频文件到手机,其目的为了提高用户体验…

尚硅谷Vue3入门到实战,最新版vue3+TypeScript前端开发教程

Vue3 编码规范 创建vue3工程 基于vite创建 快速上手 | Vue.js (vuejs.org) npm create vuelatest 在nodejs环境下运行进行创建 按提示进行创建 用vscode打开项目 安装依赖 源文件有src 内有main.ts App.vue 简单分析 编写src vue2语法在三中适用 vue2中的date metho…

【深度学习入门篇 ⑤ 】PyTorch网络模型创建

【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】 大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官…

OSPF.综合实验

1、首先将各个网段基于172.16.0.0 16 进行划分 1.1、划分为4个大区域 172.16.0.0 18 172.16.64.0 18 172.16.128.0 18 172.16.192.0 18 四个网段 划分R4 划分area2 划分area3 划分area1 2、进行IP配置 如图使用配置指令进行配置 ip address x.x.x.x /x 并且将缺省路由…

Sortable.js板块拖拽示例

图例 代码在图片后面 点赞❤️关注&#x1f64f;收藏⭐️ 页面加载后显示 拖拽效果 源代码 由于js库使用外链&#xff0c;所以会加载一会儿 <!DOCTYPE html> <html lang"en"> <head> <meta charset"UTF-8"> <meta name&qu…

C语言 ——— 实用调试技巧(Visual Studio)

目录 Debug 和 Release 的区别 F10 --- 逐过程调试 & F11 --- 逐语句调试 F9 --- 新建/切换断点 & F5 --- 开始调试 shift F5 & ctrl F5 Debug 和 Release 的区别 Debug&#xff1a;通常为调试版本&#xff0c;它包含调试信息&#xff0c;并且不作任何优化…

一 GD32 MCU 开发环境搭建

GD32 系列为通用型 MCU &#xff0c;所以开发环境也可以使用通用型的 IDE &#xff0c;目前使用较多的是 KEIL、 IAR 、 GCC 和 Embedded Builder &#xff0c;客户可以根据个人喜好来选择相应的开发环境。 目录 1、使用 Keil 开发 GD32 目前市面通用的MDK for ARM版本有Kei…

企业网三层架构

企业网三层架构&#xff1a;是一种层次化模型设计&#xff0c;旨在将复杂的网络设计分成三个层次&#xff0c;每个层次都着重于某些特定的功能&#xff0c;以提高效率和稳定性。 企业网三层架构层次&#xff1a; 接入层&#xff1a;使终端设备接入到网络中来&#xff0c;提供…

Android12 MultiMedia框架之GenericSource extractor

前面两节学习到了各种Source的创建和extractor service的启动&#xff0c;本节将以本地播放为例记录下GenericSource是如何创建一个extractor的。extractor是在PrepareAsync()方法中被创建出来的&#xff0c;为了不过多赘述&#xff0c;我们直接从GenericSource的onPrepareAsyn…

13_Shell系统函数

13_Shell系统函数和自定义函数 一、系统函数 basename 获取文件名 #!/bin/bash#basename 相对路径文件名 basename ./1.sh#basename 绝对路径文件名 basename /tmp/1.sh#basename 去除文件后缀名 basename /tmp/1.sh .shdirname 获取文件所在目录名 #!/bin/bash#dirname 相对路…

Redis持久化RDB,AOF

目 录 CONFIG动态修改配置 慢查询 持久化 在上一篇主要对redis的了解入门&#xff0c;安装&#xff0c;以及基础配置&#xff0c;多实例的实现&#xff1a;redis的安装看我上一篇&#xff1a; Redis安装部署与使用,多实例 redis是挡在MySQL前面的&#xff0c;运行在内存…

产品经理-研发流程-敏捷开发-迭代-需求评审及产品规划(15)

敏捷开发是以用户的需求进化为核心&#xff0c;采用迭代、循序渐进的方法进行软件开发。 通俗来说&#xff0c;敏捷开发是一个软件开发流程&#xff0c;是一个采用了迭代方法的开发流程 简单来说&#xff0c;迭代就是把一个大产品拆分出一些最小的实现单位。完成不同的迭代就最…