人工智能算法工程师(中级)课程9-PyTorch神经网络之全连接神经网络实战与代码详解

大家好,我是微学AI,今天给大家介绍一下人工智能算法工程师(中级)课程9-PyTorch神经网络之全连接神经网络实战与代码详解。本文将给大家展示全连接神经网络与代码详解,包括全连接模型的设计、数学原理介绍,并从手写数字识别到猫狗识别实战演练。

文章目录

  • 一、引言
  • 二、全连接模型的设计
    • 1. 神经元模型
    • 2. 网络结构
  • 三、全连接模型的参数计算
    • 1. 前向传播
    • 2. 反向传播
  • 四、全连接模型实现手写数字识别
    • 1. 数据准备
    • 2. 模型构建
    • 3. 代码实现
  • 五、阶段实战:猫狗识别
    • 1. 数据准备
    • 2. 模型构建
    • 3. 代码实现
  • 六、数学原理详解
    • 1. 激活函数
    • 2. 损失函数
    • 3. 优化算法
  • 七、总结

一、引言

全连接神经网络(Fully Connected Neural Network,FCNN)是一种经典的神经网络结构,它在众多领域都有着广泛的应用。本文将详细介绍全连接神经网络的设计、参数计算及其在图像识别任务中的应用。通过本文的学习,读者将掌握全连接神经网络的基本原理,并能够实现手写数字识别和猫狗识别等实战项目。

二、全连接模型的设计

1. 神经元模型

全连接神经网络的基本单元是神经元,其数学表达式为:
f ( x ) = σ ( ∑ i = 1 n w i x i + b ) f(x) = \sigma(\sum_{i=1}^{n}w_ix_i + b) f(x)=σ(i=1nwixi+b)
其中, x x x 为输入向量, w w w 为权重向量, b b b 为偏置, σ \sigma σ 为激活函数。

2. 网络结构

全连接神经网络由输入层、隐藏层和输出层组成。每一层的神经元都与上一层的所有神经元相连,如图1所示。
在这里插入图片描述

三、全连接模型的参数计算

1. 前向传播

假设一个全连接神经网络共有 l l l层,第 k k k层的输入为 X ( k ) X^{(k)} X(k),输出为 Y ( k ) Y^{(k)} Y(k),则有:
Y ( k ) = σ ( W ( k ) X ( k ) + b ( k ) ) Y^{(k)} = \sigma(W^{(k)}X^{(k)} + b^{(k)}) Y(k)=σ(W(k)X(k)+b(k))
其中, W ( k ) W^{(k)} W(k) b ( k ) b^{(k)} b(k) 分别为第 k k k层的权重和偏置。

2. 反向传播

全连接神经网络的参数更新通过反向传播算法实现。对于输出层,损失函数为:
L = 1 2 ( Y t r u e − Y p r e d ) 2 L = \frac{1}{2}(Y_{true} - Y_{pred})^2 L=21(YtrueYpred)2
其中, Y t r u e Y_{true} Ytrue 为真实标签, Y p r e d Y_{pred} Ypred 为预测值。
根据链式法则,输出层的权重梯度为:
∂ L ∂ W ( l ) = ∂ L ∂ Y ( l ) ⋅ ∂ Y ( l ) ∂ Z ( l ) ⋅ ∂ Z ( l ) ∂ W ( l ) \frac{\partial L}{\partial W^{(l)}} = \frac{\partial L}{\partial Y^{(l)}} \cdot \frac{\partial Y^{(l)}}{\partial Z^{(l)}} \cdot \frac{\partial Z^{(l)}}{\partial W^{(l)}} W(l)L=Y(l)LZ(l)Y(l)W(l)Z(l)
其中, Z ( l ) = W ( l ) X ( l ) + b ( l ) Z^{(l)} = W^{(l)}X^{(l)} + b^{(l)} Z(l)=W(l)X(l)+b(l)
同理,可求得输出层的偏置梯度、隐藏层的权重梯度和偏置梯度。

四、全连接模型实现手写数字识别

1. 数据准备

使用MNIST数据集,包含60000个训练样本和10000个测试样本。

2. 模型构建

构建一个简单的全连接神经网络,包含一个输入层(784个神经元)、两个隐藏层(128个神经元)和一个输出层(10个神经元)。
在这里插入图片描述

3. 代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.model = nn.Sequential(nn.Flatten(),nn.Linear(28*28, 128),nn.ReLU(),nn.Linear(128, 128),nn.ReLU(),nn.Linear(128, 10),nn.Softmax(dim=1))def forward(self, x):return self.model(x)# 加载数据
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)# 初始化模型和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()# 训练模型
for epoch in range(5):for i, (images, labels) in enumerate(dataloader):images, labels = images.to(device), labels.to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 评估模型
correct = 0
total = 0
with torch.no_grad():for images, labels in test_dataloader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

五、阶段实战:猫狗识别

1. 数据准备

使用猫狗数据集,包含25000张猫和狗的图片。我们将猫和狗的照片放在目录’data/train’下。

2. 模型构建

构建一个全连接神经网络,包含一个输入层(64643个神经元)、三个隐藏层(256、128、64个神经元)和一个输出层(2个神经元)。

3. 代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 定义数据预处理
data_transforms = transforms.Compose([transforms.Resize((64, 64)),transforms.RandomRotation(40),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.RandomAffine(0, translate=(0.2, 0.2), scale=(0.8, 1.2)),transforms.ToTensor(),
])# 加载数据
train_dataset = datasets.ImageFolder('data/train', transform=data_transforms)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.model = nn.Sequential(nn.Flatten(),nn.Linear(64*64*3, 256),nn.ReLU(),nn.Linear(256, 128),nn.ReLU(),nn.Linear(128, 64),nn.ReLU(),nn.Linear(64, 1),nn.Sigmoid())def forward(self, x):return self.model(x)# 初始化模型和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCELoss()# 训练模型
for epoch in range(15):for i, (images, labels) in enumerate(train_loader):images, labels = images.to(device), labels.float().unsqueeze(1).to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 评估模型
# 假设有一个测试数据集的加载器叫做 validation_loader
correct = 0
total = 0
with torch.no_grad():for images, labels in validation_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)predicted = (outputs > 0.5).float()total += labels.size(0)correct += (predicted == labels).sum().item()
print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))

六、数学原理详解

1. 激活函数

激活函数用于引入非线性因素,使得神经网络能够学习和模拟复杂函数。常用的激活函数有:

  • Sigmoid函数: σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1 + e^{-x}} σ(x)=1+ex1
  • ReLU函数: R e L U ( x ) = max ⁡ ( 0 , x ) ReLU(x) = \max(0, x) ReLU(x)=max(0,x)
  • Softmax函数: s o f t m a x ( x ) i = e x i ∑ j e x j softmax(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} softmax(x)i=jexjexi

2. 损失函数

损失函数用于衡量模型预测值与真实值之间的差异。常用的损失函数有:

  • 均方误差(MSE): M S E ( y , y ^ ) = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 MSE(y, \hat{y}) = \frac{1}{n}\sum_{i=1}^{n}(y_i - \hat{y}_i)^2 MSE(y,y^)=n1i=1n(yiy^i)2
  • 交叉熵损失:对于二分类问题, C E ( y , y ^ ) = − y log ⁡ ( y ^ ) − ( 1 − y ) log ⁡ ( 1 − y ^ ) CE(y, \hat{y}) = -y\log(\hat{y}) - (1-y)\log(1-\hat{y}) CE(y,y^)=ylog(y^)(1y)log(1y^)

3. 优化算法

优化算法用于更新网络的权重和偏置,以最小化损失函数。常用的优化算法有:

  • 梯度下降(Gradient Descent): w : = w − α ∂ L ∂ w w := w - \alpha \frac{\partial L}{\partial w} w:=wαwL
  • Adam优化器:结合了动量(Momentum)和自适应学习率(Adagrad)的优点。

七、总结

本篇文章从全连接神经网络的基本原理出发,介绍了全连接模型的设计、参数计算以及如何实现手写数字识别和猫狗识别。通过配套的完整可运行代码,读者可以更好地理解全连接神经网络的实现过程。在实际应用中,全连接神经网络虽然已被卷积神经网络(CNN)等更先进的网络结构所取代,但其基本原理仍然是深度学习领域的重要基石。希望本文能帮助读者深入掌握全连接神经网络,并为后续学习打下坚实的基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/45808.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【第32章】MyBatis-Plus之代码生成器配置

文章目录 前言一、概述1.特点说明2.示例配置3. 数据库配置 (DataSourceConfig) 二、全局配置 (GlobalConfig)1.方法说明2.示例配置 三、包配置 (PackageConfig)1. 方法说明2. 示例配置 四、模板配置 (TemplateConfig)1. 方法说明2. 示例配置 五、注入配置 (InjectionConfig)1. …

使用 exe4j 转换 Java jar 程序为 Windows 平台可执行文件 (.exe)

使用 exe4j 转换 Java jar 程序为 Windows 平台可执行文件 (.exe) 介绍exe4j 特点:转换全过程(软件操作)1、注册2、选择模式3、配置应用4、选择执行的方式(我这里管这个叫呈现方式)5、选择 JAR …

Mybatis 学习之 数字字符串判断“失效”问题

目录 1. 现象2. 原因3. 解决4. 特别注意 1. 现象 <?xml version"1.0" encoding"UTF-8"?> <!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd"> <mapper…

Java 中的正则表达式

转义字符由反斜杠\x组成&#xff0c;用于实现特殊功能当想取消这些特殊功能时可能在前面加上反斜杠\ 例如在Java中\也具有特殊意义&#xff0c;前面加一个反斜杠表示取消特殊意义&#xff0c;表示1个普通的反斜杠\&#xff0c;\\\\表示2个普通的反斜杠\\。其实就是要求Java中的…

Python那些优质可视化工具!

作者&#xff1a;Lty美丽人生 https://blog.csdn.net/weixin_44208569 本次分享10个适用于多个学科的Python数据可视化库&#xff0c;其中有名气很大的也有鲜为人知的&#xff01; 1、matplotlib 两个直方图 matplotlib 是Python可视化程序库的泰斗。经过十几年它任然是Pytho…

【前端速通系列|第二篇】Vue3前置知识

文章目录 1.前言2.包管理工具npm2.1下载node.js2.2配置 npm 镜像源2.3 npm 常用命令 3.Vite构建工具4.Vue3组件化5.Vue3运行原理 1.前言 本系列文章旨在帮助大家快速上手前端开发。 2.包管理工具npm npm 是 node.js中进行 包管理 的工具. 类似于Java中的Maven。 2.1下载nod…

Autoware 定位之基于ARTag的landmark定位(六)

Tip: 如果你在进行深度学习、自动驾驶、模型推理、微调或AI绘画出图等任务&#xff0c;并且需要GPU资源&#xff0c;可以考虑使用UCloud云计算旗下的Compshare的GPU算力云平台。他们提供高性价比的4090 GPU&#xff0c;按时收费每卡2.6元&#xff0c;月卡只需要1.7元每小时&…

CSS相对定位和绝对定位的区别

CSS相对定位和绝对定位的区别 区别1&#xff1a;相对的对象不同 相对定位是相对于自己绝对定位是相对于离自己最近的有定位的祖先 区别2:是否会脱离文档流 相对定位不会脱离文档流&#xff0c;不会影响其他元素的位置绝对定位会脱离文档流&#xff0c;会影响其他元素的布局 代…

玩转springboot之SpringBoot打成jar包的结构

SpringBoot打成jar包的结构 springboot通常会打成jar包&#xff0c;然后使用java -jar来进行执行&#xff0c;那么这个jar包里的结构是什么样的呢 其中 BOOT-INF 中包含的classes是我们程序中所有的代码编译后的class文件&#xff0c;lib是程序所引用的外部依赖 META-INF 这个…

AM243-Timer

目录 简介初始化代码测试API补充 简介 定时中断。 初始化 开启定时器&#xff0c;最多支持8个硬件定时器 定时周期1ms 增加一个GPIO输出口PRG0_PRU1_GPO15/M4 &#xff0c;我们会在定时中断中每隔1ms翻转该引脚&#xff0c;理想情况下应该在该引脚上测得2ms周期500Hz的矩形…

手把手教你打数学建模国赛!!!第一天软件准备篇

第一天软件准备 MATLAB MATLAB&#xff08;Matrix Laboratory&#xff09;是一种强大的数值计算和科学编程软件。它提供了丰富的数学函数和工具&#xff0c;用于数据分析、算法开发、信号处理、图像处理、控制系统设计、仿真等应用领域。 MATLAB具有直观的语法&#xff0c;使…

Postman接口模拟请求工具使用技巧

Postman是一款非常强大的接口模拟请求工具&#xff0c;可以帮助开发者快速测试、调试API接口。下面集合实际使用过程中的经验&#xff0c;分享大家一些基础使用技巧&#xff1a; 1. 安装与启动&#xff1a;首先在官网&#xff08;Download Postman | Get Started for Free&…

【Linux信号】阻塞信号、信号在内核中的表示、信号集操作函数、sigprocmask、sigpending

我们先来了解一下关于信号的一些常见概念&#xff1a; 实际执行 信号的处理动作 称为信号递达。 信号从产生到递达的之间的状态称为信号未决。 进程可以选择阻塞(Block)某个信号。 被阻塞的信号产生时是处于未决状态的&#xff0c;知道进程解除对该信号的阻塞&#xff0c;该…

零信任作为解决方案,Hvv还能打进去么?

零信任平台由“中心组件服务”三大部分构成&#xff0c;以平台形式充分融合软件定义边界&#xff08;SDP&#xff09;、身份与访问管理&#xff08;IAM&#xff09;、微隔离 &#xff08;MSG&#xff09;的技术方案优势&#xff0c;通过关键技术的创新&#xff0c;实现最佳可信…

Vue中的Mixins与钩子函数:理解合并与调用

在Vue的开发过程中&#xff0c;mixins是一个非常有用的特性&#xff0c;它允许我们跨多个组件共享可复用的代码。然而&#xff0c;当我们在组件与mixins之间定义同名的钩子函数或方法时&#xff0c;理解它们之间的相互作用和合并机制就显得尤为重要。 在Vue.js中&#xff0c;对…

Reinforement Learning学习记录(五)

前言 最近两周的工作主要是在做方向的探索和相关论文的学习,这次的介绍会分为,项目介绍,论文学习,当前进度,未来计划 项目介绍 最近主要是尝试了两个大类的项目,第一个是视觉追踪,第二个是三维重建 视觉跟踪 视觉追踪的话,参考了这几个开源项目: CoTracker: It i…

手机数据恢复篇:如何从 Android 手机恢复消失的照片

丢失 Android 手机中的照片现在已成为您可能遇到的最糟糕的情况之一。随着手机在相机方面越来越好&#xff0c;即使是那些不热衷于拍照的人也成为了摄影师。 如今&#xff0c;人们可以随时随地拍摄照片&#xff0c;每一张照片都保存着回忆和数据&#xff0c;因此&#xff0c;丢…

变得越来越优秀的方法

反省后看到问题很正常&#xff0c;接纳-行动-改变-能量-帮助-成长变优秀&#xff1b;温和后需要【中庸智慧】灵活处世&#xff0c;不做老好人&#xff0c;须有原则有框架&#xff01; —— 只有深刻地反省&#xff0c;我们才能真正地认识自己&#xff0c;我们反省后会看到自己…

昇思25天学习打卡营第19天|应用实践之基于MobileNetv2的垃圾分类

基本介绍 今天的应用实践是垃圾分类代码开发&#xff0c;整体流程是读取本地图像数据作为输入&#xff0c;对图像中的垃圾物体进行检测&#xff0c;并且将检测结果图片保存到文件中。采用的是MobileNetv2模型&#xff0c;使用官方提供的数据集&#xff0c;数据集分为4大类&…

python如何与前端交互

文章目录 1. 选择一个 Python Web 框架2. 创建 Web 应用程序3. 编写后端逻辑4. 编写前端代码5. 连接前后端6. 部署和测试扩展Jupyter Notebook Python 与前端&#xff08;如 HTML, CSS, JavaScript&#xff09;的关联通常是通过 Web 框架来实现的&#xff0c;这些框架允许 Pyth…