实现多层感知机

目录

多层感知机:

介绍:

代码实现:

运行结果:

问题答疑:

线性变换与非线性变换

参数含义

为什么清除梯度?

反向传播的作用

为什么更新权重?


多层感知机:

介绍:

缩写:MLP,这是一种人工神经网络,由一个输入层、一个或多个隐藏层以及一个输出层组成,每一层都由多个节点(神经元)构成。在MLP中,节点之间只有前向连接,没有循环连接,这使得它属于前馈神经网络的一种。每个节点都应用一个激活函数,如sigmoid、ReLU等,以引入非线性,从而使网络能够拟合复杂的函数和数据分布。

代码实现:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# Step 1: Define the MLP model
class SimpleMLP(nn.Module):def __init__(self):super(SimpleMLP, self).__init__()self.fc1 = nn.Linear(784, 128)  # Input layer to hidden layerself.fc2 = nn.Linear(128, 64)   # Hidden layer to another hidden layerself.fc3 = nn.Linear(64, 10)    # Hidden layer to output layerself.relu = nn.ReLU()def forward(self, x):x = x.view(-1, 784)             # Flatten the input from 28x28 to 784x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.fc3(x)return x# Step 2: Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# Step 3: Define loss function and optimizer
model = SimpleMLP()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# Step 4: Train the model
num_epochs = 5
for epoch in range(num_epochs):for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))# Step 5: Evaluate the model on the test set (optional)
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))

运行结果:

问题答疑:

线性变换与非线性变换

在神经网络中

线性变换通常指的是权重矩阵和输入数据的矩阵乘法,再加上偏置向量。数学上,对于一个输入向量𝑥x和权重矩阵𝑊W,加上偏置向量𝑏b,线性变换可以表示为: 𝑧=𝑊𝑥+𝑏z=Wx+b

非线性变换是指在神经网络的每一层之后应用的激活函数,如ReLU、sigmoid或tanh等。这些函数引入了非线性,使神经网络能够学习和表达复杂的函数关系。没有非线性变换,无论多少层的神经网络最终都将简化为一个线性模型。

参数含义

在上述模型中,参数如784, 128, 64, 10并不是字节,而是神经网络层的尺寸,具体来说是神经元的数量:

  • 784: 这是输入层的神经元数量,对应于MNIST数据集中每个图片的像素数量。MNIST的图片是28x28像素,因此总共有784个像素点。
  • 128 和 64: 这是两个隐藏层的神经元数量。它们代表了第一层和第二层的宽度,即这一层有多少个神经元。
  • 10: 这是输出层的神经元数量,对应于MNIST数据集中的10个数字类别(0到9)。

为什么清除梯度?

在每一次前向传播和反向传播过程中,梯度会被累积在张量的.grad属性中。如果不手动清零,这些梯度将会被累加,导致不正确的梯度值。因此,在每次迭代开始之前,都需要调用optimizer.zero_grad()来清空梯度。

反向传播的作用

反向传播(Backpropagation)是一种算法,用于计算损失函数相对于神经网络中所有权重的梯度。它的目的是为了让神经网络知道,当损失函数值较高时,哪些权重需要调整,以及调整的方向和幅度。这些梯度随后被用于权重更新,以最小化损失函数。

为什么更新权重?

权重更新是基于梯度下降算法进行的。在反向传播计算出梯度后,权重通过optimizer.step()函数更新,以朝着减小损失函数的方向移动。

这是训练神经网络的核心,即通过不断调整权重和偏置,使模型能够更好地拟合训练数据,从而提高预测准确性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/44853.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

taocms 3.0.1 本地文件泄露漏洞(CVE-2021-44983)

前言 CVE-2021-44983 是一个影响 taoCMS 3.0.1 的远程代码执行(RCE)漏洞。该漏洞允许攻击者通过上传恶意文件并在服务器上执行任意代码来利用这一安全缺陷。 漏洞描述 taoCMS 是一个内容管理系统(CMS),用于创建和管…

【眼疾病识别】图像识别+深度学习技术+人工智能+卷积神经网络算法+计算机课设+Python+TensorFlow

一、项目介绍 眼疾识别系统,使用Python作为主要编程语言进行开发,基于深度学习等技术使用TensorFlow搭建ResNet50卷积神经网络算法,通过对眼疾图片4种数据集进行训练(‘白内障’, ‘糖尿病性视网膜病变’, ‘青光眼’, ‘正常’&…

jenkins系列-05-jenkins构建golang程序

下载go1.20.2.linux-arm64.tar.gz 并存放到jenkins home目录: 写一个golang demo程序:静态文件服务器:https://gitee.com/jelex/jenkins_golang package mainimport ("encoding/base64""flag""fmt""lo…

window下安装go环境

一、go官网下载安装包 官网地址如下:https://golang.google.cn/dl/ 选择对应系统的安装包,这里是window系统,可以选择zip包,下载完解压就可以使用 二、配置环境变量 这里的截图配置以win11为例 我的文件解压目录是 D:\Software…

力扣32.最长有效括号

力扣32.最长有效括号 class Solution {public:int longestValidParentheses(string s) {int n s.size();int res0;int start -1;vector<int> st;for(int i0;i<n;i){if(s[i] ()st.push_back(i);else{//前面没有( , (开启下一段)下一段的开始更新为当前下标if(st.emp…

机器学习和人工智能在农业的应用——案例分析

作者主页: 知孤云出岫 目录 引言机器学习和人工智能在农业的应用1. 精准农业作物健康监测土壤分析 2. 作物产量预测3. 农业机器人自动化播种和收割智能灌溉 4. 农业市场分析价格预测需求预测 机器学习和人工智能带来的变革1. 提高生产效率2. 降低生产成本3. 提升作物产量和质量…

探索JT808协议在车辆远程视频监控系统中的应用

一、部标JT808协议概述 随着物联网技术的迅猛发展&#xff0c;智能交通系统&#xff08;ITS&#xff09;已成为现代交通领域的重要组成部分。其中&#xff0c;车辆远程监控与管理技术作为ITS的核心技术之一&#xff0c;对于提升交通管理效率、保障道路安全具有重要意义。 JT8…

TensorBoard ,PIL 和 OpenCV 在深度学习中的应用

重要工具介绍 TensorBoard&#xff1a; 是一个TensorFlow提供的强大工具&#xff0c;用于可视化和理解深度学习模型的训练过程和结果。下面我将介绍TensorBoard的相关知识和使用方法。 TensorBoard 简介 TensorBoard是TensorFlow提供的一个可视化工具&#xff0c;用于&#x…

尚品汇-(十七)

目录&#xff1a; &#xff08;1&#xff09;获取价格信息 &#xff08;2&#xff09;获取销售信息 前面的表&#xff1a; &#xff08;1&#xff09;获取价格信息 继续编写接口&#xff1a;ManagerService /*** 获取sku价格* param skuId* return*/ BigDecimal getSkuPrice…

『 Linux 』匿名管道应用 - 简易进程池

文章目录 池化技术进程池框架及基本思路进程的描述组织管道通信建立的潜在问题 任务的描述与组织子进程读取管道信息控制子进程进程退出及资源回收 池化技术 池化技术是一种编程技巧,一般用于优化资源的分配与复用; 当一种资源需要被使用时这意味着这个资源可能会被进行多次使…

mqtt.fx连接阿里云

本文主要是记述一下如何使用mqtt.fx连接在阿里云上创建好的MQTT服务。 1 根据MQTT填写对应端口即可 找到设备信息&#xff0c;里面有MQTT连接参数 2 使用物模型通信Topic&#xff0c;注意这里的post说设备上报&#xff0c;那也就是意味着云端订阅post&#xff1b;set则意味着设…

【轻松拿捏】Java-final关键字(面试)

目录 1. 定义和基本用法 回答要点&#xff1a; 示例回答&#xff1a; 2. final 变量 回答要点&#xff1a; 示例回答&#xff1a; 3. final 方法 回答要点&#xff1a; 示例回答&#xff1a; 4. final 类 回答要点&#xff1a; 示例回答&#xff1a; 5. final 关键…

搭建hadoop+spark完全分布式集群环境

目录 一、集群规划 二、更改主机名 三、建立主机名和ip的映射 四、关闭防火墙(master,slave1,slave2) 五、配置ssh免密码登录 六、安装JDK 七、hadoop之hdfs安装与配置 1)解压Hadoop 2)修改hadoop-env.sh 3)修改 core-site.xml 4)修改hdfs-site.xml 5) 修改s…

【进阶篇-Day9:JAVA中单列集合Collection、List、ArrayList、LinkedList的介绍】

目录 1、集合的介绍1.1 概念1.2 集合的分类 2、单列集合&#xff1a;Collection2.1 Collection的使用2.2 集合的通用遍历方式2.2.1 迭代器遍历&#xff1a;&#xff08;1&#xff09;例子&#xff1a;&#xff08;2&#xff09;迭代器遍历的原理&#xff1a;&#xff08;3&…

排序——交换排序

在上篇文章我们详细介绍了排序的概念与插入排序&#xff0c;大家可以通过下面这个链接去看&#xff1a; 排序的概念及插入排序 这篇文章就介绍一下一种排序方式&#xff1a;交换排序。 一&#xff0c;交换排序 基本思想&#xff1a;两两比较&#xff0c;如果发生逆序则交换…

jenkins系列-09.jpom构建java docker harbor

本地先启动jpom server agent: /Users/jelex/Documents/work/jpom-2.10.40/server-2.10.40-release/bin jelexjelexxudeMacBook-Pro bin % sh Server.sh start/Users/jelex/Documents/work/jpom-2.10.40/agent-2.10.40-release/bin jelexjelexxudeMacBook-Pro bin % ./Agent.…

达梦数据库的系统视图v$sessions

达梦数据库的系统视图v$sessions 达梦数据库&#xff08;DM Database&#xff09;是中国的一款国产数据库管理系统&#xff0c;它提供了类似于Oracle的系统视图来监控和管理数据库。V$SESSIONS 是达梦数据库中的一个系统视图&#xff0c;用于显示当前数据库会话的信息。 以下…

全自主巡航无人机项目思路:STM32/PX4 + ROS + AI 实现从传感融合到智能规划的端到端解决方案

1. 项目概述 本项目旨在设计并实现一款高度自主的自动巡航无人机系统。该系统能够按照预设路径自主飞行&#xff0c;完成各种巡航任务&#xff0c;如电力巡线、森林防火、边境巡逻和灾害监测等。 1.1 系统特点 基于STM32F4和PX4的高性能嵌入式飞控系统多传感器融合技术实现精…

MYSQL--第八次作业

MYSQL–第八次作业 一、备份与恢复 环境搭建&#xff1a; CREATE DATABASE booksDB; use booksDB;CREATE TABLE books ( bk_id INT NOT NULL PRIMARY KEY, bk_title VARCHAR(50) NOT NULL, copyright YEAR NOT NULL );CREATE TABLE authors ( auth_id INT NOT NULL PRI…

geoServer在windows中下载安装部署详细操作教程

这里写目录标题 1.安装环境检查2.下载安装包&#xff08;1&#xff09;进入下载地址&#xff1a;&#xff08;2&#xff09;以下载最新版为例&#xff0c;点击“Stable GeoServer”下载&#xff08;3&#xff09;安装有两种方式&#xff08;4&#xff09;我这里选择下载war包 3…