【Python代码】以线性模型为例,详解深度学习算法流程,包括数据生成、定义模型、损失函数、优化算法和训练

**使用带有噪声的线性模型构造数据集,并根据有限的数据恢复该线性模型的参数。**其中包括数据集构造、模型参数初始化、损失函数定义、定义优化算法和训练等过程。是大多数算法实现过程的一个缩影,理解此过程有助于在开发或改进算法时更深刻了解其算法的构造和框架。

  • 过程详解
    • 生成数据
    • 小批量划分数据集
    • 模型参数初始化
    • 模型定义
    • 损失函数定义
    • 定义优化算法
    • 模型训练
  • 运行示例
    • 整体代码
  • 示例中部分函数详解
    • torch.normal()
    • .backward()

过程详解

生成数据

此时以简单的线性模型为例,生成1000个2维的数据,作为样本集。每个样本都符合均值为0,标准差为1的正态分布。具体代码如下所示。

import torchdef synthetic_data(w,b,num_examples):"""生成y=Wx+b+噪声"""##X是一个1000行2列的数据,符合0,1的正态分布X=torch.normal(0,1,(num_examples,len(w)))#特征y=torch.matmul(X,w)+by+=torch.normal(0,0.01,y.shape)#添加噪声return X,y.reshape((-1,1))
true_w=torch.tensor([2,-3.4])
true_b=4.2
features,labels=synthetic_data(true_w,true_b,1000)
print()
print('features',features[0],'\nlabels:',labels[0])
print('features.shape',features.shape,'\nlabels.shape:',labels.shape)

输出:

features tensor([0.1724, 0.8911]) 
labels: tensor([1.5308])
features.shape torch.Size([1000, 2]) 
labels.shape: torch.Size([1000, 1])

可以看出,已经生成了1000个2维的样本集X,大小为1000行2列。添加完噪声的标签labels,为1000行1列,即一个样本对应一个标签。
对生成数据的第一维和标签的结果可视化:

import matplotlib.pyplot as plt
plt.scatter(features[:,0].detach().numpy(),labels.detach().numpy(),1)
plt.savefig('x1000.jpg')
plt.show()

在这里插入图片描述

小批量划分数据集

把生成的数据打乱,并根据设置的批量大小,根据索引提取样本和标签。

def data_iter(batch_size,features,labels):num_examples=len(features)indices=list(range(num_examples))##这些样本是随机读取的,没有特定顺序random.shuffle(indices)for i in range(0,num_examples,batch_size):batch_indices=torch.tensor(indices[i:min(i+batch_size,num_examples)])#确定索引##根据索引值提取相应的特征和标签yield features[batch_indices],labels[batch_indices]

选择其中的一个批量样本和标签可视化进行展示。

batch_size=10#批量设置为10
for X,y in data_iter(batch_size,features,labels):print(X,'\n',y)break##读取一次就退出,即选择其中的一个批量

输出:

tensor([[-0.6141, -0.9904],[ 2.2592, -1.2401],[ 0.3217, -2.0419],[ 2.6761,  1.6293],[-0.3886,  1.4958],[-1.4074,  0.2157],[-1.9986, -0.1091],[ 0.3808,  0.3756],[-0.6877,  0.3499],[ 1.5450, -1.0313]]) tensor([[ 6.3528],[12.9585],[11.7972],[ 4.0089],[-1.6630],[ 0.6698],[ 0.5591],[ 3.6852],[ 1.6348],[10.7883]])

样本是10个1行2列的数据,标签是10个1行1列的数据。

模型参数初始化

这里设置权重w为符合均值为0,标准差为0.01的正态分布的随机生成数据,偏置b设置为0。也可以根据自己情况设置初始参数。

w=torch.normal(0,0.01,size=(2,1),requires_grad=True)
b=torch.zeros(1,requires_grad=True)

输出:

w: tensor([[-0.0164],[-0.0022]], requires_grad=True) 
b: tensor([0.], requires_grad=True)

初始参数设置之后,接下来就是更新这些参数,直到这些参数满足我们的数据拟合。
在更新时,需要计算损失函数关于参数的梯度,再根据梯度向减小损失的方向更新参数。

模型定义

定义一个模型,将输入和输出关联起来。前面生成的数据是线性的,所以定义的模型也是一个线性的 y=Wx+b

def linreg(X,w,b):"""线性回归模型"""return torch.matmul(X,w)+b

损失函数定义

这里简单的定义一个损失函数,计算预测结果和真实结果的平均平方差。

def squared_loss(y_hat,y):"""均方损失"""return (y_hat-y.reshape(y_hat.shape))**2/2

定义优化算法

使用小批量随机梯度下降算法作为优化算法。这里要确定超参数批量大小和学习率。

def sgd(params,lr,batch_size):"""小批量随机梯度下降优化法"""with torch.no_grad():for param in params:param-=lr*param.grad/batch_sizeparam.grad.zero_()

模型训练

有了数据、损失函数、初始参数和优化算法,我们就可以开始训练,更新参数。

lr=0.01
num_epochs=10
net=linreg
loss=squared_lossfor epoch in range(num_epochs):for X,y in data_iter(batch_size,features,labels):y_pred=net(X,w,b)l=loss(y_pred,y)l.sum().backward()sgd([w,b],lr,batch_size)with torch.no_grad():train_l=loss(net(features,w,b),labels)#使用更新后的参数计算损失

运行示例

整体代码


import torch
import random##生成数据
def synthetic_data(w,b,num_examples):"""生成y=Wx+b+噪声"""##X是一个1000行2列的数据,符合0,1的正态分布X=torch.normal(0,1,(num_examples,len(w)))y=torch.matmul(X,w)+by+=torch.normal(0,0.01,y.shape)return X,y.reshape((-1,1))
#真实的权重和偏置
true_w=torch.tensor([2,-3.4])
true_b=4.2
##调用生成数据函数,生成数据
features,labels=synthetic_data(true_w,true_b,1000)
# print('features',features[0],'\nlabels:',labels[0])
# print('features.shape',features.shape,'\nlabels.shape:',labels.shape)# import matplotlib.pyplot as plt
# #set_figsize()
# plt.scatter(features[:,0].detach().numpy(),labels.detach().numpy(),1)
# plt.savefig('x1000.jpg')
# plt.show()##读取数据,根据设置的批量大小
def data_iter(batch_size,features,labels):num_examples=len(features)indices=list(range(num_examples))##这些样本是随机读取的,没有特定顺序random.shuffle(indices)for i in range(0,num_examples,batch_size):batch_indices=torch.tensor(indices[i:min(i+batch_size,num_examples)])yield features[batch_indices],labels[batch_indices]
batch_size=10
# for X,y in data_iter(batch_size,features,labels):
#     print(X,'\n',y)
#     break##初始化参数
w=torch.normal(0,0.01,size=(2,1),requires_grad=True)
b=torch.zeros(1,requires_grad=True)##定义模型、损失函数和优化算法
def linreg(X,w,b):"""线性回归模型"""return torch.matmul(X,w)+b
def squared_loss(y_hat,y):"""均方损失"""return (y_hat-y.reshape(y_hat.shape))**2/2def sgd(params,lr,batch_size):"""小批量随机梯度下降优化法"""with torch.no_grad():for param in params:param-=lr*param.grad/batch_sizeparam.grad.zero_()
#设置参数
lr=0.01
num_epochs=10
net=linreg
loss=squared_loss
##开始训练
for epoch in range(num_epochs):for X,y in data_iter(batch_size,features,labels):y_pred=net(X,w,b)l=loss(y_pred,y)l.sum().backward()sgd([w,b],lr,batch_size)with torch.no_grad():train_l=loss(net(features,w,b),labels)#使用更新后的参数计算损失print('w:',w,'\nb',b)print(f'epoch{epoch+1},loss{float(train_l.mean()):f}')

输出结果:

w: tensor([[ 1.1745],[-2.1489]], requires_grad=True) 
b tensor([2.6504], requires_grad=True)
epoch1,loss2.268522
w: tensor([[ 1.6670],[-2.9392]], requires_grad=True) 
b tensor([3.6257], requires_grad=True)
epoch2,loss0.318115
w: tensor([[ 1.8670],[-3.2300]], requires_grad=True) 
b tensor([3.9866], requires_grad=True)
epoch3,loss0.044844
w: tensor([[ 1.9472],[-3.3373]], requires_grad=True) 
b tensor([4.1204], requires_grad=True)
epoch4,loss0.006389
w: tensor([[ 1.9793],[-3.3768]], requires_grad=True) 
b tensor([4.1701], requires_grad=True)
epoch5,loss0.000951
w: tensor([[ 1.9920],[-3.3914]], requires_grad=True) 
b tensor([4.1888], requires_grad=True)
epoch6,loss0.000178
w: tensor([[ 1.9970],[-3.3968]], requires_grad=True) 
b tensor([4.1958], requires_grad=True)
epoch7,loss0.000069
w: tensor([[ 1.9991],[-3.3988]], requires_grad=True) 
b tensor([4.1984], requires_grad=True)
epoch8,loss0.000053
w: tensor([[ 1.9997],[-3.3995]], requires_grad=True) 
b tensor([4.1993], requires_grad=True)
epoch9,loss0.000051
w: tensor([[ 1.9999],[-3.3998]], requires_grad=True) 
b tensor([4.1997], requires_grad=True)
epoch10,loss0.000051

示例中部分函数详解

此部分,对代码中的部分函数进行解释和说明,以帮助大家理解和使用。

torch.normal()

torch.normal 是 PyTorch 中的一个函数,用于从正态分布(也称为高斯分布)中生成随机数。返回一个与输入张量形状相同的张量,其中的元素是从均值为 mean,标准差为 std 的正态分布中随机采样的

import torchx = torch.normal(0, 1, (3, 3))
print(x)

输出:

tensor([[-1.5393,  0.2281,  1.2181],[ 0.7260, -1.4805,  0.5720],[ 0.0170, -0.9961, -0.2761]])

.backward()

在PyTorch中,a.backward() 是一个用于自动微分的方法。它通常用于计算一个张量(tensor)相对于其操作数(即输入和参数)的梯度。当你使用 PyTorch 的 autograd 模块时,可以通过调用 backward() 方法来自动计算梯度。

import torch# 创建一个张量并设置requires_grad=True来跟踪其计算历史
a = torch.tensor([5.0], requires_grad=True)
print('a:',a)
# 定义一个简单的操作
b = a * 2# 调用backward()来自动计算梯度
b.backward()# 输出梯度
print(a.grad)  

输出:

a: tensor([5.], requires_grad=True)
tensor([2.])

其中,requires_grad属性是为了使PyTorch跟踪张量的计算历史并自动计算梯度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/638303.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

寒假每日一题-公路

小苞准备开着车沿着公路自驾。公路上一共有 n个站点,编号为从 1 到 n。其中站点 i与站点 i1 的距离为 vi公里。 公路上每个站点都可以加油,编号为 i的站点一升油的价格为 ai元,且每个站点只出售整数升的油。 小苞想从站点 1开车到站点 n&am…

golang学习笔记——http.Handle和http.HandleFunc的区别与type func巧妙运用

文章目录 http.Handle和http.HandleFunc的区别http.Handle分析type func巧妙运用 http.HandleFunc分析总结参考资料 http.Handle和http.HandleFunc的区别 http.Handle和http.HandleFunc的区别体现了Go语言接口的巧妙运用 下面代码启动了一个 http 服务器,监听 808…

基于python的数字识别-含数据集和代码

数据集介绍,下载本资源后,界面如下: 有一个文件夹一个是存放数据集的文件。 数据集介绍: 一共含有:16个类别,包含:division, eight, five, four, left_bracket, minus, multiplication, nine, one, plus, right_brac…

Golang杀死子进程的三种方式

目录 前言 正文 一、cmd.Process.Kill() 二、syscall.Kill 三、cmd.Process.Signal 结论 前言 熟悉Golang语言的小伙伴一定都知道,杀死子进程有三种方式,今天就来简单介绍一下。 正文 Golang中有三种方式可以杀死子进程,分别是cmd.P…

逻辑回归中的损失函数

一、引言 逻辑回归中的损失函数通常采用的是交叉熵损失函数(cross-entropy loss function)。在逻辑回归中,我们通常使用sigmoid函数将线性模型的输出转换为概率值,然后将这些概率值与实际标签进行比较,从而计算损失。 …

《Windows核心编程》若干知识点应用实战分享

目录 1、进程的虚拟内存分区与小于0x10000的小地址内存区 1.1、进程的虚拟内存分区 1.2、小于0x10000的小地址内存区 2、保存线程上下文的CONTEXT结构体 3、从汇编代码角度去理解多线程运行过程的典型实例 4、调用TerminateThread强制结束线程会导致线程中的资源没有释放…

C#练习 — 第一期(帮助卢锡安给他的女朋友准备晚餐)

前言 纸上得来终觉浅,绝知此事要躬行。我们之前学习了C#许多基础知识,但很少有练习,今天开始,我们将通过练习题的形式,巩固此前学到的知识点,加油! 目录 提示 要求 分步实现 构建框架预定义…

多人在线聊天交友工具,匿名聊天室网站源码,附带搭建教程

源码介绍 匿名聊天室(nodejs vue) 多人在线聊天交友工具,无需注册即可畅所欲言!你也可以放心讲述自己的故事,说出自己的秘密,因为谁也不知道对方是谁。 运行说明 安装依赖项:npm install 启动…

Web server failed to start.Port xxxx was already in use.

目录 一、报错截图:二、解决方式 一、报错截图: 某端口被占用,导致出现如下报错: 二、解决方式 windowsR 输入cmd—>回车 如下图所示 查看被占用的端口的进程,如下图: netstat -ano |findstr 端口号结束这个进程…

python定义可调用的类型

除了用户定义的函数,调用运算符(即 ())还可以应用到其他对象上。如果想判断对象能否调用,可以使用内置的 callable() 函数。Python 数据模型文档列出了 7 种可调用对象。 使用 def 语句或 lambda 表达式创建内置函数:…

【大模型研究】(1):从零开始部署书生·浦语2-20B大模型,使用fastchat和webui部署测试,autodl申请2张显卡,占用显存40G可以运行

1,演示视频 https://www.bilibili.com/video/BV1pT4y1h7Af/ 【大模型研究】(1):从零开始部署书生浦语2-20B大模型,使用fastchat和webui部署测试,autodl申请2张显卡,占用显存40G可以运行 2&…

WEB接口测试之Jmeter接口测试自动化 (三)(数据驱动测试)

接口测试与数据驱动 1简介 数据驱动测试,即是分离测试逻辑与测试数据,通过如excel表格的形式来保存测试数据,用测试脚本读取并执行测试的过程。 2 数据驱动与jmeter接口测试 我们已经简单介绍了接口测试参数录入及测试执行的过程&#xff0…

2024.1.15力扣每日一题——删除排序链表中的重复元素 II

2024.1.15 题目来源我的题解方法一 三指针虚拟头结点 题目来源 力扣每日一题;题序:82 我的题解 方法一 三指针虚拟头结点 先构建一个带虚拟头结点的链表,然后使用三个指针p,left,right,分别指向最右非重复节点,可能…

C++——数组、多维数组、简单排序、模板类vector

个人简介 👀个人主页: 前端杂货铺 🙋‍♂️学习方向: 主攻前端方向,正逐渐往全干发展 📃个人状态: 研发工程师,现效力于中国工业软件事业 🚀人生格言: 积跬步…

数据结构实验7:查找的应用

目录 一、实验目的 二、实验原理 1. 顺序查找 2. 折半查找 3. 二叉树查找 三、实验内容 实验一 任务 代码 截图 实验2 任务 代码 截图 一、实验目的 1.掌握查找的基本概念; 2.掌握并实现以下查找算法:顺序查找、折半查找、二叉树查找。 …

给github设置代理

1 引言 本文详细介绍了在 Linux 环境下配置和使用网络代理的步骤。包括使用环境变量设置代理的方法、在 Git 中配置代理的常用方法以及一些调试工具。这些内容对于需要在网络受限环境下使用 Git 的用户非常实用。 2 配置代理 export http_proxyhttp://host:port/ export h…

Python正则表达式Regular Expression初探

目录 Regular 匹配规则 单字符匹配 数量匹配 边界匹配 分组匹配 贪婪与懒惰 原版说明 特殊字符 转义序列 模块方法 函数说明 匹配模式 常用匹配规则 1. 匹配出所有整数 2. 匹配11位且13开头的整数 Regular Python的re模块提供了完整的正则表达式功能。正则表达式…

js算法不连续子序列

涉及力扣题目: 1143.最长公共子序列 1035.不相交的线 53. 最大子序和 上一次我们说过如何求连续子序列,解决方法是模拟一个”棋盘“两两相同对撞,又因为是连续所以一定是对角线为上一组相同。这次有点变化,要求是非连续子序列。 …

前端上传图片至OSS

环境:VUE3NODEJS16 一、第一步肯定是引入依赖 在package.json文件中的dependencies加上"ali-oss": "^6.17.1"如下代码所示: //加入后的整体展示"dependencies": {"ali-oss": "^6.17.1"},然后在控制台…

【Delphi 基础知识 22】TStringList 的详细用法

文章目录 TStringList 与TStrings的区别TStringList 常用方法与属性 TStringList 类在Delphi中会经常使用到,我们这里一起来看看 TStringList 的详细用法. TStringList 与TStrings的区别 TStringList 和 TStrings 都是 Delphi 编程语言中用于处理字符串列表的类。它…