深度学习_4 数据训练之线性回归

训练数据

线性回归

基本原理

比如我们要买房,机器学习深度学习来预测房价。房价的影响因素有:卧室数量,卫生间数量,居住面积。此外,还需要加上偏差值来计算。我们要找到一个正确率高的计算方法来计算。

首先,我们需要有一个计算公式: y = w 1 x 1 + w 2 x 2 + w 3 x 3 + b y=w_1x_1+w_2x_2+w_3x_3+b y=w1x1+w2x2+w3x3+b ,w是权重,b是偏差值。w 和 b 的值需要我们自己选取最优解。

线性模型实现:输入 x 是一个 n*1 的向量,权重 w 是一个 n*1 的向量,b是一个标量。 y = < w , x > + b y=<w,x>+b y=<w,x>+b

这就可以看做一个简单的神经网络了。我们输入多个参数,神经网络处理后最终得到一个标量结果。

1698814961423

神经网络就像级联的神经元一样,每一个神经元是一层,进行一次处理,得到的参数再传递给下一层进行进一步处理。这个例子中层数比较少。

计算得到结果后,如何评价结果的质量?

1698815105842

训练数据:结合计算公式和每次的 loss 反馈,得到 w 和 b 的最优解。

线性回归是一种可以得到显示解(就是确定的解,不像 y=x+C 这种包含未知部分的函数解)的单层神经网络,主要运用加权和偏差对 n 维输入进行处理,通过平方损失来衡量预测值和真实值的差异。

优化方法

具体是用什么样的方法对 w b 进行优化呢?

image-20231101131536440

每次我们沿垂直于切线的方向,也就是求导方向,前进一个步长(学习率),因为沿着这个方向 w 的优化效率最高。

因此很容易联想到学习率步长是有一个合适的范围的,太长了一下子跳太远了。太短了要迭代太多次速度太慢。

另外,重新取点计算梯度值对于复杂模型来说其实是很耗时间的,可能几个小时以上。而且最小值可能不止一个,非要找到损失最小的样本点也很废算力。因此实际操作中一般是随机挑选几个样本点,一部分批量(batch)平均来计算梯度值(小批量随机梯度下降)。batchsize 太小,并行化计算无法完全发挥出来;太大,浪费计算。

实现
%matplotlib inline
import random
import torch
from d2l import torch as d2l# 生成人造数据集
def synthetic_data(w, b, num_examples):  #@save"""生成y=Xw+b+噪声"""X = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(X, w) + by += torch.normal(0, 0.01, y.shape)   # 噪声return X, y.reshape((-1, 1))         # 列向量,-1 表示行数自己计算应该是多少。 # 我们人造数据集的参数。训练期望就是得到的 w 和 b 离我们的 true_w true_b 误差很小
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)# 随机读取一批数据
def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))# 这些样本是随机读取的,没有特定的顺序random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]# 初始化模型参数,批量大小10,初始 wb 如下
batch_size = 10
w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)def linreg(X, w, b):  #@save"""线性回归模型"""return torch.matmul(X, w) + bdef squared_loss(y_hat, y):  #@save"""均方损失"""return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def sgd(params, lr, batch_size):  #@save"""小批量随机梯度下降优化函数"""with torch.no_grad():for param in params:param -= lr * param.grad / batch_size  # 损失函数那里没有归一化,这里归一化param.grad.zero_()lr = 0.03			# 学习率
num_epochs = 3		# 优化计算重复3次
net = linreg		# 网络模型
loss = squared_loss	# 损失。这样写后面直接改参数很方便for epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y)  # X和y的小批量损失# 因为l形状是(batch_size,1),而不是一个标量。l中的所有元素被加到一起,# 并以此计算关于[w,b]的梯度l.sum().backward()sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数with torch.no_grad():train_l = loss(net(features, w, b), labels)print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')# output:loss 代表每一次循环 wb 的均方损失
epoch 1, loss 0.038786
epoch 2, loss 0.000152
epoch 3, loss 0.000048# 和我们自己生成的真正的数据参数比较,来看训练准确性:
print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差: {true_b - b}')
# output:
w的估计误差: tensor([0.0003, 0.0003], grad_fn=<SubBackward0>)
b的估计误差: tensor([-0.0002], grad_fn=<RsubBackward1>)

image-20231101203301587

简洁实现

由于数据迭代器、损失函数、优化器和神经网络层很常用, 现代深度学习库也为我们实现了这些组件。

首先,生成数据集部分是一样的。

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2ltrue_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

读取数据集不用自己定义一个随机读取函数:

def load_array(data_arrays, batch_size, is_train=True):  #@save"""构造一个PyTorch数据迭代器"""dataset = data.TensorDataset(*data_arrays)return data.DataLoader(dataset, batch_size, shuffle=is_train)batch_size = 10
data_iter = load_array((features, labels), batch_size)

用法可以和前面的复杂实现一样,for X, y in data_iter 。我们也可以用 iter 构造迭代器,用 next 迭代。next(iter(data_iter)) 获取第一项。

模型也可以直接利用深度学习框架预先定义好的层,没有必要自己写。

Sequential 将多个层串联到一起。其实我们这个简单例子只用到了一个层,但是还是使用 Sequential 定义一下来熟悉一波流程。首先我们定义 Sequential 的输入输出,然后给 Sequential 内部具体的层进行配置。

# nn是神经网络的缩写
from torch import nnnet = nn.Sequential(nn.Linear(2, 1))net[0].weight.data.normal_(0, 0.01)	# 输入层权重参数的正态分布
net[0].bias.data.fill_(0)			# 偏置参数的初始值

损失函数使用 MSELoss L2 范式返回所有样本损失的平均值。

loss = nn.MSELoss()

优化算法:采用小批量优化算法,给定参数和学习率,参数从 net.parameter 中可以直接获取。

trainer = torch.optim.SGD(net.parameters(), lr=0.03)

训练:

num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X) ,y)trainer.zero_grad()l.backward()trainer.step()l = loss(net(features), labels)print(f'epoch {epoch + 1}, loss {l:f}')

误差:

w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)

我的输出:

epoch 1, loss 0.000201

epoch 2, loss 0.000103

epoch 3, loss 0.000103

w的估计误差: tensor([-0.0004, 0.0003])

b的估计误差: tensor([0.0003])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/129131.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SOLIDWORKS参数化设计之部分打包 慧德敏学

参数化设计就是通过主参数来驱动整个模型的变化&#xff0c;类似于SOLIDWORKS的方程式中&#xff0c;使用全局变量来控制模型其它参数的变化&#xff0c;因此要做参数化就必须要确定好主参数以及变化逻辑。 我们之前介绍过SOLIDWORKS参数化设计软件-SolidKits.AutoWorks&#…

【C++ 系列文章 -- 程序员考试 201811 下午场 C++ 专题 】

1.1 C 题目六 阅读下列说明和C代码&#xff0c;填写程序中的空&#xff08;1&#xff09; &#xff5e;&#xff08;5&#xff09;&#xff0c;将解答写入答题纸的对应栏内。 【说明】 以下C代码实现一个简单乐器系统&#xff0c;音乐类&#xff08;Music&#xff09;可以使用…

[Unity][VR]透视开发系列4-解决只看得到Passthrough但看不到Unity对象的问题

【视频资源】 视频讲解地址请关注我的B站。 专栏后期会有一些不公开的高阶实战内容或是更细节的指导内容。 B站地址: https://www.bilibili.com/video/BV1Zg4y1w7fZ/ 我还有一些免费和收费课程在网易云课堂(大徐VR课堂): https://study.163.com/provider/480000002282025/…

MongoDB——MongoDB删除系统自带的local数据库

一、MongoDB删除系统自带的local数据库 1.1、linux环境进入mongo客户端 输入 mongo 命令&#xff0c;进入命令行客户端 进入admin库&#xff0c;并登录&#xff0c;查看所有数据库 #进入admin库 use admin #并登录admin db.auth("username","password")…

理解训练深度前馈神经网络的难度【PMLR 2010】

论文地址&#xff1a;Excellent-Paper-For-Daily-Reading/summarize at main 类别&#xff1a;综述 时间&#xff1a;2023/11/03 摘要 这篇论文比较久了&#xff0c;但仍能从里面获得一些收获&#xff0c;论文主要是讨论并研究了不同的非线性激活函数的影响&#xff0c;sig…

不一样的编程方式 —— 协程(设计原理与汇编实现)

主要通过以下9个方面来了解协程的原理&#xff1a; 目录 1、为什么使用协程 1.3、协程的适用场景 2、协程的原语操作 3、协程的切换 3.1、汇编实现 4.协程的运行流程 5.协程的结构体定义(我们其实可以参照线程或者进程的状态来设计) 5.1、多状态集合设计 6.协程的调度…

UE5.0.3版本 像素流送 Pixel Streaming

目录 0 引言1 准备工作1.1 下载Node.js1.2 下载 PixelStreaming&#xff08;非必须&#xff09; 2 快速入门2.1 打包工程2.2 启动信令服务器2.3 启动工程2.4 打开网页 3 总结 &#x1f64b;‍♂️ 作者&#xff1a;海码007&#x1f4dc; 专栏&#xff1a;UE虚幻引擎专栏&#x…

MySQL表的增删改查(基础)

文章目录 一、CRUD二、新增&#xff08;Create&#xff09;2.1 单行数据全列插入2.2多行数据指定列插入 三、查询3.1 全列查询3.2 指定列查询3.3 查询字段表达式3.4 别名3.5 去重 DISTINCT3.6 排序3.7 条件查询 WHERE3.8 分页查询 LIMIT 四、修改&#xff08;Update&#xff09…

防火墙日志记录和分析

防火墙监控进出网络的流量&#xff0c;并保护部署防火墙的网络免受恶意流量的侵害。它是一个网络安全系统&#xff0c;它根据一些预定义的规则监控传入和传出的流量&#xff0c;它以日志的形式记录有关如何管理流量的信息&#xff0c;日志数据包含流量的源和目标 IP 地址、端口…

【广州华锐互动】VR特警作战模拟演练系统

在科技发展的驱动下&#xff0c;各行各业都在寻找新的方式来提升效率和培训质量。其中&#xff0c;虚拟现实&#xff08;VR&#xff09;技术在各个领域都有广泛的应用&#xff0c;包括警察培训。VR特警作战模拟演练系统由VR公司广州华锐互动开发&#xff0c;它使用虚拟现实环境…

pb:导入EXCEL,提示“不能连接EXCEL”

pb:导入EXCEL,提示“不能连接EXCEL” ------------------------------------------------------------------------------------------------------------------------------- 1.pb连上EXCEL代码: //从EXCEL读取文件 STRING LS_PATH,LS_FILE,ls_file_tmp oleobject ole_1…

CHS零壹视频恢复程序高级版视频修复OCR使用方法

目前CHS零壹视频恢复程序监控版、专业版、高级版已经支持了OCR&#xff0c;OCR是一种光学识别系统&#xff0c;高级版最新版本中不仅仅是在视频恢复中支持OCR&#xff0c;同时视频修复模块也增加了OCR功能&#xff0c;此功能可以针对一些批量修复的视频文件&#xff08;如执法仪…

预防HPV?谭巍主任分享提高抵抗力的五种水果

在快节奏的现代生活中&#xff0c;我们常常会因为工作、学习或者其他原因而忽视了自己的健康。身体抵抗力是人体抵御外部环境压力和疾病入侵的重要防线。这个防线需要我们通过营养的补充和适当的锻炼来维护。在这篇文章中&#xff0c;劲松HPV防治诊疗中心谭巍主任将介绍五种能够…

使用IDEA生成JavaDoc文档(IDEA2023)

1、Tool-->Generate JavaDoc 2、配置生成JavaDoc文档 1、选择生成范围&#xff0c;可以根据需要选择单独一个文件或者包&#xff0c;也可以是整个项目 2、输出目录&#xff0c;要把JavaDoc文档生成在哪个文件中&#xff0c;最好新建一个文件夹结束 3、Local&#xff1a;…

hadoop mapreduce的api调用WordCount本机和集群代码

本机运行代码 package com.example.hadoop.api.mr;import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.Job; import org.apache…

Leetcode—100.相同的树【简单】明天写另一种解法!

2023每日刷题&#xff08;十八&#xff09; Leetcode—100.相同的树 递归实现代码 /*** Definition for a binary tree node.* struct TreeNode {* int val;* struct TreeNode *left;* struct TreeNode *right;* };*/ bool isSameTree(struct TreeNode* p, struc…

C++之栈容器

1.简介 stack &#xff0c;栈(堆栈)&#xff0c;是一种先进后出(First In Last Out,FILO)的数据结构&#xff0c;先插入的数据在栈底&#xff0c;后放入的数据在栈顶&#xff0c;所有的数据只能从栈顶取出。   在生活中先进后出的例子友很多&#xff0c;例如我们在桌子上摞书…

最亮那颗星的中年危机 —— 程序员的职场困境与破局

如果说最近的这十年国内市场什么工作是最受瞩目的&#xff0c;那么程序员绝对算得上是夜空中最闪亮的那颗星。 伴随科技的迅猛发展&#xff0c;计算机走进千家万户&#xff0c;智能终端深深融入每个人的生活&#xff0c;程序员这一职业群体也逐渐成为了现代社会中不可或缺的一…

掌控你的Mac性能:System Dashboard Pro,一款专业的系统监视器

作为Mac用户&#xff0c;你是否曾经想要更好地了解你的电脑性能&#xff0c;以便优化其运行&#xff1f;是否想要实时监控系统状态&#xff0c;以便及时发现并解决问题&#xff1f;如果你有这样的需求&#xff0c;那么System Dashboard Pro就是你的不二之选。 System Dashboar…

Visual Studio使用Git忽略不想上传到远程仓库的文件

前言 作为一个.NET开发者而言&#xff0c;有着宇宙最强IDE&#xff1a;Visual Studio加持&#xff0c;让我们的开发效率得到了更好的提升。我们不需要担心环境变量的配置和其他代码管理工具&#xff0c;因为Visual Studio有着众多的拓展工具。废话不多说&#xff0c;直接进入正…