#04 构建您的第一个神经网络:PyTorch入门指南

文章目录

  • 前言
    • 理论基础
      • 神经网络层的组成
      • 前向传播与反向传播
    • 神经网络设计
      • 步骤1:准备数据集
      • 步骤2:构建模型
      • 步骤3:定义损失函数和优化器
      • 步骤4:训练模型
      • 步骤5:评估模型
      • 结论


前言

  在过去的几天里,我们深入了解了深度学习的基础,从安装PyTorch开始,到理解Tensor运算,以及自动微分系统Autograd的工作原理。今天,我们将整合我们所学的知识和技能,并迈出一个重要的步伐:构建和训练我们的第一个神经网络。
在这里插入图片描述

  在本篇博文中,我会指导您如何设计一个简单的全连接神经网络(也称为多层感知机MLP),并使用PyTorch作为您的工具。我们的目标是用这个网络解决一个分类问题,并了解神经网络训练的基本流程。让我们一步一步地进行。

理论基础

  在开始编码之前,让我们先回顾一下全连接神经网络的基本组件和原理。一个典型的神经网络由多个层组成,每个层由多个神经元组成。在全连接网络中,一个层中的每个神经元都与前一层的所有神经元连接。这种密集的连接模式能让网络从输入数据中捕获复杂的模式。

神经网络层的组成

神经网络中的每一层主要包含以下部分:

  1. 输入节点:这些是将数据输入到网络的节点。
  2. 权重:每个连接都有一个权重,它决定了输入信号对神经元的激活程度的影响。
  3. 偏置:每个神经元都有一个偏置,它提供了除了输入以外的另一个调节激活的手段。
  4. 激活函数:它决定了神经元的输出,通常是一个非线性函数,如ReLU或Sigmoid。

前向传播与反向传播

神经网络的训练分为两个阶段:前向传播和反向传播。

  • 前向传播:在这个阶段,输入数据通过网络的每一层,每个神经元根据其权重、偏置和激活函数计算输出。
  • 反向传播:训练的这一部分涉及根据网络输出和实际结果之间的差异来调整权重和偏置。这是通过计算损失函数的梯度并将其反向传播到网络中完成的。

神经网络设计

现在,让我们设计一个简单的神经网络。假设我们要解决的是一个二分类问题,我们的网络将有两个输入节点(对应于两个特征),几个隐藏层,以及一个输出节点。

步骤1:准备数据集

在构建模型之前,我们首先需要准备数据集。通常,我们会对数据进行预处理,如标准化,然后分为训练集和测试集。

import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split# 假设data为特征,labels为标签
data_train, data_test, labels_train, labels_test = train_test_split(data, labels, test_size=0.2)
train_dataset = TensorDataset(torch.tensor(data_train, dtype=torch.float32), torch.tensor(labels_train, dtype=torch.float32))
test_dataset = TensorDataset(torch.tensor(data_test, dtype=torch.float32), torch.tensor(labels_test, dtype=torch.float32))train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

步骤2:构建模型

  接下来,我们将使用PyTorch定义我们的神经网络。我们将使用nn.Module作为基类,并定义我们的层和前向传播方法。

import torch.nn as nn
import torch.nn.functional as Fclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(2, 5)self.fc2 = nn.Linear(5, 1)def forward(self, x):x = F.relu(self.fc1(x))x = torch.sigmoid(self.fc2(x))return xmodel = SimpleNN()

步骤3:定义损失函数和优化器

  损失函数用于评估模型的预测与实际标签之间的差异。优化器用于根据损失函数的梯度更新模型的权重。

import torch.optim as optimcriterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

步骤4:训练模型

训练模型涉及多个迭代(或“epoch”),在每个迭代中,我们将完成一个完整的前向和后向传播过程。

for epoch in range(100):  # 举例迭代100次for inputs, labels in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs.squeeze(), labels)loss.backward()optimizer.step()

步骤5:评估模型

最后,我们需要评估我们的模型性能,通常使用测试集来完成这个任务。

correct = 0
total = 0
with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)predicted = outputs.round().squeeze()  # 将输出四舍五入为0或1total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy: {100 * correct / total}%')

结论

  在以上步骤中,我们从头开始构建了一个简单的神经网络,并在PyTorch中进行了训练和评估。这个过程涵盖了许多重要的概念和技巧,是进一步学习深度学习的坚实基础。

  值得注意的是,真实世界的神经网络要复杂得多,涉及到更复杂的数据预处理、更深的网络结构、正则化技术,以及更精细的训练技巧。但通过这个基本的例子,我们已经开始了深度学习之旅,并为未来的学习打下了坚实的基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/10007.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

头歌实践教学平台:CG1-v2.0-直线绘制

第1关&#xff1a;直线光栅化-DDA画线算法 一.任务描述 1.本关任务 (1)根据直线DDA算法补全line函数&#xff0c;其中直线斜率0<k<1&#xff1b; (2)当直线方程恰好经过P(x,y)和T(x,y1)的中点M时&#xff0c;统一选取直线上方的T点为显示的像素点。 2.输入 (1)直线两…

使用com.google.common.collect依赖包中的Lists.transform()方法转换集合对象之后,修改集合中的对象属性,发现不生效

目录 1.1、错误描述 &#xff08;1&#xff09;引入依赖 &#xff08;2&#xff09;模拟代码 &#xff08;3&#xff09;运行结果 1.2、解决方案 1.1、错误描述 最近在开发过程中&#xff0c;使用到了com.google.common.collect依赖包&#xff0c;通过这个依赖包中提供的…

Vue踩坑,less与less-loader安装,版本不一致

无脑通过npm i less -D安装less之后&#xff0c;继续无脑通过npm i less-loader -D安装less-loader出现如下错误&#xff1a; 解决方法&#xff1a; 1) npm uninstall less与 npm uninstall less-loader 2) 直接对其版本&#xff1a; npm i less3.0.4 -D npm i less-loader…

es关闭开启除了系统索引以外的所有索引

1、es 开启 “删除或关闭时索引名称支持通配符” 功能 2、kibanan平台执行 POST *,-.*/_close 关闭索引POST *,-.*/_open 打开索引3、其他命令 DELETE index_* // 按通配符删除以index_开头的索引 DELETE _all // 删除全部索引 DELETE *,-.* 删除全…

鸿蒙OpenHarmony开发板解析:【系统能力配置规则】

如何按需配置部件的系统能力 SysCap&#xff08;SystemCapability&#xff0c;系统能力&#xff09;是部件向开发者提供的接口的集合。 开发前请熟悉鸿蒙开发指导文档&#xff1a;gitee.com/li-shizhen-skin/harmony-os/blob/master/README.md点击或者复制转到。 部件配置系统…

Java入门——类和对象(上)

经读者反映与笔者考虑&#xff0c;近期以及往后内容更新将主要以java为主&#xff0c;望读者周知、见谅。 类与对象是什么&#xff1f; C语言是面向过程的&#xff0c;关注的是过程&#xff0c;分析出求解问题的步骤&#xff0c;通过函数调用逐步解决问题。 JAVA是基于面向对…

DDOS攻击实战演示,一次DDOS的成本有多低?

DDoS攻击成本概览 分布式拒绝服务&#xff08;DDoS&#xff09;攻击以其低廉的启动成本和惊人的破坏力著称。攻击者通过黑市轻松获取服务&#xff0c;成本从几十元人民币的小额支出到针对大型目标的数千乃至数万元不等。为了具体理解这一成本结构&#xff0c;我们将通过一个简…

每日两题 / 226. 翻转二叉树 98. 验证二叉搜索树(LeetCode热题100)

226. 翻转二叉树 - 力扣&#xff08;LeetCode&#xff09; 以后续遍历的方式交换当前节点的左右指针 /*** Definition for a binary tree node.* struct TreeNode {* int val;* TreeNode *left;* TreeNode *right;* TreeNode() : val(0), left(nullptr), ri…

机器学习-12-sklearn案例03-flask对外提供服务

整体思路 训练一个模型&#xff0c;把模型保存 写一个基于flask的web服务&#xff0c;在web运行时加载模型&#xff0c;并在对应的接口调用模型进行预测并返回 使用curl进行测试&#xff0c;测试通过 再创建一个html页面&#xff0c;接受参数输入&#xff0c;并返回。 目录结…

CSS悬浮动画

<button class"btn">悬浮动画</button>.btn {position: absolute;top: 50%;left: 50%;transform: translate(-50%, -50%);padding: 10px 20px;width: 200px;height: 50px;background-color: transparent;border-radius: 5px;border: 2px solid powderblu…

R2S+ZeroTier+Trilium

软路由使用ZeroTier搭建远程笔记 软路由使用ZeroTier搭建远程笔记 环境部署 安装ZeroTier安装trilium 环境 软路由硬件&#xff1a;友善 Nanopo R2S软路由系统&#xff1a;OpenWrt&#xff0c;使用第三方固件nanopi-openwrt。内网穿透&#xff1a;ZeroTier。远程笔记&…

银河麒麟操作系统 v10 离线安装 Docker v20.10.9

1查看系统版本 [rootweb-0001 ~]# cat /etc/os-release NAME"Kylin Linux Advanced Server" VERSION"V10 (Tercel)" ID"kylin" VERSION_ID"V10" PRETTY_NAME"Kylin Linux Advanced Server V10 (Tercel)" ANSI_COLOR"…

pyqt动画效果放大与缩小

pyqt动画效果放大与缩小 QPropertyAnimation介绍放大与缩小效果代码 QPropertyAnimation介绍 QPropertyAnimation 是 PyQt中的一个类&#xff0c;它用于对 Qt 对象的属性进行动画处理。通过使用 QPropertyAnimation&#xff0c;你可以平滑地改变一个对象的属性值&#xff0c;例…

服务器2080ti驱动的卸载与安装

服务器2080ti驱动的卸载与安装 前言1、下载驱动2、驱动卸载与安装2.1 卸载原来驱动2.2 安装新驱动 3、查看安装情况 前言 安装transformers库&#xff0c;运行bert模型时出错&#xff0c;显示torch版本太低&#xff0c;要2.0以上的&#xff0c;所以更新显卡驱动&#xff0c;重…

基于vgg16和efficientnet卷积神经网络的天气识别系统(pytorch框架)全网首发【图像识别-天气分类】

一个能够从给定的环境图像中自动识别并分类天气&#xff08;如晴天、多云、雨天、雪天闪电等&#xff09;的系统。 技术栈&#xff1a; 深度学习框架&#xff1a;PyTorch基础模型&#xff1a;VGG16与EfficientNet任务类型&#xff1a;计算机视觉中的图像分类 模型选择 VGG16 …

1.基于python的单细胞数据预处理-归一化

目录 归一化的引入移位对数皮尔森近似残差两个归一化方法的总结 参考&#xff1a; [1] https://github.com/Starlitnightly/single_cell_tutorial [2] https://github.com/theislab/single-cell-best-practices 归一化的引入 在质量控制中&#xff0c;已经从数据集删除了低质…

【网络安全】一次sql注入问题的处理

目录 问题 10.60.100.194&#xff0c;修改之前 修改方案 问题解决 测试过程 问题思考与总结 问题 一次sql注入问题的筛查报告&#xff0c;主要是sql注入的问题资源-CSDN文库 doc-new\20-设计文档\34-Mesh设备管理\100-网络安全 10.60.100.194&#xff0c;修改之前 修改…

Multitouch for Mac:手势自定义,提升工作效率

Multitouch for Mac作为一款触控板手势增强软件&#xff0c;其核心功能在于手势的自定义和与Mac系统的深度整合。通过Multitouch&#xff0c;用户可以轻松设置各种手势&#xff0c;如三指轻点、四指左右滑动等&#xff0c;来执行常见的任务&#xff0c;如打开应用、切换窗口、滚…

ansible部署lamp架构

搭建参考&#xff1a;ansible批量运维管理-CSDN博客 定义ansible主机清单 [rootansible-server ~]# vim /etc/hosts 192.168.200.129 host01 192.168.200.130 host02 [rootansible-server ~]# vim /etc/ansible/hosts [webserver] host01 host02 在ansible端编写index.html…

DRF渲染之异常处理

异常处理 【1 】引言 Django REST Framework 这个就是我们常常说的DRF APIView的dispatch方法&#xff1a; 当请求到达视图时&#xff0c;DRF 的 APIView 类会调用 dispatch 方法来处理请求。在 dispatch 方法中&#xff0c;有一个关键的步骤是处理异常。如果在视图类的方法…