深度学习 线性神经网络(线性回归 从零开始实现)

介绍:

在线性神经网络中,线性回归是一种常见的任务,用于预测一个连续的数值输出。其目标是根据输入特征来拟合一个线性函数,使得预测值与真实值之间的误差最小化。

线性回归的数学表达式为:
y = w1x1 + w2x2 + ... + wnxn + b

其中,y表示预测的输出值,x1, x2, ..., xn表示输入特征,w1, w2, ..., wn表示特征的权重,b表示偏置项。

训练线性回归模型的目标是找到最优的权重和偏置项,使得模型预测的输出与真实值之间的平方差(即损失函数)最小化。这一最优化问题可以通过梯度下降等优化算法来解决。

线性回归在深度学习中也被广泛应用,特别是在浅层神经网络中。在深度学习中,通过将多个线性回归模型组合在一起,可以构建更复杂的神经网络结构,以解决更复杂的问题。

 手动生成数据集:

%matplotlib inline
import torch
from d2l import torch as d2l
import random#"""生成y=Xw+b+噪声"""
def synthetic_data(w, b, num_examples):  #生成num_examples个样本X = d2l.normal(0, 1, (num_examples, len(w)))#随机x,长度为特征个数,权重个数y = d2l.matmul(X, w) + b#y的函数y += d2l.normal(0, 0.01, y.shape)#加上0~0.001的随机噪音return X, d2l.reshape(y, (-1, 1))#返回true_w = d2l.tensor([2, -3.4])#初始化真实w
true_b = 4.2#初始化真实bfeatures, labels = synthetic_data(true_w, true_b, 1000)#随机一些数据
print(features)
print(labels)

显示数据集:

print('features:', features[0],'\nlabel:', labels[0])'''
features: tensor([ 2.1714, -0.6891]) 
label: tensor([10.8673])
'''d2l.set_figsize()
d2l.plt.scatter(d2l.numpy(features[:, 1]), d2l.numpy(labels), 1);

读取小批量数据集:

#每次抽取一批量样本
def data_iter(batch_size, features, labels):#步长、特征、标签num_examples = len(features)#特征个数indices = list(range(num_examples))random.shuffle(indices)# 这些样本是随机读取的,没有特定的顺序,打乱顺序for i in range(0, num_examples, batch_size):#随机访问,步长为batch_sizebatch_indices = d2l.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]

定义模型:

#定义模型
def linreg(X, w, b):  """线性回归模型"""return d2l.matmul(X, w) + b

定义损失函数:

#定义损失和函数
def squared_loss(y_hat, y):  #@save"""均方损失"""return (y_hat - d2l.reshape(y, y_hat.shape)) ** 2 / 2

定义优化算法(小批量随机梯度下降):

#定义优化算法  """小批量随机梯度下降"""
def sgd(params, lr, batch_size):  #参数、lr学习率、with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()

模型训练:

#训练
lr = 0.03#学习率
num_epochs = 3#数据扫三遍
net = linreg#模型
loss = squared_loss#损失函数
#初始化模型参数
w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)#权重
b = torch.zeros(1, requires_grad=True)#b全赋为0for epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):#拿出一批量x,yl = loss(net(X, w, b), y)  # X和y的小批量损失,实际的和预测的# 因为l形状是(batch_size,1),而不是一个标量。l中的所有元素被加到一起,# 并以此计算关于[w,b]的梯度l.sum().backward()sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数with torch.no_grad():train_l = loss(net(features, w, b), labels)print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
'''
epoch 1, loss 0.037302
epoch 2, loss 0.000140
epoch 3, loss 0.000048
'''print(f'w的估计误差: {true_w - d2l.reshape(w, true_w.shape)}')
print(f'b的估计误差: {true_b - b}')
'''
w的估计误差: tensor([0.0006, 0.0001], grad_fn=<SubBackward0>)
b的估计误差: tensor([-0.0003], grad_fn=<RsubBackward1>)
'''print(w)
'''
tensor([[ 1.9994],[-3.4001]], requires_grad=True)
'''print(b)
'''
tensor([4.2003], requires_grad=True)
'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/766305.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【隐私计算实训营——004上手隐语SecretFlow和SecretNote安装部署】

1. SecretFlow安装 1.1 环境要求 Python>3.8操作系统 Ubuntu18 资源&#xff1a;>8核16GB安装包 secretflow-lite 安装方式 docker&#xff08;推荐&#xff09; 2. SecretFlow部署模式 SecretFlow使用Ray作为分布式计算调度框架。 Ray集群由一个主节点和零或若干个…

Fabric Measurement

Fabric Measurement 布料测量

分布式组件 Nacos

1.在之前的文章写过的就不用重复写。 写一些没有写过的新东西 2.细节 2.1命名空间 &#xff1a; 配置隔离 默认&#xff1a; public &#xff08;默认命名空间&#xff09;:默认新增所有的配置都在public空间下 2.1.1 开发 、测试 、生产&#xff1a;有不同的配置文件 比如…

docker 数据卷 (二)

1&#xff0c;为什么使用数据卷 卷是在一个或多个容器内被选定的目录&#xff0c;为docker提供持久化数据或共享数据&#xff0c;是docker存储容器生成和使用的数据的首选机制。对卷的修改会直接生效&#xff0c;当提交或创建镜像时&#xff0c;卷不被包括在镜像中。 总结为两…

Orbit 使用指南 10|在机器人上安装传感器 | Isaac Sim | Omniverse

如是我闻&#xff1a; 资产类&#xff08;asset classes&#xff09;允许我们创建和模拟机器人&#xff0c;而传感器 (sensors) 则帮助我们获取关于环境的信息&#xff0c;获取不同的本体感知和外界感知信息。例如&#xff0c;摄像头传感器可用于获取环境的视觉信息&#xff0c…

ADB环境配置和基础使用

目录 一、ADB简介工作原理 二、安装ADB驱动程序配置环境变量验证ADB安装 三、启用USB调试模式四、连接设备到计算机五、使用ADB命令安装/卸载包Android 设备与电脑传输文件exit 退出目录日志操作指令系统操作指令adb ps命令 一、ADB简介 ADB全称是Android Debug Bridge&#x…

CentOS系统部署YesPlayMusic播放器并实现公网访问本地音乐资源

文章目录 1. 安装Docker2. 本地安装部署YesPlayMusic3. 安装cpolar内网穿透4. 固定YesPlayMusic公网地址 本篇文章讲解如何使用Docker搭建YesPlayMusic网易云音乐播放器&#xff0c;并且结合cpolar内网穿透实现公网访问音乐播放器。 YesPlayMusic是一款优秀的个人音乐播放器&am…

校园大数据平台的顶层设计与微观应用PDF下载

校园大数据平台的顶层设计与微观应用文档&#xff0c;是一份全面深入的解决方案&#xff0c;旨在构建一个集数据收集、存储、处理、分析及可视化于一体的综合平台。该设计以提升教育教学质量、优化资源配置、增强学生服务体验和提高管理效率为核心目标&#xff0c;通过大数据分…

c++的学习之路:3、入门(2)

一、引用 1、引用的概念 引用不是新定义一个变量&#xff0c;而是给已存在变量取了一个别名&#xff0c;编译器不会为引用变量开辟内存空 间&#xff0c;它和它引用的变量共用同一块内存空间。 怎么说呢&#xff0c;简单点理解就是你的小名&#xff0c;家里人叫你小名&#…

基于springboot和vue的旅游资源网站的设计与实现

环境以及简介 基于vue, springboot旅游资源网站的设计与实现&#xff0c;Java项目&#xff0c;SpringBoot项目&#xff0c;含开发文档&#xff0c;源码&#xff0c;数据库以及ppt 环境配置&#xff1a; 框架&#xff1a;springboot JDK版本&#xff1a;JDK1.8 服务器&#xf…

谷歌seo营销服务有哪些服务?

以我们举例&#xff0c;如果你在做B2B外贸建站&#xff0c;这里有全套保姆式托管服务&#xff0c;让你既省心又省力&#xff0c;七天就能搞定网站建设&#xff0c;快速上线&#xff0c;再来就是谷歌白帽SEO&#xff0c;我们这边强调的是纯白帽操作&#xff0c;专注于高质量的原…

今天聊聊新零售

一、什么是新零售&#xff1f; 2016年&#xff0c;在杭州举行的“云栖大会”上&#xff0c;马云发表了讲话&#xff0c;首次提出了“新零售”这一概念。 1.1 新零售概念 新零售&#xff0c;英文是New Retailing&#xff0c;新零售是对人货场的重构。人是消费者、销售人员、…

CISP 4.2备考之《物理与网络通信安全》知识点总结

文章目录 第 1 节 物理与环境安全第 2 节 网络安全基础第 3 节 网络安全技术与设备第 1 部分 防火墙第 2 部分 入侵检测系统第 3 部分 其他安全产品 第 4 节 网络安全设计规划 第 1 节 物理与环境安全 1.场地选择 1.1 场地选择:自然条件、社会条件、其他条件。1.2 抗震和承重&…

【操作系统】进程基础知识

目录 1、进程的介绍 2、进程的五个基本特性 3、进程的组成 4、进程的并行和并发执行 5、进程的状态 6、进程的通信 7、线程 1、进程的介绍 进程&#xff08;Process&#xff09;是程序在某个数据集合上的一次运行活动&#xff0c;也是操作系统进行资源分配和保护的基本单…

java设计模式(1)---总则

设计模式总则 一、概述 1、什么是设计模式 设计模式是一套被反复使用、多数人知晓的、经过分类编目的、代码设计经验的总结。 解释下&#xff1a; 分类编目&#xff1a;就是说可以找到一些特征去划分这些设计模式&#xff0c;从而进行分类。 代码设计经验&#xff1a;这句很重…

使用Intellij idea编写Spark应用程序(Scala+SBT)

使用Intellij idea编写Spark应用程序(ScalaSBT) 对Scala代码进行打包编译时&#xff0c;可以采用Maven&#xff0c;也可以采用SBT&#xff0c;相对而言&#xff0c;业界更多使用SBT。 运行环境 Ubuntu 16.04 Spark 2.1.0 Intellij Idea (Version 2017.1) 安装Scala插件 安…

【微服务】StackOverflow的架构学习

目录 架构基础设施网络服务器SQL 服务器Redis推荐超级课程: Docker快速入门到精通Kubernetes入门到大师通关课AWS云服务快速入门实战StackOverflow 是资源需求量最大的网站之一。我们作为架构师,在进行各种微服务架构的实践的同时,也需要学习借鉴各个成熟实践的精华。 因此本…

【HarmonyOS】ArkUI - 状态管理

在声明式 UI 中&#xff0c;是以状态驱动视图更新&#xff0c;如图1所示&#xff1a; 图1 其中核心的概念就是状态&#xff08;State&#xff09;和视图&#xff08;View&#xff09;&#xff1a; 状态&#xff08;State&#xff09;&#xff1a;指驱动视图更新的数据&#xf…

第十一届蓝桥杯大赛第二场省赛试题 CC++ 研究生组-子串分值和

solution1&#xff08;通过40%&#xff09; 依次求子串并统计出现过的字母个数 #include<iostream> #include<string> #include<set> using namespace std; int main(){string s, subs;cin >> s;int len s.size(), ans 0;for(int j 1; j < len…

【LabVIEW FPGA入门】FPGA寄存器(Register)

当您需要从多个时钟域或设计的不同部分访问数据&#xff0c;并且需要编写可重复使用的代码时&#xff0c;可使用寄存器项来存储数据。与 FIFO 相比&#xff0c;寄存器项消耗的 FPGA 逻辑资源更少&#xff0c;而且不消耗块存储器&#xff0c;而块存储器是最有限的 FPGA 资源类型…