动手学深度学习(Pytorch版)代码实践 -深度学习基础-03线性回归简洁版

03线性回归简洁版

主要内容

  1. 生成数据集:使用给定的权重和偏置,以及一些噪声,生成模拟数据。
  2. 读取数据集:将数据打乱,并按批次读取数据。
  3. 初始化模型参数:随机初始化模型的权重和偏置,并启用自动求导。
  4. 定义模型:实现简单的线性回归模型。
  5. 定义损失函数:使用均方损失函数计算误差。
  6. 定义优化函数:实现小批量随机梯度下降算法。
  7. 模型训练:训练模型,更新参数,输出训练过程中每轮的损失以及最终的参数误差。
import numpy as np
import torch
from torch.utils import data# 生成数据集
def synthetic_data(w, b, num_examples):"""生成 y = Xw + b + 噪声"""# torch.normal: 返回一个从均值为0,标准差为1的正态分布中提取的随机数的张量# 生成形状为(num_examples, len(w))的矩阵X = torch.normal(0, 1, (num_examples, len(w)))# torch.matmul: 矩阵乘法y = torch.matmul(X, w) + b# 添加噪声:torch.normal(0, 0.01, y.shape)y += torch.normal(0, 0.01, y.shape)# reshape: 只改变张量的视图,不改变数据,将y转换为列向量return X, y.reshape((-1, 1))# 设定真实的权重和偏置
true_w = torch.tensor([2, -3.4])
true_b = 4.2
# 生成特征和标签
features, labels = synthetic_data(true_w, true_b, 1000)# 定义数据加载器
def load_array(data_arrays, batch_size, is_train=True):"""构造一个Pytorch数据迭代器"""dataset = data.TensorDataset(*data_arrays)# 使用DataLoader每次从dataset中抽选batch_size个样本# shuffle设定是否随机抽取return data.DataLoader(dataset, batch_size, shuffle=is_train)# 设置批量大小
batch_size = 10
data_iter = load_array((features, labels), batch_size)# "nn"是神经网络的缩写
from torch import nn# 定义模型
# Sequential类将多个层串联在一起
net = nn.Sequential(nn.Linear(2, 1))
# nn.Linear: 全连接层,输入特征数为2,输出特征数为1# 初始化模型参数
# normal: 生成符合正态分布的随机数,参数为均值0和标准差0.01
net[0].weight.data.normal_(0, 0.01)
# fill_: 将偏置初始化为0
net[0].bias.data.fill_(0)# 定义损失函数
# MSELoss: 计算均方误差,默认返回所有样本损失的平均值
loss = nn.MSELoss()# 实例化SGD优化器
trainer = torch.optim.SGD(net.parameters(), lr=0.03)#实例化SGD优化器时,需要指定要优化的参数和学习率(lr)
#这里的参数通过net.parameters()获取,学习率设置为0.03# 训练模型
num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:# 计算模型输出和损失l = loss(net(X), y)# 梯度清零trainer.zero_grad()# 反向传播,计算梯度l.backward()# 更新模型参数trainer.step()# 每个epoch结束后计算整个数据集上的损失l = loss(net(features), labels)print(f'第{epoch + 1}轮,损失: {l:f}')# 打印生成数据集的真实参数和通过有限数据训练获得的模型参数的误差
w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)# 示例输出:
# 第1轮,损失: 0.000246
# 第2轮,损失: 0.000103
# 第3轮,损失: 0.000103
# w的估计误差: tensor([ 0.0006, -0.0005])
# b的估计误差: tensor([0.0006])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/844434.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JAVA 17

文章目录 概述一 语法层面变化1_JEP 409:密封类2_JEP 406:switch模式匹配(预览) 二 API层面变化1_JEP 414:Vector API(第二个孵化器)2_JEP 415:特定于上下文的反序列化过滤器 三 其他…

手机投屏技巧:手机怎么投屏到电脑显示屏上?精选6招解决!

手机怎么投屏到电脑显示屏上?出于一些不同的原因,大多数人都希望能将手机投屏到电脑上。其中一个常见的原因是,大家经常会希望在笔记本电脑上共享图片,而无需上传或者登录微信进行文件传输。以及希望不依靠投影仪,就能…

只刷题可以通过PMP考试吗?

咱们都知道,PMBOK那本书,哎呀,读起来确实有点费劲。所以,有些人就想了,干脆我就刷题吧,题海战术,没准儿也能过。这话啊,听起来似乎有点道理,但咱们得好好琢磨琢磨。 刷题…

Linux: 为什么不应该在内核代码中使用 volatile ?

文章目录 1. 前言2. 背景3. 为什么不应该在内核代码中使用 volatile ?4. 参考资料 1. 前言 限于作者能力水平,本文可能存在谬误,因此而给读者带来的损失,作者不做任何承诺。 2. 背景 本文基于 Linux 内核文档 Why the “volati…

【YashanDB知识库】自动选举配置错误引发的一系列问题

问题现象 问题出现的步骤/操作: ● 配置自动选举,数据库备库手动发起switch over,命令会报错 ● 主、备库变为只读状态,数据库无法进行读写操作 ● shutdown immediate 停止数据库,此时发现数据库一直没有退出&…

C++ Primer Chapter 2 Variables and Basic Types

C Primer Chapter 2 Variables and Basic Types 2024/05/27 2.3 复合类型 引用 定义 通过将声明符写成&d的形式来定义引用类型,其中d是声明的变量名。 int ival1024; int &refValival; int &refVal2; //报错:引用必须被初始化 引用即别名 …

script 标签中 defer 和 async 属性的区别

script 标签中的 defer Vs. async 在 HTML 中,script 标签可以使用 defer 和 async 属性来控制外部 JavaScript 脚本加载和执行的方式。defer 和 async 都可以提高页面的加载性能,主要区别整理如下。 区别点deferasync加载顺序按顺序加载异步加载&…

论文笔记:Vision GNN: An Image is Worth Graph of Nodes

neurips 2022 首次将图神经网络用于视觉任务,同时能取得很好的效果 1 方法 2 架构 在计算机视觉领域,常用的 transformer 通常是 isotropic 的架构(如 ViT),而 CNN 更喜欢使用 pyramid 架构(如 ResNet&am…

开源数据库同步工具DBSyncer

前言: 这么实用的工具,竟然今天才发现,相见恨晚呀!!!! DBSyncer(英[dbsɪŋkɜː],美[dbsɪŋkɜː 简称dbs)是一款开源的数据同步中间件,提供M…

必看项目|多维度揭示心力衰竭患者生存关键因素(生存分析、统计检验、随机森林)

1.项目背景 心力衰竭是一种严重的公共卫生问题,影响着全球数百万人的生活质量和寿命,心力衰竭的病因复杂多样,既有个体生理因素的影响,也受到环境和社会因素的制约,个体的生活方式、饮食结构和医疗状况在很大程度上决定了其心力衰竭的风险。在现代社会,随着生活水平的提…

使用moquette mqtt发布wss服务

文章目录 概要一、制作的ssl证书二、配置wss小结 概要 moquette是一款不错的开源mqtt中间件,github地址:https://github.com/moquette-io/moquette。我们在发布mqtt服务的同时,是可以提供websocket服务器的,有些场景下需要用到&a…

OpenAI新模型开始训练!GPT6?

国内可用潘多拉镜像站GPT-4o、GPT-4(更多信息请加Q群865143845): 站点:https://xgpt4.ai0.cn/ OpenAI 官网 28 日发文称,新模型已经开始训练! 一、新模型开始训练 原话:OpenAI has recently begun training…

价值飙升30%,AI PC拉动半导体出货潮

由于处理器和DRAM的升级,大摩预测每台AI PC的半导体价值将增长20%-30%,PC平均售价也将提高7%。 台北国际电脑展即将于6月2日隆重开幕。 随着展会的临近,各种现象级的AI PC也蓄势待发。 就在上周,联想在业绩会上,首次…

2-EMMC启动及各分区文件生成过程

EMMC的使用比nand flash还是复杂一些,有其特有的分区和电器性能 1、启动过程介绍 跟普通nand或spi flash不同,uboot前面还有好几级 在vendor某些厂商的设计中,ATF并不是BOOTROM加载后的第一个启动镜像,可能是这样的: …

java的方法重写

重写的概述 重写是基于继承来说的,因为父类的方法需求不满足于子类,所以就要在进行方法重写,如果不知道继承是啥可以看我上一篇笔记 在这里用代码举个栗子 例如:我们定义了一个动物类代码如下: public class Animal…

Leecode热题100---二分查找--4:寻找两个正序数组的中位数

题目: 给定两个大小分别为 m 和 n 的正序(从小到大)数组 nums1 和 nums2。请你找出并返回这两个正序数组的 中位数 。 解法1、暴力解法(归并) 思路: 合并 nums1,nums2 为第三个数组 排序第三个数…

XXL-JOB分布式任务调度框架详解(全网最详细!!!)

​​​​​​​ 引言 第一部分:XXL-JOB概述 第二部分:架构与组件 第三部分:使用教程 第四部分:源码分析 第五部分:最佳实践 引言 在分布式系统中,任务调度是一项基础而又关键的服务,它涉…

Java设计模式:享元模式实现高效对象共享与内存优化(十一)

码到三十五 : 个人主页 目录 一、引言二、享元设计模式的概念1. 对象状态的划分2. 共享机制 三、享元设计模式的组成四、享元设计模式的工作原理五、享元模式的使用六、享元设计模式的优点和适用场景结语 [参见]: Java设计模式:核心概述&…

解决Spring BeanCreationException的常见问题

解决Spring BeanCreationException的常见问题 在使用Spring框架进行开发时,可能会遇到各种异常,其中之一就是BeanCreationException。本文将介绍如何解决以下特定的异常: org.springframework.beans.factory.BeanCreationException: Error …

拼接字符串

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 使用“”运算符可完成对多个字符串的拼接,“”运算符可以连接多个字符串并产生一个字符串对象。 例如,定义两个字符串&#…