动手学深度学习——线性回归从零开始

  1. 生成数据集synthetic_data()
  2. 读取数据集data_iter()
  3. 初始化模型参数w, b
  4. 定义模型:线性回归模型linreg()
  5. 定义损失函数:均方损失squared_loss()
  6. 定义优化算法:梯度下降sgd()
  7. 进行训练:输出损失loss和估计误差
%matplotlib inline
import random
import torch
from d2l import torch as d2l# 生成数据集
def synthetic_data(w, b, num_examples): #@save"""生成y=Xw+b+噪声"""X = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(X, w) + by += torch.normal(0, 0.01, y.shape)return X, y.reshape(-1, 1)true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)# 读取数据集
def data_iter(batch_size, features, labels):# 获取x中特征的长度,转换成列表,通过for循环进行批量生成num_examples = len(features)indices = list(range(num_examples))# 这些样本是随机读取的,没有特定的顺序random.shuffle(indices)for i in range(0, num_examples, batch_size):# 此时获取的是向量了,最后如果不足批量大小取最后剩余的batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]# 初始化模型参数
w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)# 定义模型:线性回归模型
def linreg(X, w, b):return torch.matmul(X, w) + b# 定义优化算法sgd
# lr:学习率
def sgd(params, lr, batch_size):with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()"""训练:1、读取批量样本获取预测2、计算损失,反向传播,存储每个参数的梯度3、调用优化算法sgd来更新模型参数4、输出每轮的损失
"""
lr = 0.03
num_epochs = 10
net = linreg
loss = squared_lossfor epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):# X和y的小批量损失# net()返回y=X*w+b,loss()返回(y'-y)^2/2l = loss(net(X, w, b), y)\# 因为l形状是(batch_size, 1),而不是一个标量。L中的所有元素被加到一起# 并以此计算关于[w, b]的梯度l.sum().backward()# sgd():w = w - lr*w/batch_size# 使用参数的梯度更新参数sgd([w, b], lr, batch_size)with torch.no_grad():# loss(y_hat, y)# net(features, w, b)相当于y_hat,labels相当于ytrain_1 = loss(net(features, w, b), labels)print(f'epoch {epoch + 1}, loss{float(train_1.mean()):f}')# 输出w和b的估计误差
print(f'w的估计误差:{true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差:{true_b - b}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/8686.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java项目之人才公寓管理系统(ssm+mysql+jsp)

风定落花生,歌声逐流水,大家好我是风歌,混迹在java圈的辛苦码农。今天要和大家聊的是一款基于ssm的人才公寓管理系统。技术交流和部署相关看文章末尾! 开发环境: 后端: 开发语言:Java 框架&…

iOS pod EaseIMKit库如何放在本地使用

在使用环信EaseIMKit库的时候,发现有些开发者需要改动库中的一些逻辑,或者有UI上的一些调整,如果直接去改pods里面的库,在之后的库版本升级会把之前修改过的代码覆盖掉,这个时候我们就需要pod指向本地的库,…

KubeVela篇05:为kubevela开发terraform-mycloud Addon插件

通过前面的章节,我们已经学习了解terraform,并通过vpc资源例子,为私有云/混合云开发了terraform provider,这一节介绍如何将我们开发的mycloud terraform provider整合到kubevela控制平台上,以通过在application中声明一个kubevela组件的方式去申请基础设施资源。 我们需…

【数据结构】---时间复杂度与空间复杂度

时间复杂度与空间复杂度 1.📉 时间复杂度📌1.1 时间复杂度的概念1.2 大O的渐进表示法 🏰空间复杂度📃例题分析1.案例(常数阶)2.案例(线性阶)3.案例:(平方阶&a…

css元素定位:通过元素的标签或者元素的id、class属性定位

前言 大部分人在使用selenium定位元素时,用的是xpath元素定位方式,因为xpath元素定位方式基本能解决定位的需求。xpath元素定位方式更直观,更好理解一些。 css元素定位方式往往被忽略掉了,其实css元素定位方式也有它的价值&…

【数据库 - 用户权限管理】(简略)

目录 一、概述 二、用户权限类型 1.ALL PRIVILEGES 2.CREATE 3.DROP 4.SELECT 5.INSERT 6.UPDATE 7.DELETE 8.INDEX 9.ALTER 10.CREATE VIEW和CREATE ROUTINE 11.SHUTDOWN 12GRANT OPTION 三、语句格式 1.用户赋权 2.权限删除 3.用户删除 一、概述 数据库用…

Redis多级缓存

文章目录 多级缓存背景JVM进程缓存Caffeine案例分析安装MySQL导入SQL Lua语法变量与循环数据类型声明变量循环 函数与条件控制函数条件控制 实现多级缓存安装OpenResty安装opm工具目录结构配置Nginx的环境变量运行启动 快速入门反向代理流程OpenResty监听请求编写item.lua 请求…

基于深度学习的高精度交通信号灯检测系统(PyTorch+Pyside6+YOLOv5模型)

摘要:基于深度学习的高精度交通信号灯检测识别可用于日常生活中检测与定位交通信号灯目标,利用深度学习算法可实现图片、视频、摄像头等方式的交通信号灯目标检测识别,另外支持结果可视化与图片或视频检测结果的导出。本系统采用YOLOv5目标检…

AI > 语音识别开源项目列举

名称所属开发机构使用场景优缺点技术特点占有率描述CMU Sphinx卡内基梅隆大学嵌入式设备、服务器应用优点:可用于嵌入式设备和服务器应用。 缺点:准确率相对较低,适用范围有限。- 支持多种语言模型和工具。- 适用于嵌入式设备和服务器应用。中…

站在读者角度:10个技巧写出有价值的文章

站在读者的角度,以下是10个写出有价值的文章的技巧: 1.确定你的目标读者:在开始写作之前,确定你的目标读者是谁,这有助于你更好地针对他们的需求和兴趣来写作。 2.了解你的读者:通过调查、研究和互动&…

Unity UGUI的EventSystem(事件系统)组件的介绍及使用

Unity UGUI的EventSystem(事件系统)组件的介绍及使用 1. 什么是EventSystem组件? EventSystem是Unity UGUI中的一个重要组件,用于处理用户输入事件,如点击、拖拽、滚动等。它负责将用户输入事件传递给合适的UI元素&a…

【LeetCode】78.子集

题目 给你一个整数数组 nums ,数组中的元素 互不相同 。返回该数组所有可能的子集(幂集)。 解集 不能 包含重复的子集。你可以按 任意顺序 返回解集。 示例 1: 输入:nums [1,2,3] 输出:[[],[1],[2],[1…

vue实现@唤起列表功能(借助ElAutocomplete)

实现一个输入组件 myAutoComplete.vue <template><el-autocomplete ref"autoRef" :model-value"state" input"handleInput" :onkeyup"handleKey":fetch-suggestions"querySearch" select"handleSelect" …

Spring动态代理

一、代理 代理&#xff08;Proxy&#xff09;是一种设计模式&#xff0c;提供了对目标对象的另外的访问方式。 代理意义&#xff1a;可以再目标对象代码实现的基础上&#xff0c;增强额外的功能代码。 二、静态代理 静态代理&#xff0c;编译时就已经确定下来了接口代理类被…

LeetCode每日一题-接雨水

给定 n 个非负整数表示每个宽度为 1 的柱子的高度图&#xff0c;计算按此排列的柱子&#xff0c;下雨之后能接多少雨水。 示例 1&#xff1a; 输入&#xff1a;height [0,1,0,2,1,0,1,3,2,1,2,1] 输出&#xff1a;6 解释&#xff1a;上面是由数组 [0,1,0,2,1,0,1,3,2,1,2,1] 表…

Spring中事务失效的8中场景

1. 数据库引擎不支持事务 这里以 MySQL为例&#xff0c;MyISAM引擎是不支持事务操作的&#xff0c;一般要支持事务都会使用InnoDB引擎&#xff0c;根据MySQL 的官方文档说明&#xff0c;从MySQL 5.5.5 开始的默认存储引擎是 InnoDB&#xff0c;之前默认的都是 MyISAM&#xff…

Python in VS Code 2023年7月发布|Mypy 扩展预览版与调试扩展、Pylance 本地化及其他

排版&#xff1a;Alan Wang 我们很高兴地宣布 Visual Studio Code 的 Python 和 Jupyter 扩展将于 2023 年 7 月发布&#xff01; 此版本包括以下更新&#xff1a; Mypy 扩展预览版预览版中的调试扩展Pylance 本地化使用 Pylance 的第三方库的索引持久性即将弃用 Python 3.7 支…

分享5款有点冷门的实用派软件

​ 分享5款冷门但值得下载的Windows软件&#xff0c;个个都是实用&#xff0c;你可能一个都没见过&#xff0c;但是 我觉得你用过之后可能就再也离不开了。 系统监控——XMeters ​ XMeters是一个系统监控软件&#xff0c;可以让你在任务栏上显示各种系统信息&#xff0c;如C…

C# List 详解三

目录 11.Equals(Object) 12.Exists(Predicate) 13.Find(Predicate) 14.FindAll(Predicate) 15.FindIndex(Int32, Int32, Predicate) 16.FindIndex(Int32, Predicate) 17.FindIndex(Predicate) C# List 详解一 1.Add(T)&#xff0c;2.AddRa…

15 | 线性回归代码实现

文章目录 线性回归实现Lasso回归和岭回归多项式回归线性回归实现 线性回归是处理一个或者多个自变量和因变量之间的关系,然后进行建模的一种回归分析方法。如果只有一个自变量的情况称为一元线性回归,如果有两个或两个以上的自变量,就称为多元回归。在sklearn中linear_mode…