使用pytorch实现线性回归(很基础模型搭建详解)

使用pytorch实现线性回归

步骤:

        1.prepare dataset
        2.design model using Class 目的是为了前向传播forward,即计算y hat(预测值)
        3.Construct loss and optimizer (using pytorch API) 其中计算loss是为了进行反向传播,optimizer是为了更新梯度
        4.Train Cycle
import torch
x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[2.0],[4.0],[6.0]])

1.nn.Module: 是所有神经网络单元(neural network modules)的基类

nn:neural network

2.构造函数__init__():是用来初始化对象的时候默认调用的函数

3.class torch.nn.Linear(in_features,out_features,bias=True) y = Ax+b

in_features   : size of each input sample
out_features : size of each output sample
# 实例化模型
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__() # 调用父类的构造self.linear = torch.nn.Linear(1,1) # 实例化类,构造对象,包含了权重和偏置# Linear也是继承自module的,也能进行反向传播# nn:neural networkdef forward(self,x):y_pred = self.linear(x)return y_predmodel = LinearModel()

1.class torch.nn.MSELoss(size_average=True, reduce=True)

size_average:是否计算损失均值
reduce:

2.class torch.optim.SGD(model.parameters(),lr=<object object》,momentum=0,dampening=0,weight_decay=0,nesterov=False)

model.parameters(): 告诉优化器对哪些Tensor进行优化,检查所有成员
lr:learning rate 学习率

3.torch.nn.MSELoss也跟torch.nn.Module有关,参与计算图的构建,torch.optim.SGD与torch.nn.Module无关,不参与构建计算图。

# 定义损失函数
criterion = torch.nn.MSELoss(size_average=False)
# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)
# 训练
for epoch in range(100):y_pred = model(x_data) # 计算y hatloss = criterion(y_pred,y_data) # 计算lossprint(epoch,loss.item())# optimizer.zero_grad() # 梯度归零loss.backward() # 反向传播,计算梯度optimizer.step() # update 参数,即更新w和b的值optimizer.zero_grad() # 梯度归零print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred',y_test.data)

1.第三行直接使用对象调用函数(可调用对象)y_pred = model(x_data):¶

Module实现了魔法函数__call__(),call()里面有一条语句是要调用forward(),因此新写类中需要重写forward()覆盖父类的forward()
call() 函数的作用四可以直接在对象后面加(),例如实例化的model对象,和实例化的linear对象
调用此方法:super(LinearModel, self).init() 实例化了父类中所有方法

2.每一次epoch的训练过程,总结就是

①前向传播,求y hat (输入的预测值)
②根据y_hat和y_label(y_data)计算loss
③反向传播 backward (计算梯度)
④根据梯度,更新参数
⑤梯度清零

3.optimizer.zero_grad()

因为grad在 反向传播 的过程中是累加的,也就是说上一次反向传播的结果会对下一次的反向传播的结果造成影响,则意味着每一次运行反向传播,梯度都会累加之前的梯度,所以一般在反向传播之前需要把梯度清零。

本次代码实现的是批量处理数据

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/736832.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FPGA-AXI4接口协议概述

假设我们要传一帧1080P的图片到显示屏显示&#xff0c;那么需要多大的储存空间呢&#xff1f; 一帧1080P的RGB565图像数据需要1920*1080*1633.1776Mb 存储空间 下图是ZYNQ-7000系列中Block RAM的大小&#xff1a; 可以看到最大存储空间的BRAM都不能存储一帧图片&#xff0c;那…

为什么选择.com域名?

com是company简称&#xff0c;表示公司企业。.com是目前全球最流行的通用域名后缀&#xff0c;全球的注册量1.1亿个&#xff0c;所有公司都会优先注册.com域名。 西部数码连续7年被评选为五星级域名注册服务商&#xff0c;22年行业经验&#xff0c;全国3强。是.com域名注册&am…

MySQL性能分析:性能模式和慢查询日志的使用

目录 一、性能模式 步骤1. 启用性能模式 步骤2. 查询性能数据 步骤3. 分析性能数据 步骤4. 优化与调整 注意事项 二、慢查询日志 步骤1. 启用慢查询日志

深入理解Vue.js中的nextTick:实现异步更新的奥秘

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

PostgreSQL教程(三十三):服务器管理(十五)之可靠性和预写式日志

本章解释预写式日志如何用于获得有效的、可靠的操作。 一、 可靠性 可靠性是任何严肃的数据库系统的重要属性&#xff0c;PostgreSQL尽一切可能来保证可靠的操作。可靠的操作的一个方面是&#xff0c;被一个提交事务记录的所有数据应该被存储在一个非易失的区域&#xff0c; …

javascript实现解决浮点数加减乘除运算误差丢失精度问题【收藏点赞】

相信程序都会遇到这样的问题&#xff0c;有时需要在js上做运算合计等浮点数加减乘除&#xff0c;但会有些浮点数会有误差问题。下面用js来解决浮点数加减乘除运算误差丢失精度这个请 【收藏点赞】。 是程序都会在浮点数加减乘除上有误差问题&#xff0c;这是计算机二进制生成的…

GPU:使用阿里云服务器,免费部署一个开源大模型

前面提到CPU版本如何安装和部署ChatGLM&#xff0c;虽然能部署&#xff0c;但是速度和GPU比起来确实一言难尽。 然后找阿里云白嫖了一个服务器&#xff08;省点用的话&#xff0c;不用的时候关机&#xff0c;可以免费用两个多月没问题&#xff09;&#xff0c;只要没有申请过 …

大带宽服务器租用 满足高速网络访问

大带宽服务器租用通常指的是租用具备较大网络带宽的服务器&#xff0c;以满足对高速网络访问需求较为迫切的业务场景。RAKsmart小编为您整理发布大带宽服务器租用如何才能满足高速网络访问的详细信息。 以下是一些关于大带宽服务器租用的详细信息&#xff1a; 1. **带宽大小**&…

计算机网络—eNSP搭建基础 IP网络

目录 1.下载eNSP 2.启动eNSP 3.建立拓扑 4.建立一条物理连接 5.进入终端系统配置界面 6.配置终端系统 7.启动终端系统设备 8.捕获接口报文 9.生成接口流量 10.观察捕获的报文 1.下载eNSP 网上有许多下载eNSP的方式&#xff0c;记得还要下其它三个Virtual Box、Winpa…

composer require 包时,指定版本

composer 如果不加版本上去&#xff0c;则默认是下载最新的版本。 版本约束使用示例 : 和 都可以 版本约束可以加引号&#xff0c;也可以不加 composer官方文档使用的是 : 并且版本约束加引号 示例代码&#xff1a; composer create-project topthink/think:"5…

2024.3.11每日一题

LeetCode 将标题首字母大写 题目链接&#xff1a;2129. 将标题首字母大写 - 力扣&#xff08;LeetCode&#xff09; 题目描述 给你一个字符串 title &#xff0c;它由单个空格连接一个或多个单词组成&#xff0c;每个单词都只包含英文字母。请你按以下规则将每个单词的首字…

【STA】SRAM / DDR SDRAM 接口时序约束学习记录

1. SRAM接口 相比于DDR SDRAM&#xff0c;SRAM接口数据与控制信号共享同一时钟。在用户逻辑&#xff08;这里记作DUA&#xff08;Design Under Analysis&#xff09;&#xff09;将数据写到SRAM中去的写周期中&#xff0c;数据和地址从DUA传送到SRAM中&#xff0c;并都在有效时…

安卓studio安装

安卓studio安装 2024.3.11官网的版本&#xff08;有些翻墙步骤下载东西也解决了&#xff09; 这次写的略有草率&#xff0c;后面会更新布局的&#xff0c;因为截图量太大了&#xff0c;有需要的小伙伴可以试着接受一下哈哈哈哈 !(https://gitee.com/jiuzheyangbawjf/img/raw/ma…

(学习日记)2024.03.08:UCOSIII第十节:临界段

写在前面&#xff1a; 由于时间的不足与学习的碎片化&#xff0c;写博客变得有些奢侈。 但是对于记录学习&#xff08;忘了以后能快速复习&#xff09;的渴望一天天变得强烈。 既然如此 不如以天为单位&#xff0c;以时间为顺序&#xff0c;仅仅将博客当做一个知识学习的目录&a…

mybatis如何打印出完整sql语句

分两步: 1. 在application.properties配置中添加配置项: mybatis-plus.configuration.log-implorg.apache.ibatis.logging.stdout.StdOutImpl logging.level.mapper文件的包路径DEBUG (示例: logging.level.com.test.biztest.service.dalDEBUG, com.test.biztest.service.d…

基于SpringBoot的农产品特色供销系统(蔬菜商城)

基于SpringBoot的农产品特色供销系统&#xff08;蔬菜商城&#xff09; 系统介绍 该系统使用Java、MySQL、Redis、Spring Boot和HTML等技术作为系统的技术支撑&#xff0c;实现了以下功能模块&#xff1a; &#xff08;1&#xff09;后台管理模块&#xff0c;包括权限、日志、…

MySQL数据库在Windows和Linux中由于大小写默认规则不同,出现大小写问题如何解决?

Windows和Linux差异&#xff1a;在Windows上&#xff0c;lower_case_table_names默认为1&#xff0c;而在Linux上&#xff0c;默认值通常为0。因此&#xff0c;在Linux上更改这个设置更常见&#xff0c;以确保与Windows环境的兼容性或实现特定的大小写敏感性需求。 操作系统的大…

[Flutter]自定义等待转圈和Toast提示

1.自定义样式 2.自定义LoadingView import package:flutter/material.dart;enum LoadingStyle {onlyIndicator, // 仅一个转圈等待roundedRectangle, // 添加一个圆角矩形当背景maskingOperation, // 添加一个背景蒙层&#xff0c; 阻止用户操作 }class LoadingView {static f…

【数据结构与算法】贪心算法题解(一)

这里写目录标题 一、455. 分发饼干二、56. 合并区间三、53. 最大子数组和 一、455. 分发饼干 简单 假设你是一位很棒的家长&#xff0c;想要给你的孩子们一些小饼干。但是&#xff0c;每个孩子最多只能给一块饼干。 对每个孩子 i&#xff0c;都有一个胃口值 g[i]&#xff0c;这…

Visual Studio 2019重装vs2019打不开.netcore项目

无法打开项目文件。 .NET SDK 的版本 7.0.306 至少需要 MSBuild 的 17.4.0 版本。当前可用的 MSBuild 版本为 16.11.2.50704。请将在 global.json 中指定的 .NET SDK 更改为需要当前可用的 MSBuild 版本的旧版本。 无法打开项目文件。 .NET SDK 的版本 7.0.306 至少需要 MSBui…