采用自动微分进行模型的训练

 自动微分训练模型

 简单代码实现:

import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的线性回归模型
class LinearRegression(nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.linear = nn.Linear(1, 1)  # 输入维度是1,输出维度也是1def forward(self, x):return self.linear(x)# 准备训练数据
x_train = torch.tensor([[1.0], [2.0], [3.0]])
y_train = torch.tensor([[2.0], [4.0], [6.0]])# 实例化模型、损失函数和优化器
model = LinearRegression()
criterion = nn.MSELoss()  # 均方误差损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器# 训练模型
epochs = 1000
for epoch in range(epochs):# 前向传播outputs = model(x_train)loss = criterion(outputs, y_train)# 反向传播optimizer.zero_grad()  # 清空之前的梯度loss.backward()  # 自动计算梯度optimizer.step()  # 更新模型参数if (epoch+1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')# 测试模型
x_test = torch.tensor([[4.0]])
predicted = model(x_test)
print(f'预测值: {predicted.item():.4f}')

代码分解:

1.定义一个简单的线性回归模型:

  • LinearRegression 类继承自nn.Module,这是所有神经网络模型的基类
  • 在 __init__ 方法中,定义了一个线性层 self.linear,它的输入维度是1,输出维度也是1。
  • forward 方法定义了数据在模型中的传播路径,即输入 x 经过 self.linear 层后得到输出。
    class LinearRegression(nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.linear = nn.Linear(1, 1)  # 输入维度是1,输出维度也是1def forward(self, x):return self.linear(x)
    

2.准备训练数据:

  • x_train 和 y_train 分别是输入和目标输出的训练数据。每个张量表示一个样本,x_train 中的每个元素是一个维度为1的张量,因为模型的输入维度是1。
    x_train = torch.tensor([[1.0], [2.0], [3.0]])
    y_train = torch.tensor([[2.0], [4.0], [6.0]])
    

3.实例化模型,损失函数和优化器:

  • model 是我们定义的 LinearRegression 类的一个实例,即我们要训练的线性回归模型。
  • criterion 是损失函数,这里选择了均方误差损失(MSE Loss),用于衡量预测值与实际值之间的差异。
  • optimizer 是优化器,这里选择了随机梯度下降(SGD),用于更新模型参数以最小化损失。
    model = LinearRegression()
    criterion = nn.MSELoss()  # 均方误差损失函数
    optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器
    

4.训练模型:

  • 这里进行了1000次迭代的训练过程。
  • 在每个迭代中,首先进行前向传播,计算模型对 x_train 的预测输出 outputs,然后计算损失 loss
  • 调用 optimizer.zero_grad() 来清空之前的梯度,然后调用 loss.backward() 自动计算梯度,最后调用 optimizer.step() 来更新模型参数
    epochs = 1000
    for epoch in range(epochs):# 前向传播outputs = model(x_train)loss = criterion(outputs, y_train)# 反向传播optimizer.zero_grad()  # 清空之前的梯度loss.backward()  # 自动计算梯度optimizer.step()  # 更新模型参数if (epoch+1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
    

5.测试模型:

  • x_test 是用来测试模型的输入数据,这里表示输入为4.0。
  • model(x_test) 对 x_test 进行前向传播,得到预测结果 predicted
  • predicted.item() 取出预测结果的标量值并打印出来。
    x_test = torch.tensor([[4.0]])
    predicted = model(x_test)
    print(f'预测值: {predicted.item():.4f}')
    

运行结果:

运行结果如下:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/44963.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】数据流重定向

数据流重定向(redirect)由字面上的意思来看,好像就是将【数据给它定向到其他地方去】的样子? 没错,数据流重定向就是将某个命令执行后应该要出现在屏幕上的数据,给它传输到其他的地方,例如文件或…

[图解]企业应用架构模式2024新译本讲解26-层超类型2

1 00:00:00,510 --> 00:00:03,030 这个时候,如果再次查找所有人员 2 00:00:03,040 --> 00:00:03,750 我们会发现 3 00:00:05,010 --> 00:00:06,370 这一次所有的对象 4 00:00:06,740 --> 00:00:08,690 都是来自标识映射的 5 00:00:10,540 --> 00…

VB 上位机开发

VB 上位机开发第一节 在 VB(Visual Basic)上位机开发的第一节课程中涵盖以下基础内容: 一、上位机开发简介 解释上位机的概念和作用,它是与硬件设备进行通信和控制的软件应用程序。举例说明上位机在工业自动化、智能家居、监控系统等领域的应用。二、VB 开发环境介绍 展示如…

2024辽宁省数学建模C题【改性生物碳对水中洛克沙胂和砷离子的吸附】原创论文分享

大家好呀,从发布赛题一直到现在,总算完成了2024 年辽宁省大学数学建模竞赛C题改性生物碳对水中洛克沙胂和砷离子的吸附完整的成品论文。 本论文可以保证原创,保证高质量。绝不是随便引用一大堆模型和代码复制粘贴进来完全没有应用糊弄人的垃…

Rubber Duck Debugging: History and Benefits 橡皮鸭调试:历史和优势

注:机翻,未校对。 Discover the origins of rubber duck debugging, why it works, and why it has become so popular among programmers. 了解橡皮鸭调试的起源,它为什么有效,以及为什么它在程序员中如此受欢迎。 Debugging co…

AMD CPU加 vega 显卡运行ollama本地大模型

显卡是VEGA56,这个卡代号是gfx900 虽然ollama页面上写着这个卡可以,但是实际是不可以的 报错如下: levelWARN sourceamd_windows.go:97 msg"amdgpu is not supported" gpu0 gpu_typegfx900:xnack 它认为的GPU型号是 gfx900:xna…

【JavaScript】解决 JavaScript 语言报错:Uncaught SyntaxError: Unexpected identifier

文章目录 一、背景介绍常见场景 二、报错信息解析三、常见原因分析1. 缺少必要的标点符号2. 使用了不正确的标识符3. 关键词拼写错误4. 变量名与保留字冲突 四、解决方案与预防措施1. 检查和添加必要的标点符号2. 使用正确的标识符3. 检查关键词拼写4. 避免使用保留字作为变量名…

全栈 Discord 克隆:Next.js 13、React、Socket.io、Prisma、Tailwind、MySQL笔记(一)

前言 阅读本文你需要有 Next.js 基础 React 基础 Prisma 基础 tailwind 基础 MySql基础 准备工作 打开网站 https://ui.shadcn.com/docs 这不是一个组件库。它是可重用组件的集合,您可以将其复制并粘贴到应用中。 打开installation 选择Next.js 也就是此页面…

Python3 第十七课 -- 编程第一步

目录 一. 前言 二. end 关键字 一. 前言 在前面的教程中我们已经学习了一些 Python3 的基本语法知识,接下来我们来尝试一些实例。 打印字符串: print("Hello, world!") 输出结果为: Hello, world! 输出变量值: i 256*256…

智慧校园服务监控功能

智慧校园系统中的服务监控功能,扮演着维护整个校园数字化生态系统稳定与高效运作的重要角色。它如同一位全天候的守护者,通过实时跟踪、分析并响应系统各层面的运行状况,确保教学、管理等核心业务流程的顺畅进行。 服务监控功能覆盖了智慧校园…

开发个人Ollama-Chat--6 OpenUI

开发个人Ollama-Chat–6 OpenUI Open-webui Open WebUI 是一种可扩展、功能丰富且用户友好的自托管 WebUI,旨在完全离线运行。它支持各种 LLM 运行器,包括 Ollama 和 OpenAI 兼容的 API。 功能 由于总所周知的原由,OpenAI 的接口需要密钥才…

知识图谱与 LLM:微调与检索增强生成

Midjourney 的知识图谱聊天机器人的想法。 大型语言模型 (LLM) 的第一波炒作来自 ChatGPT 和类似的基于网络的聊天机器人,这些模型在理解和生成文本方面非常出色,这让人们(包括我自己)感到震惊。 我们中的许多人登录并测试了它写…

微信视频号的视频怎么下载到本地?快速教你下载视频号视频

天来说说市面上常见的微信视频号视频下载工具,教大家快速下载视频号视频! 方法一:缓存方法 该方法来源早期视频技术,因早期无法将大量视频通过网络存储,故而会有缓存视频文件到手机,其目的为了提高用户体验…

尚硅谷Vue3入门到实战,最新版vue3+TypeScript前端开发教程

Vue3 编码规范 创建vue3工程 基于vite创建 快速上手 | Vue.js (vuejs.org) npm create vuelatest 在nodejs环境下运行进行创建 按提示进行创建 用vscode打开项目 安装依赖 源文件有src 内有main.ts App.vue 简单分析 编写src vue2语法在三中适用 vue2中的date metho…

UnityECS学习中问题及总结entityQuery.ToComponentDataArray和entityQuery.ToEntityArray区别

在Unity的ECS&#xff08;Entity Component System&#xff09;开发中&#xff0c;entityQuery.ToComponentDataArray<T>(Allocator.Temp) 和 entityQuery.ToEntityArray(Allocator.Temp) 是两种不同的方法&#xff0c;用于从实体查询中获取数据。除了泛型参数之外&#…

【深度学习入门篇 ⑤ 】PyTorch网络模型创建

【&#x1f34a;易编橙&#xff1a;一个帮助编程小伙伴少走弯路的终身成长社群&#x1f34a;】 大家好&#xff0c;我是小森( &#xfe61;ˆoˆ&#xfe61; ) &#xff01; 易编橙终身成长社群创始团队嘉宾&#xff0c;橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官…

git、huggingface 学术加速

1、git 有时候服务器不能直接访问 github&#xff0c;下载代码会很麻烦&#xff1b;安装库的时候&#xff0c;pip xx git 就更难了 比如&#xff0c;这次我需要安装 unsloth&#xff0c;官方给出的脚本是&#xff1a; pip install “unsloth[cu121-torch220] githttps://git…

【python】函数重构

函数重构 函数重构pycharm函数重构步骤函数重构练习 函数重构 函数重构是指对现有函数进行修改和优化的过程。重构的目的是改善代码的可读性、可维护性和灵活性&#xff0c;同时保持其功能不变。函数重构通常包括以下步骤&#xff1a; 理解函数的功能和目的。了解函数的作用和…

OSPF.综合实验

1、首先将各个网段基于172.16.0.0 16 进行划分 1.1、划分为4个大区域 172.16.0.0 18 172.16.64.0 18 172.16.128.0 18 172.16.192.0 18 四个网段 划分R4 划分area2 划分area3 划分area1 2、进行IP配置 如图使用配置指令进行配置 ip address x.x.x.x /x 并且将缺省路由…

Sortable.js板块拖拽示例

图例 代码在图片后面 点赞❤️关注&#x1f64f;收藏⭐️ 页面加载后显示 拖拽效果 源代码 由于js库使用外链&#xff0c;所以会加载一会儿 <!DOCTYPE html> <html lang"en"> <head> <meta charset"UTF-8"> <meta name&qu…