【Python代码】以线性模型为例,详解深度学习算法流程,包括数据生成、定义模型、损失函数、优化算法和训练

**使用带有噪声的线性模型构造数据集,并根据有限的数据恢复该线性模型的参数。**其中包括数据集构造、模型参数初始化、损失函数定义、定义优化算法和训练等过程。是大多数算法实现过程的一个缩影,理解此过程有助于在开发或改进算法时更深刻了解其算法的构造和框架。

  • 过程详解
    • 生成数据
    • 小批量划分数据集
    • 模型参数初始化
    • 模型定义
    • 损失函数定义
    • 定义优化算法
    • 模型训练
  • 运行示例
    • 整体代码
  • 示例中部分函数详解
    • torch.normal()
    • .backward()

过程详解

生成数据

此时以简单的线性模型为例,生成1000个2维的数据,作为样本集。每个样本都符合均值为0,标准差为1的正态分布。具体代码如下所示。

import torchdef synthetic_data(w,b,num_examples):"""生成y=Wx+b+噪声"""##X是一个1000行2列的数据,符合0,1的正态分布X=torch.normal(0,1,(num_examples,len(w)))#特征y=torch.matmul(X,w)+by+=torch.normal(0,0.01,y.shape)#添加噪声return X,y.reshape((-1,1))
true_w=torch.tensor([2,-3.4])
true_b=4.2
features,labels=synthetic_data(true_w,true_b,1000)
print()
print('features',features[0],'\nlabels:',labels[0])
print('features.shape',features.shape,'\nlabels.shape:',labels.shape)

输出:

features tensor([0.1724, 0.8911]) 
labels: tensor([1.5308])
features.shape torch.Size([1000, 2]) 
labels.shape: torch.Size([1000, 1])

可以看出,已经生成了1000个2维的样本集X,大小为1000行2列。添加完噪声的标签labels,为1000行1列,即一个样本对应一个标签。
对生成数据的第一维和标签的结果可视化:

import matplotlib.pyplot as plt
plt.scatter(features[:,0].detach().numpy(),labels.detach().numpy(),1)
plt.savefig('x1000.jpg')
plt.show()

在这里插入图片描述

小批量划分数据集

把生成的数据打乱,并根据设置的批量大小,根据索引提取样本和标签。

def data_iter(batch_size,features,labels):num_examples=len(features)indices=list(range(num_examples))##这些样本是随机读取的,没有特定顺序random.shuffle(indices)for i in range(0,num_examples,batch_size):batch_indices=torch.tensor(indices[i:min(i+batch_size,num_examples)])#确定索引##根据索引值提取相应的特征和标签yield features[batch_indices],labels[batch_indices]

选择其中的一个批量样本和标签可视化进行展示。

batch_size=10#批量设置为10
for X,y in data_iter(batch_size,features,labels):print(X,'\n',y)break##读取一次就退出,即选择其中的一个批量

输出:

tensor([[-0.6141, -0.9904],[ 2.2592, -1.2401],[ 0.3217, -2.0419],[ 2.6761,  1.6293],[-0.3886,  1.4958],[-1.4074,  0.2157],[-1.9986, -0.1091],[ 0.3808,  0.3756],[-0.6877,  0.3499],[ 1.5450, -1.0313]]) tensor([[ 6.3528],[12.9585],[11.7972],[ 4.0089],[-1.6630],[ 0.6698],[ 0.5591],[ 3.6852],[ 1.6348],[10.7883]])

样本是10个1行2列的数据,标签是10个1行1列的数据。

模型参数初始化

这里设置权重w为符合均值为0,标准差为0.01的正态分布的随机生成数据,偏置b设置为0。也可以根据自己情况设置初始参数。

w=torch.normal(0,0.01,size=(2,1),requires_grad=True)
b=torch.zeros(1,requires_grad=True)

输出:

w: tensor([[-0.0164],[-0.0022]], requires_grad=True) 
b: tensor([0.], requires_grad=True)

初始参数设置之后,接下来就是更新这些参数,直到这些参数满足我们的数据拟合。
在更新时,需要计算损失函数关于参数的梯度,再根据梯度向减小损失的方向更新参数。

模型定义

定义一个模型,将输入和输出关联起来。前面生成的数据是线性的,所以定义的模型也是一个线性的 y=Wx+b

def linreg(X,w,b):"""线性回归模型"""return torch.matmul(X,w)+b

损失函数定义

这里简单的定义一个损失函数,计算预测结果和真实结果的平均平方差。

def squared_loss(y_hat,y):"""均方损失"""return (y_hat-y.reshape(y_hat.shape))**2/2

定义优化算法

使用小批量随机梯度下降算法作为优化算法。这里要确定超参数批量大小和学习率。

def sgd(params,lr,batch_size):"""小批量随机梯度下降优化法"""with torch.no_grad():for param in params:param-=lr*param.grad/batch_sizeparam.grad.zero_()

模型训练

有了数据、损失函数、初始参数和优化算法,我们就可以开始训练,更新参数。

lr=0.01
num_epochs=10
net=linreg
loss=squared_lossfor epoch in range(num_epochs):for X,y in data_iter(batch_size,features,labels):y_pred=net(X,w,b)l=loss(y_pred,y)l.sum().backward()sgd([w,b],lr,batch_size)with torch.no_grad():train_l=loss(net(features,w,b),labels)#使用更新后的参数计算损失

运行示例

整体代码


import torch
import random##生成数据
def synthetic_data(w,b,num_examples):"""生成y=Wx+b+噪声"""##X是一个1000行2列的数据,符合0,1的正态分布X=torch.normal(0,1,(num_examples,len(w)))y=torch.matmul(X,w)+by+=torch.normal(0,0.01,y.shape)return X,y.reshape((-1,1))
#真实的权重和偏置
true_w=torch.tensor([2,-3.4])
true_b=4.2
##调用生成数据函数,生成数据
features,labels=synthetic_data(true_w,true_b,1000)
# print('features',features[0],'\nlabels:',labels[0])
# print('features.shape',features.shape,'\nlabels.shape:',labels.shape)# import matplotlib.pyplot as plt
# #set_figsize()
# plt.scatter(features[:,0].detach().numpy(),labels.detach().numpy(),1)
# plt.savefig('x1000.jpg')
# plt.show()##读取数据,根据设置的批量大小
def data_iter(batch_size,features,labels):num_examples=len(features)indices=list(range(num_examples))##这些样本是随机读取的,没有特定顺序random.shuffle(indices)for i in range(0,num_examples,batch_size):batch_indices=torch.tensor(indices[i:min(i+batch_size,num_examples)])yield features[batch_indices],labels[batch_indices]
batch_size=10
# for X,y in data_iter(batch_size,features,labels):
#     print(X,'\n',y)
#     break##初始化参数
w=torch.normal(0,0.01,size=(2,1),requires_grad=True)
b=torch.zeros(1,requires_grad=True)##定义模型、损失函数和优化算法
def linreg(X,w,b):"""线性回归模型"""return torch.matmul(X,w)+b
def squared_loss(y_hat,y):"""均方损失"""return (y_hat-y.reshape(y_hat.shape))**2/2def sgd(params,lr,batch_size):"""小批量随机梯度下降优化法"""with torch.no_grad():for param in params:param-=lr*param.grad/batch_sizeparam.grad.zero_()
#设置参数
lr=0.01
num_epochs=10
net=linreg
loss=squared_loss
##开始训练
for epoch in range(num_epochs):for X,y in data_iter(batch_size,features,labels):y_pred=net(X,w,b)l=loss(y_pred,y)l.sum().backward()sgd([w,b],lr,batch_size)with torch.no_grad():train_l=loss(net(features,w,b),labels)#使用更新后的参数计算损失print('w:',w,'\nb',b)print(f'epoch{epoch+1},loss{float(train_l.mean()):f}')

输出结果:

w: tensor([[ 1.1745],[-2.1489]], requires_grad=True) 
b tensor([2.6504], requires_grad=True)
epoch1,loss2.268522
w: tensor([[ 1.6670],[-2.9392]], requires_grad=True) 
b tensor([3.6257], requires_grad=True)
epoch2,loss0.318115
w: tensor([[ 1.8670],[-3.2300]], requires_grad=True) 
b tensor([3.9866], requires_grad=True)
epoch3,loss0.044844
w: tensor([[ 1.9472],[-3.3373]], requires_grad=True) 
b tensor([4.1204], requires_grad=True)
epoch4,loss0.006389
w: tensor([[ 1.9793],[-3.3768]], requires_grad=True) 
b tensor([4.1701], requires_grad=True)
epoch5,loss0.000951
w: tensor([[ 1.9920],[-3.3914]], requires_grad=True) 
b tensor([4.1888], requires_grad=True)
epoch6,loss0.000178
w: tensor([[ 1.9970],[-3.3968]], requires_grad=True) 
b tensor([4.1958], requires_grad=True)
epoch7,loss0.000069
w: tensor([[ 1.9991],[-3.3988]], requires_grad=True) 
b tensor([4.1984], requires_grad=True)
epoch8,loss0.000053
w: tensor([[ 1.9997],[-3.3995]], requires_grad=True) 
b tensor([4.1993], requires_grad=True)
epoch9,loss0.000051
w: tensor([[ 1.9999],[-3.3998]], requires_grad=True) 
b tensor([4.1997], requires_grad=True)
epoch10,loss0.000051

示例中部分函数详解

此部分,对代码中的部分函数进行解释和说明,以帮助大家理解和使用。

torch.normal()

torch.normal 是 PyTorch 中的一个函数,用于从正态分布(也称为高斯分布)中生成随机数。返回一个与输入张量形状相同的张量,其中的元素是从均值为 mean,标准差为 std 的正态分布中随机采样的

import torchx = torch.normal(0, 1, (3, 3))
print(x)

输出:

tensor([[-1.5393,  0.2281,  1.2181],[ 0.7260, -1.4805,  0.5720],[ 0.0170, -0.9961, -0.2761]])

.backward()

在PyTorch中,a.backward() 是一个用于自动微分的方法。它通常用于计算一个张量(tensor)相对于其操作数(即输入和参数)的梯度。当你使用 PyTorch 的 autograd 模块时,可以通过调用 backward() 方法来自动计算梯度。

import torch# 创建一个张量并设置requires_grad=True来跟踪其计算历史
a = torch.tensor([5.0], requires_grad=True)
print('a:',a)
# 定义一个简单的操作
b = a * 2# 调用backward()来自动计算梯度
b.backward()# 输出梯度
print(a.grad)  

输出:

a: tensor([5.], requires_grad=True)
tensor([2.])

其中,requires_grad属性是为了使PyTorch跟踪张量的计算历史并自动计算梯度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/638303.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

寒假每日一题-公路

小苞准备开着车沿着公路自驾。公路上一共有 n个站点,编号为从 1 到 n。其中站点 i与站点 i1 的距离为 vi公里。 公路上每个站点都可以加油,编号为 i的站点一升油的价格为 ai元,且每个站点只出售整数升的油。 小苞想从站点 1开车到站点 n&am…

golang学习笔记——http.Handle和http.HandleFunc的区别与type func巧妙运用

文章目录 http.Handle和http.HandleFunc的区别http.Handle分析type func巧妙运用 http.HandleFunc分析总结参考资料 http.Handle和http.HandleFunc的区别 http.Handle和http.HandleFunc的区别体现了Go语言接口的巧妙运用 下面代码启动了一个 http 服务器,监听 808…

基于python的数字识别-含数据集和代码

数据集介绍,下载本资源后,界面如下: 有一个文件夹一个是存放数据集的文件。 数据集介绍: 一共含有:16个类别,包含:division, eight, five, four, left_bracket, minus, multiplication, nine, one, plus, right_brac…

逻辑回归中的损失函数

一、引言 逻辑回归中的损失函数通常采用的是交叉熵损失函数(cross-entropy loss function)。在逻辑回归中,我们通常使用sigmoid函数将线性模型的输出转换为概率值,然后将这些概率值与实际标签进行比较,从而计算损失。 …

《Windows核心编程》若干知识点应用实战分享

目录 1、进程的虚拟内存分区与小于0x10000的小地址内存区 1.1、进程的虚拟内存分区 1.2、小于0x10000的小地址内存区 2、保存线程上下文的CONTEXT结构体 3、从汇编代码角度去理解多线程运行过程的典型实例 4、调用TerminateThread强制结束线程会导致线程中的资源没有释放…

多人在线聊天交友工具,匿名聊天室网站源码,附带搭建教程

源码介绍 匿名聊天室(nodejs vue) 多人在线聊天交友工具,无需注册即可畅所欲言!你也可以放心讲述自己的故事,说出自己的秘密,因为谁也不知道对方是谁。 运行说明 安装依赖项:npm install 启动…

Web server failed to start.Port xxxx was already in use.

目录 一、报错截图:二、解决方式 一、报错截图: 某端口被占用,导致出现如下报错: 二、解决方式 windowsR 输入cmd—>回车 如下图所示 查看被占用的端口的进程,如下图: netstat -ano |findstr 端口号结束这个进程…

【大模型研究】(1):从零开始部署书生·浦语2-20B大模型,使用fastchat和webui部署测试,autodl申请2张显卡,占用显存40G可以运行

1,演示视频 https://www.bilibili.com/video/BV1pT4y1h7Af/ 【大模型研究】(1):从零开始部署书生浦语2-20B大模型,使用fastchat和webui部署测试,autodl申请2张显卡,占用显存40G可以运行 2&…

WEB接口测试之Jmeter接口测试自动化 (三)(数据驱动测试)

接口测试与数据驱动 1简介 数据驱动测试,即是分离测试逻辑与测试数据,通过如excel表格的形式来保存测试数据,用测试脚本读取并执行测试的过程。 2 数据驱动与jmeter接口测试 我们已经简单介绍了接口测试参数录入及测试执行的过程&#xff0…

C++——数组、多维数组、简单排序、模板类vector

个人简介 👀个人主页: 前端杂货铺 🙋‍♂️学习方向: 主攻前端方向,正逐渐往全干发展 📃个人状态: 研发工程师,现效力于中国工业软件事业 🚀人生格言: 积跬步…

数据结构实验7:查找的应用

目录 一、实验目的 二、实验原理 1. 顺序查找 2. 折半查找 3. 二叉树查找 三、实验内容 实验一 任务 代码 截图 实验2 任务 代码 截图 一、实验目的 1.掌握查找的基本概念; 2.掌握并实现以下查找算法:顺序查找、折半查找、二叉树查找。 …

Python正则表达式Regular Expression初探

目录 Regular 匹配规则 单字符匹配 数量匹配 边界匹配 分组匹配 贪婪与懒惰 原版说明 特殊字符 转义序列 模块方法 函数说明 匹配模式 常用匹配规则 1. 匹配出所有整数 2. 匹配11位且13开头的整数 Regular Python的re模块提供了完整的正则表达式功能。正则表达式…

Win10升级Win11后卡顿了?

目录 关闭动画效果 任务栏居中改为居左 调整外观和性能 其他 当你看到最后,还知道哪些升级WIN11后必做的优化呢?欢迎在评论区分享出来!❤️ win11上市目前也有一段时间了,想必很多大家都已经进行更新了。新的系统确实更加简洁…

安规电容的知识

1、常见安规电容有哪些? 一般我们所说的安规电容也就有两种,一种就是X安规电容(X1/X2/X3安规电容),还有一种是Y电容(最常见的是Y1和Y2安规电容)。 2、x电容的位置 火线零线间的是X电容。x电容用…

Git将某个文件合并到指定分支

企业开发中&#xff0c;经常会单独拉分支去做自己的需求开发&#xff0c;但是某些时候一些公共的配置我们需要从主线pull&#xff0c;这时候整个分支merge显然不合适 1.切换至待合并文件的分支 git checkout <branch>2.将目标分支的单个文件合并到当前分支 git checkou…

力扣刷MySQL-第三弹(详细讲解)

&#x1f389;欢迎您来到我的MySQL基础复习专栏 ☆* o(≧▽≦)o *☆哈喽~我是小小恶斯法克&#x1f379; ✨博客主页&#xff1a;小小恶斯法克的博客 &#x1f388;该系列文章专栏&#xff1a;力扣刷题讲解-MySQL &#x1f379;文章作者技术和水平很有限&#xff0c;如果文中出…

JVM中的垃圾收集算法

标记-清除算法 首先标记出所有需要回收的对象&#xff0c;在标记完成后&#xff0c;统一回收掉所有被标记的对象&#xff0c;也可以反过来&#xff0c;标记存活的对象&#xff0c;统一回收所有未被标记的对象。标记过程就是对象是否属于垃圾的判定过程 缺点 第一个是执行效率…

C# 获取QQ会话聊天信息

目录 利用UIAutomation获取QQ会话聊天信息 效果 代码 目前遇到一个问题 其他解决办法 利用UIAutomation获取QQ会话聊天信息 效果 代码 AutomationElement window AutomationElement.FromHandle(get.WindowHwnd); AutomationElement QQMsgList window.FindFirst(Tr…

4.postman批量运行及json、cvs文件运行

一、批量运行collection 1.各个接口设置信息已保存&#xff0c;在collection中点击run collection 2.编辑并运行集合 集合运行时&#xff0c;单独上传图片时报错。需修改postman设置 二、csv文件运行 可新建记事本&#xff0c;输入测试数据&#xff0c;后另存为新的文本文件&…

Pytest 结合 Allure 生成测试报告

测试报告在项目中是至关重要的角色&#xff0c;一个好的测试报告&#xff1a; 可以体现测试人员的工作量&#xff1b; 开发人员可以从测试报告中了解缺陷的情况&#xff1b; 测试经理可以从测试报告中看到测试人员的执行情况及测试用例的覆盖率&#xff1b; 项目负责人可以通过…