【李沐】3.2线性回归从0开始实现

%matplotlib inline
import random
import torch
from d2l import torch as d2l

1、生成数据集:
看最后的效果,用正态分布弄了一些噪音
在这里插入图片描述
上面这个具体实现可以看书,又想了想还是上代码把:
在这里插入图片描述
按照上面生成噪声,其中最后那个代表服从正态分布的噪声

def synthetic_data(w, b, num_examples):  # 定义函数 synthetic_data,接受权重 w、偏差 b 和样本数量 num_examples 作为参数"""生成 y = Xw + b + 噪声 的合成数据集"""# 生成一个形状为 (num_examples, len(w)) 的特征矩阵 X,其中的元素是从均值为 0、标准差为 1 的正态分布中随机采样得到X = torch.normal(0, 1, (num_examples, len(w)))# 计算目标值 y,通过将特征矩阵 X 与权重 w 相乘,然后加上偏差 b,模拟线性回归的预测过程y = torch.matmul(X, w) + b# 给目标值 y 添加一个小的随机噪声,以模拟真实数据中的噪声。噪声从均值为 0、标准差为 0.01 的正态分布中随机采样得到y += torch.normal(0, 0.01, y.shape)# 返回特征矩阵 X 和目标值 y(将目标值 y 重塑为列向量的形式)return X, y.reshape((-1, 1)
# 定义真实的权重 true_w 为 [2, -3.4]
true_w = torch.tensor([2, -3.4])# 定义真实的偏差 true_b 为 4.2
true_b = 4.2# 调用 synthetic_data 函数生成合成数据集,传入真实的权重 true_w、偏差 true_b 和样本数量 1000
# 这将返回特征矩阵 features 和目标值 labels
features, labels = synthetic_data(true_w, true_b, 1000)

2、读取数据集
注意一般情况下要打乱。
下面函数的作用是该函数接收批量⼤⼩、特征矩阵和标签向量作为输⼊,⽣成⼤⼩为batch_size的⼩批量。每个⼩批量包含⼀组特征和标签。

def data_iter(batch_size, features, labels):num_examples = len(features)  # 获取样本数量indices = list(range(num_examples))  # 创建一个样本索引列表,表示样本的顺序# 将样本索引列表随机打乱,以便随机读取样本,没有特定的顺序random.shuffle(indices)# 通过循环每次取出一个批次大小的样本for i in range(0, num_examples, batch_size):# 计算当前批次的样本索引范围,确保不超出总样本数量batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])# 通过索引获取对应的特征和标签,然后通过 yield 返回这个批次的数据# yield 使得函数可以作为迭代器使用,在每次迭代时产生一个新的批次数据yield features[batch_indices], labels[batch_indices]

3、初始化模型参数
第一步:前面两行代码,,我
们通过从均值为0、标准差为0.01的正态分布中采样随机数来初始化权重,并将偏置初始化为0。
计算梯度使用2.5节引入的自动微分

w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

4、定义模型
这里注意b是一个标量和向量相加,咋办?
前面说过向量的广播机制,就相当于是加到每一个上面

def linreg(X, w, b): #@save
"""线性回归模型"""
return torch.matmul(X, w) + b

5、定义损失函数
y.reshape(y_hat.shape))啥意思?
y_hat是真实值,这里的意思是弄成和y_hat相同的大小

def squared_loss(y_hat, y): #@save
"""均⽅损失"""
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

6、优化算法
问:这里的参数是啥参数?params
更新完的参数不用返回吗?
为什么需要梯度清零?

def sgd(params, lr, batch_size):  # 定义函数 sgd,接受参数 params、学习率 lr 和批次大小 batch_size"""小批量随机梯度下降"""with torch.no_grad():  # 使用 torch.no_grad() 来关闭梯度跟踪,以减少内存消耗for param in params:  # 遍历模型参数列表param -= lr * param.grad / batch_size  # 更新参数:参数 = 参数 - 学习率 * 参数梯度 / 批次大小param.grad.zero_()  # 清零参数的梯度,以便下一轮梯度计算

7、训练
问:反向传播是为了干啥?
是为了计算梯度,那梯度是啥呢
梯度是参数更快收敛的方向(就是向量)
优化方法是干啥的?
优化方法就是根据上面传过来的梯度,计算参数更新
所以,这几章看完后需要梳理深度学习的整个过程,以及每块有哪些方法,这些方法的特点和用那种方法更好
问(1)每个epoch训练多少数据?
整个训练集
(2)损失函数是啥?
损失函数是用来计算真实值域预测值之间的距离,当然是距离越小越好,可以拿均方误差想一下
(3)l.sum().backward()是啥意思?
看注释,补充:.backward() 方法用于执行自动求导,计算总的损失值对于模型参数的梯度。这将会构建计算图并沿着图的反向传播路径计算梯度。
(4)但是上面所说的梯度保存在哪里呢?
w.grad 和 b.grad 中
(5)但是sgd中也没有用到w.grad 啊?
用到了,param 可以是 w 或者 b,而 param.grad 则是相应参数的梯度。
(6)新问题:train_l = loss(net(features, w, b), labels)不是在前面已经计算过损失函数了吗?为啥在这里还需要计算?
前面计算损失函数是间断性的,目的是更新模型参数。
后面仍然计算的目的是根据更新完的参数对模型在整个训练集上与真实标签的差距做一个评估。

lr = 0.03  # 设置学习率为 0.03,控制每次参数更新的步幅num_epochs = 3  # 设置训练的轮次(迭代次数)为 3,即遍历整个数据集的次数net = linreg  # 定义模型 net,通常表示线性回归模型loss = squared_loss  # 定义损失函数 loss,通常为均方损失函数,用于衡量预测值与真实值之间的差距
for epoch in range(num_epochs):  # 迭代 num_epochs 轮,进行训练for X, y in data_iter(batch_size, features, labels):  # 遍历数据集的每个批次l = loss(net(X, w, b), y)  # 计算当前批次的损失值 l,表示预测值与真实值之间的差距# 因为 l 的形状是 (batch_size, 1),而不是一个标量。将 l 中的所有元素加起来,# 并计算关于 [w, b] 的梯度l.sum().backward()sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数,执行随机梯度下降算法with torch.no_grad():train_l = loss(net(features, w, b), labels)  # 在整个训练集上计算损失值# 打印当前迭代轮次和训练损失值的均值print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')

8、练习中的问题

  1. 如果我们将权重初始化为零,会发⽣什么。算法仍然有效吗?
    无效,为啥?因为,不同的X输入是相同的输出

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/48749.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

利用大模型反馈故障的解决方案

背景 观测云有两个错误巡检脚本,RUM 错误巡检和 APM 错误巡检,代码均开源。 错误巡检的主要目的是发现新出现的错误消息(error stack),原有的巡检在上报了相应的事件报告后,只是定位了问题,并没有给出合适的解决方案。…

C++(3)C++对C的扩展Extension

类型增强 1、类型更加严格 不初始化&#xff0c;无法通过编译&#xff1b;C不初始化&#xff0c;则随机赋值 #include <iostream> #include <stdlib.h>int main() {const int a 100; //真正的const,无法修改 // int *p &a; 报错const int *p…

Pandas基础知识

文章目录 Pandas的数据结构Series --- 由数据和索引组成&#xff08;索引&#xff08;index&#xff09;在左&#xff0c;数据&#xff08;values&#xff09;在右&#xff09;DataFrame --- 索引包括行索引和列索引&#xff0c;每列数据可以是不同的类型 Pandas的索引操作 ---…

SpringMVC拦截器学习笔记

SpringMVC拦截器 拦截器知识 拦截器(Interceptor)用于对URL请求进行前置/后置过滤 Interceptor与Filter用途相似但实现方式不同 Interceptor底层就是基于Spring AOP面向切面编程实现 拦截器开发流程 Maven添加依赖包servlet-api <dependency><groupId>javax.se…

C++学习之九

1)普通类的成员函数模板 class A { public:template<typename T> //类的成员函数模板,//成员函数模板和函数模板长得样子一样&#xff01;void func(T tmp); };template<typename T> void A::func(T tmp) {cout << tmp << endl; }int main() {A a;a.…

nginx代理webSocket链接,webSocket频繁断开重连

一、场景 1、使用nginx代理webSocket链接&#xff0c;消息发送和接收都是正常的&#xff0c;但webSocket链接会频繁断开重连 2、如果不使用nginx代理则一切正常 3、程序没有做webSocket心跳处理 如下图 二、nginx代理配置 upstream cloud_ass {#ip_hash;server 192.168.1.…

2023年国赛数学建模思路 - 案例:随机森林

文章目录 1 什么是随机森林&#xff1f;2 随机深林构造流程3 随机森林的优缺点3.1 优点3.2 缺点 4 随机深林算法实现 建模资料 ## 0 赛题思路 &#xff08;赛题出来以后第一时间在CSDN分享&#xff09; https://blog.csdn.net/dc_sinor?typeblog 1 什么是随机森林&#xff…

leetcode做题笔记95. 不同的二叉搜索树 II

给你一个整数 n &#xff0c;请你生成并返回所有由 n 个节点组成且节点值从 1 到 n 互不相同的不同 二叉搜索树 。可以按 任意顺序 返回答案。 思路一&#xff1a;递归 struct TreeNode ** partition(int start, int end, int* returnSize){*returnSize 0;int size 32;stru…

从2023年世界机器人大会发现机器人新趋势

机器人零部件为何成2023年世界机器人大会关注热门&#xff1f; 在原先&#xff0c;机器人的三大核心零部件是控制系统中的控制器、驱动系统中的伺服电机和机械系统中的精密减速器。如今&#xff0c;机器人的主体框架结构已经落实&#xff0c;更多机器人已经开始深入到各类场景中…

c++ std::sort的简单用法

直接看代码即可&#xff0c;对于lambda部分的解释&#xff1a;a,b基本算作两个抽象的比较对象&#xff0c;每次会比较这两个&#xff0c;返回true:表示a应该排在b前面(a<b)&#xff0c;返回false:表示b应该排在a前面(a>b)&#xff0c;具体可查看cpprefernece #include &…

Prompt本质解密及Evaluation实战(一)

一、基于evaluation的prompt使用解析 基于大模型的应用评估与传统应用程序的评估不太一样&#xff0c;特别是基于GPT系列或者生成式语言模型&#xff0c;因为模型生成的内容与传统意义上所说的内容或者标签不太一样。 以下是借用了ChatGPT官方的evaluation指南提出的对结果的具…

kali的一些使用和ms08-067、ms17-010漏洞

VM虚拟机-三种网络连接方式&#xff08;桥接、NAT、仅主机模式&#xff09; 虚拟机网络连接 一、Bridged&#xff08;桥接&#xff09; 二、NAT&#xff08;网络地址转换&#xff09; 三、Host-Only&#xff08;仅主机&#xff09; 在vmware软件中&#xff0c;选项栏的“编…

[计算机入门] 窗口操作

3.3 窗口操作 之前介绍过如何调整窗口大小。接下来介绍如何对窗口进行排布等操作。 当我们想要将某个窗口调整到整个屏幕的左边或者右边(占整个屏幕的一半)&#xff0c;可以在选中并激活窗口后&#xff0c;按Win ←/→ 进行调整。 此时&#xff0c;还可以通过Win↑/↓调整该…

微信小程序教学系列(3)

微信小程序教学系列 第三章&#xff1a;小程序高级开发技巧 1. 小程序API的使用 小程序API简介 小程序API是小程序提供的一系列接口&#xff0c;用于实现各种功能和操作。通过调用小程序API&#xff0c;可以实现页面跳转、数据存储、网络请求等功能。 使用小程序API的步骤…

ATFX汇市:杰克逊霍尔年会降至,鲍威尔或再发鹰派言论

环球汇市行情摘要—— 昨日&#xff0c;美元指数下跌0.11%&#xff0c;收盘在103.33点&#xff0c; 欧元升值0.22%&#xff0c;收盘价1.0898点&#xff1b; 日元贬值0.58%&#xff0c;收盘价146.23点&#xff1b; 英镑升值0.18%&#xff0c;收盘价1.2757点&#xff1b; 瑞…

Flutter GetXController 动态Tabbar 报错问题

场景&#xff1a; 1.Tabbar的内容是接口获取的 2. TabController? tabController;&#xff1b; 在onInit 方法中初始化tabbarController tabController TabController(initialIndex: 0, length: titleDataList.length, vsync: this); 这时候会报一个错误 Controllers l…

docker版jxTMS使用指南:使用jxTMS提供数据

本文讲解了如何jxTMS的数据访问框架&#xff0c;整个系列的文章请查看&#xff1a;docker版jxTMS使用指南&#xff1a;4.4版升级内容 docker版本的使用&#xff0c;请查看&#xff1a;docker版jxTMS使用指南 4.0版jxTMS的说明&#xff0c;请查看&#xff1a;4.0版升级内容 4…

微信小程序教学系列(8)

微信小程序教学系列 第八章&#xff1a;小程序国际化开发 欢迎来到第八章&#xff01;这一次我们要谈论的是小程序国际化开发。你可能会问&#xff0c;什么是国际化&#xff1f;简单来说&#xff0c;国际化就是让小程序能够适应不同的语言和地区&#xff0c;让用户们感受到更…

Python 合并多个 PDF 文件并建立书签目录

今天在用 WPS 的 PDF 工具合并多个文件的时候&#xff0c;非常不给力&#xff0c;居然卡死了好几次&#xff0c;什么毛病&#xff1f;&#xff01; 心里想&#xff0c;就这么点儿功能&#xff0c;居然收了我会员费都实现不了&#xff1f;不是吧…… 只能自己来了&#xff0c;…

导出pdf

该方法导出的pdf大小是A4纸的尺寸&#xff0c;如果大于1页需要根据元素高度进行截断的话&#xff0c;页面元素需要加 class ergodic-dom&#xff0c;方法里面会获取ergodic-dom元素&#xff0c;对元素高度和A4高度做比较&#xff0c;如果大于A4高度&#xff0c;会塞一个空白元素…