2.线性神经网络

目录

  • 1.线性回归
    • 一个简化模型
    • 线性模型:可以看做是单层神经网络
    • 衡量预估质量
    • 训练数据
    • 参数学习
    • 显示解
    • 总结
  • 2.基础优化方法
    • 小批量随机梯度下降
    • 总结
  • 3.Softmax回归:其实是一个分类问题
    • 回归VS分类
    • 从回归到多类分类---均方损失
    • Softmax和交叉熵损失
  • 4.损失函数
    • L2 Loss(L2范式)
    • L1 Loss
    • Huber Robust Loss
  • 5.从零开始实现Softmax回归
  • 6.Softmax回归的简介实现

1.线性回归

一个简化模型

  • 假设1:影响放假的关键因素是卧室个数,卫生间个数和居住面积,记为x1,x2,x3
  • 假设2:成交价实关键因素的加权和 y = w1 * x1 + w2 * x2 + w3 * x3 +b

权重和偏差的实际值在后面决定

线性模型:可以看做是单层神经网络

  • 给定n维输入 x = [x1,x2,…xn]
  • 线性模型有一个n维权重和一个标量偏差
    • w = [w1,w2,…,wn] b
  • 输入是输出的加权和
    • y = w1 * x1 + w2 * x2 +…+wn * xn + b
  • 向量版:y = <w , x> + b

衡量预估质量

  • 比较真实值和预估值,例如房屋售价和估价
  • 假设y是真实值,z是估计值,我们可以比较
    • L(y,z) = 1/2(y-z)^2 这个叫做平方损失

训练数据

  • 收集一些数据点来决定参数值(权重和偏差),例如过去6个月卖的房子
  • 这被称为训练数据
  • 通常越多越好
  • 假设我们有n个样本,记
    • X =[x1,x2,x3…xn] y = [y1,y2,y3…yn]

参数学习

在这里插入图片描述

显示解

总结

  • 线性回归是对N维输入的加权,外加偏差
  • 使用平方损失来衡量预测值和真实值的差异
  • 线性回归有显示解,一般都网络都是非线性的没有显示解
  • 线性回归可以看做是单层神经网络

2.基础优化方法

梯度:使得损失函数增加最快的方向

负梯度:是损失函数减小最快的方向
学习率的选择.PNG

小批量随机梯度下降

​ 在实际中我们很少使用梯度下降,深度学习中最常用的小批量随机梯度下降

  • 在整个训练集上算梯度太贵
    • 一个深度神经网络模型可能需要数分钟至数小时
      在这里插入图片描述

选择批量大小

  • 不能太小:每次计算量太小,不适合并行来最大利用计算资源
  • 不能太大:内存消耗增加浪费计算

总结

  • 梯度下降通过不断沿着反梯度方向更新参数求解
  • 小批量随机梯度下降是深度学习默认的求解算法
  • 两个重要的超参数是批量大小和学习率

3.Softmax回归:其实是一个分类问题

回归VS分类

  • 回归估计一个连续值:如房价的预测
  • 分类预测一个离散类别:如猫狗等图片分类

MNIST:手写数字识别(10类)

ImageNet:自然物体分类(1000类)

Kaggle上的分类问题:Kaggle一个数据建模和数据分析竞赛平台

  • 将人类蛋白质显微镜图片分成28类
  • 将恶意软件分成9个类
  • 将恶意的Wikopedia评论分成7类

回归

  • 单连续数值输出
  • 自然区间R
  • 跟真实值的区别做为损失

分类

  • 通常多个输出
  • 输出i是预测为第i类的置信度

从回归到多类分类—均方损失

​ 针对分类来讲,不关心他们之间的实际的值,是关心对正确类别的置信度特别大。

​ 假设oy为真实值 oi为预测值,他们不关心oi的大小

​ 他们是关系 oy - oi 大于等于 某个阀值,或者说我们正确的oi要远远大于其他类别的oi
在这里插入图片描述

Softmax和交叉熵损失

在这里插入图片描述

4.损失函数

​ 损失函数用来衡量真实值和预测值的区别。介绍三个基本损失函数

L2 Loss(L2范式)

橙色线为:梯度

绿色线为:它的似然函数,高斯分布

蓝色线为:y = 0时变换y’的函数

​ 蓝线代表真实值和预测值的差值,当它大时,梯度值也是较大的

L1 Loss

​ 这种损失函数的特性:不管预测值和真实值相差多远,我的梯度永远是常数,所以权重更新稳定,但是此函数零点出不可导,所以它具有不平滑性,当你的预测值和真实值靠近时,会有剧烈变化

Huber Robust Loss

Huber Robust Loss

​ 当你的y’(导数)大于1时,是L1,当y’小于1时,是L2

​ 这个相当于L1和L2的结合,避免L1的不平滑性和L2在开始的不稳定性

5.从零开始实现Softmax回归

import torch
from IPython import display
from d2l import torch as d2lbatch_size = 256 
#load_data_fashion_mnist加载 Fashion-MNIST 数据集  10个类别的服饰图片数据集,每个类别包含7000张28x28像素的灰度图像。
#获取了两个迭代器 train_iter 和 test_iter 每次256张图像和对应的标签
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)#每张图片是28 * 28 * 1但是Softmax输入为向量,要把图片拉成向量
num_inputs = 784#28*28=784
num_outputs = 10#10个类
#定义权重和偏置
#w 以均值 0、标准差 0.01 的正态分布随机初始化,大小为 (num_inputs, num_outputs) 的张量
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
#b 是一个大小为 (num_outputs,) 的张量,其中 num_outputs 是输出的数量。通过 torch.zeros() 函数以全零初始化
b = torch.zeros(num_outputs, requires_grad=True)#定义一个Softmax
def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True)return X_exp / partition  # 这里应用了广播机制#定义网络模型
def net(X):#使用 torch.matmul() 函数计算输入数据 X 与权重 W 的矩阵乘法。#将偏置 b 加到上述结果中。#将结果应用 softmax 函数。#256 * 784return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)#定义交叉熵损失函数
def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])#将预测类别y_hat与真实y元素进行比较
def accuracy(y_hat, y): """计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:#axis=1 表示沿着第一个维度(通常是行)进行操作#y_hat.argmax返回一个一维张量,其中每个元素是对应样本的最大值的索引y_hat = y_hat.argmax(axis=1)#张量 y_hat 里的数据类型是否与张量 y 里的数据类型相同,并将结果存储在 cmp 变量中cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())#cmp.type(y.dtype) 将张量 cmp 中的数据类型转换为与张量 y 相同的数据类型。这是为了确保类型匹配。#.sum() 对张量中的所有元素进行求和操作。由于 True 在数值上等价于 1,而 False 在数值上等价于 0,因此求和操作将计算出 cmp 中值为 True 的元素的数量。#float() 将计算结果转换为浮点数类型。#这个结果通常用于计算模型预测的准确率,即模型正确预测的样本数量与总样本数量的比例。#可以评估在任意模型net的准确率
def evaluate_accuracy(net, data_iter): """计算在指定数据集上模型的精度"""if isinstance(net, torch.nn.Module):net.eval()  # 将模型设置为评估模式 在评估模式下,模型不会进行参数更新metric = Accumulator(2)  # 正确预测数、预测总数with torch.no_grad():#一个上下文管理器,它指示 PyTorch 在接下来的代码块中不要进行梯度计算for X, y in data_iter:#遍历 data_iter 中的数据。在每次迭代中,X 是输入特征的批量,y 是对应的标签。metric.add(accuracy(net(X), y), y.numel())#神经网络模型 net 对输入数据 X 进行前向传播,得到预测结果。然后使用 accuracy 函数计算预测结果的准确率。接着调用 metric 对象的 add 方法,将准确率和当前批量数据的样本数量 y.numel() 作为参数传入,用于累积这些值。这样可以在整个评估过程中跟踪模型的平均准确率return metric[0] / metric[1]#Softmax回归训练,训练模型一个迭代周期
def train_epoch_ch3(net, train_iter, loss, updater):  #@save"""训练模型一个迭代周期(定义见第3章)"""# 将模型设置为训练模式if isinstance(net, torch.nn.Module):net.train()# 训练损失总和、训练准确度总和、样本数metric = Accumulator(3)for X, y in train_iter:# 计算梯度并更新参数y_hat = net(X)l = loss(y_hat, y)#交叉熵损失函数if isinstance(updater, torch.optim.Optimizer):# 使用PyTorch内置的优化器和损失函数updater.zero_grad()#梯度置为0l.mean().backward()#计算损失函数 l 对模型参数的梯度updater.step()#计算梯度后,调用优化器的 step() 方法来更新模型的参数else:# 使用定制的优化器和损失函数l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]#训练函数
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save"""训练模型(定义见第3章)"""animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):#使用训练数据集,训练一个周期train_metrics = train_epoch_ch3(net, train_iter, loss, updater)#使用测试数据集,评价精度test_acc = evaluate_accuracy(net, test_iter)#将当前周期的训练指标和测试准确率添加到动画器中,以便可视化animator.add(epoch + 1, train_metrics + (test_acc,))train_loss, train_acc = train_metrics#将训练指标解包,得到训练损失和训练准确率。assert train_loss < 0.5, train_lossassert train_acc <= 1 and train_acc > 0.7, train_accassert test_acc <= 1 and test_acc > 0.7, test_acc#小批量随机梯度下降来优化模型的损失函数
lr = 0.1
def updater(batch_size):return d2l.sgd([W, b], lr, batch_size)#训练模型10个迭代周期
num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

6.Softmax回归的简介实现

import torch
from torch import nn
from d2l import torch as d2lbatch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)#初始化模型参数
# PyTorch不会隐式(自动)地调整输入的形状。因此,
#创建了一个简单的神经网络模型 net
#nn.Flatten(): 这是一个用于将输入数据展平的层,将输入的多维数据(比如图像)展平成一维向量
# nn.Linear这是一个全连接层,将输入的 784 维向量映射到一个 10 维的输出向量。
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))def init_weights(m):#接受一个参数 mif type(m) == nn.Linear:#检查当前层 m 是否为线性层 (nn.Linear)nn.init.normal_(m.weight, std=0.01)#如果是线性层,则使用正态分布(normal distribution)随机初始化该层的权重 m.weight,标准差为 0.01net.apply(init_weights);#将定义的 init_weights 函数应用到神经网络模型 net 的所有层上,以便对每一层的权重进行初始化操作#在交叉熵损失函数中传递未归一化的预测,并同时计算Softmax及其对数
loss = nn.CrossEntropyLoss(reduction='none')
#使用学习率为0.1的小批量随机梯度下降作为优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.1)
#训练
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/26432.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

web前端:作业三

1.回到顶部案例(固定定位) <!DOCTYPE html> <html><head><meta charset"utf-8"><title></title><style>#container{height: 5000px;border: 1px solid blue;}#back-button{width: 100px;height: 100px;border: 1px solid…

如何申请小程序SSL证书

在互联网时代&#xff0c;数据安全和用户隐私保护变得尤为重要。SSL证书作为网站、应用或小程序与用户之间建立安全连接的关键工具&#xff0c;其重要性不言而喻。SSL证书能够加密数据传输&#xff0c;防止信息被窃取&#xff0c;提升用户信任度&#xff0c;对于小程序开发者来…

Redux 与 MVI:Android 应用的对比

Redux 与 MVI&#xff1a;Android 应用的对比 在为 Android 应用选择合适的状态管理架构时可能会感到困惑。在这个领域中&#xff0c;有两种流行的选择是 Redux 和 MVI&#xff08;Model-View-Intent&#xff09;。两者都有各自的优缺点&#xff0c;因此在深入研究之前了解它们…

WebGIS开发:你还在纠结的10大问题合集!

问题1&#xff1a;GIS开发到底是学Java还是Python&#xff1f; Java是后端语言&#xff0c;Python更重数据分析和算法。 假设通常说的GIS开发是指Webgis&#xff0c;Web就是指网页端&#xff0c;所以我们说的GIS开发大部分情况下是指网页端的地图可视化开发。 GIS开发需要学…

工业烤箱设备厂家:专业制造,助力工业发展

随着现代工业的不断发展&#xff0c;工业烤箱设备在各个领域的应用越来越广泛。作为专业的工业烤箱设备厂家&#xff0c;我们致力于为客户提供高质量、高效率的烤箱设备&#xff0c;助力工业生产的顺利进行。 工业烤箱设备在工业生产中扮演着至关重要的角色。无论是电子、化工、…

Flask快速入门

Flask快速入门&#xff08;路由、CBV、请求和响应、session&#xff09; 目录 Flask快速入门&#xff08;路由、CBV、请求和响应、session&#xff09;安装创建页面Debug模式快速使用Werkzeug介绍watchdog介绍快速体验 路由系统源码分析手动配置路由动态路由-转换器 Flask的CBV…

SpringBoot整合SpringDataRedis

目录 1.导入Maven坐标 2.配置相关的数据源 3.编写配置类 4.通过RedisTemplate对象操作Redis SpringBoot整合Redis有很多种&#xff0c;这里使用的是Spring Data Redis。接下来就springboot整合springDataRedis步骤做一个详细介绍。 1.导入Maven坐标 首先&#xff0c;需要导…

Mysql中使用where 1=1有什么问题吗

昨天偶然看见一篇文章&#xff0c;提到说如果在mysql查询语句中&#xff0c;使用where 11会有性能问题&#xff1f;&#xff1f; 这着实把我吸引了&#xff0c;因为我项目中就有不少同事&#xff0c;包括我自己也有这样写的。为了不给其他人挖坑&#xff0c;赶紧学习一下&…

ABAP调用JavaScript进行幂乘运算

ECC版本没有内置的ipow运算函数&#xff0c;所以需要进行幂乘运算的话&#xff0c;可以采用调用JavaScript的方式来实现&#xff0c;参考代码如下&#xff1a;

集合java

1.集合 ArrayList 集合和数组的优势对比&#xff1a; 长度可变 添加数据的时候不需要考虑索引&#xff0c;默认将数据添加到末尾 package com.itheima;import java.util.ArrayList;/*public boolean add(要添加的元素) | 将指定的元素追加到此集合的末尾 | | p…

Chrome/Edge浏览器视频画中画可拉动进度条插件

目录 前言 一、Separate Window 忽略插件安装&#xff0c;直接使用 注意事项 插件缺点 1 .无置顶功能 2.保留原网页&#xff0c;但会刷新原网页 3.窗口不够美观 二、弹幕画中画播放器 三、失败的尝试 三、Potplayer播放器 总结 前言 平时看一些视频的时候&#xff…

Linux——自动化运维ansibe

一、自动化运维定义 自动化--- 自动化运维&#xff1a; 服务的自动化部署操作系统的日常运维&#xff1a;日志的备份、临时文件清理、服务器日常状态巡检、&#xff08;几乎包括了linux服务管理、linux 系统管理以及在docker 容器课程中涉及的所有内容&#xff09;服务架构的…

maven学习小结

背景 大佬指路我负责实践 目录结构 maven为项目提供一个标准目录结构 环境配置 下载maven包后解压&#xff0c;配置解压目录的bin到path变量&#xff0c;然后终端mvn -v&#xff0c;有回显则表明maven安装成功 pom POM&#xff0c;Project Object Model&#xff0c;项目对…

01_简单信号的连续和离散形式(2)

1. 单位阶跃信号 1.1离散 离散单位阶跃信号&#xff0c;也称为单位阶跃序列&#xff0c;是一个在离散时间信号分析中基础且重要的信号&#xff0c;用于描述在某个时间点后信号值发生突变的情形。它的定义如下&#xff1a; 离散单位阶跃信号具有以下几个重要性质和应用&#x…

Django中使用下拉列表过滤HTML表格数据

在Django中&#xff0c;你可以使用下拉列表&#xff08;即选择框&#xff09;来过滤HTML表格中的数据。这通常涉及两个主要步骤&#xff1a;创建过滤表单和处理过滤逻辑。 创建过滤表单 首先&#xff0c;你需要创建一个表单&#xff0c;用于接收用户选择的过滤条件。这个表单可…

【CT】LeetCode手撕—21. 合并两个有序链表

目录 题目1-思路2- 实现⭐21. 合并两个有序链表——题解思路 3- ACM实现 题目 原题连接&#xff1a;21. 合并两个有序链表 1-思路 双指针&#xff1a;题目提供的 list1 和 list2 就是两个双指针 通过每次移动 list1 和 list2 并判断二者的值&#xff0c;判断完成后将其 插入…

IDEA项目上传Github流程+常见问题解决

一、Github上创建仓库 项目创建好后如图所示 二、IDEA连接Github远程仓库 管理远程 复制远程地址 定义远程 登录Github 点击进入File->Settings->Version Control->Github登录自己的账号并勾上“√” 三、推送项目 点击推送 修改为main 点击确定&#xff0c;打开远程…

编辑并保存hosts文件

1.以管理员权限打开cmd 2.执行命令 notepad C:\Windows\System32\drivers\etc\hosts 回车后会通过记事本打开hosts文件&#xff0c;然后就可以编辑并保存了。

pdf添加书签的软件,分享3个实用的软件!

在数字化阅读日益盛行的今天&#xff0c;PDF文件已成为我们工作、学习和生活中不可或缺的一部分。然而&#xff0c;面对海量的PDF文件&#xff0c;如何高效地进行管理和阅读&#xff0c;成为了许多人关注的焦点。其中&#xff0c;添加书签功能作为提高PDF文件阅读体验的重要工具…

使用adb通过wifi连接手机

1&#xff0c;手机打开开发者模式&#xff0c;打开无线调试 2&#xff0c;命令行使用adb命令配对&#xff1a; adb pair 192.168.0.102:40731 输入验证码&#xff1a;422859 3&#xff0c;连接设备&#xff1a; adb connect 192.168.0.102:36995 4&#xff0c;查看连接状态:…