线性回归+基础优化算法

案例代码用法

torch.tensor(data, dtype=None, device=None, requires_grad=False)
# data:表示要转换为张量的数据。可以是列表、NumPy 数组、标量值或其他可转换为张量的对象。
# dtype:可选参数,用于指定输出张量的数据类型。如果不指定,则会自动推断数据类型。
# device:可选参数,用于指定输出张量所在的设备。默认为 None,表示使用默认设备(通常是 CPU)。
# requires_grad:可选参数,表示是否需要计算梯度。默认为 False,表示不计算梯度。torch.normal(mean, std, size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)
# mean:表示正态分布的均值。
# std:表示正态分布的标准差。
# size:表示输出张量的大小。
# out:可选参数,用于指定输出张量。
# dtype:可选参数,用于指定输出张量的数据类型。
# layout:可选参数,用于指定输出张量的布局。
# device:可选参数,用于指定输出张量所在的设备。
# requires_grad:可选参数,表示是否需要计算梯度。torch.matmul(input, other, out=None)
# input:表示第一个输入张量,可以是一个具有至少两个维度的张量。
# other:表示第二个输入张量,也可以是一个具有至少两个维度的张量。
# out:可选参数,表示输出张量。如果指定了 out,则结果将被写入到该张量中。如果未指定 out,则会创建一个新的张量来保存结果。d2l.set_figsize() #是 Deep Learning - The Straight Dope(D2L)教材中提供的一个辅助函数,用于设置绘图的图像尺寸。def load_array(data_arrays, batch_size, is_train=True):# data_arrays:一个包含输入特征    X    和对应标签    y    的列表或元组。通过解包    data_arrays,将其中的元素作为参数传递给    TensorDataset    类,构建一个数据集。# batch_size:整数值,指定每个批次的样本数量。# is_train:一个布尔值,指示是否在训练模式下,默认为    True。dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)# dataset:数据集,可以是 TensorDataset 或其他自定义的数据集类。数据集中包含了训练样本的特征和标签。# batch_size:一个整数,表示每个批次(batch)中包含的样本数量。# shuffle:一个布尔值,指示是否在每个迭代周期中打乱数据顺序。# num_workers:一个整数,表示用于加载数据的子进程数量。net[0]是模型中的第一个全连接层。.weight.data用于访问全连接层的权重参数,并使其正态分布地初始化。

线性回归

import random
import torch
from d2l import torch as d2ldef synthetic_data(w, b, num_examples):""""生成 y = Xw + b + 噪声"""# 第一个参数 0 表示正态分布的均值。# 第二个参数 1 表示正态分布的标准差(方差的平方根)。# 第三个参数 (num_examples, len(w)) 表示输出张量的形状,其中 num_examples 是例子的数量,len(w) 是 w 向量的长度。X = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(X, w) + by += torch.normal(0, 0.01, y.shape)return X, y.reshape((-1, 1))true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)print("features.shape:", features.shape)
print("labels.shape:", labels.shape)d2l.set_figsize()# features[:,1]表示选择 features 张量的第二列数据
#    .detach().numpy() 将其转换为 NumPy 数组
d2l.plt.scatter(features[:, 1].detach().numpy(), labels.detach().numpy(), 1)
# d2l.plt.show()# data_iter函数接收批量大小、特征矩阵和标签向量作为输入,生成大小为batch_size的小批量
def data_iter(batch_size, features, labels):num_examples = len(features)  # 样本个数indices = list(range(num_examples))  # 样本索引random.shuffle(indices)  # 把索引随机打乱for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i:min(i+batch_size,num_examples)])  # 当i+batch_size超出时,取num_examplesyield features[batch_indices], labels[batch_indices]  # 它会在迭代过程中依次生成值,而不是一次性返回所有值。batch_size = 10
for X, y in data_iter(batch_size, features, labels):# print(X, '\n', y)breakw = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)def linreg(X, w, b):"""线性回归模型"""return torch.matmul(X, w) + bdef squared_loss(y_hat, y):"""均方损失"""return (y_hat - y.reshape(y_hat.shape))**2 / 2def sgd(params, lr, batch_size):"""小批量随机梯度下降"""with torch.no_grad():  # 不要产生梯度计算,减少内存消耗for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss
# 训练过程
for epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y)  # x和y的小批量损失l.sum().backward()  # 反向传播,计算梯度sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数with torch.no_grad():train_l = loss(net(features, w, b), labels)print(f'epoch{epoch + 1}, loss{float(train_l.mean()):f}')print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差: {true_b - b}')

代码优化

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
from torch import nntrue_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)  # 生成人工数据集def load_array(data_arrays, batch_size, is_train=True):"""构造一个pytorch数据迭代器"""#*data_arrays 表示将 data_arrays 中的元素解包,并作为参数传递给 TensorDataset 类dataset = data.TensorDataset(*data_arrays)#使用 DataLoader 创建一个数据加载器,用于从数据集中随机选择批次大小的样本。return data.DataLoader(dataset,batch_size,shuffle=is_train)#返回的是从dataset中随机挑选出batch_size个样本出来batch_size = 10
data_iter = load_array((features,labels),batch_size)# 返回的数据的迭代器
print(next(iter(data_iter)))
# print(next(iter(data_iter)))net = nn.Sequential(nn.Linear(2,1))# 初始化参数
net[0].weight.data.normal_(0,0.01)# 初始化第一个全连接层的权重参数
net[0].bias.data.fill_(0)# 初始化第一个全连接层的偏置为0loss = nn.MSELoss()trainer= torch.optim.SGD(net.parameters(),lr=0.03)num_epochs = 3
for epoch in range(num_epochs):for X,y in data_iter:# print("X:",X)# print("y:",y)l = loss(net(X),y)# net(X) 为计算出来的线性回归的预测值trainer.zero_grad()# 梯度清零l.backward()trainer.step()# SGD优化器优化模型l = loss(net(features),labels)print(f'epoch{epoch+1},loss{l:f}')

跟学:B站李沐《动手学深度学习》

资源来源:B站小王同学在积累

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/46249.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

leetcode:字符串相乘(两种方法)

题目: 给定两个以字符串形式表示的非负整数 num1 和 num2,返回 num1 和 num2 的乘积,它们的乘积也表示为字符串形式。 注意:不能使用任何内置的 BigInteger 库或直接将输入转换为整数。 示例 1: 输入: num1 "2", nu…

【生态经济学】利用R语言进行经济学研究技术——从数据的收集与清洗、综合建模评价、数据的分析与可视化、因果推断等方面入手

查看原文>>>如何快速掌握利用R语言进行经济学研究技术——从数据的收集与清洗、综合建模评价、数据的分析与可视化、因果推断等方面入手 近年来,人工智能领域已经取得突破性进展,对经济社会各个领域都产生了重大影响,结合了统计学、…

周易卦爻解读笔记——未济

第六十四卦未济 火水未济 离上坎下 未济卦由否卦所变,否卦六二与九五换位,象征尚未完成。 天地否 未济卦和既济卦既是错卦又是覆卦,这也是最后一卦,序卦传【物不可穷也,故受之以未济终焉】 未济卦象征尚未完成&…

跨域资源共享 (CORS) | PortSwigger(burpsuite官方靶场)【万字】

写在前面 在开始之前,先要看看ajax的局限性和其他跨域资源共享的方式,这里简单说说。 下面提到大量的origin,注意区分referer,origin只说明请求发出的域。 浏览器的同源组策略:如果两个 URL 的 protocol、port 和 h…

达梦数据库表空间创建和管理

概述 本文将介绍在达梦数据库如何创建和管理表空间。 1.创建表空间 1.1表空间个数限制 理论上最多允许有65535个表空间,但用户允许创建的表空间 ID 取值范围为0~32767, 超过 32767 的只允许系统使用,ID 由系统自动分配,ID不能…

网页及屏幕的尺寸区域宽高总结

网页可见区域宽 document.body.clientWidth 网页可见区域高 document.body.clientHeight 网页可见区域宽(包括边线的宽) document.body.offsetWidth 网页可见区域高(包括边线的宽) document.body.offsetHeight 网页正文全文宽 document.body.scrollWidth 网页正…

数据库厂商智臾科技加入龙蜥社区,打造多样化的数据底座

近日,浙江智臾科技有限公司(以下简称“智臾科技”)正式签署 CLA 贡献者许可协议,加入龙蜥社区(OpenAnolis)。 智臾科技主创团队从 2012 年开始投入研发 DolphinDB。DolphinDB 作为一款基于高性能时序数据库…

W5500-EVB-PICO做UDP Client进行数据回环测试(八)

前言 上一章我们用开发板作为UDP Server进行数据回环测试,本章我们让我们的开发板作为UDP Client进行数据回环测试。 连接方式 使开发板和我们的电脑处于同一网段: 开发板通过交叉线直连主机开发板和主机都接在路由器LAN口 测试工具 网路调试工具&a…

Vue--进度条

挺有意思的&#xff0c;大家可以玩一玩儿&#xff1a; 前端代码如下&#xff1a;可以直接运行的代码。 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content&qu…

open cv学习 (三) 绘制图形和文字

绘制图形和文字 demo1 # 绘制线段 import cv2 import numpy as np # 创建一个300300 3通道的图像 canvas np.ones((300, 300, 3), np.uint8)*255 # 绘制一条直线起点坐标为(50, 50)终点坐标为(250,50),颜色的BGR值为(255, 0, 0)(蓝色)&#xff0c;粗细为5 canvas cv2.line(…

使用Mavon-Editor编辑器上传本地图片到又拍云云存储(Vue+SpringBoot)

需求&#xff1a;将本地的图片上传到服务器或者云存储中&#xff0c;考虑之后&#xff0c;这里我选的是上传到又拍云云存储。 技术背景&#xff1a; 前端&#xff1a;VueAjax 后端&#xff1a;SpringBoot 存储&#xff1a;又拍云云存储原理&#xff1a;Mavon-Editor编辑器有两个…

flutter定位简单工具类

import package:permission_handler/permission_handler.dart;class PermissionUtil {/// 获取用户定位权限static Future<bool> getLocationStatus() async {Map<Permission, PermissionStatus> statuses await [Permission.location,].request();return statuse…

财务数据分析用什么软件好?奥威BI自带方案

做财务数据分析&#xff0c;光有软件还不够&#xff0c;还需要有标准化的智能财务数据分析方案。奥威BI数据可视化工具就是这样一款自带智能财务数据分析方案的软件。 ”BI方案“&#xff0c;一站式做财务数据分析 奥威BI数据可视化工具和智能财务分析方案结合&#xff0c;可…

Dockerfile创建 LNMP 服务+Wordpress 网站平台

文章目录 一.环境及准备工作1.项目环境2.服务器环境3.任务需求 二.Linux 系统基础镜像三.docker构建Nginx1.建立工作目录上传安装包2.编写 Dockerfile 脚本3.准备 nginx.conf 配置文件4.生成镜像5.创建自定义网络6.启动镜像容器7.验证 nginx 四.docker构建Mysql1. 建立工作目录…

第22次CCF计算机软件能力认证

第一题&#xff1a;灰度直方图 解题思路&#xff1a; 哈希表即可 #include<iostream> #include<cstring>using namespace std;const int N 610; int a[N]; int n , m , l;int main() {memset(a , 0 , sizeof a);cin >> n >> m >> l;for(int …

Python魔术方法大全

Python魔术方法大全 在Python中&#xff0c;所有以“__”双下划线包起来的方法&#xff0c;都统称为“Magic Method”&#xff08;魔术方法&#xff09;,例如类的初始化方法 init ,Python中所有的魔术方法均在官方文档中有相应描述&#xff0c;这边给大家把所有的魔术方法汇总…

Docker mysql主从同步安装

1. 构建master实例 docker run -p 3307:3306 --name mysql-master \ -v /mydata/mysql-master/log:/var/log/mysql \ -v /mydata/mysql-master/data:/var/lib/mysql \ -v /mydata/mysql-master/conf:/etc/mysql \ -e MYSQL_ROOT_PASSWORDroot \ -d mysql:5.7 2. 构建master配置…

热烈祝贺贵州董程酿酒成功入选航天系统采购供应商库

经过航天系统采购平台的严审&#xff0c;贵州董程酿酒有限公司成功入选中国航天系统采购供应商库。航天系统采购平台是航天系统内企业采购专用平台&#xff0c;服务航天全球范围千亿采购需求&#xff0c;目前&#xff0c;已有华为、三一重工、格力电器、科大讯飞等企业、机构加…

Markdown 基本语法

风无痕 August 21,2023 总览 几乎所有 Markdown 应用程序都支持 John Gruber 原始设计文档中列出的 Markdown 基本语法。但是&#xff0c;Markdown 处理程序之间存在着细微的变化和差异&#xff0c;我们都会尽可能标记出来。 标题&#xff08;Headings&#xff09; 要创建标…

「UG/NX」Block UI 体收集器BodyCollector

✨博客主页何曾参静谧的博客📌文章专栏「UG/NX」BlockUI集合📚全部专栏「UG/NX」NX二次开发「UG/NX」BlockUI集合「VS」Visual Studio「QT」QT5程序设计「C/C+&#