Pytorch实现线性回归模型

9a3ad9cb79654198ab4c9f862270ed93.gif

在机器学习和深度学习的世界中,线性回归模型是一种基础且广泛使用的算法,简单易于理解,但功能强大,可以作为更复杂模型的基础。使用PyTorch实现线性回归模型不仅可以帮助初学者理解模型的基本概念,还可以为进一步探索更复杂的模型打下坚实的基础。⚔️

💡在接下来的教程中,我们将详细讨论如何使用PyTorch来实现线性回归模型,包括代码实现、参数调整以及模型优化等方面的内容~

💡我们接下来使用Pytorch的API来手动构建一个线性回归的假设函数损失函数及优化方法,熟悉一下训练模型的流程。熟悉流程之后我们再学习如何使用PyTorch的API来自动训练模型~

import torch
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt
import random
def creat_data():x, y, coef = make_regression(n_samples=100, n_features=1, noise=10, coef=True, bias=14.5, random_state=0)# 所有的特征值X都是0,目标变量y的平均值也会是14.5(加上或减去由于noise参数引入的噪声)# coef:权重系数,表示线性回归模型中每个特征的权重,y_pred = x * coef + biasx = torch.tensor(x)y = torch.tensor(y)return x, y ,coef  # x , y 不是按顺序的, 而是随机顺序的def data_loader(x, y, batch_size):data_len = len(y)data_index = list(range(data_len))random.shuffle(data_index)batch_number = data_len // batch_sizefor idx in range(batch_number):start = idx * batch_sizeend = start + batch_sizebatch_train_x = x[start: end]batch_train_y = y[start: end]yield batch_train_x, batch_train_y  # 相当于reutrn, 返回一个值,但是不会结束函数

🧨这一部分creat_data是来生成线性回归的数据,coef=True(截距)表示所有的特征值X都是0时,目标变量y的平均值也会是14.5(加上或减去由于noise参数引入的噪声)

# 假设函数
w = torch.tensor(0.1, requires_grad=True, dtype=torch.float64)
b = torch.tensor(0.0, requires_grad=True, dtype=torch.float64)def linear_regression(x):return w * x + b# 损失函数
def square_loss(y_pre, y_true):return (y_pre - y_true) ** 2# 优化方法(梯度下降)
def sgd(lr=0.01):w.data = w.data - lr * w.grad.data / 16  # 批次样本的平均梯度值,梯度累积了16次b.data = b.data - lr * b.grad.data / 16

def train():# 加载数据集x, y, coef = creat_data()# 定义训练参数epochs = 100learning_rate = 0.01# 存储训练信息epochs_loss = []total_loss = 0.0train_samples = 0for _ in range(epochs):for train_x, train_y in data_loader(x, y, batch_size=16):y_pred = linear_regression(train_x)# 计算平方损失loss = square_loss(y_pred, train_y.reshape(-1, 1)).sum()  # 16个tensor(16行1列)# print(loss)total_loss += loss.item()train_samples += len(train_y)# 梯度清零if w.grad is not None:w.grad.data.zero_()if b.grad is not None:b.grad.data.zero_()# 自动微分loss.backward()  sgd(learning_rate)print('loss:%.10f' % (total_loss / train_samples))# 记录每一个epochs的平均损失epochs_loss.append(total_loss / train_samples)# 先绘制数据集散点图plt.scatter(x, y)# 绘制拟合的直线x = torch.linspace(x.min(), x.max(), 1000)y1 = torch.tensor([v * w + b for v in x])y2 = torch.tensor([v * coef + b for v in x])plt.plot(x, y1, label='训练')plt.plot(x, y2, label='真实')plt.grid()plt.legend()plt.show()# 打印损失变化曲线plt.plot(range(epochs), epochs_loss)plt.grid()plt.title('损失变化曲线')plt.show()if __name__ == '__main__':train()

🧨 我们将整个数据集分成多个批次(batch),每个批次包含16个数据。由于每个批次的数据都是随机抽取的。这样可以增加模型的泛化能力,避免过拟合。分批次训练可以提高学习的稳定性。当使用梯度下降法优化模型参数时,较小的批次可以使梯度下降方向更加稳定,从而更容易收敛到最优解。

🧨我们将这批数据每次分成16份训练,并且这样重复训练epochs次,可以更深入地学习数据中的特征和模式,有助于防止模型快速陷入局部最优解,从而提高模型的泛化能力,而且适当的epoch数量可以在欠拟合和过拟合之间找到平衡点,确保模型具有良好的泛化能力。

关于backward方法: 调用loss.backward()时,PyTorch会计算损失函数相对于所有需要梯度的参数的梯度。在我们的例子中,backward() 方法被调用在一个张量(即损失函数的输出)上。这是因为在 PyTorch 中,backward() 方法用于计算某个张量(通常是损失函数的输出)相对于所有需要梯度的参数的梯度。当 backward() 方法被调用时,PyTorch 会自动计算该张量相对于所有需要梯度的参数的梯度,并将这些梯度累加到对应参数的 .grad 属性上。

我们再来看一个例子:

def test03():# y = x**2x = torch.tensor(10, requires_grad=True, dtype=torch.float64)for _ in range(500):# 正向计算f = x ** 2print(x.grad)# 梯度清零if x.grad is not None:x.grad.data.zero_()# 反向传播计算梯度f.backward()# 更新参数x.data = x.data - 0.01 * x.gradprint('%.10f' % x.data)

虽然 f 本身不是损失函数,但在 PyTorch 中,任何需要进行梯度计算的张量都可以使用 backward() 方法来帮助进行梯度更新。这是自动微分机制的一部分,使得无论 f 是简单函数还是复杂的损失函数,都能利用相同的方法来进行梯度的反向传播。

我们看一下训练后的效果:

7246040a180d4e5ba9fb2d40554431c6.png

可以看到经过重复训练几乎和原本的真实直线吻合, 我们在每次epochs后都会记录平均损失,看一下平均损失的下降趋势:

67d55b7b802444c5981597fc36e7bfbd.png

回顾:随机梯度下降算法(SGD) 

from sklearn.linear_model import SGDRegressor
  • 随机梯度下降算法(SGD)
  • 每次迭代时, 随机选择并使用一个样本梯度值

由于FG每迭代更新一次权重都需要计算所有样本误差,而实际问题中经常有上亿的训练样本,故效率偏低,且容易陷入局部最优解,因此提出了随机梯度下降算法。其每轮计算的目标函数不再是全体样本误差,而仅是单个样本误差,即 每次只代入计算一个样本目标函数的梯度来更新权重,再取下一个样本重复此过程,直到损失函数值停止下降或损失函数值小于某个可以容忍的阈值。

但是由于,SG每次只使用一个样本迭代,若遇上噪声则容易陷入局部最优解。 


🥂接下来我们看一下PyTorch的相关API的自动训练: 

模型定义方法

  • 使用 PyTorch 的 nn.MSELoss() 代替自定义的平方损失函数
  • 使用 PyTorch 的 data.DataLoader 代替自定义的数据加载器
  • 使用 PyTorch 的 optim.SGD 代替自定义的优化器
  • 使用 PyTorch 的 nn.Linear 代替自定义的假设函数
  1. PyTorch的nn.MSELoss():这是PyTorch中用于计算预测值与真实值之间均方误差的损失函数,主要用于回归问题。它提供了参数来控制输出形式,可以是同维度的tensor或者是一个标量。
  2. PyTorch的data.DataLoader:这是PyTorch中负责数据装载的类,它支持自动批处理、采样、打乱数据和多进程数据加载等功能。DataLoader可以高效地在一个大数据集上进行迭代。
  3. PyTorch的optim.SGD:这是PyTorch中实现随机梯度下降(SGD)优化算法的类。SGD是一种常用的优化算法,尤其在深度学习中被广泛应用。它的主要参数包括学习率、动量等,用于调整神经网络中的参数以最小化损失函数。
  4. PyTorch的nn.Linear:这是PyTorch中用于创建线性层的类,也被称为全连接层。它将输入与权重矩阵相乘并加上偏置,然后通过激活函数进行非线性变换。nn.Linear定义了神经网络的一个线性层,可以指定输入和输出的特征数。
  5. 通过这些组件,我们可以构建和训练复杂的网络模型,而无需手动编写大量的底层代码。

 接下来使用 PyTorch 来构建线性回归:

import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_regression
import matplotlib.pyplot as pltdef create_data():x, y, coef = make_regression(n_samples=100,n_features=1,noise=10,coef=True,bias=14.5,random_state=0)x = torch.tensor(x)y = torch.tensor(y)return x, y, coefdef train():x, y, coef = create_data()dataset = TensorDataset(x, y)# 数据加载器dataloader = DataLoader(dataset, batch_size=16, shuffle=True)model = nn.Linear(in_features=1, out_features=1)# 构建损失函数criterion = nn.MSELoss()# 优化方法optimizer = optim.SGD(model.parameters(), lr=1e-2)# 初始化训练参数epochs = 100for _ in range(epochs):for train_x, train_y in dataloader:y_pred = model(train_x.type(torch.float32))# 计算损失值loss = criterion(y_pred, train_y.reshape(-1, 1).type(torch.float32))# 梯度清零optimizer.zero_grad()# 自动微分(反向传播)loss.backward()# 更新参数optimizer.step()# 绘制拟合直线plt.scatter(x, y)x = torch.linspace(x.min(), x.max(), 1000)y1 = torch.tensor([v * model.weight + model.bias for v in x])y2 = torch.tensor([v * coef + 14.5 for v in x])plt.plot(x, y1, label='训练')plt.plot(x, y2, label='真实')plt.legend()plt.show()if __name__ == '__main__':train()

5d03f68fbac94a2db8b8b19cd008ae50.gif

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/4327.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

WebSocket 深入浅出

WebSocket 深入浅出 1. WebSocket 是什么2. WebSocket 建立连接通信的过程3. WebSocket 和http的联系与区别4. WebSocket 的使用场景及限制 1. WebSocket 是什么 定义:WebSocket 是一种网络通信协议,它允许在单个TCP连接上进行全双工通信。是HTML5规范提…

电商技术揭秘三十七:电商智能风控业务架构设计

相关系列文章 电商技术揭秘相关系列文章合集(1) 电商技术揭秘相关系列文章合集(2) 电商技术揭秘二十八:安全与合规性保障 电商技术揭秘二十九:电商法律合规浅析 电商技术揭秘三十:知识产权保…

无人机+激光雷达:高精度测绘级实时点云激光雷达技术详解

在现代测绘技术中,无人机与激光雷达的结合已经成为一种重要的技术手段。激光雷达(LiDAR)是一种主动式航空传感器,通过发射激光束并探测其与目标物体的反射,可以获取目标物体的位置、速度等特征信息。而无人机则作为一种…

web.config 学习

1 appSettings 节 打开一个项目的web.config文件看一下&#xff1b;appSettings节如下&#xff0c; <appSettings><add key"Telerik.Skin" value"Windows7" /><add key"ValidationSettings:UnobtrusiveValidationMode" value&qu…

ULTIMATE VOCAL REMOVER V5 for Mac:专业人声消除软件

ULTIMATE VOCAL REMOVER V5 for Mac是一款专为Mac用户设计的人声消除软件&#xff0c;它凭借强大的功能和卓越的性能&#xff0c;在音乐制作和后期处理领域崭露头角。 ULTIMATE VOCAL REMOVER V5 for Mac v5.6激活版下载 这款软件基于深度神经网络&#xff0c;通过先进的训练模…

不可重复读,幻读和脏读

不可重复读一般在读未提交&#xff0c;读已提交这两种隔离级别出现&#xff0c;第一次读和第二次读的数据不一致。 幻读一般在读未提交&#xff0c;读已提交&#xff0c;可重复读出现&#xff0c;原因是第一个事务执行时&#xff0c;第二个事务完成了提交&#xff0c;在第一个…

解决HttpServletRequest中的InputStream/getReader只能被读取一次的问题

一、事由 由于我们业务接口需要做签名校验&#xff0c;但因为是老系统了签名规则被放在了Body里而不是Header里面&#xff0c;但是我们不能在每个Controller层都手动去做签名校验&#xff0c;这样不是优雅的做法&#xff0c;然后我就写了一个AOP&#xff0c;在AOP中实现签名校…

Stable Diffusion教程:额外功能/后期处理/高清化

"额外功能"对应的英文单词是Extras&#xff0c;算是直译。但是部分版本中的翻译是“后期处理”或者“高清化”&#xff0c;这都是意译&#xff0c;因为它的主要功能是放大图片、去噪、修脸等对图片的后期处理。注意这里边对图片的处理不是 Stable Diffusion 本身的能…

RabbitMq基础概念知识复习

消息拥有消息头和消息体&#xff0c;消息具有rounting key&#xff0c;主题交换机和扇形交换机都是分布与订阅的实现方式&#xff0c;主题交换机用于匹配接收的消息的rount key 动态匹配模式匹配到多个符合的队列&#xff0c;扇形fanout交换机则不会使用消息的路由key&#xff…

PyTorch深度学习实战(41)——循环神经网络与长短期记忆网络

PyTorch深度学习实战&#xff08;41&#xff09;——循环神经网络与长短期记忆网络 0. 前言1. 循环神经网络1.1 传统文本处理方法的局限性1.2 RNN 架构2.3 RNN 内存机制 2. RNN 的局限性3. 长短期记忆网络3.1 LSTM 架构3.2 构建 LSTM 小结系列链接 0. 前言 循环神经网络 (Recu…

数据结构––串

5.1 串的定义 由零个或者任意多个字符组成的有限序列&#xff0c;是一种特殊的顺序表&#xff0c;每一个元素都是单独一个字符 空格也可以是一个字符 串的长度&#xff1a;串中的有效元素的个数&#xff08;不包括\0&#xff09; 空串&#xff1a;不包括任何元素的串&#…

web server apache tomcat11-26-maven jars

前言 整理这个官方翻译的系列&#xff0c;原因是网上大部分的 tomcat 版本比较旧&#xff0c;此版本为 v11 最新的版本。 开源项目 从零手写实现 tomcat minicat 别称【嗅虎】心有猛虎&#xff0c;轻嗅蔷薇。 系列文章 web server apache tomcat11-01-官方文档入门介绍 web…

传统过程自动化工厂的智能扩展

一 通过NOA概念&#xff0c;公开、安全地迈向未来 随着数字化转型在过程自动化工业中的不断深入&#xff0c;许多公司都面临着同一挑战——如何平衡创新和传统。放眼望去&#xff0c;过程自动化工业和信息技术似乎在以不同的速度发展。虽然过程自动化工厂通过使用传统的自动化…

基于Springboot的幼儿园管理系统

基于SpringbootVue的幼儿园管理系统的设计与实现 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringbootMybatis工具&#xff1a;IDEA、Maven、Navicat 系统展示 用户登录 用户管理 教师管理 幼儿园信息管理 班级信息管理 工作日志管理 会议记录管理…

Go语言中,常用的同步机制

在 Go 语言中&#xff0c;保证多线程&#xff08;或者更准确地说&#xff0c;多协程&#xff09;有序执行&#xff0c;主要依赖于协程间的同步机制。Go 提供了几种工具来帮助开发者控制协程&#xff08;goroutine&#xff09;的执行顺序&#xff0c;确保数据的一致性和操作的原…

大模型实战提示工程4—结构化信息与代码相关任务示例

1. 结构化信息处理类 1.1. 命名实体识别(Named Entity Recognition, NER) 任务描述:从中文文本中识别具有特定意义的实体,如人名、地点、组织、时间等。 示例:原文:"小米集团 CEO雷军在2024年 发布了小米汽车。" 实体识别结果: 人名:雷军职务:小米集团 …

Vue2基础知识:组件的样式冲突scoped,为什么加了scoped样式就会独立出来呢?

默认情况&#xff1a;写在组件中的样式会全局生效&#xff0c;这样就容易造成多个组件之间的样式冲突问题。 1.全局样式&#xff1a;默认组件中的样式会作用到全局.&#xff08;也就是说不管你在哪个页面或者组件中写入样式&#xff0c;只要页面生效&#xff0c;该页面的style…

七大排序算法(Java实现)——冒泡、快排、插入、希尔、选择、堆排、归并

升序排序为例 交换元素的通用代码&#xff1a; /*** 交换元素* param arr* param idx1* param idx2*/private void swap(int[] arr, int idx1, int idx2) {int tmp arr[idx1];arr[idx1] arr[idx2];arr[idx2] tmp;} 一、交换排序——冒泡排序 冒泡排序&#xff1a; 相邻两…

sql连续登录

1、sql建表语句 DROP TABLE IF EXISTS app_login_record; CREATE TABLE app_login_record (user_id int(0) NULL DEFAULT NULL,enter_time datetime(0) NULL DEFAULT NULL,leave_time datetime(0) NULL DEFAULT NULL );INSERT INTO app_login_record VALUES (789012, 2023-05-…