16.线性回归代码实现

线性回归的实操与理解

介绍

线性回归是一种广泛应用的统计方法,用于建模一个或多个自变量(特征)与因变量(目标)之间的线性关系。在机器学习和数据科学中,线性回归是许多入门者的第一个模型,它提供了对监督学习问题的基础理解。本文将介绍线性回归的基本概念,并通过Python和PyTorch库来实操线性回归模型,深入理解其训练和预测过程。

线性回归的基本概念

线性回归假设目标变量(y)是输入变量(X)的线性组合,并可以通过最小二乘法来估计模型的参数(权重w和偏置b)。数学上,线性回归模型可以表示为:

y=w1​x1​+w2​x2​+…+wn​xn​+b

或者更一般地,使用矩阵形式表示:

y=XW+b

其中,X 是特征矩阵,W 是权重向量,b 是偏置项。

实操:使用PyTorch实现线性回归

1. 导入必要的库

首先,我们需要导入PyTorch和其他必要的库。

import torch  
import torch.nn as nn  
import torch.optim as optim  
import numpy as np  
import matplotlib.pyplot as plt


2. 生成模拟数据

为了演示线性回归,我们将生成一些模拟数据。

# 设置随机种子  
torch.manual_seed(0)  
np.random.seed(0)  # 生成数据  
n_samples = 100  
x = torch.randn(n_samples, 1) * 10  # 输入数据  
w_true = 2  
b_true = 1  
y = x * w_true + b_true + torch.randn(n_samples, 1) * 0.5  # 真实标签


3. 定义线性回归模型

使用PyTorch的nn.Module来定义线性回归模型。

class LinearRegressionModel(nn.Module):  def __init__(self, input_dim=1, output_dim=1):  super(LinearRegressionModel, self).__init__()  self.linear = nn.Linear(input_dim, output_dim)  def forward(self, x):  out = self.linear(x)  return out


4. 初始化模型和优化器

实例化模型,并定义损失函数和优化器。

# 初始化模型  
model = LinearRegressionModel()  # 定义损失函数和优化器  
criterion = nn.MSELoss()  
optimizer = optim.SGD(model.parameters(), lr=0.01)


5. 训练模型

通过迭代训练数据来训练模型。

# 训练模型  
num_epochs = 1000  
for epoch in range(num_epochs):  # 前向传播  outputs = model(x)  loss = criterion(outputs, y)  # 反向传播和优化  optimizer.zero_grad()  # 清空梯度  loss.backward()  # 反向传播计算梯度  optimizer.step()  # 更新参数  if (epoch+1) % 100 == 0:  print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))


6. 评估模型

在训练完成后,我们可以评估模型的性能。但在这个简单的例子中,我们主要关注于模型是否能学习到正确的权重和偏置。

7. 可视化结果

我们可以将预测结果和真实数据可视化出来。

# 提取训练后的参数  
w, b = model.linear.weight.item(), model.linear.bias.item()  
print('w = {}, b = {}'.format(w, b))  # 可视化结果  
predicted = model(x).detach().numpy()  
plt.scatter(x.numpy(), y.numpy(), color='blue', label='True data')  
plt.plot(x.numpy(), predicted, color='red', linewidth=2, label='Predicted data')  
plt.legend()  
plt.show()


总结

通过本文的实操,我们深入理解了线性回归的基本原理和其在PyTorch中的实现方式。我们生成了模拟数据,定义了线性回归模型,并使用随机梯度下降优化器来训练模型。通过可视化结果,我们可以看到模型能够很好地拟合生成的数据,并且学习到的权重和偏置与真实

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/15036.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

A股重磅!史上最严减持新规,发布!

此次减持新规被市场视为A股史上最严、最全面的规则,“花式”减持通道被全面“封堵”。 5月24日晚间,证监会正式发布《上市公司股东减持股份管理暂行办法》(以下简称《减持管理办法》)及相关配套规则。 据了解,《减持…

工作学习的电脑定时关机,定时重启,定时提醒

可以直接下载工具: 定时自动关机 大家好,! 在我们学习与工作时,经常会遇到想要在完成一个任务后,再关闭电脑或对电脑重启,但这个时间点,操作电脑的人可能不能在电脑旁边,这样就需要…

大语言模型的工程技巧(四)——梯度检查点

相关说明 这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。 本文将讨论如何利用梯度检查点算法来减少模型在训练时候(更准确地说是运行反向传播算法时)的内存开支。…

机器学习-决策树算法

前言 本篇介绍决策树与随机森林的内容,先完成了决策树的部分。 决策树 决策树(Decision Tree)是一种有监督学习的方法,可以同时解决分类和回归问题,它能够从一系列有特征和标签的数据中总结出决策规则,并用树状图的结构来呈现这…

SecureCRT for Mac注册激活版:专业终端SSH工具

SecureCRT是一款支持SSH(SSH1和SSH2)的终端仿真程序,简单地说是Windows下登录UNIX或Linux服务器主机的软件。 SecureCRT支持SSH,同时支持Telnet和rlogin协议。SecureCRT是一款用于连接运行包括Windows、UNIX和VMS的理想工具。通过…

大摩:AI到“临界点”了,资管公司到了广泛部署的时刻

大摩表示,尽管AI技术在资产管理行业中的应用仍处于早期阶段,但其潜力巨大,能够为行业带来根本性的变革。预计生成式AI能够在资产管理公司的运营模型中带来20%至40%的生产力提升。 正文介绍 在全球经济面临诸多不确定因素的当下,…

【全开源】答题考试系统源码(FastAdmin+ThinkPHP+Uniapp)

答题考试系统源码:构建高效、安全的在线考试平台 引言 在当今数字化时代,在线考试系统已成为教育机构和企业选拔人才的重要工具。一个稳定、高效、安全的答题考试系统源码是构建这样平台的核心。本文将深入探讨答题考试系统源码的关键要素,…

大佬大讲堂(1)电机及其驱动内核-自适应观察器

点击上方 “机械电气电机杂谈 ” → 点击右上角“...” → 点选“设为星标 ★”,为加上机械电气电机杂谈星标,以后找夏老师就方便啦!你的星标就是我更新动力,星标越多,更新越快,干货越多! 关注…

Java面试八股之可重入锁ReentrantLock是怎么实现可重入的

可重入锁ReentrantLock是怎么实现可重入的 ReentrantLock实现可重入性的机制主要依赖于以下几个核心组件和步骤: 状态计数器:ReentrantLock内部维护一个名为state的整型变量作为状态计数器,这个计数器不仅用来记录锁是否被持有,…

Java进阶学习笔记9——子类中访问其他成员遵循就近原则

正确访问成员的方法。 在子类方法中访问其他成员(成员变量、成员方法),是依照就近原则的。 F类: package cn.ensource.d13_extends_visit;public class F {String name "父类名字";public void print() {System.out.p…

langchian进阶二:LCEL表达式,轻松进行chain的组装

LangChain表达式语言-LCEL,是一种声明式的方式,可以轻松地将链条组合在一起。 你会在这些情况下使用到LCEL表达式: 流式支持 当你用LCEL构建你的链时,你可以得到最佳的首次到令牌的时间(输出的第一块内容出来之前的时间)。对于一些链&#…

Springboot+Vue项目-基于Java+MySQL的酒店管理系统(附源码+演示视频+LW)

大家好!我是程序猿老A,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:Java毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 Python毕业设计 &…

手撕算法|斯坦福大学教授用60页PPT搞定了八大神经网络

人工智能领域深度学习的八大神经网络常见的是以下几种 1.卷积神经网络(CNN): 卷积神经网络是用于图像和空间数据处理的神经网络,通过卷积层和池化层来捕捉图像的局部特征,广泛应用于图像分类、物体检测等领域。 2.循…

blender 布尔运算,切割模型。

1.创建一个立方体和球体。 2.选中立方体,在属性面板添加布尔修改器。点击物体属性右边的按钮选中球体。参数如下。 3.此时隐藏球体,就可以看到被切掉的效果了。

【算法】前缀和算法——和可被K整除的子数组

题解:和可被K整除的子数组(前缀和算法) 目录 1.题目2.前置知识2.1同余定理2.2CPP中‘%’的计算方式与数学‘%’的差异 及其 修正2.3题目思路 3.代码示例4.总结 1.题目 题目链接:LINK 2.前置知识 2.1同余定理 注:这里的‘/’代表的是数学…

Creating Server TCP listening socket *:6379: listen: Unknown error

错误: 解决方法: 在redis安装路径中打开cmd命令行窗口,输入 E:\Redis-x64-3.2.100>redis-server ./redis.windows.conf结果:

动态链接学习总结

背景 之前了解了静态链接的原理,就想着把动态链接的原理也学习一下,提高编程能力。 关键知识点 动态链接的工作原理: 编译时的处理: 当程序被编译时,编译器知道程序需要某些库函数,但并不把这些函数的代…

【C++】C++11(一)

C11是一次里程碑式的更新,我们一起来看一看~ 目录 列表初始化:{ }初始化:std::initializer_list: 声明:auto:decltype: STL的一些变化: 列表初始化: { }初始化&#xf…

学习记录16-反电动势

一、反电动势公式 在负载下反电势和端电压的关系式为:𝑈𝐼𝑅𝐿*(𝑑𝑖 / 𝑑𝑡)𝐸 E为线圈电动势、 𝜓 为磁链、f为频率、N…

博客说明 5/12~5/24【个人】

博客说明 5/12~5/24【个人】 前言版权博客说明 5/12~5/24【个人】对比最后 前言 2024-5-24 13:39:23 对我在2024年5月12日到5月24日发布的博客做一下简要的说明 以下内容源自《【个人】》 仅供学习交流使用 版权 禁止其他平台发布时删除以下此话 本文首次发布于CSDN平台 作…